From d642255e70c5906646b40b6f8df730b23e440110 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 13 Feb 2026 13:11:45 +0100 Subject: [PATCH] Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 --- __init__.py | 15 + gimm_vfi_arch/__init__.py | 15 + gimm_vfi_arch/configs/__init__.py | 0 gimm_vfi_arch/configs/gimmvfi_f_arb.yaml | 57 + gimm_vfi_arch/configs/gimmvfi_r_arb.yaml | 57 + gimm_vfi_arch/generalizable_INR/__init__.py | 0 gimm_vfi_arch/generalizable_INR/configs.py | 57 + .../generalizable_INR/flowformer/__init__.py | 0 .../flowformer/configs/__init__.py | 0 .../flowformer/configs/submission.py | 77 + .../FlowFormer/LatentCostFormer/__init__.py | 0 .../FlowFormer/LatentCostFormer/attention.py | 197 +++ .../core/FlowFormer/LatentCostFormer/cnn.py | 649 ++++++++ .../FlowFormer/LatentCostFormer/convnext.py | 98 ++ .../FlowFormer/LatentCostFormer/decoder.py | 316 ++++ .../FlowFormer/LatentCostFormer/encoder.py | 534 +++++++ .../core/FlowFormer/LatentCostFormer/gma.py | 123 ++ .../core/FlowFormer/LatentCostFormer/gru.py | 160 ++ .../FlowFormer/LatentCostFormer/mlpmixer.py | 55 + .../LatentCostFormer/transformer.py | 57 + .../core/FlowFormer/LatentCostFormer/twins.py | 1360 +++++++++++++++++ .../flowformer/core/FlowFormer/__init__.py | 7 + .../flowformer/core/FlowFormer/common.py | 562 +++++++ 
.../flowformer/core/FlowFormer/encoders.py | 115 ++ .../flowformer/core/__init__.py | 0 .../generalizable_INR/flowformer/core/corr.py | 90 ++ .../flowformer/core/extractor.py | 267 ++++ .../flowformer/core/position_encoding.py | 100 ++ .../flowformer/core/update.py | 154 ++ .../flowformer/core/utils/__init__.py | 0 .../flowformer/core/utils/utils.py | 113 ++ gimm_vfi_arch/generalizable_INR/gimm.py | 253 +++ gimm_vfi_arch/generalizable_INR/gimmvfi_f.py | 471 ++++++ gimm_vfi_arch/generalizable_INR/gimmvfi_r.py | 508 ++++++ .../generalizable_INR/modules/__init__.py | 0 .../modules/coord_sampler.py | 91 ++ .../modules/fi_components.py | 340 +++++ .../generalizable_INR/modules/fi_utils.py | 81 + .../generalizable_INR/modules/hyponet.py | 198 +++ .../generalizable_INR/modules/layers.py | 42 + .../modules/module_config.py | 52 + .../generalizable_INR/modules/softsplat.py | 672 ++++++++ .../generalizable_INR/modules/utils.py | 76 + .../generalizable_INR/raft/__init__.py | 1 + gimm_vfi_arch/generalizable_INR/raft/corr.py | 175 +++ .../generalizable_INR/raft/extractor.py | 293 ++++ .../generalizable_INR/raft/other_raft.py | 238 +++ gimm_vfi_arch/generalizable_INR/raft/raft.py | 169 ++ .../generalizable_INR/raft/update.py | 154 ++ .../generalizable_INR/raft/utils/__init__.py | 0 .../generalizable_INR/raft/utils/utils.py | 93 ++ gimm_vfi_arch/utils/__init__.py | 0 gimm_vfi_arch/utils/utils.py | 52 + inference.py | 180 +++ nodes.py | 396 ++++- requirements.txt | 5 + 56 files changed, 9774 insertions(+), 1 deletion(-) create mode 100644 gimm_vfi_arch/__init__.py create mode 100644 gimm_vfi_arch/configs/__init__.py create mode 100644 gimm_vfi_arch/configs/gimmvfi_f_arb.yaml create mode 100644 gimm_vfi_arch/configs/gimmvfi_r_arb.yaml create mode 100644 gimm_vfi_arch/generalizable_INR/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/configs.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/__init__.py create mode 100644 
gimm_vfi_arch/generalizable_INR/flowformer/configs/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/configs/submission.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/decoder.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/encoder.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gru.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/twins.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/common.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/encoders.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/position_encoding.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/update.py create mode 100644 
gimm_vfi_arch/generalizable_INR/flowformer/core/utils/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py create mode 100644 gimm_vfi_arch/generalizable_INR/gimm.py create mode 100644 gimm_vfi_arch/generalizable_INR/gimmvfi_f.py create mode 100644 gimm_vfi_arch/generalizable_INR/gimmvfi_r.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/coord_sampler.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/fi_components.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/fi_utils.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/hyponet.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/layers.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/module_config.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/softsplat.py create mode 100644 gimm_vfi_arch/generalizable_INR/modules/utils.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/corr.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/extractor.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/other_raft.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/raft.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/update.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/utils/__init__.py create mode 100644 gimm_vfi_arch/generalizable_INR/raft/utils/utils.py create mode 100644 gimm_vfi_arch/utils/__init__.py create mode 100644 gimm_vfi_arch/utils/utils.py diff --git a/__init__.py b/__init__.py index faf41d6..8aedf02 100644 --- a/__init__.py +++ b/__init__.py @@ -34,6 +34,14 @@ def _auto_install_deps(): except Exception as e: logger.warning(f"[Tween] Could not auto-install cupy: {e}") + # GIMM-VFI dependencies + for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"): + try: + __import__(pkg) + except 
ImportError: + logger.info(f"[Tween] Installing {pkg}...") + subprocess.check_call([sys.executable, "-m", "pip", "install", pkg]) + _auto_install_deps() @@ -41,6 +49,7 @@ from .nodes import ( LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos, LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate, LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate, + LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, ) WEB_DIRECTORY = "./web" @@ -56,6 +65,9 @@ NODE_CLASS_MAPPINGS = { "LoadSGMVFIModel": LoadSGMVFIModel, "SGMVFIInterpolate": SGMVFIInterpolate, "SGMVFISegmentInterpolate": SGMVFISegmentInterpolate, + "LoadGIMMVFIModel": LoadGIMMVFIModel, + "GIMMVFIInterpolate": GIMMVFIInterpolate, + "GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -69,4 +81,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadSGMVFIModel": "Load SGM-VFI Model", "SGMVFIInterpolate": "SGM-VFI Interpolate", "SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate", + "LoadGIMMVFIModel": "Load GIMM-VFI Model", + "GIMMVFIInterpolate": "GIMM-VFI Interpolate", + "GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate", } diff --git a/gimm_vfi_arch/__init__.py b/gimm_vfi_arch/__init__.py new file mode 100644 index 0000000..8cfcdec --- /dev/null +++ b/gimm_vfi_arch/__init__.py @@ -0,0 +1,15 @@ +from .generalizable_INR.gimmvfi_r import GIMMVFI_R +from .generalizable_INR.gimmvfi_f import GIMMVFI_F +from .generalizable_INR.configs import GIMMVFIConfig +from .generalizable_INR.raft.raft import RAFT as GIMM_RAFT +from .generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer as GIMM_FlowFormer +from .generalizable_INR.flowformer.configs.submission import get_cfg as gimm_get_flowformer_cfg +from .utils.utils import InputPadder as GIMMInputPadder, RaftArgs as GIMMRaftArgs, easydict_to_dict +from .generalizable_INR.modules.softsplat import objCudacache as gimm_softsplat_cache + + +def 
clear_gimm_caches(): + """Clear cached CUDA kernels and warp grids for GIMM-VFI.""" + from .generalizable_INR.modules.fi_utils import backwarp_tenGrid + backwarp_tenGrid.clear() + gimm_softsplat_cache.clear() diff --git a/gimm_vfi_arch/configs/__init__.py b/gimm_vfi_arch/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/configs/gimmvfi_f_arb.yaml b/gimm_vfi_arch/configs/gimmvfi_f_arb.yaml new file mode 100644 index 0000000..8139455 --- /dev/null +++ b/gimm_vfi_arch/configs/gimmvfi_f_arb.yaml @@ -0,0 +1,57 @@ +trainer: stage_inr +dataset: + type: vimeo_arb + path: ./data/vimeo90k/vimeo_septuplet + aug: true + +arch: + type: gimmvfi_f + ema: true + modulated_layer_idxs: [1] + + coord_range: [-1., 1.] + + hyponet: + type: mlp + n_layer: 5 # including the output layer + hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1] + use_bias: true + input_dim: 3 + output_dim: 2 + output_bias: 0.5 + activation: + type: siren + siren_w0: 1.0 + initialization: + weight_init_type: siren + bias_init_type: siren + +loss: + subsample: + type: random + ratio: 0.1 + +optimizer: + type: adamw + init_lr: 0.00008 + weight_decay: 0.00004 + betas: [0.9, 0.999] + ft: true + warmup: + epoch: 1 + multiplier: 1 + buffer_epoch: 0 + min_lr: 0.000008 + mode: fix + start_from_zero: True + max_gn: null + +experiment: + amp: True + batch_size: 4 + total_batch_size: 32 + epochs: 60 + save_ckpt_freq: 10 + test_freq: 10 + test_imlog_freq: 10 + diff --git a/gimm_vfi_arch/configs/gimmvfi_r_arb.yaml b/gimm_vfi_arch/configs/gimmvfi_r_arb.yaml new file mode 100644 index 0000000..a5005fe --- /dev/null +++ b/gimm_vfi_arch/configs/gimmvfi_r_arb.yaml @@ -0,0 +1,57 @@ +trainer: stage_inr +dataset: + type: vimeo_arb + path: ./data/vimeo90k/vimeo_septuplet + aug: true + +arch: + type: gimmvfi_r + ema: true + modulated_layer_idxs: [1] + + coord_range: [-1., 1.] 
+ + hyponet: + type: mlp + n_layer: 5 # including the output layer + hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1] + use_bias: true + input_dim: 3 + output_dim: 2 + output_bias: 0.5 + activation: + type: siren + siren_w0: 1.0 + initialization: + weight_init_type: siren + bias_init_type: siren + +loss: + subsample: + type: random + ratio: 0.1 + +optimizer: + type: adamw + init_lr: 0.00008 + weight_decay: 0.00004 + betas: [0.9, 0.999] + ft: true + warmup: + epoch: 1 + multiplier: 1 + buffer_epoch: 0 + min_lr: 0.000008 + mode: fix + start_from_zero: True + max_gn: null + +experiment: + amp: True + batch_size: 4 + total_batch_size: 32 + epochs: 60 + save_ckpt_freq: 10 + test_freq: 10 + test_imlog_freq: 10 + diff --git a/gimm_vfi_arch/generalizable_INR/__init__.py b/gimm_vfi_arch/generalizable_INR/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/configs.py b/gimm_vfi_arch/generalizable_INR/configs.py new file mode 100644 index 0000000..3e80460 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/configs.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +from typing import List, Optional +from dataclasses import dataclass, field + +from omegaconf import OmegaConf, MISSING +from .modules.module_config import HypoNetConfig + + +@dataclass +class GIMMConfig: + type: str = "gimm" + ema: Optional[bool] = None + ema_value: Optional[float] = None + fwarp_type: str = "linear" + hyponet: HypoNetConfig = field(default_factory=HypoNetConfig) + coord_range: List[float] = MISSING + modulated_layer_idxs: Optional[List[int]] = None + + @classmethod + def create(cls, config): + # We need to specify the type of the default DataEncoderConfig. + # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) + # hence merging with the config with other type would cause config error. + defaults = OmegaConf.structured(cls(ema=False)) + config = OmegaConf.merge(defaults, config) + return config + + +@dataclass +class GIMMVFIConfig: + type: str = "gimmvfi" + ema: Optional[bool] = None + ema_value: Optional[float] = None + fwarp_type: str = "linear" + rec_weight: float = 0.1 + hyponet: HypoNetConfig = field(default_factory=HypoNetConfig) + raft_iter: int = 20 + coord_range: List[float] = MISSING + modulated_layer_idxs: Optional[List[int]] = None + + @classmethod + def create(cls, config): + # We need to specify the type of the default DataEncoderConfig. + # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) + # hence merging with the config with other type would cause config error. 
+ defaults = OmegaConf.structured(cls(ema=False)) + config = OmegaConf.merge(defaults, config) + return config diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/configs/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/configs/submission.py b/gimm_vfi_arch/generalizable_INR/flowformer/configs/submission.py new file mode 100644 index 0000000..4d225d3 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/configs/submission.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +_CN.name = "" +_CN.suffix = "" +_CN.gamma = 0.8 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = False +_CN.critical_params = [] + +_CN.transformer = "latentcostformer" +_CN.model = "pretrained_ckpt/flowformer_sintel.pth" + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = "linear" +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.arc_type = "transformer" +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = "single" +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True 
+_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = "twins" +_CN.latentcostformer.fnet = "twins" +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 32 +_CN.latentcostformer.critical_params = [ + "cost_heads_num", + "vert_c_dim", + "cnet", + "pretrain", + "add_flow_token", + "encoder_depth", + "gma", + "cost_encoder_res", +] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = "OneCycleLR" +_CN.trainer.optimizer = "adamw" +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-4 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = "linear" + + +def get_cfg(): + return _CN.clone() diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py new file mode 100644 index 0000000..0095917 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/attention.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + + +class BroadMultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(BroadMultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = 
rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, _, _ = K.shape + _, N, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N) + + return out + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = ( + einsum("bhid, bhjd -> bhij", Q, K) * self.scale + ) # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, HW, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + + return out + + +# class MultiHeadAttentionRelative_encoder(nn.Module): +# def __init__(self, dim, heads): +# super(MultiHeadAttentionRelative, self).__init__() +# self.dim = dim +# self.heads = heads +# self.scale = (dim/heads) ** -0.5 +# self.attend = nn.Softmax(dim=-1) + +# def attend_with_rpe(self, Q, K, Q_r, K_r): +# """ +# Q: [BH1W1, H3W3, dim] +# K: [BH1W1, H3W3, dim] +# Q_r: [BH1W1, H3W3, H3W3, dim] +# K_r: [BH1W1, H3W3, H3W3, dim] +# """ + +# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, 
heads, H3W3, dim] +# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + +# # context-context similarity +# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3] +# # context-position similarity +# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] +# # position-context similarity +# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) +# p_c = torch.squeeze(p_c, dim=4) +# p_c = p_c.permute(0, 1, 3, 2) +# dots = c_c + c_p + p_c +# return self.attend(dots) + +# def forward(self, Q, K, V, Q_r, K_r): +# attn = self.attend_with_rpe(Q, K, Q_r, K_r) +# B, HW, _ = Q.shape + +# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + +# out = einsum('bhij, bhjd -> bhid', attn, V) +# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + +# return out + + +class MultiHeadAttentionRelative(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttentionRelative, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim / heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K, Q_r, K_r): + """ + Q: [BH1W1, 1, dim] + K: [BH1W1, H3W3, dim] + Q_r: [BH1W1, H3W3, dim] + K_r: [BH1W1, H3W3, dim] + """ + + Q = rearrange( + Q, "b i (heads d) -> b heads i d", heads=self.heads + ) # [BH1W1, heads, 1, dim] + K = rearrange( + K, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + K_r = rearrange( + K_r, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + Q_r = rearrange( + Q_r, "b j (heads d) -> b heads j d", heads=self.heads + ) # [BH1W1, heads, H3W3, dim] + + # context-context similarity + c_c = einsum("bhid, bhjd -> bhij", Q, K) * self.scale # [(B H1W1) heads 1 H3W3] + # context-position similarity + c_p = ( + 
einsum("bhid, bhjd -> bhij", Q, K_r) * self.scale + ) # [(B H1W1) heads 1 H3W3] + # position-context similarity + p_c = ( + einsum("bhijd, bhikd -> bhijk", Q_r[:, :, :, None, :], K[:, :, :, None, :]) + * self.scale + ) + p_c = torch.squeeze(p_c, dim=4) + p_c = p_c.permute(0, 1, 3, 2) + dots = c_c + c_p + p_c + return self.attend(dots) + + def forward(self, Q, K, V, Q_r, K_r): + attn = self.attend_with_rpe(Q, K, Q_r, K_r) + B, HW, _ = Q.shape + + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + + return out + + +def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device) + return torch.cat( + [ + torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + ], + dim=-1, + ) + + +def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device) + return torch.cat( + [ + torch.sin(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.cos(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.sin(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)), + torch.cos(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)), + ], + dim=-1, + ) diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py new file mode 100644 index 0000000..da1887c --- /dev/null +++ 
b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/cnn.py @@ -0,0 +1,649 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +import math +import numpy as np + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) 
+ self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + mul = input_dim // 3 + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64 * mul) + + elif 
self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64 * mul) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 * mul + self.layer1 = self._make_layer(64 * mul, stride=1) + self.layer2 = self._make_layer(96 * mul, stride=2) + self.layer3 = self._make_layer(128 * mul, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if 
self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class ConvNets(nn.Module): + def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1): + super(ConvNets, self).__init__() + + self.conv_first = nn.Conv2d( + in_dim, inter_dim, kernel_size=3, 
class FlowHead(nn.Module):
    """Two-layer conv head mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)


class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        combined = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(combined))
        reset = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        # Standard GRU blend of previous state and candidate state.
        return (1 - update) * h + update * candidate
class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and the current flow into a 128-channel
    motion feature.

    The final two output channels are the raw input flow, re-attached so the
    downstream update block can read it directly.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.motion_feature_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        # Two-branch encoding: correlation branch and flow branch.
        corr_feat = F.relu(self.convc1(corr))
        corr_feat = F.relu(self.convc2(corr_feat))
        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        merged = torch.cat([corr_feat, flow_feat], dim=1)
        merged = F.relu(self.conv(merged))
        # Append the raw flow as the last two channels (total 128).
        return torch.cat([merged, flow], dim=1)
class BasicUpdateBlock(nn.Module):
    # One recurrent refinement step: encode (flow, correlation) into motion
    # features, update the hidden state with a separable ConvGRU, then predict
    # a flow delta plus a convex-upsampling mask.
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts 9 weights per pixel of each 8x8 upsampling cell (64 * 9).
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        # NOTE(review): `upsample` is accepted but unused in this body --
        # confirm callers rely only on the parameter existing.
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow


class DirectMeanMaskPredictor(nn.Module):
    # Stateless head: maps motion features straight to (mask, delta_flow)
    # with no recurrent update.
    def __init__(self, args):
        super(DirectMeanMaskPredictor, self).__init__()
        self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(args.predictor_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, motion_features):
        delta_flow = self.flow_head(motion_features)
        # scale mask to balance gradients (same convention as BasicUpdateBlock)
        mask = 0.25 * self.mask(motion_features)

        return mask, delta_flow
class BasicRPEEncoder(nn.Module):
    """MLP that embeds 2-D relative-position tokens into `query_latent_dim`."""

    def __init__(self, args):
        super(BasicRPEEncoder, self).__init__()
        self.args = args
        dim = args.query_latent_dim
        # Keep the exact Sequential layout (indices 0..4) so pretrained
        # state_dicts continue to load.
        layers = [
            nn.Linear(2, dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, rpe_tokens):
        # Purely feed-forward: (..., 2) -> (..., dim).
        return self.encoder(rpe_tokens)
class TwinsCrossAttentionLayer(nn.Module):
    # Twins-style attention pair for two feature maps: a shared local windowed
    # self-attention block applied to each map, then a global cross-attention
    # block that exchanges information between the two streams.
    def __init__(self, args):
        super(TwinsCrossAttentionLayer, self).__init__()
        self.args = args
        embed_dim = 256
        num_heads = 8
        mlp_ratio = 4
        ws = 7
        sr_ratio = 4
        dpr = 0.0
        drop_rate = 0.0
        attn_drop_rate = 0.0

        # Windowed (ws=7) self-attention, shared by both inputs.
        self.local_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
        )
        # ws=1 makes the cross block attend globally (with spatial reduction).
        self.global_block = CrossBlock(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=1,
            with_rpe=True,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Standard Twins/ViT initialisation scheme.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style fan-out init for convolutions.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Local self-attention on each stream, then cross-attention between them.
        x = self.local_block(x, size)
        tgt = self.local_block(tgt, size)
        x, tgt = self.global_block(x, tgt, size)

        return x, tgt
class ConvNextLayer(nn.Module):
    """A stack of `depth` ConvNeXt blocks applied sequentially."""

    def __init__(self, dim, depth=4):
        super().__init__()
        blocks = [ConvNextBlock(dim=dim) for _ in range(depth)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

    def compute_params(self):
        # Total number of scalar parameters in the stack.
        return sum(np.prod(p.size()) for p in self.parameters())


class ConvNextBlock(nn.Module):
    r"""ConvNeXt Block (channels_last variant).

    Pipeline: 7x7 depthwise conv -> LayerNorm -> Linear (4x expand) -> GELU ->
    Linear (project back) -> optional learnable per-channel scale (gamma) ->
    residual add.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale; <= 0
            disables the gamma scaling. Default: 1e-6.
    """

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs implemented as Linear on NHWC tensors.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        residual = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.norm(out)
        out = self.pwconv2(self.act(self.pwconv1(out)))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return residual + out


class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (N, H, W, C) or channels_first
    (N, C, H, W) data layouts; normalisation is over the channel dim."""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Manual layer norm over dim 1 (biased variance, like F.layer_norm).
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
def initialize_flow(img):
    """Flow is represented as difference between two means flow = mean1 - mean0"""
    N, C, H, W = img.shape
    mean = coords_grid(N, H, W).to(img.device)
    mean_init = coords_grid(N, H, W).to(img.device)

    # optical flow computed as difference: flow = mean1 - mean0
    # (both grids start identical, so the initial flow is zero)
    return mean, mean_init


class CrossAttentionLayer(nn.Module):
    # Cross-attends per-pixel flow query tokens against the cost-volume
    # memory. Keys/values over `memory` are computed once on the first call
    # and then reused across recurrent iterations (the caller passes them
    # back in unchanged).
    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        add_flow_token=True,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
        pe="linear",
    ):
        super(CrossAttentionLayer, self).__init__()

        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5
        self.query_token_dim = query_token_dim
        self.pe = pe

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        # Projection takes [attended, shortcut] concatenated, hence v_dim * 2.
        self.proj = nn.Linear(v_dim * 2, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )
        self.add_flow_token = add_flow_token
        self.dim = qk_dim

    def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3):
        """
        query_coord [B, 2, H1, W1]
        """
        B, _, H1, W1 = query_coord.shape

        # Compute K/V lazily on the first iteration; reuse afterwards.
        if key is None and value is None:
            key = self.k(memory)
            value = self.v(memory)

        # [B, 2, H1, W1] -> [BH1W1, 1, 2]
        query_coord = query_coord.contiguous()
        query_coord = (
            query_coord.view(B, 2, -1)
            .permute(0, 2, 1)[:, :, None, :]
            .contiguous()
            .view(B * H1 * W1, 1, 2)
        )
        # NOTE(review): if `pe` is neither "linear" nor "exp",
        # query_coord_enc stays unbound and the use below raises NameError.
        if self.pe == "linear":
            query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim)
        elif self.pe == "exp":
            query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim)

        short_cut = query
        query = self.norm1(query)

        # Optionally inject the flow token; otherwise attend from position only.
        if self.add_flow_token:
            q = self.q(query + query_coord_enc)
        else:
            q = self.q(query_coord_enc)
        k, v = key, value

        x = self.multi_head_attn(q, k, v)

        # Residual projection over [attended, shortcut], then FFN residual.
        x = self.proj(torch.cat([x, short_cut], dim=2))
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x, k, v
class ReverseCostExtractor(nn.Module):
    # Samples the cost volume in the reverse direction: looks up the cost maps
    # at coords1, transposes the correlation, then extracts a (2r+1)^2 window
    # around coords0.
    def __init__(self, cfg):
        super(ReverseCostExtractor, self).__init__()
        self.cfg = cfg

    def forward(self, cost_maps, coords0, coords1):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1
        """
        BH1W1, heads, H2, W2 = cost_maps.shape
        B, _, H1, W1 = coords1.shape

        # This extractor only supports equal source/target resolution.
        assert (H1 == H2) and (W1 == W2)
        assert BH1W1 == B * H1 * W1

        cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2)
        coords = coords1.permute(0, 2, 3, 1)
        corr = bilinear_sampler(cost_maps, coords)  # [B, H1*W1*heads, H2, W2]
        # Transpose the correlation so we can index it from the target side.
        corr = rearrange(
            corr,
            "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1",
            b=B,
            heads=heads,
            h1=H1,
            w1=W1,
            h2=H2,
            w2=W2,
        )

        # Sample a (2r+1) x (2r+1) neighbourhood around each coords0 centre.
        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        # NOTE(review): torch.meshgrid is called without `indexing=`; this
        # relies on the legacy "ij" default and warns on modern torch --
        # confirm target torch version before silencing.
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device)
        centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(corr, coords)
        corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2)
        return corr
def upsample_flow(self, flow, mask):
    """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
    N, _, H, W = flow.shape
    # 9 candidate weights per output pixel of each 8x8 upsampling cell.
    mask = mask.view(N, 1, 9, 8, 8, H, W)
    mask = torch.softmax(mask, dim=2)

    # Gather each coarse pixel's 3x3 neighbourhood; scale flow by 8 because
    # displacement magnitudes grow with resolution.
    up_flow = F.unfold(8 * flow, [3, 3], padding=1)
    up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

    up_flow = torch.sum(mask * up_flow, dim=2)
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
    return up_flow.reshape(N, 2, 8 * H, 8 * W)

def encode_flow_token(self, cost_maps, coords):
    """
    cost_maps - B*H1*W1, cost_heads_num, H2, W2
    coords - B, 2, H1, W1
    """
    # Sample a (2r+1) x (2r+1) cost window centred on each coordinate.
    coords = coords.permute(0, 2, 3, 1)
    batch, h1, w1, _ = coords.shape

    r = 4
    dx = torch.linspace(-r, r, 2 * r + 1)
    dy = torch.linspace(-r, r, 2 * r + 1)
    # NOTE(review): meshgrid without `indexing=` relies on the legacy "ij"
    # default (deprecation warning on modern torch).
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

    centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)
    delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
    coords = centroid + delta
    corr = bilinear_sampler(cost_maps, coords)
    corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2)
    return corr

def forward(self, cost_memory, context, data={}, flow_init=None, iters=None):
    """
    memory: [B*H1*W1, H2'*W2', C]
    context: [B, D, H1, W1]
    """
    # NOTE(review): mutable default `data={}` is shared across calls if the
    # caller omits it -- callers should always pass their own dict.
    cost_maps = data["cost_maps"]
    coords0, coords1 = initialize_flow(context)

    if flow_init is not None:
        # Warm start from an externally supplied flow estimate.
        coords1 = coords1 + flow_init

    flow_predictions = []

    # Split the projected context into hidden state (net) and input (inp).
    context = self.proj(context)
    net, inp = torch.split(context, [128, 128], dim=1)
    net = torch.tanh(net)
    inp = torch.relu(inp)
    if self.cfg.gma:
        attention = self.att(inp)

    size = net.shape
    key, value = None, None
    if iters is None:
        iters = self.depth
    for idx in range(iters):
        # Detach so gradients do not flow through earlier iterations.
        coords1 = coords1.detach()

        cost_forward = self.encode_flow_token(cost_maps, coords1)

        query = self.flow_token_encoder(cost_forward)
        query = (
            query.permute(0, 2, 3, 1)
            .contiguous()
            .view(size[0] * size[2] * size[3], 1, self.dim)
        )
        # K/V computed once by the decoder layer and reused across iterations.
        cost_global, key, value = self.decoder_layer(
            query, key, value, cost_memory, coords1, size, data["H3W3"]
        )
        if self.cfg.only_global:
            corr = cost_global
        else:
            corr = torch.cat([cost_global, cost_forward], dim=1)

        flow = coords1 - coords0

        if self.cfg.gma:
            net, up_mask, delta_flow = self.update_block(
                net, inp, corr, flow, attention
            )
        else:
            net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

        coords1 = coords1 + delta_flow
        flow_up = self.upsample_flow(coords1 - coords0, up_mask)
        flow_predictions.append(flow_up)

    # Inference-style return: final upsampled flow plus the low-res flow.
    return flow_predictions[-1], coords1 - coords0
class GroupVerticalSelfAttentionLayer(nn.Module):
    # Single Twins block with group attention enabled, applied over the
    # latent-token ("vertical") dimension with optional context conditioning.
    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(GroupVerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        embed_dim = dim
        mlp_ratio = 4
        ws = 7
        sr_ratio = 4
        dpr = 0.0
        drop_rate = dropout
        attn_drop_rate = 0.0

        self.block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
            groupattention=True,
            cfg=self.cfg,
        )

    def forward(self, x, size, context=None):
        x = self.block(x, size, context)

        return x


class VerticalSelfAttentionLayer(nn.Module):
    # Local (windowed, ws=7) then global (ws=1) Twins self-attention blocks,
    # both conditioned on an optional context through vert_c_dim.
    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(VerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        embed_dim = dim
        mlp_ratio = 4
        ws = 7
        sr_ratio = 4
        dpr = 0.0
        drop_rate = dropout
        attn_drop_rate = 0.0

        self.local_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )
        self.global_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=1,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )

    def forward(self, x, size, context=None):
        x = self.local_block(x, size, context)
        x = self.global_block(x, size, context)

        return x

    def compute_params(self):
        # Total number of scalar parameters in both blocks.
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())

        return num


class SelfAttentionLayer(nn.Module):
    # Pre-norm multi-head self-attention plus FFN, each with a residual
    # connection; operates token-wise with no spatial structure.
    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(SelfAttentionLayer, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
        )

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        x: [BH1W1, H3W3, D]
        """
        short_cut = x
        x = self.norm1(x)

        q, k, v = self.q(x), self.k(x), self.v(x)

        x = self.multi_head_attn(q, k, v)

        x = self.proj(x)
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x

    def compute_params(self):
        # Total number of scalar parameters in this layer.
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())

        return num
+ """ + Query Token: [N, C] -> [N, qk_dim] (Q) + Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V) + """ + self.num_heads = num_heads + head_dim = qk_dim // num_heads + self.scale = head_dim**-0.5 + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = ( + nn.Linear(query_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, qk_dim, bias=True), + nn.Linear(tgt_token_dim, v_dim, bias=True), + ) + + self.proj = nn.Linear(v_dim, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout), + ) + + def forward(self, query, tgt_token): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = query + query = self.norm1(query) + + q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token) + + x = self.multi_head_attn(q, k, v) + + x = short_cut + self.proj_drop(self.proj(x)) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + +class CostPerceiverEncoder(nn.Module): + def __init__(self, cfg): + super(CostPerceiverEncoder, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size + self.patch_embed = PatchEmbed( + in_chans=self.cfg.cost_heads_num, + patch_size=self.patch_size, + embed_dim=cfg.cost_latent_input_dim, + pe=cfg.pe, + ) + + self.depth = cfg.encoder_depth + + self.latent_tokens = nn.Parameter( + torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim) + ) + + query_token_dim, tgt_token_dim = ( + cfg.cost_latent_dim, + cfg.cost_latent_input_dim * 2, + ) + qk_dim, v_dim = query_token_dim, query_token_dim + self.input_layer = CrossAttentionLayer( + qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout + ) + + if cfg.use_mlp: + 
self.encoder_layers = nn.ModuleList( + [ + MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) + for idx in range(self.depth) + ] + ) + else: + self.encoder_layers = nn.ModuleList( + [ + SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) + for idx in range(self.depth) + ] + ) + + if self.cfg.vertical_conv: + self.vertical_encoder_layers = nn.ModuleList( + [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)] + ) + else: + self.vertical_encoder_layers = nn.ModuleList( + [ + VerticalSelfAttentionLayer( + cfg.cost_latent_dim, cfg, dropout=cfg.dropout + ) + for idx in range(self.depth) + ] + ) + self.cost_scale_aug = None + if "cost_scale_aug" in cfg.keys(): + self.cost_scale_aug = cfg.cost_scale_aug + print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug)) + + def forward(self, cost_volume, data, context=None): + B, heads, H1, W1, H2, W2 = cost_volume.shape + cost_maps = ( + cost_volume.permute(0, 2, 3, 1, 4, 5) + .contiguous() + .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2) + ) + data["cost_maps"] = cost_maps + + if self.cost_scale_aug is not None: + scale_factor = ( + torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2) + .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1]) + .to(cost_maps.device) + ) + cost_maps = cost_maps * scale_factor + + x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C + data["H3W3"] = size + H3, W3 = size + + x = self.input_layer(self.latent_tokens, x) + + short_cut = x + + for idx, layer in enumerate(self.encoder_layers): + x = layer(x) + if self.cfg.vertical_conv: + # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1 + x = ( + x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1) + .permute(0, 3, 1, 2) + .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1) + ) + x = self.vertical_encoder_layers[idx](x) + # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D + x = ( + x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1) + .permute(0, 2, 3, 1) + 
.reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1) + ) + else: + x = ( + x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1) + .permute(0, 2, 1, 3) + .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1) + ) + x = self.vertical_encoder_layers[idx](x, (H1, W1), context) + x = ( + x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1) + .permute(0, 2, 1, 3) + .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1) + ) + + if self.cfg.cost_encoder_res is True: + x = x + short_cut + # print("~~~~") + return x + + +class MemoryEncoder(nn.Module): + def __init__(self, cfg): + super(MemoryEncoder, self).__init__() + self.cfg = cfg + + if cfg.fnet == "twins": + self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain) + elif cfg.fnet == "basicencoder": + self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance") + else: + exit() + self.channel_convertor = nn.Conv2d( + cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False + ) + self.cost_perceiver_encoder = CostPerceiverEncoder(cfg) + + def corr(self, fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = rearrange( + fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num + ) + fmap2 = rearrange( + fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num + ) + corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2) + corr = corr.permute(0, 2, 1, 3).view( + batch * ht * wd, self.cfg.cost_heads_num, ht, wd + ) + # corr = self.norm(self.relu(corr)) + corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute( + 0, 2, 1, 3 + ) + corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd) + + return corr + + def forward(self, img1, img2, data, context=None, return_feat=False): + # The original implementation + # feat_s = self.feat_encoder(img1) + # feat_t = self.feat_encoder(img2) + # feat_s = self.channel_convertor(feat_s) + # feat_t = self.channel_convertor(feat_t) + + imgs = torch.cat([img1, img2], dim=0) + feats = 
self.feat_encoder(imgs)
        feats = self.channel_convertor(feats)
        B = feats.shape[0] // 2
        feat_s = feats[:B]
        if return_feat:
            ffeat = feats[:B]
        feat_t = feats[B:]

        B, C, H, W = feat_s.shape
        size = (H, W)

        if self.cfg.feat_cross_attn:
            feat_s = feat_s.flatten(2).transpose(1, 2)
            feat_t = feat_t.flatten(2).transpose(1, 2)

            # NOTE(review): self.layers is not defined in the visible
            # MemoryEncoder.__init__ — this branch would raise AttributeError
            # unless cfg.feat_cross_attn is always False. Confirm upstream.
            for layer in self.layers:
                feat_s, feat_t = layer(feat_s, feat_t, size)

            feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

        cost_volume = self.corr(feat_s, feat_t)
        x = self.cost_perceiver_encoder(cost_volume, data, context)

        if return_feat:
            return x, ffeat
        return x
# === patch boundary: new file ===
# gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py
import torch
from torch import nn, einsum
from einops import rearrange


class RelPosEmb(nn.Module):
    """Decomposed 2D relative positional embedding (separate height/width
    tables), as used by GMA."""

    def __init__(self, max_pos_size, dim_head):
        super().__init__()
        self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
        self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)

        # Pairwise index offsets shifted into [0, 2*max_pos_size-2].
        deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(
            max_pos_size
        ).view(-1, 1)
        rel_ind = deltas + max_pos_size - 1
        self.register_buffer("rel_ind", rel_ind)

    def forward(self, q):
        batch, heads, h, w, c = q.shape
        height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
        width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))

        height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h)
        width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w)

        height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb)
        width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb)

        return height_score + width_score


class Attention(nn.Module):
    """Computes the GMA self-similarity attention map over the feature map."""

    def __init__(
        self,
        *,
        args,
        dim,
        max_pos_size=100,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # NOTE(review): pos_emb is constructed and frozen but never used in
        # forward (the positional branches below are commented out) — dead
        # parameters kept, presumably, for checkpoint compatibility.
        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
        for param in self.pos_emb.parameters():
            param.requires_grad = False

    def forward(self, fmap):
        heads, b, c, h, w = self.heads, *fmap.shape

        q, k = self.to_qk(fmap).chunk(2, dim=1)

        q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k))
        q = self.scale * q

        # if self.args.position_only:
        #     sim = self.pos_emb(q)

        # elif self.args.position_and_content:
        #     sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
        #     sim_pos = self.pos_emb(q)
        #     sim = sim_content + sim_pos

        # else:
        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)

        sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
        attn = sim.softmax(dim=-1)

        return attn


class Aggregate(nn.Module):
    """Aggregates motion features with a precomputed attention map; the
    contribution is gated by a zero-initialized, learnable gamma."""

    def __init__(
        self,
        args,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)

        self.gamma = nn.Parameter(torch.zeros(1))

        if dim != inner_dim:
            self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
        else:
            self.project = None

    def forward(self, attn, fmap):
        heads, b, c, h, w = self.heads, *fmap.shape

        v = self.to_v(fmap)
        v = rearrange(v, "b (h d) x y -> b h (x y) d", h=heads)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        if self.project is not None:
            out = self.project(out)

        # Residual with learned gate (starts at 0, i.e. identity).
        out = fmap + self.gamma * out

        return out


if __name__ == "__main__":
    att =
Attention(dim=128, heads=1)
    # NOTE(review): this smoke test omits the required keyword-only `args`
    # argument of Attention.__init__ and would raise TypeError if executed.
    fmap = torch.randn(2, 128, 40, 90)
    out = att(fmap)

    print(out.shape)
# === patch boundary: new file ===
# gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gru.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowHead(nn.Module):
    """Two-layer conv head predicting a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on spatial hidden state `h`."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))

        h = (1 - z) * h + z * q
        return h


class SepConvGRU(nn.Module):
    """ConvGRU with separable 1x5 / 5x1 kernels, applied as two sequential
    (horizontal then vertical) GRU updates."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convr1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convq1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )

        self.convz2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convr2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convq2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        return h


class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and current flow into motion features; the
    raw flow is re-appended so downstream blocks see it directly."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        if args.only_global:
            print("[Decoding with only global cost]")
            cor_planes = args.query_latent_dim
        else:
            # 81 = 9x9 local correlation window stacked on the latent query.
            cor_planes = 81 + args.query_latent_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)


class BasicUpdateBlock(nn.Module):
    """One RAFT-style recurrent update: motion encoding -> GRU -> flow delta
    plus a 64*9 convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        # NOTE(review): the `upsample` parameter is never used in this body.
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow


# NOTE(review): mid-file import kept as in the vendored source (avoids a
# circular import at module load time — confirm before moving to the top).
from .gma import Aggregate


class GMAUpdateBlock(nn.Module):
    """BasicUpdateBlock variant that also aggregates globally-attended motion
    features (GMA) before the GRU update."""

    def __init__(self, args, hidden_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim
        )
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

        self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1)

    def forward(self, net, inp, corr, flow, attention):
        motion_features = self.encoder(flow, corr)
        motion_features_global = self.aggregator(attention, motion_features)
        inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)

        # Attentional update
        net = self.gru(net, inp_cat)

        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
# === patch boundary: new file ===
# gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py
from torch import nn
from einops.layers.torch import Rearrange, Reduce
from functools import partial
import numpy as np


class PreNormResidual(nn.Module):
    """LayerNorm -> fn -> residual add."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    # `dense` selects token-mixing (Conv1d k=1) vs channel-mixing (Linear).
    return nn.Sequential(
dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout),
    )


class MLPMixerLayer(nn.Module):
    """One MLP-Mixer block over the K latent cost tokens: token-mixing
    (Conv1d over the K axis) followed by channel-mixing (Linear)."""

    def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
        super(MLPMixerLayer, self).__init__()

        # print(f"use mlp mixer layer")
        K = cfg.cost_latent_token_num
        expansion_factor = cfg.mlp_expansion_factor
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

        self.mlpmixer = nn.Sequential(
            PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)),
            PreNormResidual(
                dim, FeedForward(dim, expansion_factor, dropout, chan_last)
            ),
        )

    def compute_params(self):
        # Parameter count of the mixer, used for reporting.
        num = 0
        for param in self.mlpmixer.parameters():
            num += np.prod(param.size())

        return num

    def forward(self, x):
        """
        x: [BH1W1, K, D]
        """

        return self.mlpmixer(x)
# === patch boundary: new file ===
# gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py
import torch
import torch.nn as nn

from ...utils.utils import coords_grid
from ..encoders import twins_svt_large
from .encoder import MemoryEncoder
from .decoder import MemoryDecoder
from .cnn import BasicEncoder


class FlowFormer(nn.Module):
    """Top-level FlowFormer: context encoder + cost-memory encoder/decoder."""

    def __init__(self, cfg):
        super(FlowFormer, self).__init__()
        self.cfg = cfg

        self.memory_encoder = MemoryEncoder(cfg)
        self.memory_decoder = MemoryDecoder(cfg)
        # NOTE(review): an unrecognized cfg.cnet leaves self.context_encoder
        # undefined and only fails later, in forward().
        if cfg.cnet == "twins":
            self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.cnet == "basicencoder":
            self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")

    def build_coord(self, img):
        # 1/8-resolution coordinate grid (encoder stride is 8).
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8)
        return coords

    def forward(
        self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
    ):
        # NOTE(review): the `output` parameter is unused in this body — kept,
        # presumably, for interface compatibility with other model variants.
        # Following https://github.com/princeton-vl/RAFT/
        # Normalize uint8-range images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        data = {}

        if self.cfg.context_concat:
            # NOTE(review): this branch never sets `cfeat`; combining
            # context_concat=True with return_feat=True would raise NameError
            # at the return below. Confirm these flags are mutually exclusive.
            context = self.context_encoder(torch.cat([image1, image2], dim=1))
        else:
            if return_feat:
                context, cfeat = self.context_encoder(image1, return_feat=return_feat)
            else:
                context = self.context_encoder(image1)
        if return_feat:
            cost_memory, ffeat = self.memory_encoder(
                image1, image2, data, context, return_feat=return_feat
            )
        else:
            cost_memory = self.memory_encoder(image1, image2, data, context)

        flow_predictions = self.memory_decoder(
            cost_memory, context, data, flow_init=flow_init, iters=iters
        )

        if return_feat:
            return flow_predictions, cfeat, ffeat
        return flow_predictions
# === patch boundary: new file ===
# gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/twins.py
""" Twins
A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers`
    - https://arxiv.org/pdf/2104.13840.pdf
Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below
"""
# --------------------------------------------------------
# Twins
# Copyright (c) 2021 Meituan
# Licensed under The Apache 2.0 License [see LICENSE for details]
# Written by Xinjie Li, Xiangxiang Chu
# --------------------------------------------------------
import math
from copy import deepcopy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import Attention
from timm.models.helpers import build_model_with_cfg  # , overlay_external_default_cfg
from .attention import MultiHeadAttention, LinearPositionEmbeddingSine
from ...utils.utils import coords_grid, bilinear_sampler, upflow8


def _cfg(url="", **kwargs):
    """timm-style default pretrained-config dict for the Twins variants."""
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "fixed_input_size": True,
        "mean": IMAGENET_DEFAULT_MEAN,
        "std": IMAGENET_DEFAULT_STD,
        "first_conv": "patch_embeds.0.proj",
        "classifier": "head",
        **kwargs,
    }


default_cfgs = {
    "twins_pcpvt_small": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth",
    ),
    "twins_pcpvt_base": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth",
    ),
    "twins_pcpvt_large": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth",
    ),
    "twins_svt_small": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth",
    ),
    "twins_svt_base": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth",
    ),
    "twins_svt_large": _cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth",
    ),
}

Size_ = Tuple[int, int]


class GroupAttnRPEContext(nn.Module):
    """Latent cost tokens attend to different group"""

    def __init__(
        self,
        dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        ws=1,
        cfg=None,
        vert_c_dim=0,
    ):
        super(GroupAttnRPEContext, self).__init__()
        assert ws != 1
        assert cfg is not None
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."
        assert (
            cfg.cost_latent_token_num % 5 == 0
        ), "cost_latent_token_num should be divided by 5."
        assert vert_c_dim > 0, "vert_c_dim should not be 0"

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.vert_c_dim = vert_c_dim

        self.cfg = cfg

        # Projects the 256-channel context feature down to vert_c_dim; context
        # is concatenated onto q/k inputs only (values stay context-free).
        self.context_proj = nn.Linear(256, vert_c_dim)
        self.q = nn.Linear(dim + vert_c_dim, dim, bias=True)
        self.k = nn.Linear(dim + vert_c_dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.ws = ws

    def forward(self, x, size: Size_, context=None):
        # NOTE(review): `context` has a None default but is dereferenced
        # unconditionally below — it is effectively a required argument.
        B, N, C = x.shape
        C_qk = C + self.vert_c_dim
        H, W = size
        # The batch is laid out as 5 equal groups (asserted in __init__ via
        # cost_latent_token_num % 5 == 0); each group gets a different k/v
        # spatial shift (up/down/left/right/center) below.
        batch_num = B // 5

        context = context.repeat(B // context.shape[0], 1, 1, 1)
        context = context.view(B, -1, H * W).permute(0, 2, 1)
        context = self.context_proj(context)
        context = context.view(B, H, W, -1)

        x = x.view(B, H, W, C)
        x_qk = torch.cat([x, context], dim=-1)

        pad_l = pad_t = 0
        pad_r = (self.ws - W % self.ws) % self.ws
        pad_b = (self.ws - H % self.ws) % self.ws
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        _h, _w = Hp // self.ws, Wp // self.ws
        padded_N = Hp * Wp  # NOTE(review): computed but unused

        coords = coords_grid(B, Hp, Wp).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk)
        coords_enc = coords_enc.reshape(B, Hp, Wp, C_qk)

        q = (
            self.q(x_qk + coords_enc)
            .reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads)
            .transpose(2, 3)
        )
        q = q.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        v = self.v(x)
        k = self.k(x_qk + coords_enc)
        # concatenate and do shifting operation together
        kv = torch.cat([k, v], dim=-1)
        kv_up = torch.cat(
            [
                kv[:batch_num, self.ws : Hp, :, :],
                kv[:batch_num, Hp - self.ws : Hp, :, :],
            ],
            dim=1,
        )
        kv_down = torch.cat(
            [
                kv[batch_num : batch_num * 2, : self.ws, :, :],
                kv[batch_num : batch_num * 2, : Hp - self.ws, :, :],
            ],
            dim=1,
        )
        kv_left = torch.cat(
            [
                kv[batch_num * 2 : batch_num * 3, :, self.ws : Wp, :],
                kv[batch_num * 2 : batch_num * 3, :, Wp - self.ws : Wp, :],
            ],
            dim=2,
        )
        kv_right = torch.cat(
            [
                kv[batch_num * 3 : batch_num * 4, :, : self.ws, :],
                kv[batch_num * 3 : batch_num * 4, :, : Wp - self.ws, :],
            ],
            dim=2,
        )
        kv_center = kv[batch_num * 4 : batch_num * 5, :, :, :]
        kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0)
        k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1)

        k = k.reshape(
            B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads
        ).transpose(2, 3)
        k = k.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        v = v.reshape(
            B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads
        ).transpose(2, 3)
        v = v.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class GroupAttnRPE(nn.Module):
    """Latent cost tokens attend to different group"""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1, cfg=None):
        super(GroupAttnRPE, self).__init__()
        assert ws != 1
        assert cfg is not None
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."
        assert (
            cfg.cost_latent_token_num % 5 == 0
        ), "cost_latent_token_num should be divided by 5."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.cfg = cfg

        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.ws = ws

    def forward(self, x, size: Size_, context=None):
        # Context-free variant of GroupAttnRPEContext.forward: same 5-group
        # k/v shifting, but q/k are built from x + positional encoding only.
        B, N, C = x.shape
        H, W = size
        batch_num = B // 5
        x = x.view(B, H, W, C)
        pad_l = pad_t = 0
        pad_r = (self.ws - W % self.ws) % self.ws
        pad_b = (self.ws - H % self.ws) % self.ws
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        _h, _w = Hp // self.ws, Wp // self.ws
        padded_N = Hp * Wp  # NOTE(review): computed but unused

        coords = coords_grid(B, Hp, Wp).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C)
        coords_enc = coords_enc.reshape(B, Hp, Wp, C)

        q = (
            self.q(x + coords_enc)
            .reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads)
            .transpose(2, 3)
        )
        q = q.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        v = self.v(x)
        k = self.k(x + coords_enc)
        # concatenate and do shifting operation together
        kv = torch.cat([k, v], dim=-1)
        kv_up = torch.cat(
            [
                kv[:batch_num, self.ws : Hp, :, :],
                kv[:batch_num, Hp - self.ws : Hp, :, :],
            ],
            dim=1,
        )
        kv_down = torch.cat(
            [
                kv[batch_num : batch_num * 2, : self.ws, :, :],
                kv[batch_num : batch_num * 2, : Hp - self.ws, :, :],
            ],
            dim=1,
        )
        kv_left = torch.cat(
            [
                kv[batch_num * 2 : batch_num * 3, :, self.ws : Wp, :],
                kv[batch_num * 2 : batch_num * 3, :, Wp - self.ws : Wp, :],
            ],
            dim=2,
        )
        kv_right = torch.cat(
            [
                kv[batch_num * 3 : batch_num * 4, :, : self.ws, :],
                kv[batch_num * 3 : batch_num * 4, :, : Wp - self.ws, :],
            ],
            dim=2,
        )
        kv_center = kv[batch_num * 4 : batch_num * 5, :, :, :]
        kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0)
        k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1)

        k = k.reshape(
            B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads
        ).transpose(2, 3)
        k = k.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        v = v.reshape(
            B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads
        ).transpose(2, 3)
        v = v.reshape(
            B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads
        ).permute(0, 1, 3, 2, 4)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LocallyGroupedAttnRPEContext(nn.Module):
    """LSA: self attention within a group"""

    def __init__(
        self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1, vert_c_dim=0
    ):
        assert ws != 1
        super(LocallyGroupedAttnRPEContext, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."
# (continuation of LocallyGroupedAttnRPEContext.__init__)
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.vert_c_dim = vert_c_dim

        self.context_proj = nn.Linear(256, vert_c_dim)
        # context are not added to value
        self.q = nn.Linear(dim + vert_c_dim, dim, bias=True)
        self.k = nn.Linear(dim + vert_c_dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.ws = ws

    def forward(self, x, size: Size_, context=None):
        # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for
        # both. You can choose any one, we recommend forward_padding because it's neat. However,
        # the masking implementation is more reasonable and accurate.
        B, N, C = x.shape
        H, W = size
        C_qk = C + self.vert_c_dim

        context = context.repeat(B // context.shape[0], 1, 1, 1)
        context = context.view(B, -1, H * W).permute(0, 2, 1)
        context = self.context_proj(context)
        context = context.view(B, H, W, -1)

        x = x.view(B, H, W, C)
        x_qk = torch.cat([x, context], dim=-1)

        pad_l = pad_t = 0
        pad_r = (self.ws - W % self.ws) % self.ws
        pad_b = (self.ws - H % self.ws) % self.ws
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b))

        _, Hp, Wp, _ = x.shape
        _h, _w = Hp // self.ws, Wp // self.ws
        x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
        x_qk = x_qk.reshape(B, _h, self.ws, _w, self.ws, C_qk).transpose(2, 3)

        v = (
            self.v(x)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )

        # Per-window positional encoding, shared by every window.
        coords = coords_grid(B, self.ws, self.ws).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk).view(
            B, self.ws, self.ws, C_qk
        )
        # coords_enc: B, ws, ws, C
        # x: B, _h, _w, self.ws, self.ws, C
        x_qk = x_qk + coords_enc[:, None, None, :, :, :]

        q = (
            self.q(x_qk)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )
        k = (
            self.k(x_qk)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class GlobalSubSampleAttnRPEContext(nn.Module):
    """GSA: using a key to summarize the information for a group to be efficient."""

    def __init__(
        self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1, vert_c_dim=0
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.vert_c_dim = vert_c_dim
        self.context_proj = nn.Linear(256, vert_c_dim)
        self.q = nn.Linear(dim + vert_c_dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr_key = nn.Conv2d(
                dim + vert_c_dim, dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.sr_value = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        else:
            # NOTE(review): this branch sets self.sr, but forward() tests
            # self.sr_key — with sr_ratio == 1 forward raises AttributeError.
            # Kept as-is (vendored); only sr_ratio > 1 appears exercised.
            self.sr = None
            self.norm = None

    def forward(self, x, size: Size_, context=None):
        B, N, C = x.shape
        C_qk = C + self.vert_c_dim
        H, W = size
        context = context.repeat(B // context.shape[0], 1, 1, 1)
        context = context.view(B, -1, H * W).permute(0, 2, 1)
        context = self.context_proj(context)
        context = context.view(B, H, W, -1)
        x = x.view(B, H, W, C)
        x_qk = torch.cat([x, context], dim=-1)
        pad_l = pad_t = 0
        pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio
        pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b))

        _, Hp, Wp, _ = x.shape
        padded_size = (Hp, Wp)
        padded_N = Hp * Wp
        x = x.view(B, -1, C)
        x_qk = x_qk.view(B, -1, C_qk)

        coords = coords_grid(B, *padded_size).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk)
        # coords_enc: B, Hp*Wp, C
        # x: B, Hp*Wp, C
        q = (
            self.q(x_qk + coords_enc)
            .reshape(B, padded_N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if self.sr_key is not None:
            x = x.permute(0, 2, 1).reshape(B, C, *padded_size)
            x_qk = x_qk.permute(0, 2, 1).reshape(B, C_qk, *padded_size)
            # sr_key also reduces x_qk from C_qk to C channels.
            x = self.sr_value(x).reshape(B, C, -1).permute(0, 2, 1)
            x_qk = self.sr_key(x_qk).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)
            x_qk = self.norm(x_qk)

        # NOTE(review): scope of the lines below relative to the `if` above is
        # ambiguous in the mangled source; reconstructed at method level
        # (harmless given only sr_ratio > 1 is reachable — see __init__ note).
        coords = coords_grid(
            B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio
        ).to(x.device)
        # align the coordinate of local and global
        coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C)
        k = (
            self.k(x_qk + coords_enc)
            .reshape(
                B,
                (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio),
                self.num_heads,
                C // self.num_heads,
            )
            .permute(0, 2, 1, 3)
        )
        v = (
            self.v(x)
            .reshape(
                B,
                (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio),
                self.num_heads,
                C // self.num_heads,
            )
            .permute(0, 2, 1, 3)
        )

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class LocallyGroupedAttnRPE(nn.Module):
    """LSA: self attention within a group"""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1):
        assert ws != 1
        super(LocallyGroupedAttnRPE, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.ws = ws

    def forward(self, x, size: Size_, context=None):
        # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for
        # both. You can choose any one, we recommend forward_padding because it's neat. However,
        # the masking implementation is more reasonable and accurate.
        B, N, C = x.shape
        H, W = size
        x = x.view(B, H, W, C)
        pad_l = pad_t = 0
        pad_r = (self.ws - W % self.ws) % self.ws
        pad_b = (self.ws - H % self.ws) % self.ws
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        _h, _w = Hp // self.ws, Wp // self.ws
        x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
        # NOTE(review): v is deliberately computed BEFORE the positional
        # encoding is added to x below — values carry no positional signal.
        v = (
            self.v(x)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )

        coords = coords_grid(B, self.ws, self.ws).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C).view(
            B, self.ws, self.ws, C
        )
        # coords_enc: B, ws, ws, C
        # x: B, _h, _w, self.ws, self.ws, C
        x = x + coords_enc[:, None, None, :, :, :]

        q = (
            self.q(x)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )
        k = (
            self.k(x)
            .reshape(
                B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads
            )
            .permute(3, 0, 1, 4, 2, 5)[0]
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class GlobalSubSampleAttnRPE(nn.Module):
    """GSA: using a key to summarize the information for a group to be efficient."""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        else:
            self.sr = None
            self.norm = None

    def forward(self, x, size: Size_, context=None):
        B, N, C = x.shape
        H, W = size
        x = x.view(B, H, W, C)
        pad_l = pad_t = 0
        pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio
        pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        padded_size = (Hp, Wp)
        padded_N = Hp * Wp
        x = x.view(B, -1, C)

        coords = coords_grid(B, *padded_size).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1)
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C)
        # coords_enc: B, Hp*Wp, C
        # x: B, Hp*Wp, C
        q = (
            self.q(x + coords_enc)
            .reshape(B, padded_N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if self.sr is not None:
            x = x.permute(0, 2, 1).reshape(B, C, *padded_size)
            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)

        coords = coords_grid(
            B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio
        ).to(x.device)
        coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio
        # align the coordinate of local and global
        coords_enc = LinearPositionEmbeddingSine(coords, dim=C)
        k = (
            self.k(x + coords_enc)
            .reshape(
                B,
                (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio),
                self.num_heads,
                C // self.num_heads,
            )
            .permute(0, 2, 1, 3)
        )
        v = (
            self.v(x)
            .reshape(
                B,
                (padded_size[0] //
self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttnRPE(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + coords = coords_grid(B, *size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, H*W, C + # x: B, H*W, C + q = ( + self.q(x + coords_enc) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + coords = coords_grid(B, size[0] // self.sr_ratio, size[1] // self.sr_ratio).to( + x.device + ) + coords = coords.view(B, 2, -1).permute(0, 2, 1) 
* self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = ( + self.k(tgt + coords_enc) + .reshape( + B, + (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(tgt) + .reshape( + B, + (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), + self.num_heads, + C // self.num_heads, + ) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LocallyGroupedAttn(nn.Module): + """LSA: self attention within a group""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. 
+ B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = ( + self.qkv(x) + .reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads + ) + .permute(3, 0, 1, 4, 2, 5) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttn(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = ( + self.kv(x) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttn(nn.Module): + """GSA: using a key to summarize the information for a group to be efficient.""" + + def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + kv = ( + self.kv(tgt) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + with_rpe=True, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossGlobalSubSampleAttnRPE( + dim, num_heads, attn_drop, drop, sr_ratio + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, src, tgt, size: Size_): + src_shortcut, tgt_shortcut = src, tgt + + src, tgt = self.norm1(src), self.norm1(tgt) + src = src_shortcut + 
self.drop_path(self.attn(src, tgt, size)) + tgt = tgt_shortcut + self.drop_path(self.attn(tgt, src, size)) + + src = src + self.drop_path(self.mlp(self.norm2(src))) + tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) + return src, tgt + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + with_rpe=False, + vert_c_dim=0, + groupattention=False, + cfg=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + if groupattention: + assert with_rpe, "Not implementing groupattention without rpe" + if vert_c_dim > 0: + self.attn = GroupAttnRPEContext( + dim, num_heads, attn_drop, drop, ws, cfg, vert_c_dim + ) + else: + self.attn = GroupAttnRPE(dim, num_heads, attn_drop, drop, ws, cfg) + elif ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + if with_rpe: + if vert_c_dim > 0: + self.attn = GlobalSubSampleAttnRPEContext( + dim, num_heads, attn_drop, drop, sr_ratio, vert_c_dim + ) + else: + self.attn = GlobalSubSampleAttnRPE( + dim, num_heads, attn_drop, drop, sr_ratio + ) + else: + self.attn = GlobalSubSampleAttn( + dim, num_heads, attn_drop, drop, sr_ratio + ) + else: + if with_rpe: + if vert_c_dim > 0: + self.attn = LocallyGroupedAttnRPEContext( + dim, num_heads, attn_drop, drop, ws, vert_c_dim + ) + else: + self.attn = LocallyGroupedAttnRPE( + dim, num_heads, attn_drop, drop, ws + ) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, size: Size_, context=None): + x = x + self.drop_path(self.attn(self.norm1(x), size, context)) + x = x + 
self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ["proj.%d.weight" % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert ( + img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0 + ), f"img_size {img_size} should be divided by patch_size {patch_size}." 
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """Twins Vision Transformer (Revisiting Spatial Attention) + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + mlp_ratios=(4, 4, 4, 4), + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=(3, 4, 6, 3), + sr_ratios=(8, 4, 2, 1), + wss=None, + block_cls=Block, + init_weight=True, + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]) + ) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList( + [ + block_cls( + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + 
norm_layer=norm_layer, + sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k], + ) + for i in range(depths[k]) + ] + ) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList( + [PosConv(embed_dim, embed_dim) for embed_dim in embed_dims] + ) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + # init weights + if init_weight: + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(["pos_block." + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + 
x = self.forward_features(x) + x = self.head(x) + return x + + +# def _create_twins(variant, pretrained=False, **kwargs): +# if kwargs.get('features_only', None): +# raise RuntimeError('features_only not implemented for Vision Transformer models.') + +# model = build_model_with_cfg( +# Twins, variant, pretrained, +# default_cfg=default_cfgs[variant], +# **kwargs) +# return model + + +# @register_model +# def twins_pcpvt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 
7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + +# @register_model +# def twins_svt_large_context(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], in_chans=6, init_weight=False, **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) +# # def twins_svt_large_context(pretrained=False, **kwargs): +# # model_kwargs = dict( +# # patch_size=4, embed_dims=[128, 256], num_heads=[4, 8], mlp_ratios=[4, 4], +# # depths=[2, 2], wss=[7, 7], sr_ratios=[8, 4], in_chans=6, init_weight=False, **kwargs) +# # return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/__init__.py new file mode 100644 index 0000000..b8884e8 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/__init__.py @@ -0,0 +1,7 @@ +def build_flowformer(cfg): + name = cfg.transformer + if name == "latentcostformer": + from .LatentCostFormer.transformer import FlowFormer + else: + raise ValueError(f"FlowFormer = {name} is not a valid architecture!") + return FlowFormer(cfg[name]) diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/common.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/common.py new file mode 100644 index 0000000..f696d79 --- /dev/null +++ 
b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/common.py @@ -0,0 +1,562 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops import rearrange + +from ..utils.utils import bilinear_sampler, indexing + + +def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300): + """ + x is of shape [*, 2]. The last dimension are two coordinates (x and y). + """ + freq_bands = 2.0 ** torch.linspace(0, L, L - 1).to(x.device) + return torch.cat( + [ + x * NORMALIZE_FACOR, + torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR), + torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR), + ], + dim=-1, + ) + + +def sampler_gaussian(latent, mean, std, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1)) + * STD_MAX + * delta + * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + return sampled_latents, sampled_weights + + +def 
sampler_gaussian_zy( + latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + + +def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = ( + F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1)) + * STD_MAX + * delta + * 3 + ) # 
[B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix(latent, mean, image_size, point_num=49): + # latent [B, H*W, D] + # mean [B, 2, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + latent, coords + ) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix_pyramid( + latent, feat_pyramid, scale_weight, mean, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, 
D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = (centroid + delta) / 2**i + coords = rearrange( + coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W + ) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze( + torch.unsqueeze(scale_weight, dim=2), dim=2 + ) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_weight, dim=-1 + ) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + + +def sampler_gaussian_pyramid( + latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange( + latent, "b (h w) c -> b c h w", h=H, w=W + ) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, 
point_num**0.5, 2] + delta_3sigma = ( + std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3 + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B * H * W, 1, 1, 2) + coords = (centroid + delta_3sigma) / 2**i + coords = rearrange( + coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W + ) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze( + torch.unsqueeze(scale_weight, dim=2), dim=2 + ) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_weight, dim=-1 + ) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + + +def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25): + """different heads have different mean""" + # latent [B, H*W, D] + # mean [B, 2, H, W, heads] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + STD_MAX = 20 + latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = ( + torch.stack(torch.meshgrid(dy, dx), axis=-1) + .to(mean.device) + .repeat(HEADS, 1, 1, 1) + ) # [HEADS, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2) + coords = centroid + delta + coords = rearrange( + coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS + ) 
+ sampled_latents = bilinear_sampler(latent, coords) # [B, dim, H*W*HEADS, pointnum] + sampled_latents = sampled_latents.permute( + 0, 2, 3, 1 + ) # [B, H*W*HEADS, pointnum, dim] + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + return sampled_latents, sampled_weights + + +def sampler_gaussian_fix_pyramid_MH( + latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25 +): + # latent [B, H*W, D] + # mean [B, 2, H, W, heands] + # scale_head weight [B, H*W, layer_num*heads] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + + latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5) - 1) / 2) + + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + mean.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2) + for i in range(len(feat_pyramid)): + coords = (centroid) / 2**i + delta + coords = rearrange( + coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS + ) + sampled_latents.append( + bilinear_sampler(feat_pyramid[i], coords) + ) # [B, dim, H*W*HEADS, point_num] + + sampled_latents = torch.stack( + sampled_latents, dim=1 + ) # [B, layer_num, dim, H*W*HEADS, point_num] + sampled_latents = sampled_latents.permute( + 0, 3, 4, 2, 1 + ) # [B, H*W*HEADS, point_num, dim, layer_num] + + scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1) + scale_head_weight = F.softmax(scale_head_weight, dim=2) # [B, H*W*HEADS, layer_num] + scale_head_weight = torch.unsqueeze( + torch.unsqueeze(scale_head_weight, dim=2), dim=2 + ) # [B, H*W*HEADS, 1, 1, layer_num] + + weighted_latent = torch.sum( + sampled_latents * scale_head_weight, dim=-1 + ) # [B, H*W*HEADS, point_num, dim] + + 
sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights + + +def sampler(feat, center, window_size): + # feat [B, C, H, W] + # center [B, 2, H, W] + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + B, H, W, C = center.shape + + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + center.device + ) # [B*H*W, window_size, point_num**0.5, 2] + + center = center.reshape(B * H * W, 1, 1, 2) + coords = center + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + sampled_latents = bilinear_sampler( + feat, coords + ) # [B*H*W, dim, window_size, window_size] + # sampled_latents = sampled_latents.permute(0, 2, 3, 1) + + return sampled_latents + + +def retrieve_tokens(feat, center, window_size, sampler): + # feat [B, C, H, W] + # center [B, 2, H, W] + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2 * radius + 1) + dy = torch.linspace(-radius, radius, 2 * radius + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to( + center.device + ) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + B, H, W, C = center.shape + centroid = center.reshape(B * H * W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W) + if sampler == "nn": + sampled_latents = indexing(feat, coords) + elif sampler == "bilinear": + sampled_latents = bilinear_sampler(feat, coords) + else: + raise ValueError("invalid sampler") + # [B, dim, H*W, point_num] + + return sampled_latents + + +def pyramid_retrieve_tokens( + feat_pyramid, center, image_size, window_sizes, sampler="bilinear" +): + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + sampled_latents_pyramid = [] + for idx in range(len(window_sizes)): + sampled_latents_pyramid.append( + 
retrieve_tokens(feat_pyramid[idx], center, window_sizes[idx], sampler) + ) + center = center / 2 + + return torch.cat(sampled_latents_pyramid, dim=-1) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + x = self.net(x) + return x + + +class MLP(nn.Module): + def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5): + super().__init__() + self.FC1 = nn.Linear(in_dim, innter_dim) + self.FC_out = nn.Linear(innter_dim, out_dim) + self.relu = torch.nn.LeakyReLU(0.2) + self.FC_inter = nn.ModuleList( + [nn.Linear(innter_dim, innter_dim) for i in range(depth)] + ) + + def forward(self, x): + x = self.FC1(x) + x = self.relu(x) + for inter_fc in self.FC_inter: + x = inter_fc(x) + x = self.relu(x) + x = self.FC_out(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.num_kv_tokens = num_kv_tokens + self.scale = (dim / heads) ** -0.5 + self.rpe = cfg.rpe + self.attend = nn.Softmax(dim=-1) + self.use_rpe = use_rpe + + if use_rpe: + if rpe_bias is None: + if self.rpe == "element-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(heads, self.num_kv_tokens, dim // heads) + ) + elif self.rpe == "head-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(1, heads, 1, self.num_kv_tokens) + ) + elif self.rpe == "token-wise": + self.rpe_bias = nn.Parameter( + torch.zeros(1, 1, 1, self.num_kv_tokens) + ) # 81 is point_num + elif self.rpe == "implicit": + pass + # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2) + # raise ValueError('Implicit Encoding Not Implemented') + elif self.rpe == "element-wise-value": + self.rpe_bias = nn.Parameter( + 
torch.zeros(heads, self.num_kv_tokens, dim // heads) + ) + self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim)) + else: + raise ValueError("Not Implemented") + else: + self.rpe_bias = rpe_bias + + def attend_with_rpe(self, Q, K, rpe_bias): + Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads) + K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads) + + dots = ( + einsum("bhid, bhjd -> bhij", Q, K) * self.scale + ) # (b hw) heads 1 pointnum + if self.use_rpe: + if self.rpe == "element-wise": + rpe_bias_weight = ( + einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale + ) # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == "implicit": + pass + rpe_bias_weight = ( + einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale + ) # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == "head-wise" or self.rpe == "token-wise": + dots = dots + rpe_bias + + return self.attend(dots), dots + + def forward(self, Q, K, V, rpe_bias=None): + if self.use_rpe: + if rpe_bias is None or self.rpe == "element-wise": + rpe_bias = self.rpe_bias + else: + rpe_bias = rearrange( + rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads + ) + attn, dots = self.attend_with_rpe(Q, K, rpe_bias) + else: + attn, dots = self.attend_with_rpe(Q, K, None) + B, HW, _ = Q.shape + + if V is not None: + V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads) + + out = einsum("bhij, bhjd -> bhid", attn, V) + out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW) + else: + out = None + + # dots = torch.squeeze(dots, 2) + # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW) + + return out, dots diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/encoders.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/encoders.py new file mode 100644 index 0000000..92d2d63 --- /dev/null +++ 
b/gimm_vfi_arch/generalizable_INR/flowformer/core/FlowFormer/encoders.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import timm +import numpy as np + + +class twins_svt_large(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model("twins_svt_large", pretrained=pretrained) + + del self.svt.head + del self.svt.patch_embeds[2] + del self.svt.patch_embeds[2] + del self.svt.blocks[2] + del self.svt.blocks[2] + del self.svt.pos_block[2] + del self.svt.pos_block[2] + self.svt.norm.weight.requires_grad = False + self.svt.norm.bias.requires_grad = False + + def forward(self, x, data=None, layer=2, return_feat=False): + B = x.shape[0] + if return_feat: + feat = [] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + if return_feat: + feat.append(x) + if i == layer - 1: + break + if return_feat: + return x, feat + return x + + def compute_params(self, layer=2): + num = 0 + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + for param in embed.parameters(): + num += np.prod(param.size()) + + for param in drop.parameters(): + num += np.prod(param.size()) + + for param in blocks.parameters(): + num += np.prod(param.size()) + + for param in pos_blk.parameters(): + num += np.prod(param.size()) + + if i == layer - 1: + break + + for param in self.svt.head.parameters(): + num += np.prod(param.size()) + + return num + + +class twins_svt_large_context(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model("twins_svt_large_context", 
pretrained=pretrained) + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip( + self.svt.patch_embeds, + self.svt.pos_drops, + self.svt.blocks, + self.svt.pos_block, + ) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer - 1: + break + + return x + + +if __name__ == "__main__": + m = twins_svt_large() + input = torch.randn(2, 3, 400, 800) + out = m.extract_feature(input) + print(out.shape) diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py new file mode 100644 index 0000000..3c5819f --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 
1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py new file mode 100644 index 0000000..1e7f2db --- /dev/null +++ 
b/gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = 
nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 
64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, 
kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/position_encoding.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/position_encoding.py new file mode 100644 index 0000000..0dd56b1 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/position_encoding.py @@ -0,0 +1,100 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, 
max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + div_term = torch.exp( + torch.arange(0, d_model // 2, 2).float() + * (-math.log(10000.0) / d_model // 2) + ) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer("pe", pe.unsqueeze(0)) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, : x.size(2), : x.size(3)] + + +class LinearPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = ( + torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1 + ) / max_shape[0] + x_position = ( + torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1 + ) / max_shape[1] + div_term = torch.arange(0, d_model // 2, 2).float() + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi) + pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi) + pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi) + pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi) + + self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x 
+ self.pe[:, :, : x.size(2), : x.size(3)] + + +class LearnedPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(80, 80)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model)) + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x + self.pe diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/update.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/update.py new file mode 100644 index 0000000..ced6df0 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/update.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = 
nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, 
padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/utils/__init__.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py b/gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py new 
file mode 100644 index 0000000..8718cbe --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py @@ -0,0 +1,113 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + elif mode == "kitti400": + self._pad = [0, 0, 0, 400 - self.ht] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 + ) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 + ) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = 
torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def indexing(img, coords, mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + """ + TODO: directly indexing features instead of sampling + """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True, mode="nearest") + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/gimm_vfi_arch/generalizable_INR/gimm.py b/gimm_vfi_arch/generalizable_INR/gimm.py new file mode 100644 index 0000000..4d18f8c --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/gimm.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configs import GIMMConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.fi_components import LateralBlock
from .modules.hyponet import HypoNet
from .modules.fi_utils import warp

from .modules.softsplat import softsplat


class GIMM(nn.Module):
    """Generalizable Implicit Motion Modeling.

    Encodes a pair of bidirectional flows into per-pixel latents, forward-
    splats both latents to an arbitrary timestep t via softmax splatting,
    refines the splatted pair, and decodes the motion at t with a
    coordinate-conditioned HypoNet.
    """

    Config = GIMMConfig

    def __init__(self, config: GIMMConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.coord_sampler = CoordSampler3D(config.coord_range)
        self.fwarp_type = config.fwarp_type

        # Motion Encoder: maps a 2-channel flow field to a 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Latent Refiner: fuses the two source latents with the splatted pair
        # (16 + 16 + 32 = 64 input channels) into a 32-channel residual.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Fixed 3x3 Gaussian kernel (as a 3D conv filter) used by the
        # flow-variance splatting metric below; never trained.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )

        # Learnable temperatures for the two splatting-weight terms
        # (flow-consistency error and local flow variance).
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel confidence weights for softmax splatting.

        Combines two cues, each mapped through 1/(1 + alpha * cue):
        local flow variance (Gaussian-smoothed second moment) and the
        forward-backward flow consistency error. Returns (weights1,
        weights2) of shape (b, 1, h, w) for the 0->1 and 1->0 flows.
        """
        batch_size = raft_flow01.shape[0]
        # Batch both directions together so one conv pass covers both.
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        ## flow variance metric
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        # E[x^2] - E[x]^2, clamped for numerical safety before sqrt.
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        ## flow warp metric (forward-backward consistency)
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2

    def forward(
        self, xs, coord=None, keep_xs_shape=True, ori_flow=None, timesteps=None
    ):
        """Predict motion at the given timestep(s).

        Args:
            xs: normalized flow pair, shape (b, 2, 2, h, w) -- assumes
                dim 2 indexes the two directions; TODO confirm.
            coord: coordinate input(s) for the HypoNet (list when
                `timesteps` is a list).
            keep_xs_shape: permute HypoNet output back to channel-first.
            ori_flow: unnormalized flows, (b, 2, 2, h, w), used for
                splatting geometry and weights.
            timesteps: tensor, or list of tensors, of target times in [0, 1].

        Returns:
            A tensor, or list of tensors (one per timestep).
        """
        # NOTE(review): this fallback call passes a single positional arg,
        # which does not match sample_coord_input's signature below; callers
        # appear to always supply `coord` explicitly -- confirm.
        coord = self.sample_coord_input(xs) if coord is None else coord
        raft_flow01 = ori_flow[:, :, 0]
        raft_flow10 = ori_flow[:, :, 1]

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        # b,c,h,w latents for each source frame's flow
        pixel_latent_0 = self.cnn_encoder(xs[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(xs[:, :, 1])
        pixel_latent = []

        modulation_params_dict = None
        strtype = self.fwarp_type
        if isinstance(timesteps, list):
            # Multi-timestep path: splat/refine once per requested t.
            assert isinstance(coord, list)
            assert len(timesteps) == len(coord)
            for i, cur_t in enumerate(timesteps):
                cur_t = cur_t.reshape(-1, 1, 1, 1)
                # Forward-splat each latent toward time t along its flow.
                tmp_pixel_latent_0 = softsplat(
                    tenIn=pixel_latent_0,
                    tenFlow=raft_flow01 * cur_t,
                    tenMetric=weights1,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent_1 = softsplat(
                    tenIn=pixel_latent_1,
                    tenFlow=raft_flow10 * (1 - cur_t),
                    tenMetric=weights2,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent = torch.cat(
                    [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
                )
                # Residual refinement conditioned on the un-splatted latents.
                tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                    torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
                )
                # HypoNet expects channel-last latents.
                pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

            all_outputs = []
            for idx, c in enumerate(coord):
                outputs = self.hyponet(
                    c,
                    modulation_params_dict=modulation_params_dict,
                    pixel_latent=pixel_latent[idx],
                )
                if keep_xs_shape:
                    permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                    outputs = outputs.permute(0, -1, *permute_idx_range)
                all_outputs.append(outputs)
            return all_outputs

        else:
            # Single-timestep path: same pipeline without the list plumbing.
            cur_t = timesteps.reshape(-1, 1, 1, 1)
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent = tmp_pixel_latent.permute(0, 2, 3, 1)

            # predict all pixels of coord after applying the modulation_parms into hyponet
            outputs = self.hyponet(
                coord,
                modulation_params_dict=modulation_params_dict,
                pixel_latent=pixel_latent,
            )
            if keep_xs_shape:
                permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                outputs = outputs.permute(0, -1, *permute_idx_range)
            return outputs

    def compute_loss(self, preds, targets, reduction="mean", single=False):
        """Per-sample MSE (and PSNR) between predictions and targets.

        Both tensors are asserted to have exactly one temporal slice
        (shape[2] == 1). Returns a dict with "loss_total", "mse", "psnr",
        reduced per `reduction` ("mean" | "sum" | "none").
        `single` is currently unused.
        """
        assert reduction in ["mean", "sum", "none"]
        batch_size = preds.shape[0]
        sample_mses = 0
        assert preds.shape[2] == 1
        assert targets.shape[2] == 1
        for i in range(preds.shape[2]):
            sample_mses += torch.reshape(
                (preds[:, :, i] - targets[:, :, i]) ** 2, (batch_size, -1)
            ).mean(dim=-1)
        sample_mses = sample_mses / preds.shape[2]
        if reduction == "mean":
            total_loss = sample_mses.mean()
            psnr = (-10 * torch.log10(sample_mses)).mean()
        elif reduction == "sum":
            total_loss = sample_mses.sum()
            psnr = (-10 * torch.log10(sample_mses)).sum()
        else:
            total_loss = sample_mses
            psnr = -10 * torch.log10(sample_mses)

        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample (t, y, x) coordinate inputs via the CoordSampler3D.

        `device` is required; `coord_range` must be None (the sampler's
        configured range is used).
        """
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configs import GIMMVFIConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.hyponet import HypoNet
from .modules.fi_components import *
#from .flowformer import initialize_Flowformer
from .modules.fi_utils import normalize_flow, unnormalize_flow, warp, resize
from .raft.corr import BidirCorrBlock
from .modules.softsplat import softsplat


class GIMMVFI_F(nn.Module):
    """GIMM-VFI with a FlowFormer flow estimator.

    Pipeline: bidirectional FlowFormer flow -> GIMM motion latents splatted
    to time t -> AMT-style coarse-to-fine synthesis of the intermediate
    frame.

    NOTE(review): ``self.flow_estimator`` is never assigned here (the
    in-repo initializer is commented out); it appears to be attached
    externally by the model loader -- confirm against the loader node.
    """

    Config = GIMMVFIConfig

    def __init__(self, dtype, config: GIMMVFIConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.raft_iter = config.raft_iter

        self.dtype = dtype

        ######### Encoder and Decoder Settings #########
        #self.flow_estimator = initialize_Flowformer()
        f_dims = [256, 128]

        skip_channels = f_dims[-1] // 2
        self.num_flows = 3
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)

        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)

        # Combines the num_flows warped candidates into a single RGB frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )

        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)

        # Fixed 3x3 Gaussian kernel for the flow-variance splatting metric.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )

        self.fwarp_type = config.fwarp_type

        # Learnable temperatures for the splatting-weight terms.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        # Motion encoder: 2-channel flow -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Latent refiner: fuses source latents with the splatted pair.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build an AMT BasicUpdateBlock with this model's fixed dims."""
        return BasicUpdateBlock(
            cdim=cdim,
            hidden_dim=192,
            flow_dim=64,
            corr_dim=256,
            corr_dim2=192,
            fc_dim=188,
            scale_factor=scale_factor,
            corr_levels=4,
            radius=4,
        )

    def cal_bidirection_flow(self, im0, im1):
        """Run FlowFormer both ways and package flows, features and corr.

        Returns (normalized flows, original flows, flow scalers,
        features0, features1, bidirectional corr volume, raw flow pair).
        Flows are stacked along a new dim 2 as [0->1, 1->0].
        """
        f01, features0, fnet0 = self.flow_estimator(
            im0, im1, return_feat=True, iters=None
        )
        f10, features1, fnet1 = self.flow_estimator(
            im1, im0, return_feat=True, iters=None
        )
        f01 = f01[0]
        f10 = f10[0]
        corr_fn = BidirCorrBlock(fnet0, fnet1, radius=4)
        flow01 = f01.unsqueeze(2)
        flow10 = f10.unsqueeze(2)
        # 1->0 flow is negated so both directions share one normalization.
        noraml_flows = torch.cat([flow01, -flow10], dim=2)
        noraml_flows, flow_scalers = normalize_flow(noraml_flows)

        ori_flows = torch.cat([flow01, flow10], dim=2)
        return (
            noraml_flows,
            ori_flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
        )

    def predict_flow(self, f, coord, t, flows):
        """GIMM stage: predict normalized flow at each timestep in `t`.

        `f` holds the normalized flow pair, `flows` the unnormalized pair
        (detached -- splatting geometry carries no gradient to the
        estimator). `coord` is a list of (coord_tensor, subsample) tuples
        aligned with `t`. Returns one HypoNet output per timestep.
        """
        raft_flow01 = flows[:, :, 0].detach()
        raft_flow10 = flows[:, :, 1].detach()

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        strtype = self.fwarp_type + "-zeroeps"

        # b,c,h,w latents for each direction's flow
        pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
        pixel_latent = []

        for i, cur_t in enumerate(t):
            cur_t = cur_t.reshape(-1, 1, 1, 1)

            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype,
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype,
            )

            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

        all_outputs = []
        permute_idx_range = [i for i in range(1, f.ndim - 1)]
        for idx, c in enumerate(coord):
            # Coord's time channel must agree with the requested timestep.
            assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
            assert isinstance(c, tuple)

            if c[1] is None:
                # Dense coords: permute back to channel-first.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                ).permute(0, -1, *permute_idx_range)
            else:
                # Subsampled coords (supervision): keep HypoNet layout.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                )
            all_outputs.append(outputs)

        return all_outputs

    def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
        """Backward-warp both frames by (upscaled) flows and blend by mask."""
        ft0 = scale * resize(ft0, scale_factor=scale)
        ft1 = scale * resize(ft1, scale_factor=scale)
        mask = resize(mask, scale_factor=scale).sigmoid()
        img0_warp = warp(img0, ft0)
        img1_warp = warp(img1, ft1)
        img_warp = mask * img0_warp + (1 - mask) * img1_warp
        return img_warp

    @torch.compiler.disable()
    def frame_synthesize(
        self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
    ):
        """AMT-style synthesis of the frame at time t.

        flow_t: b,2,h,w (GIMM-predicted motion at t, positive t-axis)
        cur_t: b,1,1,1
        If `full_img` is given, flows/masks are upscaled so synthesis runs
        at the original (pre-downscale) resolution.
        Returns (imgt_pred, flowt0_pred, flowt1_pred, other_pred).
        """
        batch_size = img_xs.shape[0]
        # Frames to [-1, 1] for the decoders.
        img0 = 2 * img_xs[:, :, 0] - 1.0
        img1 = 2 * img_xs[:, :, 1] - 1.0

        ##################### update the predicted flow #####################
        ## initialize coordinates for looking up
        lookup_coord = self.flow_estimator.build_coord(img_xs[:, :, 0]).to(
            img_xs[:, :, 0].device
        )

        # Split motion at t into t->0 and t->1 flows (linear-motion split).
        flow_t0_fullsize = flow_t * (-cur_t)
        flow_t1_fullsize = flow_t * (1.0 - cur_t)

        inv = 1 / 4
        flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
        flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)

        ############################# scale 1/4 #############################
        # i. Initialize feature t at scale 1/4
        flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
            features0[-1],
            features1[-1],
            flow_t0_inr4,
            flow_t1_inr4,
            img0=img0,
            img1=img1,
        )
        mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
        img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
        img_warp_4 = (img_warp_4 + 1.0) / 2
        img_warp_4 = torch.clamp(img_warp_4, 0, 1)

        # ii. residue update with correlation lookup at 1/8.
        corr_4, flow_4_lr = self._amt_corr_scale_lookup(
            corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
        )

        delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        flowt0_4 = flowt0_4 + delta_flow0_4
        flowt1_4 = flowt1_4 + delta_flow1_4
        ft_4_ = ft_4_ + delta_ft_4_

        # iii. residue update with lookup corr
        corr_4 = resize(corr_4, scale_factor=2.0)

        flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
        delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
        flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
        flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
        ft_4_ = ft_4_ + delta_ft_4_

        ############################# scale 1/1 #############################
        flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
            ft_4_,
            features0[0],
            features1[0],
            flowt0_4,
            flowt1_4,
            mask=mask_4_,
            img0=img0,
            img1=img1,
        )

        if full_img is not None:
            # ds_factor path: upscale everything back to full resolution.
            img0 = 2 * full_img[:, :, 0] - 1.0
            img1 = 2 * full_img[:, :, 1] - 1.0
            inv = img1.shape[2] / flowt0_1.shape[2]
            flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
            flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
            flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
            flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
            mask = resize(mask, scale_factor=inv)
            img_res = resize(img_res, scale_factor=inv)

        imgt_pred = multi_flow_combine(
            self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
        )
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        ######################################################################

        flowt0_1 = flowt0_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt1_1 = flowt1_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )

        flowt0_pred = [flowt0_1, flowt0_4]
        flowt1_pred = [flowt1_1, flowt1_4]
        other_pred = [img_warp_4]
        return imgt_pred, flowt0_pred, flowt1_pred, other_pred

    def forward(self, img_xs, coord=None, t=None, ds_factor=None):
        """Interpolate frames at every timestep in `t` (single pass).

        img_xs: (b, c, 2, h, w) frame pair in [0, 1].
        coord / t: aligned lists of coordinate tuples and timesteps.
        ds_factor: optional downscale for flow estimation; synthesis is
            then upscaled back to the original resolution.
        Returns a dict of predictions and intermediate flows.
        """
        assert isinstance(t, list)
        assert isinstance(coord, list)
        assert len(t) == len(coord)
        full_size_img = None
        if ds_factor is not None:
            full_size_img = img_xs.clone()
            img_xs = torch.cat(
                [
                    resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                    resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
                ],
                dim=2,
            )

        # FlowFormer expects inputs in [0, 255].
        (
            normal_flows,
            flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            preserved_raft_flows,
        ) = self.cal_bidirection_flow(255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1])
        assert coord is not None

        # List of flows
        normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)

        ############ Unnormalize the predicted/reconstructed flow ############
        start_idx = 0
        if coord[0][1] is not None:
            # Subsampled flows for reconstruction supervision in the GIMM module
            # In such case, first two coords in the list are subsampled for supervision up-mentioned
            # Normalized flow_t towards positive t-axis
            assert len(coord) > 2
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(2, len(coord))
            ]
            start_idx = 2
        else:
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(len(coord))
            ]

        imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []

        for idx in range(start_idx, len(coord)):
            cur_flow_t = flow_t[idx - start_idx]
            cur_t = t[idx].reshape(-1, 1, 1, 1)
            # .squeeze() above may have dropped a batch dim of 1 -- restore it.
            if cur_flow_t.ndim != 4:
                cur_flow_t = cur_flow_t.unsqueeze(0)
            assert cur_flow_t.ndim == 4

            imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
                img_xs,
                cur_flow_t,
                features0,
                features1,
                corr_fn,
                cur_t,
                full_img=full_size_img,
            )

            imgt_preds.append(imgt_pred)
            flowt0_preds.append(flowt0_pred)
            flowt1_preds.append(flowt1_pred)
            all_others.append(others)

        return {
            "imgt_pred": imgt_preds,
            "other_pred": all_others,
            "flowt0_pred": flowt0_preds,
            "flowt1_pred": flowt1_preds,
            "raft_flow": preserved_raft_flows,
            "ninrflow": normal_inr_flows,
            "nflow": normal_flows,
            "flowt": flow_t,
        }

    def warp_frame(self, frame, flow):
        """Backward-warp `frame` by `flow`."""
        return warp(frame, flow)

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample (t, y, x) coordinate inputs via the CoordSampler3D."""
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Confidence weights for softmax splatting (variance + fb-error).

        Same formulation as GIMM.cal_splatting_weights; see gimm.py.
        """
        batch_size = raft_flow01.shape[0]
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        ## flow variance metric
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        ## flow warp metric
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2

    def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-displaced coordinates."""
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t0_scale = 1.0 / embt
        t1_scale = 1.0 / (1.0 - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

import torch
import torch.nn as nn

from .configs import GIMMVFIConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.hyponet import HypoNet
from .modules.fi_components import *
from .modules.fi_utils import (
    normalize_flow,
    unnormalize_flow,
    warp,
    resize,
    build_coord,
)
import torch.nn.functional as F

from .raft.corr import BidirCorrBlock
from .modules.softsplat import softsplat


class GIMMVFI_R(nn.Module):
    """GIMM-VFI with a RAFT flow estimator.

    Pipeline: bidirectional RAFT flow -> GIMM motion latents splatted to
    time t -> AMT-style coarse-to-fine synthesis of the intermediate frame.

    NOTE(review): ``self.flow_estimator`` is never assigned here (the
    in-repo initializer is commented out); it appears to be attached
    externally by the model loader -- confirm against the loader node.
    """

    Config = GIMMVFIConfig

    def __init__(self, dtype, config: GIMMVFIConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        # NOTE(review): hard-coded default; config.raft_iter is ignored here
        # (the F-variant reads it). Kept as-is to preserve behavior.
        self.raft_iter = 20

        ######### Encoder and Decoder Settings #########
        #self.flow_estimator = initialize_RAFT()
        cur_f_dims = [128, 96]
        f_dims = [256, 128]
        self.dtype = dtype

        skip_channels = f_dims[-1] // 2
        self.num_flows = 3

        # 1x1 projections lifting RAFT feature dims to the AMT decoder dims.
        self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1)
        self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1)
        self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1)
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)

        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)

        # Combines the num_flows warped candidates into a single RGB frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )

        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)

        # Fixed 3x3 Gaussian kernel for the flow-variance splatting metric.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        self.fwarp_type = config.fwarp_type

        # Learnable temperatures for the splatting-weight terms.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        # Motion encoder: 2-channel flow -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Latent refiner: fuses source latents with the splatted pair.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build an AMT BasicUpdateBlock with this model's fixed dims."""
        return BasicUpdateBlock(
            cdim=cdim,
            hidden_dim=192,
            flow_dim=64,
            corr_dim=256,
            corr_dim2=192,
            fc_dim=188,
            scale_factor=scale_factor,
            corr_levels=4,
            radius=4,
        )

    def cal_bidirection_flow(self, im0, im1, iters=20):
        """Run RAFT both ways and package flows, features and correlation.

        Args:
            im0, im1: frame pair in [0, 255].
            iters: RAFT refinement iterations.

        Returns (normalized flows, original flows, flow scalers,
        projected features0/features1, bidirectional corr volume,
        raw flow pair stacked on dim 2 as [0->1, 1->0]).
        """
        # BUG FIX: the calls previously hard-coded iters=20, silently
        # ignoring this parameter (and thus forward(iters=...)/raft_iter).
        f01, features0, fnet0 = self.flow_estimator(
            im0.to(self.dtype), im1.to(self.dtype), return_feat=True, iters=iters
        )
        f10, features1, fnet1 = self.flow_estimator(
            im1.to(self.dtype), im0.to(self.dtype), return_feat=True, iters=iters
        )
        corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4)
        # Project RAFT feature dims to the dims the AMT decoders expect.
        features0 = [
            self.amt_second_last_cproj(features0[0]),
            self.amt_last_cproj(features0[1]),
        ]
        features1 = [
            self.amt_second_last_cproj(features1[0]),
            self.amt_last_cproj(features1[1]),
        ]
        flow01 = f01.unsqueeze(2)
        flow10 = f10.unsqueeze(2)
        # 1->0 flow is negated so both directions share one normalization.
        normal_flows = torch.cat([flow01, -flow10], dim=2)
        normal_flows, flow_scalers = normalize_flow(normal_flows)

        ori_flows = torch.cat([flow01, flow10], dim=2)
        return (
            normal_flows,
            ori_flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
        )

    def predict_flow(self, f, coord, t, flows):
        """GIMM stage: predict normalized flow at each timestep in `t`.

        `f` holds the normalized flow pair, `flows` the unnormalized pair
        (detached -- splatting geometry carries no gradient to the
        estimator). `coord` is a list of (coord_tensor, subsample) tuples
        aligned with `t`. Returns one HypoNet output per timestep.
        """
        raft_flow01 = flows[:, :, 0].detach()
        raft_flow10 = flows[:, :, 1].detach()

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        strtype = self.fwarp_type + "-zeroeps"

        # b,c,h,w latents for each direction's flow
        pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
        pixel_latent = []

        for i, cur_t in enumerate(t):
            cur_t = cur_t.reshape(-1, 1, 1, 1)

            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype,
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype,
            )

            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

        all_outputs = []
        permute_idx_range = [i for i in range(1, f.ndim - 1)]
        for idx, c in enumerate(coord):
            # Coord's time channel must agree with the requested timestep.
            assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
            assert isinstance(c, tuple)

            if c[1] is None:
                # Dense coords: permute back to channel-first.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                ).permute(0, -1, *permute_idx_range)
            else:
                # Subsampled coords (supervision): keep HypoNet layout.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                )
            all_outputs.append(outputs)

        return all_outputs

    def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
        """Backward-warp both frames by (upscaled) flows and blend by mask."""
        ft0 = scale * resize(ft0, scale_factor=scale)
        ft1 = scale * resize(ft1, scale_factor=scale)
        mask = resize(mask, scale_factor=scale).sigmoid()
        img0_warp = warp(img0, ft0)
        img1_warp = warp(img1, ft1)
        img_warp = mask * img0_warp + (1 - mask) * img1_warp
        return img_warp

    @torch.compiler.disable()
    def frame_synthesize(
        self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
    ):
        """AMT-style synthesis of the frame at time t.

        flow_t: b,2,h,w (GIMM-predicted motion at t, positive t-axis)
        cur_t: b,1,1,1
        If `full_img` is given, flows/masks are upscaled so synthesis runs
        at the original (pre-downscale) resolution.
        Returns (imgt_pred, flowt0_pred, flowt1_pred, other_pred).
        """
        batch_size = img_xs.shape[0]  # b,c,t,h,w
        # Frames to [-1, 1] for the decoders.
        img0 = 2 * img_xs[:, :, 0] - 1.0
        img1 = 2 * img_xs[:, :, 1] - 1.0

        ##################### update the predicted flow #####################
        ##initialize coordinates for looking up
        lookup_coord = build_coord(img_xs[:, :, 0]).to(
            img_xs[:, :, 0].device
        )  # H//8,W//8

        # Split motion at t into t->0 and t->1 flows (linear-motion split).
        flow_t0_fullsize = flow_t * (-cur_t)
        flow_t1_fullsize = flow_t * (1.0 - cur_t)

        inv = 1 / 4
        flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
        flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)

        ############################# scale 1/4 #############################
        # i. Initialize feature t at scale 1/4
        flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
            features0[-1],
            features1[-1],
            flow_t0_inr4,
            flow_t1_inr4,
            img0=img0,
            img1=img1,
        )
        features0, features1 = features0[:-1], features1[:-1]

        mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
        img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
        img_warp_4 = (img_warp_4 + 1.0) / 2
        img_warp_4 = torch.clamp(img_warp_4, 0, 1)

        # ii. residue update with correlation lookup at 1/8.
        corr_4, flow_4_lr = self._amt_corr_scale_lookup(
            corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
        )

        delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        flowt0_4 = flowt0_4 + delta_flow0_4
        flowt1_4 = flowt1_4 + delta_flow1_4
        ft_4_ = ft_4_ + delta_ft_4_

        # iii. residue update with lookup corr
        corr_4 = resize(corr_4, scale_factor=2.0)

        flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
        delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
        flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
        flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
        ft_4_ = ft_4_ + delta_ft_4_

        ############################# scale 1/1 #############################
        flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
            ft_4_,
            features0[0],
            features1[0],
            flowt0_4,
            flowt1_4,
            mask=mask_4_,
            img0=img0,
            img1=img1,
        )

        if full_img is not None:
            # ds_factor path: upscale everything back to full resolution.
            img0 = 2 * full_img[:, :, 0] - 1.0
            img1 = 2 * full_img[:, :, 1] - 1.0
            inv = img1.shape[2] / flowt0_1.shape[2]
            flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
            flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
            flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
            flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
            mask = resize(mask, scale_factor=inv)
            img_res = resize(img_res, scale_factor=inv)

        imgt_pred = multi_flow_combine(
            self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
        )
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        ######################################################################

        flowt0_1 = flowt0_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt1_1 = flowt1_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )

        flowt0_pred = [flowt0_1, flowt0_4]
        flowt1_pred = [flowt1_1, flowt1_4]
        other_pred = [img_warp_4]
        return imgt_pred, flowt0_pred, flowt1_pred, other_pred

    def forward(self, img_xs, coord=None, t=None, iters=None, ds_factor=None):
        """Interpolate frames at every timestep in `t` (single pass).

        img_xs: (b, c, 2, h, w) frame pair in [0, 1].
        coord / t: aligned lists of coordinate tuples and timesteps.
        iters: RAFT iterations (defaults to self.raft_iter).
        ds_factor: optional downscale for flow estimation; synthesis is
            then upscaled back to the original resolution.
        Returns a dict of predictions and intermediate flows.
        """
        assert isinstance(t, list)
        assert isinstance(coord, list)
        assert len(t) == len(coord)
        full_size_img = None
        if ds_factor is not None:
            full_size_img = img_xs.clone()
            img_xs = torch.cat(
                [
                    resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                    resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
                ],
                dim=2,
            )

        iters = self.raft_iter if iters is None else iters
        # RAFT expects inputs in [0, 255].
        (
            normal_flows,
            flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            preserved_raft_flows,
        ) = self.cal_bidirection_flow(
            255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters
        )
        assert coord is not None

        # List of flows
        normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)

        ############ Unnormalize the predicted/reconstructed flow ############
        start_idx = 0
        if coord[0][1] is not None:
            # Subsampled flows for reconstruction supervision in the GIMM module
            # In such case, by default, first two coords are subsampled for supervision up-mentioned
            # normalized flow_t versus positive t-axis
            assert len(coord) > 2
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(2, len(coord))
            ]
            start_idx = 2
        else:
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(len(coord))
            ]

        imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []

        for idx in range(start_idx, len(coord)):
            cur_flow_t = flow_t[idx - start_idx]
            cur_t = t[idx].reshape(-1, 1, 1, 1)
            # .squeeze() above may have dropped a batch dim of 1 -- restore it.
            if cur_flow_t.ndim != 4:
                cur_flow_t = cur_flow_t.unsqueeze(0)
            assert cur_flow_t.ndim == 4

            imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
                img_xs,
                cur_flow_t,
                features0,
                features1,
                corr_fn,
                cur_t,
                full_img=full_size_img,
            )

            imgt_preds.append(imgt_pred)
            flowt0_preds.append(flowt0_pred)
            flowt1_preds.append(flowt1_pred)
            all_others.append(others)

        return {
            "imgt_pred": imgt_preds,
            "other_pred": all_others,
            "flowt0_pred": flowt0_preds,
            "flowt1_pred": flowt1_preds,
            "raft_flow": preserved_raft_flows,
            "ninrflow": normal_inr_flows,
            "nflow": normal_flows,
            "flowt": flow_t,
        }

    def warp_frame(self, frame, flow):
        """Backward-warp `frame` by `flow`."""
        return warp(frame, flow)

    def compute_psnr(self, preds, targets, reduction="mean"):
        """PSNR between predictions and targets, reduced per `reduction`."""
        assert reduction in ["mean", "sum", "none"]
        batch_size = preds.shape[0]
        sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean(
            dim=-1
        )

        if reduction == "mean":
            psnr = (-10 * torch.log10(sample_mses)).mean()
        elif reduction == "sum":
            psnr = (-10 * torch.log10(sample_mses)).sum()
        else:
            psnr = -10 * torch.log10(sample_mses)

        return psnr

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample (t, y, x) coordinate inputs via the CoordSampler3D."""
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Confidence weights for softmax splatting (variance + fb-error).

        Same formulation as GIMM.cal_splatting_weights; see gimm.py.
        """
        batch_size = raft_flow01.shape[0]
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        ## flow variance metric
        square_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        # E[x^2] - E[x]^2, clamped for numerical safety before sqrt.
        var = (
            (square_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        ## flow warp metric (forward-backward consistency)
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2

    def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-displaced coordinates."""
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t0_scale = 1.0 / embt
        t1_scale = 1.0 / (1.0 - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

import torch
import torch.nn as nn


class CoordSampler3D(nn.Module):
    """Builds (t, y, x) coordinate grids used to query the spatio-temporal INR.

    The sampler returns a tensor of shape (B, T, H, W, 3) whose last axis
    holds the temporal coordinate followed by the two spatial coordinates.
    Spatial coordinates are pixel-centered and affinely mapped into
    ``coord_range``; temporal coordinates come from ``t_ids`` verbatim.

    Args:
        coord_range: default (min, max) range for the spatial axes.
        t_coord_only: if True, ``forward`` returns only the temporal channel,
            i.e. a tensor of shape (B, T, H, W, 1).
    """

    def __init__(self, coord_range, t_coord_only=False):
        super().__init__()
        self.coord_range = coord_range
        self.t_coord_only = t_coord_only

    def shape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Creates one coordinate grid shared by every sample in the batch.

        ``t_ids`` is a list of timesteps; each batch element receives the
        same (T, H, W, 3) grid, replicated along the batch axis.
        """
        assert isinstance(t_ids, list)
        axes = []
        # Temporal axis: timestep values are used as-is (assumed already
        # normalized by the caller -- TODO confirm upstream contract).
        t_axis = torch.tensor(t_ids, device=device) / 1.0
        axes.append(t_axis.to(torch.float32))
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            # Pixel-center sampling in [0, 1), then map into coord_range.
            s_axis = (0.5 + torch.arange(num_s, device=device)) / num_s
            s_axis = coord_range[0] + (coord_range[1] - coord_range[0]) * s_axis
            axes.append(s_axis)
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        ones_like_shape = (1,) * grid.ndim
        grid = grid.unsqueeze(0).repeat(batch_size, *ones_like_shape)
        return grid  # (B, T, H, W, 3)

    def batchshape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Creates a per-sample coordinate grid from a tensor of timesteps.

        ``t_ids`` is a tensor holding one timestep per batch element
        (presumably of length ``batch_size`` -- TODO confirm against callers).
        The grid is built once with a placeholder temporal value of 1 and
        then scaled per sample.
        """
        axes = []
        # Placeholder temporal coordinate; replaced by t_ids below.
        axes.append(torch.tensor(1, device=device).to(torch.float32))
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            s_axis = (0.5 + torch.arange(num_s, device=device)) / num_s
            s_axis = coord_range[0] + (coord_range[1] - coord_range[0]) * s_axis
            axes.append(s_axis)
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        ones_like_shape = (1,) * grid.ndim
        # Now grid is (B, 1, H, W, 3) with grid[..., 0] == 1 everywhere.
        grid = grid.unsqueeze(0).repeat(batch_size, *ones_like_shape)
        # Assign each sample within the batch its own timestep.
        grid[..., :1] = grid[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
        return grid

    def forward(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Dispatches on the type of ``t_ids``.

        list -> one grid shared by the batch (``shape2coordinate``);
        Tensor -> per-sample timesteps (``batchshape2coordinate``).

        Raises:
            TypeError: if ``t_ids`` is neither a list nor a torch.Tensor.
                (Previously this fell through both branches and crashed with
                an opaque ``NameError`` on ``coords``.)
        """
        coord_range = self.coord_range if coord_range is None else coord_range
        if isinstance(t_ids, list):
            coords = self.shape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        elif isinstance(t_ids, torch.Tensor):
            coords = self.batchshape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        else:
            raise TypeError(
                f"t_ids must be a list or a torch.Tensor, got {type(t_ids).__name__}"
            )
        if self.t_coord_only:
            coords = coords[..., :1]
        return coords
+# -------------------------------------------------------- +# References: +# amt: https://github.com/MCG-NKU/AMT +# motif: https://github.com/sichun233746/MoTIF +# -------------------------------------------------------- + +import torch +import torch.nn as nn +from .fi_utils import warp, resize + + +class LateralBlock(nn.Module): + def __init__(self, dim): + super(LateralBlock, self).__init__() + self.layers = nn.Sequential( + nn.Conv2d(dim, dim, 3, 1, 1, bias=True), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(dim, dim, 3, 1, 1, bias=True), + ) + + def forward(self, x): + res = x + x = self.layers(x) + return x + res + + +def convrelu( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, +): + return nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias, + ), + nn.PReLU(out_channels), + ) + + +def multi_flow_combine( + comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None +): + assert mean is None + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = ( + mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) + if mask is not None + else None + ) + img_res = ( + img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) + if img_res is not None + else 0 + ) + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = ( + torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) + if mean is not None + else 0 + ) + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + + res = comb_block(img_warps.view(b, -1, h, w)) + imgt_pred = 
img_warps.mean(1) + res + + imgt_pred = (imgt_pred + 1.0) / 2 + + return imgt_pred + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv2 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv3 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv4 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv5 = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, : -self.side_channels, ...] + side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, : -self.side_channels, ...] 
+ side_feat = out[:, -self.side_channels :, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class BasicUpdateBlock(nn.Module): + def __init__( + self, + cdim, + hidden_dim, + flow_dim, + corr_dim, + corr_dim2, + fc_dim, + corr_levels=4, + radius=3, + scale_factor=None, + out_num=1, + ): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = ( + resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net + ) + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not 
None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize( + delta_flow, scale_factor=self.scale_factor + ) + return delta_net, delta_flow + + +def get_bn(): + return nn.BatchNorm2d + + +class NewInitDecoder(nn.Module): + def __init__(self, in_ch, skip_ch): + super().__init__() + norm_layer = get_bn() + + self.upsample = nn.Sequential( + nn.PixelShuffle(2), + convrelu(in_ch // 4, in_ch // 4, 5, 1, 2), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 2), + nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), + norm_layer(in_ch // 2), + nn.ReLU(inplace=True), + ) + + in_ch = in_ch // 2 + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0), + ResBlock(in_ch, skip_ch), + ResBlock(in_ch, skip_ch), + ResBlock(in_ch, skip_ch), + nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True), + ) + + def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None): + f0 = self.upsample(f0) + f1 = self.upsample(f1) + f0_warp_ks = warp(f0, flow0_in) + f1_warp_ks = warp(f1, flow1_in) + + f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1) + + assert img0 is not None + assert img1 is not None + scale_factor = f_in.shape[2] / img0.shape[2] + img0 = resize(img0, scale_factor=scale_factor) + img1 = resize(img1, scale_factor=scale_factor) + warped_img0 = warp(img0, flow0_in) + warped_img1 = warp(img1, flow1_in) + f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1) + + out = self.convblock(f_in) + ft_ = out[:, 4:, ...] + flow0 = flow0_in + out[:, :2, ...] + flow1 = flow1_in + out[:, 2:4, ...] 
+ return flow0, flow1, ft_ + + +class NewMultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(NewMultiFlowDecoder, self).__init__() + norm_layer = get_bn() + + self.upsample = nn.Sequential( + nn.PixelShuffle(2), + nn.PixelShuffle(2), + convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 4), + convrelu(in_ch // 4, in_ch // 2), + nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), + norm_layer(in_ch // 2), + nn.ReLU(inplace=True), + ) + + self.num_flows = num_flows + ch_factor = 2 + self.convblock = nn.Sequential( + convrelu(in_ch * ch_factor + 17, in_ch * ch_factor), + ResBlock(in_ch * ch_factor, skip_ch), + ResBlock(in_ch * ch_factor, skip_ch), + ResBlock(in_ch * ch_factor, skip_ch), + nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1), + ) + + def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None): + f0 = self.upsample(f0) + # print([f1.shape,f0.shape]) + f1 = self.upsample(f1) + n = self.num_flows + flow0 = 4.0 * resize(flow0, scale_factor=4.0) + flow1 = 4.0 * resize(flow1, scale_factor=4.0) + + ft_ = resize(ft_, scale_factor=4.0) + mask = resize(mask, scale_factor=4.0) + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1) + + assert mask is not None + f_in = torch.cat([f_in, mask], 1) + + assert img0 is not None + assert img1 is not None + warped_img0 = warp(img0, flow0) + warped_img1 = warp(img1, flow1) + f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1) + + out = self.convblock(f_in) + delta_flow0, delta_flow1, delta_mask, img_res = torch.split( + out, [2 * n, 2 * n, n, 3 * n], 1 + ) + mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1) + mask = torch.sigmoid(mask) + flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1) + + return 
flow0, flow1, mask, img_res diff --git a/gimm_vfi_arch/generalizable_INR/modules/fi_utils.py b/gimm_vfi_arch/generalizable_INR/modules/fi_utils.py new file mode 100644 index 0000000..163fb0e --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/modules/fi_utils.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# raft: https://github.com/princeton-vl/RAFT +# ema-vfi: https://github.com/MCG-NJU/EMA-VFI +# -------------------------------------------------------- + +import torch +import torch.nn.functional as F + +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(tenFlow.device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample( + input=tenInput, + grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) + + +def normalize_flow(flows): + # FIXME: MULTI-DIMENSION + flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape( + -1, 1, 1, 1, 1 + ) + flows = flows / flow_scaler # [-1,1] + # # Adapt to [0,1] + flows = (flows + 1.0) / 2.0 + return flows, flow_scaler + + +def 
unnormalize_flow(flows, flow_scaler): + return (flows * 2.0 - 1.0) * flow_scaler + + +def resize(x, scale_factor): + return F.interpolate( + x, scale_factor=scale_factor, mode="bilinear", align_corners=False + ) + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def build_coord(img): + N, C, H, W = img.shape + coords = coords_grid(N, H // 8, W // 8) + return coords diff --git a/gimm_vfi_arch/generalizable_INR/modules/hyponet.py b/gimm_vfi_arch/generalizable_INR/modules/hyponet.py new file mode 100644 index 0000000..7531027 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/modules/hyponet.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ginr-ipc: https://github.com/kakaobrain/ginr-ipc +# -------------------------------------------------------- + +import einops + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +from ..configs import HypoNetConfig +from .utils import create_params_with_init, create_activation + + +class HypoNet(nn.Module): + r""" + The Hyponetwork with a coordinate-based MLP to be modulated. 
+ """ + + def __init__(self, config: HypoNetConfig, add_coord_dim=32): + super().__init__() + self.config = config + self.use_bias = config.use_bias + self.init_config = config.initialization + self.num_layer = config.n_layer + self.hidden_dims = config.hidden_dim + self.add_coord_dim = add_coord_dim + + if len(self.hidden_dims) == 1: + self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * ( + self.num_layer - 1 + ) # exclude output layer + else: + assert len(self.hidden_dims) == self.num_layer - 1 + + if self.config.activation.type == "siren": + assert self.init_config.weight_init_type == "siren" + assert self.init_config.bias_init_type == "siren" + + # after computes the shape of trainable parameters, initialize them + self.params_dict = None + self.params_shape_dict = self.compute_params_shape() + self.activation = create_activation(self.config.activation) + self.build_base_params_dict(self.config.initialization) + self.output_bias = config.output_bias + + self.normalize_weight = config.normalize_weight + + self.ignore_base_param_dict = {name: False for name in self.params_dict} + + @staticmethod + def subsample_coords(coords, subcoord_idx=None): + if subcoord_idx is None: + return coords + + batch_size = coords.shape[0] + sub_coords = [] + coords = coords.view(batch_size, -1, coords.shape[-1]) + for idx in range(batch_size): + sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]]) + sub_coords = torch.cat(sub_coords, dim=0) + return sub_coords + + def forward(self, coord, modulation_params_dict=None, pixel_latent=None): + sub_idx = None + if isinstance(coord, tuple): + coord, sub_idx = coord[0], coord[1] + + if modulation_params_dict is not None: + self.check_valid_param_keys(modulation_params_dict) + + batch_size, coord_shape, input_dim = ( + coord.shape[0], + coord.shape[1:-1], + coord.shape[-1], + ) + coord = coord.view(batch_size, -1, input_dim) # flatten the coordinates + assert pixel_latent is not None + pixel_latent = F.interpolate( + 
pixel_latent.permute(0, 3, 1, 2), + size=(coord_shape[1], coord_shape[2]), + mode="bilinear", + ).permute(0, 2, 3, 1) + pixel_latent_dim = pixel_latent.shape[-1] + pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim) + hidden = coord + + hidden = torch.cat([pixel_latent, hidden], dim=-1) + + hidden = self.subsample_coords(hidden, sub_idx) + + for idx in range(self.config.n_layer): + param_key = f"linear_wb{idx}" + base_param = einops.repeat( + self.params_dict[param_key], "n m -> b n m", b=batch_size + ) + + if (modulation_params_dict is not None) and ( + param_key in modulation_params_dict.keys() + ): + modulation_param = modulation_params_dict[param_key] + else: + if self.config.use_bias: + modulation_param = torch.ones_like(base_param[:, :-1]) + else: + modulation_param = torch.ones_like(base_param) + + if self.config.use_bias: + ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device) + hidden = torch.cat([hidden, ones], dim=-1) + + base_param_w, base_param_b = ( + base_param[:, :-1, :], + base_param[:, -1:, :], + ) + + if self.ignore_base_param_dict[param_key]: + base_param_w = 1.0 + param_w = base_param_w * modulation_param + if self.normalize_weight: + param_w = F.normalize(param_w, dim=1) + modulated_param = torch.cat([param_w, base_param_b], dim=1) + else: + if self.ignore_base_param_dict[param_key]: + base_param = 1.0 + if self.normalize_weight: + modulated_param = F.normalize(base_param * modulation_param, dim=1) + else: + modulated_param = base_param * modulation_param + # print([param_key,hidden.shape,modulated_param.shape]) + hidden = torch.bmm(hidden, modulated_param) + + if idx < (self.config.n_layer - 1): + hidden = self.activation(hidden) + + outputs = hidden + self.output_bias + if sub_idx is None: + outputs = outputs.view(batch_size, *coord_shape, -1) + return outputs + + def compute_params_shape(self): + """ + Computes the shape of MLP parameters. 
+ The computed shapes are used to build the initial weights by `build_base_params_dict`. + """ + config = self.config + use_bias = self.use_bias + + param_shape_dict = dict() + + fan_in = config.input_dim + add_dim = self.add_coord_dim + fan_in = fan_in + add_dim + fan_in = fan_in + 1 if use_bias else fan_in + + for i in range(config.n_layer - 1): + fan_out = self.hidden_dims[i] + param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out) + fan_in = fan_out + 1 if use_bias else fan_out + + param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim) + return param_shape_dict + + def build_base_params_dict(self, init_config): + assert self.params_shape_dict + params_dict = nn.ParameterDict() + for idx, (name, shape) in enumerate(self.params_shape_dict.items()): + is_first = idx == 0 + params = create_params_with_init( + shape, + init_type=init_config.weight_init_type, + include_bias=self.use_bias, + bias_init_type=init_config.bias_init_type, + is_first=is_first, + siren_w0=self.config.activation.siren_w0, # valid only for siren + ) + params = nn.Parameter(params) + params_dict[name] = params + self.set_params_dict(params_dict) + + def check_valid_param_keys(self, params_dict): + predefined_params_keys = self.params_shape_dict.keys() + for param_key in params_dict.keys(): + if param_key in predefined_params_keys: + continue + else: + raise KeyError + + def set_params_dict(self, params_dict): + self.check_valid_param_keys(params_dict) + self.params_dict = params_dict diff --git a/gimm_vfi_arch/generalizable_INR/modules/layers.py b/gimm_vfi_arch/generalizable_INR/modules/layers.py new file mode 100644 index 0000000..515de38 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/modules/layers.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------

import torch
from torch import nn


class Sine(nn.Module):
    """SIREN sine activation: ``sin(w0 * x)``.

    Args:
        w0 (float): frequency scale (Omega_0 in the SIREN paper).
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # Plain scaled sine; w0 controls the activation frequency.
        return torch.sin(x * self.w0)


class Damping(nn.Module):
    """Damped sine activation from http://arxiv.org/abs/2306.15242.

    Multiplies the scaled sine by a sublinear ``sqrt(|x|)`` factor.

    Args:
        w0 (float): frequency scale (Omega_0 in the SIREN paper).
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # Clamping to a tiny positive floor keeps the sqrt factor
        # well-defined for non-positive inputs.
        clamped = torch.clamp(x, min=1e-30)
        return torch.sin(self.w0 * clamped) * torch.sqrt(clamped.abs())
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

from dataclasses import dataclass, field
from typing import List, Optional

from omegaconf import MISSING


@dataclass
class HypoNetActivationConfig:
    """Activation settings for the HypoNet MLP."""

    # Activation name, e.g. "relu" or "siren".
    type: str = "relu"
    # Omega_0 frequency; only meaningful when type == "siren".
    siren_w0: Optional[float] = 30.0


@dataclass
class HypoNetInitConfig:
    """Weight/bias initialization scheme for the HypoNet MLP."""

    weight_init_type: Optional[str] = "kaiming_uniform"
    bias_init_type: Optional[str] = "zero"


@dataclass
class HypoNetConfig:
    """Architecture configuration for the modulated coordinate MLP (HypoNet)."""

    type: str = "mlp"
    # Total number of linear layers (hidden layers plus the output layer).
    n_layer: int = 5
    # One entry per hidden layer; a single entry is broadcast to all of them.
    hidden_dim: List[int] = MISSING
    use_bias: bool = True
    input_dim: int = 2
    output_dim: int = 3
    # Constant added to the network output.
    output_bias: float = 0.5
    activation: HypoNetActivationConfig = field(
        default_factory=HypoNetActivationConfig
    )
    initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig)

    # Normalize modulated weight columns before the matmul.
    normalize_weight: bool = True
    linear_interpo: bool = False


@dataclass
class CoordSamplerConfig:
    """Configuration for the coordinate sampler feeding the INR."""

    data_type: str = "image"
    # When True, only the temporal coordinate channel is produced.
    t_coord_only: bool = False
    coord_range: List[float] = MISSING
    time_range: List[float] = MISSING
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING
    patch_size: Optional[int] = MISSING
+# -------------------------------------------------------- +# References: +# softmax-splatting: https://github.com/sniklaus/softmax-splatting +# -------------------------------------------------------- + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): + if "device" not in objCudacache: + objCudacache["device"] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + strKey += objCudacache["device"] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace("{{" + strVariable + "}}", objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == 
torch.uint8: + strKernel = strKernel.replace("{{type}}", "unsigned char") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace("{{type}}", "half") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace("{{type}}", "float") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace("{{type}}", "double") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace("{{type}}", "int") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace("{{type}}", "long") + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert False + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + while True: + objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace( + objMatch.group(), + str( + intSizes[intArg] + if torch.is_tensor(intSizes[intArg]) == False + else intSizes[intArg].item() + ), + ) + # end + + while True: + objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + 
strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) == False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + "(" + str.join("+", strIndex) + ")", + ) + # end + + while True: + objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) == False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + strTensor + "[" + str.join("+", strIndex) + "]", + ) + # end + + objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel} + # end + + return strKey + + +# end + + +@cupy.memoize(for_each_device=True) +@torch.compiler.disable() +def cuda_launch(strKey: str): + try: + os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path()) + except Exception: + if "CUDA_HOME" not in os.environ: + raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.") + + strKernel = objCudacache[strKey]["strKernel"] + strFunction = objCudacache[strKey]["strFunction"] + + return cupy.RawModule( + 
code=strKernel, + options=( + "-I " + os.environ["CUDA_HOME"], + "-I " + os.environ["CUDA_HOME"] + "/include", + ), + ).get_function(strFunction) + + + +########################################################## + + +@torch.compiler.disable() +def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False): + assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"] + + if strMode == "sum": + assert tenMetric is None + if strMode == "avg": + assert tenMetric is None + if strMode.split("-")[0] == "linear": + assert tenMetric is not None + if strMode.split("-")[0] == "softmax": + assert tenMetric is not None + + if strMode == "avg": + tenIn = torch.cat( + [ + tenIn, + tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]), + ], + 1, + ) + + elif strMode.split("-")[0] == "linear": + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split("-")[0] == "softmax": + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + if torch.isnan(tenIn).any(): + print("NaN values detected during training in tenIn. Exiting.") + assert False + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if torch.isnan(tenOut).any(): + print("NaN values detected during training in tenOut_1. Exiting.") + assert False + + if strMode.split("-")[0] in ["avg", "linear", "softmax"]: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split("-")) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "addeps": + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "zeroeps": + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split("-")[1] == "clipeps": + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + if return_norm: + return tenOut[:, :-1, :, :], tenNormalize + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + + if torch.isnan(tenOut).any(): + print("NaN values detected during training in tenOut_2. 
Exiting.") + assert False + + # end + + return tenOut + + +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + + if tenIn.is_cuda == True: + cuda_launch( + cuda_kernel( + "softsplat_out", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) 
(intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + """, + {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut}, + ) + )( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOut.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + + elif tenIn.is_cuda != True: + assert False + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.compiler.disable() + @torch.amp.custom_bwd(device_type="cuda") + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda == True + + tenIngrad = ( + tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + if self.needs_input_grad[0] == True + else None + ) + tenFlowgrad = ( + 
tenFlow.new_zeros( + [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]] + ) + if self.needs_input_grad[1] == True + else None + ) + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_ingrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) 
(intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), + None, + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + if tenFlowgrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_flowgrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int 
intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && 
(intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenFlowgrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + None, + tenFlowgrad.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + return tenIngrad, tenFlowgrad + + # end + + +# end diff --git a/gimm_vfi_arch/generalizable_INR/modules/utils.py b/gimm_vfi_arch/generalizable_INR/modules/utils.py new file mode 100644 index 0000000..7f65480 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/modules/utils.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------

import math
import torch
import torch.nn as nn


def convert_int_to_list(size, len_list=2):
    """Broadcast an int to a list of length *len_list*.

    Args:
        size: an int, or a sequence already of length ``len_list``.
        len_list: target length when *size* is an int.

    Returns:
        ``[size] * len_list`` for an int input, otherwise *size* unchanged.
    """
    if isinstance(size, int):
        return [size] * len_list
    assert len(size) == len_list
    return size


def initialize_params(params, init_type, **kwargs):
    """In-place initialization of a 2-D parameter tensor.

    Args:
        params: tensor; ``params.shape[0]`` is treated as fan-in.
        init_type: one of ``None``/"normal", "kaiming_uniform",
            "uniform_fan_in", "zero", "siren".
        **kwargs: for "siren", requires ``siren_w0`` (float) and
            ``is_first`` (bool).

    Raises:
        NotImplementedError: for an unrecognized *init_type*.
    """
    fan_in = params.shape[0]
    if init_type is None or init_type == "normal":
        nn.init.normal_(params)
    elif init_type == "kaiming_uniform":
        nn.init.kaiming_uniform_(params, a=math.sqrt(5))
    elif init_type == "uniform_fan_in":
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(params, -bound, bound)
    elif init_type == "zero":
        nn.init.zeros_(params)
    elif init_type == "siren":
        assert "siren_w0" in kwargs and "is_first" in kwargs
        w0 = kwargs["siren_w0"]
        # SIREN scheme: first layer U(-1/fan_in, 1/fan_in); later layers
        # U(-sqrt(6/fan_in)/w0, +sqrt(6/fan_in)/w0).
        w_std = (1 / fan_in) if kwargs["is_first"] else math.sqrt(6.0 / fan_in) / w0
        nn.init.uniform_(params, -w_std, w_std)
    else:
        raise NotImplementedError(f"unknown init_type: {init_type!r}")


def create_params_with_init(
    shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
):
    """Create and initialize a ``[shape[0], shape[1]]`` parameter tensor.

    When *include_bias* is True, the last row is a bias initialized with
    *bias_init_type* while the remaining ``shape[0] - 1`` rows use
    *init_type*; the two are concatenated along dim 0.
    """
    if not include_bias:
        params = torch.empty([shape[0], shape[1]])
        initialize_params(params, init_type, **kwargs)
        return params
    params = torch.empty([shape[0] - 1, shape[1]])
    bias = torch.empty([1, shape[1]])
    initialize_params(params, init_type, **kwargs)
    initialize_params(bias, bias_init_type, **kwargs)
    return torch.cat([params, bias], dim=0)


def create_activation(config):
    """Build the activation module named by ``config.type``.

    ``config.siren_w0`` is read only for the "siren" and "damp" variants.

    Raises:
        NotImplementedError: for an unrecognized ``config.type``.
    """
    if config.type == "relu":
        activation = nn.ReLU()
    elif config.type == "siren":
        # Imported lazily so the plain numeric helpers above stay importable
        # without pulling in the project-local layers module.
        from .layers import Sine

        activation = Sine(config.siren_w0)
    elif config.type == "silu":
        activation = nn.SiLU()
    elif config.type == "damp":
        from .layers import Damping

        activation = Damping(config.siren_w0)
    else:
        raise NotImplementedError(f"unknown activation type: {config.type!r}")
    return activation
# return activation
# NOTE(review): the line above is the tail of modules/utils.py
# (create_activation), and the import below is the whole of
# raft/__init__.py; both were fused into this span by the patch
# extraction and are kept as comments so the chunk stays parseable.
# from .raft import RAFT

# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# raft: https://github.com/princeton-vl/RAFT
# --------------------------------------------------------

import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is an optional compiled CUDA extension; without it only
    # AlternateCorrBlock is unusable (calling it raises NameError).
    pass


class BidirCorrBlock:
    """Bidirectional all-pairs correlation volume pyramid.

    Precomputes the correlation between two feature maps and its transpose,
    average-pools each into *num_levels* pyramid levels, and on __call__
    samples (2r+1)^2 windows around the given coordinates at every level,
    returning forward and backward correlation features.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)

        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        """Sample forward/backward correlation features at the given coords."""
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert (
            coords0.shape == coords1.shape
        ), f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            # indexing="ij" is the legacy torch.meshgrid behavior; stating it
            # explicitly silences the deprecation warning on newer PyTorch.
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

            centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return (
            out.permute(0, 3, 1, 2).contiguous().float(),
            out_T.permute(0, 3, 1, 2).contiguous().float(),
        )

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
    """Memory-efficient correlation via the optional alt_cuda_corr extension.

    Only usable when alt_cuda_corr compiled successfully at import time.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())


class CorrBlock:
    """Unidirectional all-pairs correlation pyramid (standard RAFT)."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)

            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
# torch.sqrt(torch.tensor(dim).float())
# NOTE(review): the line above is the tail of CorrBlock.corr in raft/corr.py,
# cut at this point by the patch extraction; kept as a comment so this chunk
# remains parseable.  Below begins raft/extractor.py.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """RAFT residual block: two 3x3 convolutions with a selectable
    group/batch/instance/no normalization, plus a strided 1x1 projection
    shortcut whenever the spatial resolution changes (stride != 1).
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        groups = planes // 8

        # norm3 is created only when the shortcut needs a projection
        # (stride != 1) and is then reused inside the downsample Sequential,
        # keeping the module layout — and hence state_dict keys — identical
        # to the reference implementation.
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        """Apply conv-norm-relu twice, then add the (possibly projected)
        shortcut and apply a final ReLU."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(identity + out)


class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()
+ + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0, only_feat=False): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + self.only_feat = only_feat + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + 
+ elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + if not self.only_feat: + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x, return_feature=False, mif=False): + features = [] + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x_2 = F.interpolate(x, scale_factor=1 / 2, mode="bilinear", align_corners=False) + x_4 = F.interpolate(x, scale_factor=1 / 4, mode="bilinear", align_corners=False) + + def f1(feat): + feat = self.conv1(feat) + feat = self.norm1(feat) + feat = self.relu1(feat) + feat = self.layer1(feat) + return feat + + x = f1(x) + features.append(x) + x = self.layer2(x) + if mif: + x_2_2 = f1(x_2) + features.append(torch.cat([x, x_2_2], dim=1)) + else: + features.append(x) + x = self.layer3(x) + if mif: + 
x_2_4 = self.layer2(x_2_2) + x_4_4 = f1(x_4) + features.append(torch.cat([x, x_2_4, x_4_4], dim=1)) + else: + features.append(x) + + if not self.only_feat: + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + features = [torch.split(f, [batch_dim, batch_dim], dim=0) for f in features] + if return_feature: + return x, features + else: + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + 
is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/gimm_vfi_arch/generalizable_INR/raft/other_raft.py b/gimm_vfi_arch/generalizable_INR/raft/other_raft.py new file mode 100644 index 0000000..77a02fe --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/raft/other_raft.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import BidirCorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +# BiRAFT +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + self.corr_levels = 4 + self.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + self.corr_levels = 4 + self.corr_radius = 4 + + if "dropout" not in args._get_kwargs(): + self.args.dropout = 0 + + if "alternate_corr" not in args._get_kwargs(): + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = 
SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def build_coord(self, img): + N, C, H, W = img.shape + coords = coords_grid(N, H // 8, W // 8, device=img.device) + return coords + + def initialize_flow(self, img, img2): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + assert img.shape == img2.shape + N, C, H, W = img.shape + coords01 = coords_grid(N, H // 8, W // 8, device=img.device) + coords02 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + coords2 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords01, coords02, coords1, coords2 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def get_corr_fn(self, image1, image2, projector=None): + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmaps, feats = self.fnet([image1, image2], return_feature=True) + fmap1, fmap2 = fmaps + fmap1 = fmap1.float() + fmap2 = fmap2.float() + corr_fn1 = None + if self.args.alternate_corr: + corr_fn 
= AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + if projector is not None: + corr_fn1 = AlternateCorrBlock( + projector(feats[-1][0]), + projector(feats[-1][1]), + radius=self.args.corr_radius, + ) + else: + corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + if projector is not None: + corr_fn1 = BidirCorrBlock( + projector(feats[-1][0]), + projector(feats[-1][1]), + radius=self.args.corr_radius, + ) + if corr_fn1 is None: + return corr_fn, corr_fn + else: + return corr_fn, corr_fn1 + + def get_corr_fn_from_feat(self, fmap1, fmap2): + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + return corr_fn + + def forward( + self, + image1, + image2, + iters=12, + flow_init=None, + upsample=True, + test_mode=False, + corr_fn=None, + mif=False, + ): + """Estimate optical flow between pair of frames""" + assert flow_init is None + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + if corr_fn is None: + corr_fn, _ = self.get_corr_fn(image1, image2) + + # # run the feature network + # with autocast(enabled=self.args.mixed_precision): + # fmap1, fmap2 = self.fnet([image1, image2]) + + # fmap1 = fmap1.float() + # fmap2 = fmap2.float() + # if self.args.alternate_corr: + # corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + # else: + # corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + # for image1 + cnet1, features1 = self.cnet(image1, return_feature=True, mif=mif) + net1, inp1 = torch.split(cnet1, [hdim, cdim], dim=1) + net1 = torch.tanh(net1) + inp1 = torch.relu(inp1) + # for image2 + cnet2, 
features2 = self.cnet(image2, return_feature=True, mif=mif) + net2, inp2 = torch.split(cnet2, [hdim, cdim], dim=1) + net2 = torch.tanh(net2) + inp2 = torch.relu(inp2) + + coords01, coords02, coords1, coords2 = self.initialize_flow(image1, image2) + + # if flow_init is not None: + # coords1 = coords1 + flow_init + + # flow_predictions1 = [] + # flow_predictions2 = [] + for itr in range(iters): + coords1 = coords1.detach() + coords2 = coords2.detach() + corr1, corr2 = corr_fn(coords1, coords2) # index correlation volume + + flow1 = coords1 - coords01 + flow2 = coords2 - coords02 + + with autocast(enabled=self.args.mixed_precision): + net1, up_mask1, delta_flow1 = self.update_block( + net1, inp1, corr1, flow1 + ) + net2, up_mask2, delta_flow2 = self.update_block( + net2, inp2, corr2, flow2 + ) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow1 + coords2 = coords2 + delta_flow2 + flow_low1 = coords1 - coords01 + flow_low2 = coords2 - coords02 + # upsample predictions + if up_mask1 is None: + flow_up1 = upflow8(coords1 - coords01) + flow_up2 = upflow8(coords2 - coords02) + else: + flow_up1 = self.upsample_flow(coords1 - coords01, up_mask1) + flow_up2 = self.upsample_flow(coords2 - coords02, up_mask2) + + # flow_predictions.append(flow_up) + return flow_up1, flow_up2, flow_low1, flow_low2, features1, features2 + # if test_mode: + # return coords1 - coords0, flow_up + + # return flow_predictions diff --git a/gimm_vfi_arch/generalizable_INR/raft/raft.py b/gimm_vfi_arch/generalizable_INR/raft/raft.py new file mode 100644 index 0000000..75d9d02 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/raft/raft.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = 
torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + self.corr_levels = 4 + self.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + self.corr_levels = 4 + self.corr_radius = 4 + + if "dropout" not in args._get_kwargs(): + self.args.dropout = 0 + + if "alternate_corr" not in args._get_kwargs(): + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = 
flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward( + self, + image1, + image2, + iters=12, + flow_init=None, + upsample=True, + test_mode=False, + return_feat=True, + ): + """Estimate optical flow between pair of frames""" + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet, feats = self.cnet(image1, return_feature=True) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - 
coords0, flow_up + + if return_feat: + return flow_up, feats[1:], fmap1 + + return flow_predictions diff --git a/gimm_vfi_arch/generalizable_INR/raft/update.py b/gimm_vfi_arch/generalizable_INR/raft/update.py new file mode 100644 index 0000000..90abe56 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/raft/update.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), 
padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = 
SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow.to(inp), corr.to(inp)) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/gimm_vfi_arch/generalizable_INR/raft/utils/__init__.py b/gimm_vfi_arch/generalizable_INR/raft/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/generalizable_INR/raft/utils/utils.py b/gimm_vfi_arch/generalizable_INR/raft/utils/utils.py new file mode 100644 index 0000000..781dc55 --- /dev/null +++ b/gimm_vfi_arch/generalizable_INR/raft/utils/utils.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == "sintel": + 
self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 + ) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 + ) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid( + torch.arange(ht, device=device), torch.arange(wd, device=device) + ) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, 
align_corners=True) diff --git a/gimm_vfi_arch/utils/__init__.py b/gimm_vfi_arch/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gimm_vfi_arch/utils/utils.py b/gimm_vfi_arch/utils/utils.py new file mode 100644 index 0000000..981bff6 --- /dev/null +++ b/gimm_vfi_arch/utils/utils.py @@ -0,0 +1,52 @@ +from easydict import EasyDict as edict +import torch.nn.functional as F + +class InputPadder: + """Pads images such that dimensions are divisible by divisor""" + + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode="replicate") + else: + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] +def easydict_to_dict(obj): + if not isinstance(obj, edict): + return obj + else: + return {k: easydict_to_dict(v) for k, v in obj.items()} + + +class RaftArgs: + def __init__(self, small, mixed_precision, alternate_corr): + self.small = small + self.mixed_precision = mixed_precision + self.alternate_corr = alternate_corr + + def _get_kwargs(self): + return { + "small": self.small, + "mixed_precision": self.mixed_precision, + "alternate_corr": self.alternate_corr + } \ No newline at end of file diff --git a/inference.py b/inference.py index 33cf372..c9d71c9 100644 --- a/inference.py +++ b/inference.py @@ -441,3 +441,183 @@ class SGMVFIModel: pred = self._inference(img0, img1, timestep=time_step) pred = 
padder.unpad(pred) return torch.clamp(pred, 0, 1) + + +# --------------------------------------------------------------------------- +# GIMM-VFI model wrapper +# --------------------------------------------------------------------------- + +class GIMMVFIModel: + """Clean inference wrapper around GIMM-VFI for ComfyUI integration. + + Supports two modes: + - interpolate_batch(): standard single-midpoint interface (compatible with + recursive _interpolate_frames machinery used by other models) + - interpolate_multi(): GIMM-VFI's unique single-pass mode, generates all + N-1 intermediate frames between each pair in one forward pass + """ + + def __init__(self, checkpoint_path, flow_checkpoint_path, variant="auto", + ds_factor=1.0, device="cpu"): + import os + import yaml + from omegaconf import OmegaConf + from .gimm_vfi_arch import ( + GIMMVFI_R, GIMMVFI_F, GIMMVFIConfig, + GIMM_RAFT, GIMM_FlowFormer, gimm_get_flowformer_cfg, + GIMMInputPadder, GIMMRaftArgs, easydict_to_dict, + ) + import comfy.utils + + self.ds_factor = ds_factor + self.device = device + self._InputPadder = GIMMInputPadder + + filename = os.path.basename(checkpoint_path).lower() + + # Detect variant from filename + if variant == "auto": + self.is_flowformer = "gimmvfi_f" in filename + else: + self.is_flowformer = (variant == "flowformer") + + self.variant_name = "flowformer" if self.is_flowformer else "raft" + + # Load config + script_dir = os.path.dirname(os.path.abspath(__file__)) + if self.is_flowformer: + config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_f_arb.yaml") + else: + config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_r_arb.yaml") + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + config = easydict_to_dict(config) + config = OmegaConf.create(config) + arch_defaults = GIMMVFIConfig.create(config.arch) + config = OmegaConf.merge(arch_defaults, config.arch) + + # Build model + flow estimator + dtype = 
torch.float32 + + if self.is_flowformer: + self.model = GIMMVFI_F(dtype, config) + cfg = gimm_get_flowformer_cfg() + flow_estimator = GIMM_FlowFormer(cfg.latentcostformer) + flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path) + flow_estimator.load_state_dict(flow_sd, strict=True) + else: + self.model = GIMMVFI_R(dtype, config) + raft_args = GIMMRaftArgs(small=False, mixed_precision=False, alternate_corr=False) + flow_estimator = GIMM_RAFT(raft_args) + flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path) + flow_estimator.load_state_dict(flow_sd, strict=True) + + # Load main model weights + sd = comfy.utils.load_torch_file(checkpoint_path) + self.model.load_state_dict(sd, strict=False) + + self.model.flow_estimator = flow_estimator + self.model.eval() + + def to(self, device): + """Move model to device (returns self for chaining).""" + self.device = device if isinstance(device, str) else str(device) + self.model.to(device) + return self + + @torch.no_grad() + def interpolate_batch(self, frames0, frames1, time_step=0.5): + """Interpolate a single midpoint frame per pair (standard interface). 
+ + Args: + frames0: [B, C, H, W] tensor, float32, range [0, 1] + frames1: [B, C, H, W] tensor, float32, range [0, 1] + time_step: float in (0, 1) + + Returns: + Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1] + """ + device = next(self.model.parameters()).device + results = [] + + for i in range(frames0.shape[0]): + I0 = frames0[i:i+1].to(device) + I2 = frames1[i:i+1].to(device) + + padder = self._InputPadder(I0.shape, 32) + I0_p, I2_p = padder.pad(I0, I2) + + xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2) + batch_size = xs.shape[0] + s_shape = xs.shape[-2:] + + coord_inputs = [( + self.model.sample_coord_input( + batch_size, s_shape, [time_step], + device=xs.device, upsample_ratio=self.ds_factor, + ), + None, + )] + timesteps = [ + time_step * torch.ones(xs.shape[0]).to(xs.device) + ] + + all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor) + pred = padder.unpad(all_outputs["imgt_pred"][0]) + results.append(torch.clamp(pred, 0, 1)) + + return torch.cat(results, dim=0) + + @torch.no_grad() + def interpolate_multi(self, frame0, frame1, num_intermediates): + """Generate all intermediate frames between a pair in one forward pass. + + This is GIMM-VFI's unique capability -- arbitrary timestep interpolation + without recursive 2x passes. 
+ + Args: + frame0: [1, C, H, W] tensor, float32, range [0, 1] + frame1: [1, C, H, W] tensor, float32, range [0, 1] + num_intermediates: int, number of intermediate frames to generate + + Returns: + List of [1, C, H, W] tensors, float32, clamped to [0, 1] + """ + device = next(self.model.parameters()).device + I0 = frame0.to(device) + I2 = frame1.to(device) + + padder = self._InputPadder(I0.shape, 32) + I0_p, I2_p = padder.pad(I0, I2) + + xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2) + batch_size = xs.shape[0] + s_shape = xs.shape[-2:] + interp_factor = num_intermediates + 1 + + coord_inputs = [ + ( + self.model.sample_coord_input( + batch_size, s_shape, + [1.0 / interp_factor * i], + device=xs.device, + upsample_ratio=self.ds_factor, + ), + None, + ) + for i in range(1, interp_factor) + ] + timesteps = [ + i * 1.0 / interp_factor * torch.ones(xs.shape[0]).to(xs.device) + for i in range(1, interp_factor) + ] + + all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor) + + results = [] + for pred in all_outputs["imgt_pred"]: + unpadded = padder.unpad(pred) + results.append(torch.clamp(unpadded, 0, 1)) + + return results diff --git a/nodes.py b/nodes.py index 5308435..a25ae3f 100644 --- a/nodes.py +++ b/nodes.py @@ -8,10 +8,11 @@ import torch import folder_paths from comfy.utils import ProgressBar -from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel +from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel from .bim_vfi_arch import clear_backwarp_cache from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache +from .gimm_vfi_arch import clear_gimm_caches logger = logging.getLogger("Tween") @@ -40,6 +41,17 @@ SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi") if not os.path.exists(SGM_MODEL_DIR): os.makedirs(SGM_MODEL_DIR, exist_ok=True) +# GIMM-VFI +GIMM_HF_REPO = "Kijai/GIMM-VFI_safetensors" 
+GIMM_AVAILABLE_MODELS = [ + "gimmvfi_r_arb_lpips_fp32.safetensors", + "gimmvfi_f_arb_lpips_fp32.safetensors", +] + +GIMM_MODEL_DIR = os.path.join(folder_paths.models_dir, "gimm-vfi") +if not os.path.exists(GIMM_MODEL_DIR): + os.makedirs(GIMM_MODEL_DIR, exist_ok=True) + def get_available_models(): """List available checkpoint files in the bim-vfi model directory.""" @@ -1113,3 +1125,385 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate): result = result[1:] # skip duplicate boundary frame return (result, model) + + +# --------------------------------------------------------------------------- +# GIMM-VFI nodes +# --------------------------------------------------------------------------- + +def get_available_gimm_models(): + """List available GIMM-VFI checkpoint files in the gimm-vfi model directory.""" + models = [] + if os.path.isdir(GIMM_MODEL_DIR): + for f in os.listdir(GIMM_MODEL_DIR): + if f.endswith((".safetensors", ".pth", ".pt", ".ckpt")): + # Exclude flow estimator checkpoints from the model list + if f.startswith(("raft-", "flowformer_")): + continue + models.append(f) + if not models: + models = list(GIMM_AVAILABLE_MODELS) + return sorted(models) + + +def download_gimm_model(filename, dest_dir): + """Download a GIMM-VFI file from HuggingFace.""" + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise RuntimeError( + "huggingface_hub is required to auto-download GIMM-VFI models. 
"
+            "Install it with: pip install huggingface_hub"
+        )
+    logger.info(f"Downloading {filename} from HuggingFace ({GIMM_HF_REPO})...")
+    hf_hub_download(
+        repo_id=GIMM_HF_REPO,
+        filename=filename,
+        local_dir=dest_dir,
+        local_dir_use_symlinks=False,
+    )
+    dest_path = os.path.join(dest_dir, filename)
+    if not os.path.exists(dest_path):
+        raise RuntimeError(f"Failed to download {filename} to {dest_path}")
+    logger.info(f"Downloaded {filename}")
+
+
+class LoadGIMMVFIModel:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model_path": (get_available_gimm_models(), {
+                    "default": GIMM_AVAILABLE_MODELS[0],
+                    "tooltip": "Checkpoint file from models/gimm-vfi/. Auto-downloads from HuggingFace on first use. "
+                               "RAFT variant (~80MB) or FlowFormer variant (~123MB) auto-detected from filename.",
+                }),
+                "ds_factor": ("FLOAT", {
+                    "default": 1.0, "min": 0.125, "max": 1.0, "step": 0.125,
+                    "tooltip": "Downscale factor for internal processing. 1.0 = full resolution. "
+                               "Lower values reduce VRAM usage and speed up inference at the cost of quality. 
class GIMMVFIInterpolate:
    """Interpolate intermediate frames between consecutive images with GIMM-VFI.

    Two modes:
      * single-pass: one forward pass per frame pair produces all
        ``multiplier - 1`` intermediate timesteps at once (GIMM-VFI's
        arbitrary-timestep capability).
      * recursive: repeated 2x passes (1 pass for 2x, 2 for 4x, 3 for 8x),
        matching the behaviour of the other interpolation models.

    Tensors are handled internally in [B, C, H, W] layout; the ComfyUI
    boundary format is [B, H, W, C].
    """

    @classmethod
    def INPUT_TYPES(cls):
        """ComfyUI input declaration for this node."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
                }),
                "model": ("GIMM_VFI_MODEL", {
                    "tooltip": "GIMM-VFI model from the Load GIMM-VFI Model node.",
                }),
                "multiplier": ([2, 4, 8], {
                    "default": 2,
                    "tooltip": "Frame rate multiplier. In single-pass mode, all intermediate frames are generated "
                               "in one forward pass per pair. In recursive mode, uses 2x passes like other models.",
                }),
                "single_pass": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GIMM-VFI's single-pass arbitrary-timestep mode. Generates all intermediate frames "
                               "per pair in one forward pass (no recursive 2x passes). Disable to use the standard "
                               "recursive approach (same as BIM/EMA/SGM).",
                }),
                "clear_cache_after_n_frames": ("INT", {
                    "default": 10, "min": 1, "max": 100, "step": 1,
                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
                }),
                "keep_device": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
                }),
                "all_on_gpu": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                }),
                "batch_size": ("INT", {
                    "default": 1, "min": 1, "max": 64, "step": 1,
                    "tooltip": "Number of frame pairs to process simultaneously in recursive mode. Ignored in single-pass mode (pairs are processed one at a time since each generates multiple frames).",
                }),
                "chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    def _interpolate_frames_single_pass(self, frames, model, multiplier,
                                        device, storage_device, keep_device, all_on_gpu,
                                        clear_cache_after_n_frames, pbar, step_ref):
        """Single-pass interpolation using GIMM-VFI's arbitrary timestep capability.

        frames: [B, C, H, W] tensor residing on ``storage_device``.
        step_ref: one-element list used as a mutable progress counter that is
        shared across chunks (one step per frame pair).
        Returns the full interpolated sequence as one [B', C, H, W] tensor,
        where B' = multiplier * (B - 1) + 1.
        """
        num_intermediates = multiplier - 1
        new_frames = []
        num_pairs = frames.shape[0] - 1
        pairs_since_clear = 0

        for i in range(num_pairs):
            frame0 = frames[i:i+1]
            frame1 = frames[i+1:i+2]

            # When not keeping the model resident, shuttle it to the compute
            # device only for the duration of this pair.
            if not keep_device:
                model.to(device)

            # interpolate_multi returns a list of num_intermediates tensors
            # (one per intermediate timestep) for this pair.
            mids = model.interpolate_multi(frame0, frame1, num_intermediates)
            mids = [m.to(storage_device) for m in mids]

            if not keep_device:
                model.to("cpu")

            new_frames.append(frames[i:i+1])
            new_frames.extend(mids)

            step_ref[0] += 1
            pbar.update_absolute(step_ref[0])

            # Periodic CUDA cache clear keeps VRAM bounded in CPU-storage mode.
            pairs_since_clear += 1
            if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                clear_gimm_caches()
                torch.cuda.empty_cache()
                pairs_since_clear = 0

        new_frames.append(frames[-1:])
        result = torch.cat(new_frames, dim=0)

        if not all_on_gpu and torch.cuda.is_available():
            clear_gimm_caches()
            torch.cuda.empty_cache()

        return result

    def _interpolate_frames(self, frames, model, num_passes, batch_size,
                            device, storage_device, keep_device, all_on_gpu,
                            clear_cache_after_n_frames, pbar, step_ref):
        """Recursive 2x interpolation (standard approach, same as other models).

        Each pass doubles the frame density (B -> 2B - 1); ``num_passes`` is
        1/2/3 for 2x/4x/8x. ``batch_size`` frame pairs are interpolated per
        model call. ``step_ref`` is the shared mutable progress counter (one
        step per pair per pass). Returns the final [B', C, H, W] tensor.
        """
        for pass_idx in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1
            pairs_since_clear = 0

            for i in range(0, num_pairs, batch_size):
                batch_end = min(i + batch_size, num_pairs)
                actual_batch = batch_end - i

                # Left and right frames of each pair in this batch.
                frames0 = frames[i:batch_end]
                frames1 = frames[i + 1:batch_end + 1]

                if not keep_device:
                    model.to(device)

                mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
                mids = mids.to(storage_device)

                if not keep_device:
                    model.to("cpu")

                # Interleave original frames with their midpoints.
                for j in range(actual_batch):
                    new_frames.append(frames[i + j:i + j + 1])
                    new_frames.append(mids[j:j+1])

                step_ref[0] += actual_batch
                pbar.update_absolute(step_ref[0])

                pairs_since_clear += actual_batch
                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                    clear_gimm_caches()
                    torch.cuda.empty_cache()
                    pairs_since_clear = 0

            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)

            if not all_on_gpu and torch.cuda.is_available():
                clear_gimm_caches()
                torch.cuda.empty_cache()

        return frames

    @staticmethod
    def _count_steps(num_frames: int, num_passes: int) -> int:
        """Count total interpolation steps (frame pairs) for recursive mode.

        Each pass processes n-1 pairs and grows the sequence to 2n - 1.
        """
        n = num_frames
        total = 0
        for _ in range(num_passes):
            total += n - 1
            n = 2 * n - 1
        return total

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size):
        """Entry point: interpolate a ComfyUI [B, H, W, C] image batch.

        Splits the input into VRAM-bounding chunks (1-frame overlap), runs
        either the single-pass or recursive strategy per chunk, and stitches
        the results back together. Returns a one-tuple with the CPU result in
        ComfyUI layout.
        """
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # num_passes is only meaningful (and only looked up) in recursive mode;
        # single-pass mode derives intermediate count directly from multiplier.
        if not single_pass:
            num_passes = {2: 1, 4: 2, 8: 3}[multiplier]

        if all_on_gpu:
            keep_device = True

        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        all_frames = images.permute(0, 3, 1, 2).to(storage_device)
        total_input = all_frames.shape[0]

        # Build chunk boundaries (1-frame overlap between consecutive chunks)
        if chunk_size < 2 or chunk_size >= total_input:
            chunks = [(0, total_input)]
        else:
            chunks = []
            start = 0
            while start < total_input - 1:
                end = min(start + chunk_size, total_input)
                chunks.append((start, end))
                start = end - 1  # overlap by 1 frame
                if end == total_input:
                    break

        # Calculate total progress steps across all chunks
        if single_pass:
            total_steps = sum(ce - cs - 1 for cs, ce in chunks)
        else:
            total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
        pbar = ProgressBar(total_steps)
        step_ref = [0]

        if keep_device:
            model.to(device)

        result_chunks = []
        for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
            chunk_frames = all_frames[chunk_start:chunk_end].clone()

            if single_pass:
                chunk_result = self._interpolate_frames_single_pass(
                    chunk_frames, model, multiplier,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )
            else:
                chunk_result = self._interpolate_frames(
                    chunk_frames, model, num_passes, batch_size,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )

            # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
            if chunk_idx > 0:
                chunk_result = chunk_result[1:]

            # Move completed chunk to CPU to bound memory when chunking
            if len(chunks) > 1:
                chunk_result = chunk_result.cpu()

            result_chunks.append(chunk_result)

        result = torch.cat(result_chunks, dim=0)
        # Convert back to ComfyUI [B, H, W, C], on CPU
        result = result.cpu().permute(0, 2, 3, 1)
        return (result,)


class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
    """Process a numbered segment of the input batch for GIMM-VFI.

    Chain multiple instances with Save nodes between them to bound peak RAM.
    The model pass-through output forces sequential execution so each segment
    saves and frees from RAM before the next starts.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Extend the parent's inputs with segment selection parameters."""
        base = GIMMVFIInterpolate.INPUT_TYPES()
        base["required"]["segment_index"] = ("INT", {
            "default": 0, "min": 0, "max": 10000, "step": 1,
            "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
                       "unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
                       "Chain the model output to the next Segment Interpolate to force sequential execution.",
        })
        base["required"]["segment_size"] = ("INT", {
            "default": 500, "min": 2, "max": 10000, "step": 1,
            "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
                       "Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
        })
        return base

    RETURN_TYPES = ("IMAGE", "GIMM_VFI_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size, segment_index, segment_size):
        """Interpolate only frames [start:end] of the input, then pass the model through.

        Segments overlap by one frame; for segment_index > 0 the duplicated
        boundary frame is dropped from the output so saved segments concatenate
        seamlessly.
        """
        total_input = images.shape[0]

        # Compute segment boundaries (1-frame overlap)
        start = segment_index * (segment_size - 1)
        end = min(start + segment_size, total_input)

        if start >= total_input - 1:
            # Segment index is past the end of the input: there is no pair to
            # interpolate, so return a single placeholder frame plus the model
            # pass-through so downstream nodes still execute in order.
            return (images[:1], model)

        segment_images = images[start:end]
        is_continuation = segment_index > 0

        # Delegate to the parent interpolation logic
        (result,) = super().interpolate(
            segment_images, model, multiplier, single_pass,
            clear_cache_after_n_frames, keep_device, all_on_gpu,
            batch_size, chunk_size,
        )

        if is_continuation:
            result = result[1:]  # skip duplicate boundary frame

        return (result, model)