Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
15
__init__.py
15
__init__.py
@@ -34,6 +34,14 @@ def _auto_install_deps():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
|
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
|
||||||
|
|
||||||
|
# GIMM-VFI dependencies
|
||||||
|
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"):
|
||||||
|
try:
|
||||||
|
__import__(pkg)
|
||||||
|
except ImportError:
|
||||||
|
logger.info(f"[Tween] Installing {pkg}...")
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
|
||||||
|
|
||||||
|
|
||||||
_auto_install_deps()
|
_auto_install_deps()
|
||||||
|
|
||||||
@@ -41,6 +49,7 @@ from .nodes import (
|
|||||||
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
|
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
|
||||||
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
||||||
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
||||||
|
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
||||||
)
|
)
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web"
|
WEB_DIRECTORY = "./web"
|
||||||
@@ -56,6 +65,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadSGMVFIModel": LoadSGMVFIModel,
|
"LoadSGMVFIModel": LoadSGMVFIModel,
|
||||||
"SGMVFIInterpolate": SGMVFIInterpolate,
|
"SGMVFIInterpolate": SGMVFIInterpolate,
|
||||||
"SGMVFISegmentInterpolate": SGMVFISegmentInterpolate,
|
"SGMVFISegmentInterpolate": SGMVFISegmentInterpolate,
|
||||||
|
"LoadGIMMVFIModel": LoadGIMMVFIModel,
|
||||||
|
"GIMMVFIInterpolate": GIMMVFIInterpolate,
|
||||||
|
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@@ -69,4 +81,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadSGMVFIModel": "Load SGM-VFI Model",
|
"LoadSGMVFIModel": "Load SGM-VFI Model",
|
||||||
"SGMVFIInterpolate": "SGM-VFI Interpolate",
|
"SGMVFIInterpolate": "SGM-VFI Interpolate",
|
||||||
"SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate",
|
"SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate",
|
||||||
|
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
|
||||||
|
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
|
||||||
|
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
|
||||||
}
|
}
|
||||||
|
|||||||
15
gimm_vfi_arch/__init__.py
Normal file
15
gimm_vfi_arch/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from .generalizable_INR.gimmvfi_r import GIMMVFI_R
|
||||||
|
from .generalizable_INR.gimmvfi_f import GIMMVFI_F
|
||||||
|
from .generalizable_INR.configs import GIMMVFIConfig
|
||||||
|
from .generalizable_INR.raft.raft import RAFT as GIMM_RAFT
|
||||||
|
from .generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer as GIMM_FlowFormer
|
||||||
|
from .generalizable_INR.flowformer.configs.submission import get_cfg as gimm_get_flowformer_cfg
|
||||||
|
from .utils.utils import InputPadder as GIMMInputPadder, RaftArgs as GIMMRaftArgs, easydict_to_dict
|
||||||
|
from .generalizable_INR.modules.softsplat import objCudacache as gimm_softsplat_cache
|
||||||
|
|
||||||
|
|
||||||
|
def clear_gimm_caches():
|
||||||
|
"""Clear cached CUDA kernels and warp grids for GIMM-VFI."""
|
||||||
|
from .generalizable_INR.modules.fi_utils import backwarp_tenGrid
|
||||||
|
backwarp_tenGrid.clear()
|
||||||
|
gimm_softsplat_cache.clear()
|
||||||
0
gimm_vfi_arch/configs/__init__.py
Normal file
0
gimm_vfi_arch/configs/__init__.py
Normal file
57
gimm_vfi_arch/configs/gimmvfi_f_arb.yaml
Normal file
57
gimm_vfi_arch/configs/gimmvfi_f_arb.yaml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
trainer: stage_inr
|
||||||
|
dataset:
|
||||||
|
type: vimeo_arb
|
||||||
|
path: ./data/vimeo90k/vimeo_septuplet
|
||||||
|
aug: true
|
||||||
|
|
||||||
|
arch:
|
||||||
|
type: gimmvfi_f
|
||||||
|
ema: true
|
||||||
|
modulated_layer_idxs: [1]
|
||||||
|
|
||||||
|
coord_range: [-1., 1.]
|
||||||
|
|
||||||
|
hyponet:
|
||||||
|
type: mlp
|
||||||
|
n_layer: 5 # including the output layer
|
||||||
|
hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
|
||||||
|
use_bias: true
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 2
|
||||||
|
output_bias: 0.5
|
||||||
|
activation:
|
||||||
|
type: siren
|
||||||
|
siren_w0: 1.0
|
||||||
|
initialization:
|
||||||
|
weight_init_type: siren
|
||||||
|
bias_init_type: siren
|
||||||
|
|
||||||
|
loss:
|
||||||
|
subsample:
|
||||||
|
type: random
|
||||||
|
ratio: 0.1
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
type: adamw
|
||||||
|
init_lr: 0.00008
|
||||||
|
weight_decay: 0.00004
|
||||||
|
betas: [0.9, 0.999]
|
||||||
|
ft: true
|
||||||
|
warmup:
|
||||||
|
epoch: 1
|
||||||
|
multiplier: 1
|
||||||
|
buffer_epoch: 0
|
||||||
|
min_lr: 0.000008
|
||||||
|
mode: fix
|
||||||
|
start_from_zero: True
|
||||||
|
max_gn: null
|
||||||
|
|
||||||
|
experiment:
|
||||||
|
amp: True
|
||||||
|
batch_size: 4
|
||||||
|
total_batch_size: 32
|
||||||
|
epochs: 60
|
||||||
|
save_ckpt_freq: 10
|
||||||
|
test_freq: 10
|
||||||
|
test_imlog_freq: 10
|
||||||
|
|
||||||
57
gimm_vfi_arch/configs/gimmvfi_r_arb.yaml
Normal file
57
gimm_vfi_arch/configs/gimmvfi_r_arb.yaml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
trainer: stage_inr
|
||||||
|
dataset:
|
||||||
|
type: vimeo_arb
|
||||||
|
path: ./data/vimeo90k/vimeo_septuplet
|
||||||
|
aug: true
|
||||||
|
|
||||||
|
arch:
|
||||||
|
type: gimmvfi_r
|
||||||
|
ema: true
|
||||||
|
modulated_layer_idxs: [1]
|
||||||
|
|
||||||
|
coord_range: [-1., 1.]
|
||||||
|
|
||||||
|
hyponet:
|
||||||
|
type: mlp
|
||||||
|
n_layer: 5 # including the output layer
|
||||||
|
hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
|
||||||
|
use_bias: true
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 2
|
||||||
|
output_bias: 0.5
|
||||||
|
activation:
|
||||||
|
type: siren
|
||||||
|
siren_w0: 1.0
|
||||||
|
initialization:
|
||||||
|
weight_init_type: siren
|
||||||
|
bias_init_type: siren
|
||||||
|
|
||||||
|
loss:
|
||||||
|
subsample:
|
||||||
|
type: random
|
||||||
|
ratio: 0.1
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
type: adamw
|
||||||
|
init_lr: 0.00008
|
||||||
|
weight_decay: 0.00004
|
||||||
|
betas: [0.9, 0.999]
|
||||||
|
ft: true
|
||||||
|
warmup:
|
||||||
|
epoch: 1
|
||||||
|
multiplier: 1
|
||||||
|
buffer_epoch: 0
|
||||||
|
min_lr: 0.000008
|
||||||
|
mode: fix
|
||||||
|
start_from_zero: True
|
||||||
|
max_gn: null
|
||||||
|
|
||||||
|
experiment:
|
||||||
|
amp: True
|
||||||
|
batch_size: 4
|
||||||
|
total_batch_size: 32
|
||||||
|
epochs: 60
|
||||||
|
save_ckpt_freq: 10
|
||||||
|
test_freq: 10
|
||||||
|
test_imlog_freq: 10
|
||||||
|
|
||||||
0
gimm_vfi_arch/generalizable_INR/__init__.py
Normal file
0
gimm_vfi_arch/generalizable_INR/__init__.py
Normal file
57
gimm_vfi_arch/generalizable_INR/configs.py
Normal file
57
gimm_vfi_arch/generalizable_INR/configs.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf, MISSING
|
||||||
|
from .modules.module_config import HypoNetConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GIMMConfig:
|
||||||
|
type: str = "gimm"
|
||||||
|
ema: Optional[bool] = None
|
||||||
|
ema_value: Optional[float] = None
|
||||||
|
fwarp_type: str = "linear"
|
||||||
|
hyponet: HypoNetConfig = field(default_factory=HypoNetConfig)
|
||||||
|
coord_range: List[float] = MISSING
|
||||||
|
modulated_layer_idxs: Optional[List[int]] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config):
|
||||||
|
# We need to specify the type of the default DataEncoderConfig.
|
||||||
|
# Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
|
||||||
|
# hence merging with the config with other type would cause config error.
|
||||||
|
defaults = OmegaConf.structured(cls(ema=False))
|
||||||
|
config = OmegaConf.merge(defaults, config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GIMMVFIConfig:
|
||||||
|
type: str = "gimmvfi"
|
||||||
|
ema: Optional[bool] = None
|
||||||
|
ema_value: Optional[float] = None
|
||||||
|
fwarp_type: str = "linear"
|
||||||
|
rec_weight: float = 0.1
|
||||||
|
hyponet: HypoNetConfig = field(default_factory=HypoNetConfig)
|
||||||
|
raft_iter: int = 20
|
||||||
|
coord_range: List[float] = MISSING
|
||||||
|
modulated_layer_idxs: Optional[List[int]] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config):
|
||||||
|
# We need to specify the type of the default DataEncoderConfig.
|
||||||
|
# Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
|
||||||
|
# hence merging with the config with other type would cause config error.
|
||||||
|
defaults = OmegaConf.structured(cls(ema=False))
|
||||||
|
config = OmegaConf.merge(defaults, config)
|
||||||
|
return config
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
|
_CN = CN()
|
||||||
|
|
||||||
|
_CN.name = ""
|
||||||
|
_CN.suffix = ""
|
||||||
|
_CN.gamma = 0.8
|
||||||
|
_CN.max_flow = 400
|
||||||
|
_CN.batch_size = 6
|
||||||
|
_CN.sum_freq = 100
|
||||||
|
_CN.val_freq = 5000000
|
||||||
|
_CN.image_size = [432, 960]
|
||||||
|
_CN.add_noise = False
|
||||||
|
_CN.critical_params = []
|
||||||
|
|
||||||
|
_CN.transformer = "latentcostformer"
|
||||||
|
_CN.model = "pretrained_ckpt/flowformer_sintel.pth"
|
||||||
|
|
||||||
|
# latentcostformer
|
||||||
|
_CN.latentcostformer = CN()
|
||||||
|
_CN.latentcostformer.pe = "linear"
|
||||||
|
_CN.latentcostformer.dropout = 0.0
|
||||||
|
_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
|
||||||
|
_CN.latentcostformer.query_latent_dim = 64
|
||||||
|
_CN.latentcostformer.cost_latent_input_dim = 64
|
||||||
|
_CN.latentcostformer.cost_latent_token_num = 8
|
||||||
|
_CN.latentcostformer.cost_latent_dim = 128
|
||||||
|
_CN.latentcostformer.arc_type = "transformer"
|
||||||
|
_CN.latentcostformer.cost_heads_num = 1
|
||||||
|
# encoder
|
||||||
|
_CN.latentcostformer.pretrain = True
|
||||||
|
_CN.latentcostformer.context_concat = False
|
||||||
|
_CN.latentcostformer.encoder_depth = 3
|
||||||
|
_CN.latentcostformer.feat_cross_attn = False
|
||||||
|
_CN.latentcostformer.patch_size = 8
|
||||||
|
_CN.latentcostformer.patch_embed = "single"
|
||||||
|
_CN.latentcostformer.no_pe = False
|
||||||
|
_CN.latentcostformer.gma = "GMA"
|
||||||
|
_CN.latentcostformer.kernel_size = 9
|
||||||
|
_CN.latentcostformer.rm_res = True
|
||||||
|
_CN.latentcostformer.vert_c_dim = 64
|
||||||
|
_CN.latentcostformer.cost_encoder_res = True
|
||||||
|
_CN.latentcostformer.cnet = "twins"
|
||||||
|
_CN.latentcostformer.fnet = "twins"
|
||||||
|
_CN.latentcostformer.no_sc = False
|
||||||
|
_CN.latentcostformer.only_global = False
|
||||||
|
_CN.latentcostformer.add_flow_token = True
|
||||||
|
_CN.latentcostformer.use_mlp = False
|
||||||
|
_CN.latentcostformer.vertical_conv = False
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
_CN.latentcostformer.decoder_depth = 32
|
||||||
|
_CN.latentcostformer.critical_params = [
|
||||||
|
"cost_heads_num",
|
||||||
|
"vert_c_dim",
|
||||||
|
"cnet",
|
||||||
|
"pretrain",
|
||||||
|
"add_flow_token",
|
||||||
|
"encoder_depth",
|
||||||
|
"gma",
|
||||||
|
"cost_encoder_res",
|
||||||
|
]
|
||||||
|
|
||||||
|
### TRAINER
|
||||||
|
_CN.trainer = CN()
|
||||||
|
_CN.trainer.scheduler = "OneCycleLR"
|
||||||
|
_CN.trainer.optimizer = "adamw"
|
||||||
|
_CN.trainer.canonical_lr = 12.5e-5
|
||||||
|
_CN.trainer.adamw_decay = 1e-4
|
||||||
|
_CN.trainer.clip = 1.0
|
||||||
|
_CN.trainer.num_steps = 120000
|
||||||
|
_CN.trainer.epsilon = 1e-8
|
||||||
|
_CN.trainer.anneal_strategy = "linear"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cfg():
|
||||||
|
return _CN.clone()
|
||||||
@@ -0,0 +1,197 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class BroadMultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads):
|
||||||
|
super(BroadMultiHeadAttention, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = (dim / heads) ** -0.5
|
||||||
|
self.attend = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def attend_with_rpe(self, Q, K):
|
||||||
|
Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
|
||||||
|
K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
|
||||||
|
|
||||||
|
dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale # (b hw) heads 1 pointnum
|
||||||
|
|
||||||
|
return self.attend(dots)
|
||||||
|
|
||||||
|
def forward(self, Q, K, V):
|
||||||
|
attn = self.attend_with_rpe(Q, K)
|
||||||
|
B, _, _ = K.shape
|
||||||
|
_, N, _ = Q.shape
|
||||||
|
|
||||||
|
V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
|
||||||
|
|
||||||
|
out = einsum("bhij, bhjd -> bhid", attn, V)
|
||||||
|
out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads):
|
||||||
|
super(MultiHeadAttention, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = (dim / heads) ** -0.5
|
||||||
|
self.attend = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def attend_with_rpe(self, Q, K):
|
||||||
|
Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
|
||||||
|
K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
|
||||||
|
|
||||||
|
dots = (
|
||||||
|
einsum("bhid, bhjd -> bhij", Q, K) * self.scale
|
||||||
|
) # (b hw) heads 1 pointnum
|
||||||
|
|
||||||
|
return self.attend(dots)
|
||||||
|
|
||||||
|
def forward(self, Q, K, V):
|
||||||
|
attn = self.attend_with_rpe(Q, K)
|
||||||
|
B, HW, _ = Q.shape
|
||||||
|
|
||||||
|
V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
|
||||||
|
|
||||||
|
out = einsum("bhij, bhjd -> bhid", attn, V)
|
||||||
|
out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# class MultiHeadAttentionRelative_encoder(nn.Module):
|
||||||
|
# def __init__(self, dim, heads):
|
||||||
|
# super(MultiHeadAttentionRelative, self).__init__()
|
||||||
|
# self.dim = dim
|
||||||
|
# self.heads = heads
|
||||||
|
# self.scale = (dim/heads) ** -0.5
|
||||||
|
# self.attend = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
# def attend_with_rpe(self, Q, K, Q_r, K_r):
|
||||||
|
# """
|
||||||
|
# Q: [BH1W1, H3W3, dim]
|
||||||
|
# K: [BH1W1, H3W3, dim]
|
||||||
|
# Q_r: [BH1W1, H3W3, H3W3, dim]
|
||||||
|
# K_r: [BH1W1, H3W3, H3W3, dim]
|
||||||
|
# """
|
||||||
|
|
||||||
|
# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||||
|
# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||||
|
# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||||
|
# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||||
|
|
||||||
|
# # context-context similarity
|
||||||
|
# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3]
|
||||||
|
# # context-position similarity
|
||||||
|
# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3]
|
||||||
|
# # position-context similarity
|
||||||
|
# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:])
|
||||||
|
# p_c = torch.squeeze(p_c, dim=4)
|
||||||
|
# p_c = p_c.permute(0, 1, 3, 2)
|
||||||
|
# dots = c_c + c_p + p_c
|
||||||
|
# return self.attend(dots)
|
||||||
|
|
||||||
|
# def forward(self, Q, K, V, Q_r, K_r):
|
||||||
|
# attn = self.attend_with_rpe(Q, K, Q_r, K_r)
|
||||||
|
# B, HW, _ = Q.shape
|
||||||
|
|
||||||
|
# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
|
||||||
|
|
||||||
|
# out = einsum('bhij, bhjd -> bhid', attn, V)
|
||||||
|
# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
|
||||||
|
|
||||||
|
# return out
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttentionRelative(nn.Module):
|
||||||
|
def __init__(self, dim, heads):
|
||||||
|
super(MultiHeadAttentionRelative, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = (dim / heads) ** -0.5
|
||||||
|
self.attend = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def attend_with_rpe(self, Q, K, Q_r, K_r):
|
||||||
|
"""
|
||||||
|
Q: [BH1W1, 1, dim]
|
||||||
|
K: [BH1W1, H3W3, dim]
|
||||||
|
Q_r: [BH1W1, H3W3, dim]
|
||||||
|
K_r: [BH1W1, H3W3, dim]
|
||||||
|
"""
|
||||||
|
|
||||||
|
Q = rearrange(
|
||||||
|
Q, "b i (heads d) -> b heads i d", heads=self.heads
|
||||||
|
) # [BH1W1, heads, 1, dim]
|
||||||
|
K = rearrange(
|
||||||
|
K, "b j (heads d) -> b heads j d", heads=self.heads
|
||||||
|
) # [BH1W1, heads, H3W3, dim]
|
||||||
|
K_r = rearrange(
|
||||||
|
K_r, "b j (heads d) -> b heads j d", heads=self.heads
|
||||||
|
) # [BH1W1, heads, H3W3, dim]
|
||||||
|
Q_r = rearrange(
|
||||||
|
Q_r, "b j (heads d) -> b heads j d", heads=self.heads
|
||||||
|
) # [BH1W1, heads, H3W3, dim]
|
||||||
|
|
||||||
|
# context-context similarity
|
||||||
|
c_c = einsum("bhid, bhjd -> bhij", Q, K) * self.scale # [(B H1W1) heads 1 H3W3]
|
||||||
|
# context-position similarity
|
||||||
|
c_p = (
|
||||||
|
einsum("bhid, bhjd -> bhij", Q, K_r) * self.scale
|
||||||
|
) # [(B H1W1) heads 1 H3W3]
|
||||||
|
# position-context similarity
|
||||||
|
p_c = (
|
||||||
|
einsum("bhijd, bhikd -> bhijk", Q_r[:, :, :, None, :], K[:, :, :, None, :])
|
||||||
|
* self.scale
|
||||||
|
)
|
||||||
|
p_c = torch.squeeze(p_c, dim=4)
|
||||||
|
p_c = p_c.permute(0, 1, 3, 2)
|
||||||
|
dots = c_c + c_p + p_c
|
||||||
|
return self.attend(dots)
|
||||||
|
|
||||||
|
def forward(self, Q, K, V, Q_r, K_r):
|
||||||
|
attn = self.attend_with_rpe(Q, K, Q_r, K_r)
|
||||||
|
B, HW, _ = Q.shape
|
||||||
|
|
||||||
|
V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
|
||||||
|
|
||||||
|
out = einsum("bhij, bhjd -> bhid", attn, V)
|
||||||
|
out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
|
||||||
|
# 200 should be enough for a 8x downsampled image
|
||||||
|
# assume x to be [_, _, 2]
|
||||||
|
freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
torch.sin(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR),
|
||||||
|
torch.cos(3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR),
|
||||||
|
torch.sin(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR),
|
||||||
|
torch.cos(3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
|
||||||
|
# 200 should be enough for a 8x downsampled image
|
||||||
|
# assume x to be [_, _, 2]
|
||||||
|
freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
torch.sin(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)),
|
||||||
|
torch.cos(x[..., -2:-1] * (NORMALIZE_FACOR * 2**freq_bands)),
|
||||||
|
torch.sin(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)),
|
||||||
|
torch.cos(x[..., -1:] * (NORMALIZE_FACOR * 2**freq_bands)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
@@ -0,0 +1,649 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_planes, planes, kernel_size=3, padding=1, stride=stride
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.relu(self.norm1(self.conv1(y)))
|
||||||
|
y = self.relu(self.norm2(self.conv2(y)))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
|
||||||
|
)
|
||||||
|
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes // 4)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes // 4)
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes // 4)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes // 4)
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.relu(self.norm1(self.conv1(y)))
|
||||||
|
y = self.relu(self.norm2(self.conv2(y)))
|
||||||
|
y = self.relu(self.norm3(self.conv3(y)))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEncoder(nn.Module):
|
||||||
|
def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0):
|
||||||
|
super(BasicEncoder, self).__init__()
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
mul = input_dim // 3
|
||||||
|
|
||||||
|
if self.norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul)
|
||||||
|
|
||||||
|
elif self.norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(64 * mul)
|
||||||
|
|
||||||
|
elif self.norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(64 * mul)
|
||||||
|
|
||||||
|
elif self.norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.in_planes = 64 * mul
|
||||||
|
self.layer1 = self._make_layer(64 * mul, stride=1)
|
||||||
|
self.layer2 = self._make_layer(96 * mul, stride=2)
|
||||||
|
self.layer3 = self._make_layer(128 * mul, stride=2)
|
||||||
|
|
||||||
|
# output convolution
|
||||||
|
self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1)
|
||||||
|
|
||||||
|
self.dropout = None
|
||||||
|
if dropout > 0:
|
||||||
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
|
if m.weight is not None:
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, dim, stride=1):
|
||||||
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
|
self.in_planes = dim
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def compute_params(self):
|
||||||
|
num = 0
|
||||||
|
for param in self.parameters():
|
||||||
|
num += np.prod(param.size())
|
||||||
|
|
||||||
|
return num
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# if input is list, combine batch dimension
|
||||||
|
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||||
|
if is_list:
|
||||||
|
batch_dim = x[0].shape[0]
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
if self.training and self.dropout is not None:
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SmallEncoder(nn.Module):
|
||||||
|
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
||||||
|
super(SmallEncoder, self).__init__()
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
|
||||||
|
if self.norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
||||||
|
|
||||||
|
elif self.norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(32)
|
||||||
|
|
||||||
|
elif self.norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(32)
|
||||||
|
|
||||||
|
elif self.norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.in_planes = 32
|
||||||
|
self.layer1 = self._make_layer(32, stride=1)
|
||||||
|
self.layer2 = self._make_layer(64, stride=2)
|
||||||
|
self.layer3 = self._make_layer(96, stride=2)
|
||||||
|
|
||||||
|
self.dropout = None
|
||||||
|
if dropout > 0:
|
||||||
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
|
if m.weight is not None:
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, dim, stride=1):
|
||||||
|
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
|
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
|
self.in_planes = dim
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# if input is list, combine batch dimension
|
||||||
|
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||||
|
if is_list:
|
||||||
|
batch_dim = x[0].shape[0]
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
if self.training and self.dropout is not None:
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNets(nn.Module):
    """Entry conv -> `depth` residual blocks -> exit conv."""

    def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1):
        super(ConvNets, self).__init__()

        self.conv_first = nn.Conv2d(
            in_dim, inter_dim, kernel_size=3, padding=1, stride=stride
        )
        self.conv_last = nn.Conv2d(
            inter_dim, out_dim, kernel_size=3, padding=1, stride=stride
        )
        self.relu = nn.ReLU(inplace=True)
        # Intermediate trunk: `depth` unnormalized residual blocks at constant width.
        blocks = [
            ResidualBlock(inter_dim, inter_dim, norm_fn="none", stride=1)
            for _ in range(depth)
        ]
        self.inter_convs = nn.ModuleList(blocks)

    def forward(self, x):
        out = self.relu(self.conv_first(x))
        for block in self.inter_convs:
            out = block(out)
        return self.conv_last(out)
|
class FlowHead(nn.Module):
    """Two-layer convolutional head predicting a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        # Standard GRU blend of previous state and candidate.
        return (1 - update) * h + update * candidate
|
class SepConvGRU(nn.Module):
    """GRU cell with separable convolutions: a 1x5 (horizontal) pass then a 5x1 (vertical) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        # Horizontal (1x5) gate convolutions.
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        # Vertical (5x1) gate convolutions.
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, convz, convr, convq):
        # One GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)  # vertical
        return h
|
class BasicMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into a 128-channel motion feature."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.motion_feature_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + the raw 2-channel flow appended in forward = 128.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep the raw flow alongside the learned features.
        return torch.cat([fused, flow], dim=1)
|
class BasicFuseMotion(nn.Module):
    """Fuse flow with (feature, context) maps into a query-latent motion feature."""

    def __init__(self, args):
        super(BasicFuseMotion, self).__init__()
        cor_planes = args.motion_feature_dim
        out_planes = args.query_latent_dim

        # Flow branch.
        self.normf1 = nn.InstanceNorm2d(128)
        self.normf2 = nn.InstanceNorm2d(128)
        self.convf1 = nn.Conv2d(2, 128, 3, padding=1)
        self.convf2 = nn.Conv2d(128, 128, 3, padding=1)
        self.convf3 = nn.Conv2d(128, 64, 3, padding=1)

        # Feature branch; `s` is an upstream width multiplier (kept at 1).
        s = 1
        width = 256 * s
        self.normc1 = nn.InstanceNorm2d(width)
        self.normc2 = nn.InstanceNorm2d(width)
        self.normc3 = nn.InstanceNorm2d(width)
        self.convc1 = nn.Conv2d(cor_planes + 128, width, 1, padding=0)
        self.convc2 = nn.Conv2d(width, width, 3, padding=1)
        self.convc3 = nn.Conv2d(width, width, 3, padding=1)
        self.convc4 = nn.Conv2d(width, width, 3, padding=1)
        self.conv = nn.Conv2d(width + 64, out_planes, 1, padding=0)

    def forward(self, flow, feat, context1=None):
        # Flow branch: two normalized convs, then project to 64 channels.
        flo = F.relu(self.normf1(self.convf1(flow)))
        flo = F.relu(self.normf2(self.convf2(flo)))
        flo = self.convf3(flo)

        # Feature branch: concatenate the 128-channel context, then a 4-conv stack.
        # NOTE(review): context1 defaults to None but torch.cat requires it —
        # callers are expected to always pass it.
        feat = torch.cat([feat, context1], dim=1)
        feat = F.relu(self.normc1(self.convc1(feat)))
        feat = F.relu(self.normc2(self.convc2(feat)))
        feat = F.relu(self.normc3(self.convc3(feat)))
        feat = self.convc4(feat)

        # Merge both branches and project to the query latent dimension.
        return F.relu(self.conv(torch.cat([flo, feat], dim=1)))
|
class BasicUpdateBlock(nn.Module):
    """Motion encoder + separable ConvGRU + flow head, plus a convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # 9 weights per 8x8 upsampling cell.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
class DirectMeanMaskPredictor(nn.Module):
    """Predict delta-flow and an upsampling mask directly from motion features."""

    def __init__(self, args):
        super(DirectMeanMaskPredictor, self).__init__()
        self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(args.predictor_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, motion_features):
        delta_flow = self.flow_head(motion_features)
        mask = 0.25 * self.mask(motion_features)  # scaled to balance gradients
        return mask, delta_flow
|
class BaiscMeanPredictor(nn.Module):
    """Encode (flow, latent) into motion features, then predict delta-flow and mask.

    NOTE: the class name keeps the upstream "Baisc" typo so imports and
    checkpoints referencing it keep working.
    """

    def __init__(self, args, hidden_dim=128):
        super(BaiscMeanPredictor, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, latent, flow):
        motion_features = self.encoder(flow, latent)
        delta_flow = self.flow_head(motion_features)
        mask = 0.25 * self.mask(motion_features)  # scaled to balance gradients
        return mask, delta_flow
|
class BasicRPEEncoder(nn.Module):
    """MLP embedding 2-D relative-position tokens into the query latent space."""

    def __init__(self, args):
        super(BasicRPEEncoder, self).__init__()
        self.args = args
        dim = args.query_latent_dim
        half = dim // 2
        layers = [
            nn.Linear(2, half),
            nn.ReLU(inplace=True),
            nn.Linear(half, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, rpe_tokens):
        return self.encoder(rpe_tokens)
|
from .twins import Block, CrossBlock
|
||||||
|
|
||||||
|
|
||||||
|
class TwinsSelfAttentionLayer(nn.Module):
    """Windowed (local) + global Twins self-attention, applied to both feature streams."""

    def __init__(self, args):
        super(TwinsSelfAttentionLayer, self).__init__()
        self.args = args
        # Shared block configuration; only the window size differs:
        # ws=7 -> local windowed attention, ws=1 -> global attention.
        common = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **common)
        self.global_block = Block(ws=1, **common)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init based on fan-out per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Both streams pass through the same local-then-global block pair.
        for block in (self.local_block, self.global_block):
            x = block(x, size)
        for block in (self.local_block, self.global_block):
            tgt = block(tgt, size)
        return x, tgt
|
class TwinsCrossAttentionLayer(nn.Module):
    """Windowed self-attention on each stream followed by global cross-attention between them."""

    def __init__(self, args):
        super(TwinsCrossAttentionLayer, self).__init__()
        self.args = args
        # Shared configuration; the local block is windowed (ws=7), the
        # cross block attends globally (ws=1).
        common = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **common)
        self.global_block = CrossBlock(ws=1, **common)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init based on fan-out per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        x = self.local_block(x, size)
        tgt = self.local_block(tgt, size)
        x, tgt = self.global_block(x, tgt, size)
        return x, tgt
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
#from turtle import forward
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNextLayer(nn.Module):
    """A stack of `depth` ConvNeXt blocks at constant width."""

    def __init__(self, dim, depth=4):
        super().__init__()
        self.net = nn.Sequential(*(ConvNextBlock(dim=dim) for _ in range(depth)))

    def forward(self, x):
        return self.net(x)

    def compute_params(self):
        # Total number of scalar parameters in this layer.
        return sum(np.prod(p.size()) for p in self.parameters())
|
class ConvNextBlock(nn.Module):
    r"""ConvNeXt block: depthwise 7x7 conv -> LayerNorm -> inverted-bottleneck MLP.

    Uses the channels_last formulation: after the depthwise conv the tensor is
    permuted to (N, H, W, C), normalized and passed through linear layers, then
    permuted back and added to the residual input.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale; values <= 0
            disable the learned per-channel scaling. Default: 1e-6.
    """

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented with linear layers on channels_last data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        residual = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + out
|
||||||
|
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (default) or channels_first layouts.

    channels_last expects inputs shaped (batch, ..., channels) and defers to
    F.layer_norm; channels_first expects (batch, channels, height, width) and
    normalizes over dim 1 manually.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            # Same exception type as before, but with a descriptive message
            # instead of a bare NotImplementedError.
            raise NotImplementedError(f"unsupported data_format: {data_format!r}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over the channel dimension (dim 1).
        # (The unknown-format case is rejected in __init__, so the original
        # implicit `return None` fall-through cannot be reached.)
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||||
@@ -0,0 +1,316 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ...utils.utils import coords_grid, bilinear_sampler
|
||||||
|
from .attention import (
|
||||||
|
MultiHeadAttention,
|
||||||
|
LinearPositionEmbeddingSine,
|
||||||
|
ExpPositionEmbeddingSine,
|
||||||
|
)
|
||||||
|
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
|
||||||
|
from .gru import BasicUpdateBlock, GMAUpdateBlock
|
||||||
|
from .gma import Attention
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_flow(img):
    """Initialize the two coordinate grids whose difference represents the flow.

    Flow is represented as the difference between two means: flow = mean1 - mean0.
    Both grids start at the identity mapping, i.e. zero initial flow.
    """
    batch, _, height, width = img.shape
    mean = coords_grid(batch, height, width).to(img.device)
    mean_init = coords_grid(batch, height, width).to(img.device)
    return mean, mean_init
|
class CrossAttentionLayer(nn.Module):
    """Cross-attention from per-pixel flow queries to cost-memory tokens.

    Queries are positionally encoded with their (sub-pixel) coordinates;
    keys and values are computed once from `memory` and then cached by the
    caller across decoder iterations.
    """

    # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        add_flow_token=True,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
        pe="linear",
    ):
        super(CrossAttentionLayer, self).__init__()

        head_dim = qk_dim // num_heads
        # NOTE(review): self.scale is computed but attention scaling appears to
        # happen inside MultiHeadAttention — confirm it is not applied twice.
        self.scale = head_dim**-0.5
        self.query_token_dim = query_token_dim
        self.pe = pe  # positional-encoding variant: "linear" or "exp"

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads)
        # Query projects from the query token space; key/value from the memory space.
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        # Attention output is concatenated with the residual before projection.
        self.proj = nn.Linear(v_dim * 2, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )
        self.add_flow_token = add_flow_token
        self.dim = qk_dim

    def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3):
        """
        query_coord [B, 2, H1, W1]

        When `key`/`value` are None they are (re)computed from `memory` and
        returned so the caller can reuse them on subsequent iterations.
        """
        B, _, H1, W1 = query_coord.shape

        if key is None and value is None:
            key = self.k(memory)
            value = self.v(memory)

        # [B, 2, H1, W1] -> [BH1W1, 1, 2]
        query_coord = query_coord.contiguous()
        query_coord = (
            query_coord.view(B, 2, -1)
            .permute(0, 2, 1)[:, :, None, :]
            .contiguous()
            .view(B * H1 * W1, 1, 2)
        )
        if self.pe == "linear":
            query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim)
        elif self.pe == "exp":
            query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim)

        short_cut = query
        query = self.norm1(query)

        # Either add the coordinate encoding to the flow token, or use the
        # coordinate encoding alone as the query.
        if self.add_flow_token:
            q = self.q(query + query_coord_enc)
        else:
            q = self.q(query_coord_enc)
        k, v = key, value

        x = self.multi_head_attn(q, k, v)

        # Residual concat -> projection -> residual add (transformer-style).
        x = self.proj(torch.cat([x, short_cut], dim=2))
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x, k, v
|
class MemoryDecoderLayer(nn.Module):
    """Wraps CrossAttentionLayer and reshapes its token output back to a feature map."""

    def __init__(self, dim, cfg):
        super(MemoryDecoderLayer, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size  # for converting coords into H2', W2' space

        query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.cross_attend = CrossAttentionLayer(
            qk_dim,
            v_dim,
            query_token_dim,
            tgt_token_dim,
            add_flow_token=cfg.add_flow_token,
            dropout=cfg.dropout,
        )

    def forward(self, query, key, value, memory, coords1, size, size_h3w3):
        """
        x: [B*H1*W1, 1, C]
        memory: [B*H1*W1, H2'*W2', C]
        coords1 [B, 2, H2, W2]
        size: B, C, H1, W1
        1. Note that here coords0 and coords1 are in H2, W2 space.
           Should first convert it into H2', W2' space.
        2. We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0]
        """
        x_global, k, v = self.cross_attend(
            query, key, value, memory, coords1, self.patch_size, size_h3w3
        )
        batch, _, h1, w1 = size
        latent_dim = self.cfg.query_latent_dim
        # Tokens [B*H1*W1, 1, C] back to a [B, C, H1, W1] map.
        x_global = x_global.view(batch, h1, w1, latent_dim).permute(0, 3, 1, 2)
        return x_global, k, v
|
class ReverseCostExtractor(nn.Module):
    """Sample a local cost window around coords0 from backward-regathered cost maps."""

    def __init__(self, cfg):
        super(ReverseCostExtractor, self).__init__()
        self.cfg = cfg

    def forward(self, cost_maps, coords0, coords1):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1
        """
        BH1W1, heads, H2, W2 = cost_maps.shape
        B, _, H1, W1 = coords1.shape

        assert (H1 == H2) and (W1 == W2)
        assert BH1W1 == B * H1 * W1

        # Gather the cost slice addressed by coords1, then re-index so that
        # each target pixel owns its (heads, H1, W1) reverse cost map.
        cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2)
        coords = coords1.permute(0, 2, 3, 1)
        corr = bilinear_sampler(cost_maps, coords)  # [B, H1*W1*heads, H2, W2]
        corr = rearrange(
            corr,
            "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1",
            b=B,
            heads=heads,
            h1=H1,
            w1=W1,
            h2=H2,
            w2=W2,
        )

        # Sample a (2r+1)x(2r+1) window centred on coords0.
        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        # indexing="ij" preserves the historical torch.meshgrid behavior while
        # silencing the deprecation warning; `dim` is the documented kwarg
        # (the original used the numpy-style `axis` alias).
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1).to(
            coords0.device
        )
        centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(corr, coords)
        corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2)
        return corr
|
class MemoryDecoder(nn.Module):
    """Iteratively decode optical flow from cost memory (FlowFormer decoder).

    Each iteration samples a local cost window at the current flow estimate,
    cross-attends it against the cost memory, and feeds the result to a
    (GMA-)GRU update block that refines the flow and predicts a convex
    upsampling mask.
    """

    def __init__(self, cfg):
        super(MemoryDecoder, self).__init__()
        dim = self.dim = cfg.query_latent_dim
        self.cfg = cfg

        # 9x9 cost window (81 samples) per head -> query tokens.
        self.flow_token_encoder = nn.Sequential(
            nn.Conv2d(81 * cfg.cost_heads_num, dim, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1, 1),
        )
        self.proj = nn.Conv2d(256, 256, 1)
        self.depth = cfg.decoder_depth
        self.decoder_layer = MemoryDecoderLayer(dim, cfg)

        if self.cfg.gma:
            self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128)
            self.att = Attention(
                args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128
            )
        else:
            self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128)

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # 8x magnitude scaling matches the 8x spatial upsampling.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def encode_flow_token(self, cost_maps, coords):
        """Sample a 9x9 cost window centred on each coordinate.

        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1
        """
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        # indexing="ij" keeps the legacy meshgrid semantics and avoids the
        # PyTorch deprecation warning (the original relied on the default and
        # the numpy-style `axis` alias of `dim`).
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1).to(
            coords.device
        )

        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(cost_maps, coords)
        corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2)
        return corr

    def forward(self, cost_memory, context, data=None, flow_init=None, iters=None):
        """
        memory: [B*H1*W1, H2'*W2', C]
        context: [B, D, H1, W1]
        """
        # `data` previously defaulted to a shared mutable dict ({}); it is
        # required to carry "cost_maps" and "H3W3" from the encoder, so an
        # empty dict would fail below anyway — default to None instead.
        if data is None:
            data = {}
        cost_maps = data["cost_maps"]
        coords0, coords1 = initialize_flow(context)

        if flow_init is not None:
            # Warm start from a previous flow estimate.
            coords1 = coords1 + flow_init

        flow_predictions = []

        context = self.proj(context)
        net, inp = torch.split(context, [128, 128], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)
        if self.cfg.gma:
            attention = self.att(inp)

        size = net.shape
        key, value = None, None
        if iters is None:
            iters = self.depth
        for idx in range(iters):
            coords1 = coords1.detach()

            cost_forward = self.encode_flow_token(cost_maps, coords1)

            query = self.flow_token_encoder(cost_forward)
            query = (
                query.permute(0, 2, 3, 1)
                .contiguous()
                .view(size[0] * size[2] * size[3], 1, self.dim)
            )
            # Keys/values are computed on the first iteration and reused after.
            cost_global, key, value = self.decoder_layer(
                query, key, value, cost_memory, coords1, size, data["H3W3"]
            )
            if self.cfg.only_global:
                corr = cost_global
            else:
                corr = torch.cat([cost_global, cost_forward], dim=1)

            flow = coords1 - coords0

            if self.cfg.gma:
                net, up_mask, delta_flow = self.update_block(
                    net, inp, corr, flow, attention
                )
            else:
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        return flow_predictions[-1], coords1 - coords0
|
||||||
@@ -0,0 +1,534 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from ...utils.utils import coords_grid
|
||||||
|
from .attention import (
|
||||||
|
BroadMultiHeadAttention,
|
||||||
|
MultiHeadAttention,
|
||||||
|
LinearPositionEmbeddingSine,
|
||||||
|
ExpPositionEmbeddingSine,
|
||||||
|
)
|
||||||
|
from ..encoders import twins_svt_large
|
||||||
|
from typing import Tuple
|
||||||
|
from .twins import Size_
|
||||||
|
from .cnn import BasicEncoder
|
||||||
|
from .mlpmixer import MLPMixerLayer
|
||||||
|
from .convnext import ConvNextLayer
|
||||||
|
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
    """Embed a single-channel cost map into patch tokens with coordinate encoding.

    Downsamples by `patch_size` (8 or 4) with strided convolutions, appends a
    positional encoding of each patch centre, mixes both with a 1x1-conv FFN,
    and layer-normalizes the flattened tokens.
    """

    def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe="linear"):
        super().__init__()
        self.patch_size = patch_size
        self.dim = embed_dim
        self.pe = pe

        if patch_size == 8:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2
                ),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        elif patch_size == 4:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        else:
            # The original only printed a message here and continued, which
            # deferred the failure to an AttributeError on first use of
            # self.proj; fail fast instead.
            raise ValueError(f"patch size = {patch_size} is unacceptable.")

        self.ffn_with_coord = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(embed_dim * 2)

    def forward(self, x) -> Tuple[torch.Tensor, Size_]:
        B, C, H, W = x.shape  # C == 1

        # Right/bottom pad so H and W divide evenly by the patch size.
        pad_l = pad_t = 0
        pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        x = self.proj(x)
        out_size = x.shape[2:]

        patch_coord = (
            coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size
            + self.patch_size / 2
        )  # in feature coordinate space
        patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
        if self.pe == "linear":
            patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == "exp":
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        else:
            # Previously an unknown `pe` fell through to a NameError on the
            # next line; make the contract explicit.
            raise ValueError(f"unknown positional encoding: {self.pe!r}")
        patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(
            B, -1, out_size[0], out_size[1]
        )

        x_pe = torch.cat([x, patch_coord_enc], dim=1)
        x = self.ffn_with_coord(x_pe)
        x = self.norm(x.flatten(2).transpose(1, 2))

        return x, out_size
|
from .twins import Block, CrossBlock
|
||||||
|
|
||||||
|
|
||||||
|
class GroupVerticalSelfAttentionLayer(nn.Module):
    """Grouped vertical self-attention: a thin wrapper around one Twins
    ``Block`` configured with group attention enabled."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(GroupVerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Fixed Twins-block hyper-parameters used throughout FlowFormer.
        self.block = Block(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            ws=7,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
            groupattention=True,
            cfg=self.cfg,
        )

    def forward(self, x, size, context=None):
        """Apply the grouped attention block to tokens ``x`` at spatial ``size``."""
        return self.block(x, size, context)
||||||
|
|
||||||
|
|
||||||
|
class VerticalSelfAttentionLayer(nn.Module):
    """Vertical self-attention over cost tokens: a windowed (local) Twins
    ``Block`` followed by a global one (window size 1)."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(VerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Shared Twins-block settings; only the window size differs below.
        common = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )
        self.local_block = Block(ws=7, **common)
        self.global_block = Block(ws=1, **common)

    def forward(self, x, size, context=None):
        """Run local then global attention over ``x`` at spatial ``size``."""
        x = self.local_block(x, size, context)
        x = self.global_block(x, size, context)
        return x

    def compute_params(self):
        """Return the total number of scalar parameters in this layer."""
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())
        return num
||||||
|
|
||||||
|
|
||||||
|
class SelfAttentionLayer(nn.Module):
    """Pre-norm multi-head self-attention with a residual FFN, operating on
    latent cost tokens of shape [BH1W1, H3W3, D]."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(SelfAttentionLayer, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        # Independent linear projections for query, key and value.
        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        x: [BH1W1, H3W3, D]
        """
        shortcut = x
        normed = self.norm1(x)

        attended = self.multi_head_attn(self.q(normed), self.k(normed), self.v(normed))

        x = shortcut + self.proj_drop(self.proj(attended))
        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x

    def compute_params(self):
        """Return the total number of scalar parameters in this layer."""
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())
        return num
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention: latent query tokens attend to target tokens.

    Query Token: [N, C] -> [N, qk_dim] (Q)
    Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
    """

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(CrossAttentionLayer, self).__init__()
        assert (
            qk_dim % num_heads == 0
        ), f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert (
            v_dim % num_heads == 0
        ), f"dim {v_dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        self.scale = (qk_dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        # Q is projected from the query tokens, K/V from the target tokens.
        self.q = nn.Linear(query_token_dim, qk_dim, bias=True)
        self.k = nn.Linear(tgt_token_dim, qk_dim, bias=True)
        self.v = nn.Linear(tgt_token_dim, v_dim, bias=True)

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )

    def forward(self, query, tgt_token):
        """
        x: [BH1W1, H3W3, D]
        """
        shortcut = query
        normed_query = self.norm1(query)

        attended = self.multi_head_attn(
            self.q(normed_query), self.k(tgt_token), self.v(tgt_token)
        )

        x = shortcut + self.proj_drop(self.proj(attended))
        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x
||||||
|
|
||||||
|
|
||||||
|
class CostPerceiverEncoder(nn.Module):
    """Perceiver-style encoder that compresses the 4D cost volume into a
    small set of latent cost tokens: cross-attention in, then alternating
    token-mixing and spatial (vertical) mixing layers."""

    def __init__(self, cfg):
        super(CostPerceiverEncoder, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size
        # Patchify each (H2, W2) cost-map slice into tokens.
        self.patch_embed = PatchEmbed(
            in_chans=self.cfg.cost_heads_num,
            patch_size=self.patch_size,
            embed_dim=cfg.cost_latent_input_dim,
            pe=cfg.pe,
        )

        self.depth = cfg.encoder_depth

        # Learned latent tokens, shared across all source positions.
        self.latent_tokens = nn.Parameter(
            torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)
        )

        query_token_dim, tgt_token_dim = (
            cfg.cost_latent_dim,
            cfg.cost_latent_input_dim * 2,
        )
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.input_layer = CrossAttentionLayer(
            qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout
        )

        # Token-mixing stack: MLP-Mixer layers or self-attention layers.
        if cfg.use_mlp:
            self.encoder_layers = nn.ModuleList(
                [
                    MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )
        else:
            self.encoder_layers = nn.ModuleList(
                [
                    SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )

        # Mixing across the spatial (H1, W1) axis, interleaved with the above.
        if self.cfg.vertical_conv:
            self.vertical_encoder_layers = nn.ModuleList(
                [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]
            )
        else:
            self.vertical_encoder_layers = nn.ModuleList(
                [
                    VerticalSelfAttentionLayer(
                        cfg.cost_latent_dim, cfg, dropout=cfg.dropout
                    )
                    for idx in range(self.depth)
                ]
            )
        # Optional train-time augmentation: randomly rescale the cost maps.
        self.cost_scale_aug = None
        if "cost_scale_aug" in cfg.keys():
            self.cost_scale_aug = cfg.cost_scale_aug
            print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug))

    def forward(self, cost_volume, data, context=None):
        """Encode ``cost_volume`` [B, heads, H1, W1, H2, W2] into latent tokens.

        Side effects: stores ``cost_maps`` and ``H3W3`` in ``data`` for use by
        the decoder.
        """
        B, heads, H1, W1, H2, W2 = cost_volume.shape
        # One (heads, H2, W2) cost map per source pixel: [B*H1*W1, heads, H2, W2].
        cost_maps = (
            cost_volume.permute(0, 2, 3, 1, 4, 5)
            .contiguous()
            .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
        )
        data["cost_maps"] = cost_maps

        if self.cost_scale_aug is not None:
            scale_factor = (
                torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
                .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1])
                .to(cost_maps.device)
            )
            cost_maps = cost_maps * scale_factor

        x, size = self.patch_embed(cost_maps)  # B*H1*W1, size[0]*size[1], C
        data["H3W3"] = size
        H3, W3 = size

        # Cross-attend the learned latents to the patch tokens.
        x = self.input_layer(self.latent_tokens, x)

        short_cut = x

        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if self.cfg.vertical_conv:
                # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 3, 1, 2)
                    .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1)
                )
                x = self.vertical_encoder_layers[idx](x)
                # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1)
                    .permute(0, 2, 3, 1)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )
            else:
                # Regroup tokens so the vertical layer attends across H1*W1.
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1)
                )
                x = self.vertical_encoder_layers[idx](x, (H1, W1), context)
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )

        # Global residual around the whole encoder stack.
        if self.cfg.cost_encoder_res is True:
            x = x + short_cut
        # print("~~~~")
        return x
||||||
|
|
||||||
|
|
||||||
|
class MemoryEncoder(nn.Module):
    """Builds the cost-volume memory for FlowFormer: extracts features for
    both frames, computes an all-pairs multi-head correlation, and
    compresses it with a CostPerceiverEncoder."""

    def __init__(self, cfg):
        super(MemoryEncoder, self).__init__()
        self.cfg = cfg

        if cfg.fnet == "twins":
            self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.fnet == "basicencoder":
            self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance")
        else:
            # Bug fix: the original called bare exit(), which kills the host
            # process; library code must raise instead.
            raise ValueError(f"Unknown fnet: {cfg.fnet}")
        self.channel_convertor = nn.Conv2d(
            cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False
        )
        self.cost_perceiver_encoder = CostPerceiverEncoder(cfg)

    def corr(self, fmap1, fmap2):
        """All-pairs multi-head correlation volume [B, heads, H, W, H, W]."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = rearrange(
            fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        fmap2 = rearrange(
            fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2)
        corr = corr.permute(0, 2, 1, 3).view(
            batch * ht * wd, self.cfg.cost_heads_num, ht, wd
        )
        corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute(
            0, 2, 1, 3
        )
        corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd)

        return corr

    def forward(self, img1, img2, data, context=None, return_feat=False):
        """Encode the image pair into latent cost memory.

        Returns the latent tokens, plus the source-frame encoder features
        when ``return_feat`` is True.
        """
        # Batch both frames through the feature encoder in a single pass.
        imgs = torch.cat([img1, img2], dim=0)
        feats = self.feat_encoder(imgs)
        feats = self.channel_convertor(feats)
        B = feats.shape[0] // 2
        feat_s = feats[:B]
        if return_feat:
            ffeat = feats[:B]
        feat_t = feats[B:]

        B, C, H, W = feat_s.shape
        size = (H, W)

        if self.cfg.feat_cross_attn:
            feat_s = feat_s.flatten(2).transpose(1, 2)
            feat_t = feat_t.flatten(2).transpose(1, 2)

            # NOTE(review): self.layers is never defined in __init__, so this
            # branch raises AttributeError — confirm feat_cross_attn is
            # always False in the shipped configs.
            for layer in self.layers:
                feat_s, feat_t = layer(feat_s, feat_t, size)

            feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

        cost_volume = self.corr(feat_s, feat_t)
        x = self.cost_perceiver_encoder(cost_volume, data, context)

        if return_feat:
            return x, ffeat
        return x
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class RelPosEmb(nn.Module):
    """Decomposed 2D relative positional embedding (GMA style): separate
    learned embeddings for vertical and horizontal offsets."""

    def __init__(self, max_pos_size, dim_head):
        super().__init__()
        # One embedding per offset in [-(max_pos_size-1), max_pos_size-1].
        self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
        self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)

        # rel_ind[i, j] = (j - i) shifted so indices are non-negative.
        deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(
            max_pos_size
        ).view(-1, 1)
        rel_ind = deltas + max_pos_size - 1
        self.register_buffer("rel_ind", rel_ind)

    def forward(self, q):
        """Positional attention scores for queries ``q`` [B, heads, H, W, C];
        returns height-score + width-score of shape [B, heads, H, W, H, W]."""
        batch, heads, h, w, c = q.shape
        height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
        width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))

        height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h)
        width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w)

        height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb)
        width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb)

        return height_score + width_score
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Self-similarity attention over a feature map (from GMA).

    Produces only the softmax-ed attention matrix; ``Aggregate`` applies it
    to the value features.
    """

    def __init__(
        self,
        *,
        args,
        dim,
        max_pos_size=100,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # Positional embedding is frozen; note that the active code path
        # below is content-only, so pos_emb is effectively unused here.
        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
        for param in self.pos_emb.parameters():
            param.requires_grad = False

    def forward(self, fmap):
        """fmap [B, C, H, W] -> attention [B, heads, H*W, H*W] (softmax-ed)."""
        heads, b, c, h, w = self.heads, *fmap.shape

        q, k = self.to_qk(fmap).chunk(2, dim=1)

        q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k))
        q = self.scale * q

        # if self.args.position_only:
        #     sim = self.pos_emb(q)

        # elif self.args.position_and_content:
        #     sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
        #     sim_pos = self.pos_emb(q)
        #     sim = sim_content + sim_pos

        # else:
        # Content-only similarity between all pairs of positions.
        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)

        sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
        attn = sim.softmax(dim=-1)

        return attn
||||||
|
|
||||||
|
|
||||||
|
class Aggregate(nn.Module):
    """Apply a precomputed attention matrix to a feature map's values and
    add the result back as a learned (gamma-scaled) residual."""

    def __init__(
        self,
        args,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)

        # Residual weight, initialised to zero so the module starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

        # Project back only when the aggregated dimension differs from input.
        if dim != inner_dim:
            self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
        else:
            self.project = None

    def forward(self, attn, fmap):
        """attn [B, heads, HW, HW], fmap [B, C, H, W] -> [B, C, H, W]."""
        heads, b, c, h, w = self.heads, *fmap.shape

        values = rearrange(self.to_v(fmap), "b (h d) x y -> b h (x y) d", h=heads)
        aggregated = einsum("b h i j, b h j d -> b h i d", attn, values)
        aggregated = rearrange(aggregated, "b h (x y) d -> b (h d) x y", x=h, y=w)

        if self.project is not None:
            aggregated = self.project(aggregated)

        return fmap + self.gamma * aggregated
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: single-head attention over a random feature map.
    # Bug fix: Attention's `args` parameter is keyword-only with no default,
    # so the original Attention(dim=128, heads=1) raised TypeError. The
    # content-only path never reads args, so None suffices here.
    att = Attention(args=None, dim=128, heads=1)
    fmap = torch.randn(2, 128, 40, 90)
    out = att(fmap)

    print(out.shape)
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FlowHead(nn.Module):
    """Two-layer convolutional head that regresses a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """x [B, input_dim, H, W] -> flow [B, 2, H, W]."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
||||||
|
|
||||||
|
|
||||||
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell over spatial feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # update gate
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # reset gate
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # candidate

    def forward(self, h, x):
        """One GRU step: h [B, hidden, H, W], x [B, input, H, W] -> new h."""
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))

        return (1 - z) * h + z * q
||||||
|
|
||||||
|
|
||||||
|
class SepConvGRU(nn.Module):
    """ConvGRU with separable gate convolutions: one horizontal (1x5) GRU
    step followed by one vertical (5x1) step."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        # Horizontal (1x5) gate convolutions.
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))

        # Vertical (5x1) gate convolutions.
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def _step(self, h, x, convz, convr, convq):
        """One GRU update using the given gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """h [B, hidden, H, W], x [B, input, H, W] -> updated hidden state."""
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)  # vertical
        return h
||||||
|
|
||||||
|
|
||||||
|
class BasicMotionEncoder(nn.Module):
    """Encode the current flow estimate and sampled cost features into
    128-channel motion features; the last two channels carry the raw flow."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        if args.only_global:
            print("[Decoding with only global cost]")
            cor_planes = args.query_latent_dim
        else:
            # 81 local correlation channels plus the global latent cost.
            cor_planes = 81 + args.query_latent_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        """flow [B, 2, H, W], corr [B, cor_planes, H, W] -> [B, 128, H, W]."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
||||||
|
|
||||||
|
|
||||||
|
class BasicUpdateBlock(nn.Module):
    """RAFT-style update block: motion encoder + separable ConvGRU + flow
    head, plus a convex-upsampling mask predictor."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        """One refinement step.

        Returns (new hidden state, 0.25-scaled upsampling mask, delta flow).
        """
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # Scale the mask to balance gradients.
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
||||||
|
|
||||||
|
|
||||||
|
from .gma import Aggregate
|
||||||
|
|
||||||
|
|
||||||
|
class GMAUpdateBlock(nn.Module):
    """Update block with Global Motion Aggregation: like BasicUpdateBlock,
    but the GRU also receives globally aggregated motion features."""

    def __init__(self, args, hidden_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # GRU input = context (128) + local motion + aggregated motion.
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim
        )
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

        self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1)

    def forward(self, net, inp, corr, flow, attention):
        """One GMA refinement step.

        Args:
            net: GRU hidden state [B, 128, H, W].
            inp: context features [B, 128, H, W].
            corr: sampled correlation features.
            flow: current flow estimate [B, 2, H, W].
            attention: precomputed attention matrix for the aggregator.

        Returns:
            (net, mask, delta_flow): updated hidden state, 0.25-scaled
            convex-upsampling mask, and the flow residual.
        """
        motion_features = self.encoder(flow, corr)
        motion_features_global = self.aggregator(attention, motion_features)
        inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)

        # Attentional update
        net = self.gru(net, inp_cat)

        delta_flow = self.flow_head(net)

        # scale mask to balence gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
from torch import nn
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
from functools import partial
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns x + fn(LayerNorm(x))."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return x + self.fn(normed)
||||||
|
|
||||||
|
|
||||||
|
def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    """Build a two-layer MLP (dim -> dim*expansion_factor -> dim) with GELU
    and dropout; ``dense`` selects the layer type (Linear or 1x1 Conv1d)."""
    hidden = dim * expansion_factor
    return nn.Sequential(
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    )
||||||
|
|
||||||
|
|
||||||
|
class MLPMixerLayer(nn.Module):
    """One MLP-Mixer block over latent cost tokens: token mixing (1x1 Conv1d
    across the K tokens) followed by channel mixing (Linear across D)."""

    def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
        super(MLPMixerLayer, self).__init__()

        num_tokens = cfg.cost_latent_token_num
        expansion = cfg.mlp_expansion_factor
        # Conv1d(kernel_size=1) mixes across tokens; Linear mixes channels.
        token_mix = partial(nn.Conv1d, kernel_size=1)
        channel_mix = nn.Linear

        self.mlpmixer = nn.Sequential(
            PreNormResidual(
                dim, FeedForward(num_tokens, expansion, dropout, token_mix)
            ),
            PreNormResidual(
                dim, FeedForward(dim, expansion, dropout, channel_mix)
            ),
        )

    def compute_params(self):
        """Return the total number of scalar parameters in the mixer."""
        num = 0
        for param in self.mlpmixer.parameters():
            num += np.prod(param.size())
        return num

    def forward(self, x):
        """
        x: [BH1W1, K, D]
        """
        return self.mlpmixer(x)
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ...utils.utils import coords_grid
|
||||||
|
from ..encoders import twins_svt_large
|
||||||
|
from .encoder import MemoryEncoder
|
||||||
|
from .decoder import MemoryDecoder
|
||||||
|
from .cnn import BasicEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class FlowFormer(nn.Module):
    """FlowFormer optical-flow network: a memory encoder builds latent cost
    tokens from the image pair, and a recurrent memory decoder regresses
    the flow field from them."""

    def __init__(self, cfg):
        super(FlowFormer, self).__init__()
        self.cfg = cfg

        self.memory_encoder = MemoryEncoder(cfg)
        self.memory_decoder = MemoryDecoder(cfg)
        # Context encoder runs on the source image (or the concatenated pair
        # when cfg.context_concat is set in forward()).
        if cfg.cnet == "twins":
            self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.cnet == "basicencoder":
            self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")

    def build_coord(self, img):
        """Base coordinate grid at 1/8 input resolution for an image batch."""
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8)
        return coords

    def forward(
        self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
    ):
        """Estimate optical flow from image1 to image2.

        Images are expected in [0, 255] and are normalised to [-1, 1] here.
        Returns the decoder's flow predictions; with ``return_feat`` also
        returns (context features, feature-encoder features).
        """
        # Following https://github.com/princeton-vl/RAFT/
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        # Shared dict the encoder fills (cost_maps, H3W3) for the decoder.
        data = {}

        if self.cfg.context_concat:
            context = self.context_encoder(torch.cat([image1, image2], dim=1))
        else:
            if return_feat:
                # NOTE(review): assumes the context encoder accepts a
                # return_feat keyword — confirm for both cnet variants.
                context, cfeat = self.context_encoder(image1, return_feat=return_feat)
            else:
                context = self.context_encoder(image1)
        if return_feat:
            cost_memory, ffeat = self.memory_encoder(
                image1, image2, data, context, return_feat=return_feat
            )
        else:
            cost_memory = self.memory_encoder(image1, image2, data, context)

        flow_predictions = self.memory_decoder(
            cost_memory, context, data, flow_init=flow_init, iters=iters
        )

        if return_feat:
            return flow_predictions, cfeat, ffeat
        return flow_predictions
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
|||||||
|
def build_flowformer(cfg):
    """Instantiate the FlowFormer variant named by ``cfg.transformer``."""
    arch = cfg.transformer
    if arch != "latentcostformer":
        raise ValueError(f"FlowFormer = {arch} is not a valid architecture!")
    # Import lazily so the heavy transformer module loads only when needed.
    from .LatentCostFormer.transformer import FlowFormer

    return FlowFormer(cfg[arch])
||||||
@@ -0,0 +1,562 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ..utils.utils import bilinear_sampler, indexing
|
||||||
|
|
||||||
|
|
||||||
|
def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300):
    """
    x is of shape [*, 2]. The last dimension are two coordinates (x and y).
    Returns the normalised coordinates concatenated with sin/cos NeRF-style
    features of each coordinate over L-1 frequency bands.
    """
    freq_bands = 2.0 ** torch.linspace(0, L, L - 1).to(x.device)
    x_coord = x[..., -2:-1]
    y_coord = x[..., -1:]
    parts = [
        x * NORMALIZE_FACOR,
        torch.sin(3.14 * x_coord * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * x_coord * freq_bands * NORMALIZE_FACOR),
        torch.sin(3.14 * y_coord * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * y_coord * freq_bands * NORMALIZE_FACOR),
    ]
    return torch.cat(parts, dim=-1)
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian(latent, mean, std, image_size, point_num=25):
    """Sample latent features on a Gaussian-shaped grid around each pixel's
    predicted mean, with a per-pixel spread derived from ``std`` (sigmoid,
    capped at STD_MAX).

    NOTE(review): this definition is shadowed by a later ``sampler_gaussian``
    in the same module; only the last definition is effective at import time.

    Args:
        latent: [B, H*W, D] latent features.
        mean: [B, 2, H, W] per-pixel sampling centers.
        std: [B, 1, H, W] per-pixel spread logits.
        image_size: (H, W).
        point_num: number of sample points (expected to be a square number).

    Returns:
        (sampled_latents, sampled_weights): bilinearly sampled features and
        squared-distance log-weights over the sampling grid.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Regular grid in [-1, 1]^2, scaled per pixel to a 3-sigma window.
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Log-weights fall off with squared distance from the grid center.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))

    return sampled_latents, sampled_weights
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_zy(
    latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1
):
    # Variant of `sampler_gaussian` that uses `std` directly (no sigmoid
    # squashing / STD_MAX bound) and divides the distance prior by `beta`
    # (a temperature). Optionally also returns the per-pixel offsets.
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    # 3-sigma window scaled by the raw std prediction.
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta

    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False):
    # NOTE(review): this redefinition overrides the earlier `sampler_gaussian`
    # in this module; it is identical except for the added `return_deltaXY`
    # flag (backward-compatible default False).
    #
    # Sample `point_num` features per pixel on a Gaussian-shaped window of
    # half-width 3 * sigmoid(std) * STD_MAX centred at `mean`.
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # upper bound on the window scale after sigmoid squashing
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Distance-based attention prior over the unit grid offsets.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))

    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_fix(latent, mean, image_size, point_num=49):
    """Sample latent features on a fixed integer grid centred at `mean`.

    Unlike the std-driven samplers above, the window is a fixed grid of
    radius floor((sqrt(point_num) - 1) / 2) pixels around each centre.

    Args:
        latent: [B, H*W, D] flattened feature map.
        mean: [B, 2, H, W] per-pixel sampling centres in feature coordinates.
        image_size: (H, W) of the feature map.
        point_num: number of grid points; its square root sets the window side.

    Returns:
        sampled_latents: bilinearly sampled features, channels-last
            (layout follows `bilinear_sampler`; defined elsewhere in this file).
        sampled_weights: -(squared grid offset) / point_num, a distance-based
            prior ("smooth term") shared by all pixels.
    """
    H, W = image_size
    B, HW, D = latent.shape
    # (removed unused local STD_MAX -- this fixed-window variant never scales
    # the grid by a predicted std)
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return sampled_latents, sampled_weights
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_fix_pyramid(
    latent, feat_pyramid, scale_weight, mean, image_size, point_num=25
):
    # Fixed-window sampling over a 2x-downsampled feature pyramid; a softmax
    # over `scale_weight` blends the per-level samples.
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]

    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-window variant
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        # Scale the full coordinate (centre + offset) down by 2**i for level i.
        coords = (centroid + delta) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight  # normalized weights returned for visualization
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights, vis_out
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_pyramid(
    latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25
):
    # Gaussian (std-scaled) window sampling over a feature pyramid with a
    # softmax blend over levels; see `sampler_gaussian_fix_pyramid` for the
    # fixed-window analogue.
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]

    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # NOTE(review): unused here (std is applied raw, no sigmoid)
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)  # NOTE(review): unused

    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    # 3-sigma window from the raw std prediction (cf. sampler_gaussian_zy).
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        # Scale the full coordinate down by 2**i for pyramid level i.
        coords = (centroid + delta_3sigma) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight  # normalized weights returned for visualization
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights, vis_out
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25):
    """different heads have different mean"""
    # Multi-head variant of `sampler_gaussian_fix`: each attention head has
    # its own sampling centre; heads are folded into the token dimension.
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]

    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-window variant
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    # Same fixed grid repeated per head.
    delta = (
        torch.stack(torch.meshgrid(dy, dx), axis=-1)
        .to(mean.device)
        .repeat(HEADS, 1, 1, 1)
    )  # [HEADS, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(
        coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
    )
    sampled_latents = bilinear_sampler(latent, coords)  # [B, dim, H*W*HEADS, pointnum]
    sampled_latents = sampled_latents.permute(
        0, 2, 3, 1
    )  # [B, H*W*HEADS, pointnum, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
|
||||||
|
|
||||||
|
|
||||||
|
def sampler_gaussian_fix_pyramid_MH(
    latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25
):
    # Multi-head, multi-scale fixed-window sampling: per-head centres, a
    # pyramid of features, and a softmax blend over levels per head.
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]
    # scale_head weight [B, H*W, layer_num*heads]

    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape

    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    for i in range(len(feat_pyramid)):
        # NOTE(review): only the centre is scaled down per level here, while
        # the window offset `delta` stays in that level's own pixel units --
        # this differs from sampler_gaussian_fix_pyramid, which scales both.
        # Presumably intentional (constant window per level); confirm upstream.
        coords = (centroid) / 2**i + delta
        coords = rearrange(
            coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
        )
        sampled_latents.append(
            bilinear_sampler(feat_pyramid[i], coords)
        )  # [B, dim, H*W*HEADS, point_num]

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W*HEADS, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W*HEADS, point_num, dim, layer_num]

    scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1)
    scale_head_weight = F.softmax(scale_head_weight, dim=2)  # [B, H*W*HEADS, layer_num]
    scale_head_weight = torch.unsqueeze(
        torch.unsqueeze(scale_head_weight, dim=2), dim=2
    )  # [B, H*W*HEADS, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_head_weight, dim=-1
    )  # [B, H*W*HEADS, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights
|
||||||
|
|
||||||
|
|
||||||
|
def sampler(feat, center, window_size):
    # Bilinearly sample a fixed square window of side `window_size` around
    # each per-pixel `center`.
    # feat [B, C, H, W]
    # center [B, 2, H, W]
    center = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    B, H, W, C = center.shape

    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [B*H*W, window_size, point_num**0.5, 2]

    center = center.reshape(B * H * W, 1, 1, 2)
    coords = center + delta

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        feat, coords
    )  # [B*H*W, dim, window_size, window_size]
    # sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    return sampled_latents
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_tokens(feat, center, window_size, sampler):
    # Gather a square window of tokens around each centre, with either
    # nearest-neighbour ("nn" -> `indexing`) or "bilinear" sampling.
    # feat [B, C, H, W]
    # center [B, 2, H, W]
    # NOTE(review): despite the comment above, `center` is unpacked below as
    # [B, H, W, C]; the caller (`pyramid_retrieve_tokens`) permutes it to
    # channels-last before calling.
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    B, H, W, C = center.shape
    centroid = center.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    if sampler == "nn":
        sampled_latents = indexing(feat, coords)
    elif sampler == "bilinear":
        sampled_latents = bilinear_sampler(feat, coords)
    else:
        raise ValueError("invalid sampler")
    # [B, dim, H*W, point_num]

    return sampled_latents
|
||||||
|
|
||||||
|
|
||||||
|
def pyramid_retrieve_tokens(
    feat_pyramid, center, image_size, window_sizes, sampler="bilinear"
):
    """Gather tokens around `center` from every pyramid level and concatenate.

    The centre coordinates are halved after each level to track the 2x
    downsampling between consecutive pyramid levels.
    """
    center = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    gathered = []
    for level, win in enumerate(window_sizes):
        gathered.append(retrieve_tokens(feat_pyramid[level], center, win, sampler))
        center = center / 2
    return torch.cat(gathered, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Two-layer GELU MLP (dim -> dim) with dropout after each linear layer."""

    def __init__(self, dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """LeakyReLU(0.2) MLP: in_dim -> innter_dim -> (depth hidden layers) -> out_dim."""

    def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5):
        super().__init__()
        # Submodule creation order (FC1, FC_out, FC_inter) is preserved so
        # seeded initialization matches checkpoints.
        self.FC1 = nn.Linear(in_dim, innter_dim)
        self.FC_out = nn.Linear(innter_dim, out_dim)
        self.relu = torch.nn.LeakyReLU(0.2)
        self.FC_inter = nn.ModuleList(
            [nn.Linear(innter_dim, innter_dim) for _ in range(depth)]
        )

    def forward(self, x):
        h = self.relu(self.FC1(x))
        for hidden in self.FC_inter:
            h = self.relu(hidden(h))
        return self.FC_out(h)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention with several relative-positional-encoding modes.

    The rpe mode comes from ``cfg.rpe``: "element-wise", "head-wise",
    "token-wise", "implicit", or "element-wise-value". A pre-built
    ``rpe_bias`` may be passed in instead of creating a learned one.
    """

    def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.num_kv_tokens = num_kv_tokens
        self.scale = (dim / heads) ** -0.5  # 1/sqrt(head_dim)
        self.rpe = cfg.rpe
        self.attend = nn.Softmax(dim=-1)
        self.use_rpe = use_rpe

        if use_rpe:
            if rpe_bias is None:
                # Learn the bias; its shape depends on the rpe granularity.
                if self.rpe == "element-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                elif self.rpe == "head-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, heads, 1, self.num_kv_tokens)
                    )
                elif self.rpe == "token-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, 1, 1, self.num_kv_tokens)
                    )  # 81 is point_num
                elif self.rpe == "implicit":
                    # Implicit mode expects the bias to be supplied at forward().
                    pass
                    # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2)
                    # raise ValueError('Implicit Encoding Not Implemented')
                elif self.rpe == "element-wise-value":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                    self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim))
                else:
                    raise ValueError("Not Implemented")
            else:
                # Share an externally supplied bias.
                self.rpe_bias = rpe_bias

    def attend_with_rpe(self, Q, K, rpe_bias):
        # Returns (softmaxed attention, raw logits).
        Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)

        dots = (
            einsum("bhid, bhjd -> bhij", Q, K) * self.scale
        )  # (b hw) heads 1 pointnum
        if self.use_rpe:
            if self.rpe == "element-wise":
                # Bias shared across the batch: contract Q against the table.
                rpe_bias_weight = (
                    einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "implicit":
                pass
                # Per-sample bias computed by the caller.
                rpe_bias_weight = (
                    einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "head-wise" or self.rpe == "token-wise":
                dots = dots + rpe_bias

        return self.attend(dots), dots

    def forward(self, Q, K, V, rpe_bias=None):
        # Returns (attention output or None when V is None, raw logits).
        if self.use_rpe:
            if rpe_bias is None or self.rpe == "element-wise":
                rpe_bias = self.rpe_bias
            else:
                rpe_bias = rearrange(
                    rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads
                )
            attn, dots = self.attend_with_rpe(Q, K, rpe_bias)
        else:
            attn, dots = self.attend_with_rpe(Q, K, None)
        B, HW, _ = Q.shape

        if V is not None:
            V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)

            out = einsum("bhij, bhjd -> bhid", attn, V)
            out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        else:
            out = None

        # dots = torch.squeeze(dots, 2)
        # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW)

        return out, dots
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import timm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class twins_svt_large(nn.Module):
    """Truncated Twins-SVT-Large backbone used as a 2-stage feature extractor.

    The timm model's stages 3 and 4 (indices 2 and 3 of patch_embeds,
    blocks and pos_block) and its classification head are deleted; the final
    norm layer is frozen.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large", pretrained=pretrained)

        del self.svt.head
        # Deleting index 2 twice removes the original stages 3 and 4.
        del self.svt.patch_embeds[2]
        del self.svt.patch_embeds[2]
        del self.svt.blocks[2]
        del self.svt.blocks[2]
        del self.svt.pos_block[2]
        del self.svt.pos_block[2]
        self.svt.norm.weight.requires_grad = False
        self.svt.norm.bias.requires_grad = False

    def forward(self, x, data=None, layer=2, return_feat=False):
        """Run the first `layer` stages; optionally also return each stage's feature map."""
        B = x.shape[0]
        if return_feat:
            feat = []
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # Positional encoding block is applied after the first
                    # transformer block of each stage.
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # Back to NCHW for the next stage's patch embedding.
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if return_feat:
                feat.append(x)
            if i == layer - 1:
                break
        if return_feat:
            return x, feat
        return x

    def compute_params(self, layer=2):
        """Count parameters in the first `layer` stages (plus the head when present)."""
        num = 0
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            for param in embed.parameters():
                num += np.prod(param.size())

            for param in drop.parameters():
                num += np.prod(param.size())

            for param in blocks.parameters():
                num += np.prod(param.size())

            for param in pos_blk.parameters():
                num += np.prod(param.size())

            if i == layer - 1:
                break

        # BUGFIX: self.svt.head is deleted in __init__, so unconditionally
        # accessing it raised AttributeError. Guard the lookup -- for this
        # class the head contributes nothing (it no longer exists).
        head = getattr(self.svt, "head", None)
        if head is not None:
            for param in head.parameters():
                num += np.prod(param.size())

        return num
|
||||||
|
|
||||||
|
|
||||||
|
class twins_svt_large_context(nn.Module):
    """Twins-SVT-Large context variant run up to `layer` stages.

    Unlike `twins_svt_large` above, the timm model is used as-is: no stages
    are deleted and no layers are frozen.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained)

    def forward(self, x, data=None, layer=2):
        # Run stages sequentially; stop after stage `layer`.
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # Back to NCHW for the next stage's patch embedding.
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

            if i == layer - 1:
                break

        return x
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test. BUGFIX: the original called m.extract_feature(...), but
    # twins_svt_large defines no such method -- forward() is the feature
    # extractor, so call the module directly. Also avoids shadowing the
    # builtin `input`.
    m = twins_svt_large()
    dummy = torch.randn(2, 3, 400, 800)
    out = m(dummy)
    print(out.shape)
|
||||||
90
gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py
Normal file
90
gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .utils.utils import bilinear_sampler, coords_grid
|
||||||
|
|
||||||
|
# Optional compiled CUDA extension for memory-efficient correlation
# (used by AlternateCorrBlock). Narrowed from a bare `except:` so that
# only a missing/broken module is tolerated, not e.g. KeyboardInterrupt.
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is not compiled
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid (RAFT-style).

    __init__ computes the full correlation between every pixel of fmap1 and
    fmap2 once; __call__ then bilinearly samples a (2r+1)x(2r+1) window
    around the given coordinates at each pyramid level and concatenates the
    per-level results along the channel dimension.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            # Each level halves the resolution of the *target* dimensions.
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        # coords: per-pixel lookup centres, [batch, 2, h1, w1].
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            # Centre coordinates shrink by 2**i at level i; the window stays
            # (2r+1)^2 in that level's own pixels.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # Dot-product correlation of every pixel pair, scaled by 1/sqrt(dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||||
|
|
||||||
|
|
||||||
|
class AlternateCorrBlock:
    """Memory-efficient correlation via the optional `alt_cuda_corr` extension.

    Instead of materializing the full all-pairs volume, it keeps feature
    pyramids and lets the CUDA kernel compute windowed correlation on the
    fly. Requires `alt_cuda_corr` to be compiled (see the guarded import at
    the top of this file) -- otherwise __call__ raises NameError.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 always comes from the finest level; fmap2 from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||||
267
gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py
Normal file
267
gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Two 3x3 convs with configurable normalization and a residual skip.

    When `stride` != 1 the skip path downsamples the input with a 1x1 conv
    (plus norm3) so it matches the main path's shape.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the first conv and of the 1x1 downsample.

    Raises:
        ValueError: if `norm_fn` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        else:
            # Fail fast on a typo'd norm_fn instead of an AttributeError later.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck (planes//4 internal width) with residual skip.

    When `stride` != 1 the skip path downsamples with a 1x1 conv plus norm4.

    Args:
        in_planes: input channel count.
        planes: output channel count (internal width is planes // 4).
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the middle 3x3 conv and of the 1x1 downsample.

    Raises:
        ValueError: if `norm_fn` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        else:
            # Fail fast on a typo'd norm_fn instead of an AttributeError later.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEncoder(nn.Module):
    """RAFT-style feature encoder producing 1/8-resolution features.

    A stride-2 stem conv is followed by three residual stages (strides
    1, 2, 2) and a 1x1 projection to ``output_dim`` channels. A pair of
    image batches passed as a tuple/list is processed jointly and split
    back afterwards.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; "none" degenerates to an identity Sequential.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs; unit-scale affine for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two-residual-block stage; the first block may downsample."""
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A (tuple/list) of two image batches is concatenated along batch.
        paired = isinstance(x, (tuple, list))
        if paired:
            half = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [half, half], dim=0)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class SmallEncoder(nn.Module):
    """Compact (RAFT "small") encoder producing 1/8-resolution features.

    Stride-2 stem conv, three bottleneck stages (strides 1, 2, 2), then a
    1x1 projection to ``output_dim`` channels. Tuple/list inputs of two
    image batches are processed jointly and split back afterwards.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; "none" degenerates to an identity Sequential.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit-scale affine for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two-bottleneck-block stage; the first block may downsample."""
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A (tuple/list) of two image batches is concatenated along batch.
        paired = isinstance(x, (tuple, list))
        if paired:
            half = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [half, half], dim=0)

        return x
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEncodingSine(nn.Module):
    """2-D sinusoidal position encoding (LoFTR-style), added channel-wise.

    Channels are interleaved in groups of four:
    sin(x), cos(x), sin(y), cos(y).
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            d_model: number of feature channels (expected divisible by 4).
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        y_pos = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_pos = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        # NOTE(review): the grouping `(-math.log(10000.0) / d_model // 2)`
        # (floor-divide applied last) is inherited verbatim from upstream
        # FlowFormer/LoFTR and kept as-is for checkpoint compatibility.
        div_term = torch.exp(
            torch.arange(0, d_model // 2, 2).float()
            * (-math.log(10000.0) / d_model // 2)
        )
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_pos * div_term)
        pe[1::4, :, :] = torch.cos(x_pos * div_term)
        pe[2::4, :, :] = torch.sin(y_pos * div_term)
        pe[3::4, :, :] = torch.cos(y_pos * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))  # [1, C, H, W]

    def forward(self, x):
        """Add the (cropped) encoding to ``x`` of shape [N, C, H, W]."""
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
|
||||||
|
|
||||||
|
|
||||||
|
class LinearPositionEncoding(nn.Module):
    """2-D sinusoidal encoding over linearly normalized coordinates.

    Row/column positions are rescaled to [0, 1) before the sin/cos bank,
    so the encoding is resolution-relative rather than pixel-absolute.
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        # normalized positions in [0, 1)
        y_pos = (
            torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1
        ) / max_shape[0]
        x_pos = (
            torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1
        ) / max_shape[1]
        freqs = torch.arange(0, d_model // 2, 2).float()
        freqs = freqs[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_pos * freqs * math.pi)
        pe[1::4, :, :] = torch.cos(x_pos * freqs * math.pi)
        pe[2::4, :, :] = torch.sin(y_pos * freqs * math.pi)
        pe[3::4, :, :] = torch.cos(y_pos * freqs * math.pi)

        # non-persistent: recomputed at construction, never stored in checkpoints
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """Add the (cropped) encoding to ``x`` of shape [N, C, H, W]."""
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
|
||||||
|
|
||||||
|
|
||||||
|
class LearnedPositionEncoding(nn.Module):
    """Learned additive position encoding table.

    NOTE(review): ``pe`` is created with shape [1, H, W, d_model]
    (channel-last), while the forward docstring mentions [N, C, H, W];
    callers appear to supply features whose trailing dims broadcast with
    the table — confirm the expected layout against call sites.
    """

    def __init__(self, d_model, max_shape=(80, 80)):
        """
        Args:
            max_shape (tuple): spatial size of the learned table.
        """
        super().__init__()

        self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model))

    def forward(self, x):
        """Add the full learned table to ``x`` (shapes must broadcast)."""
        return x + self.pe
|
||||||
154
gimm_vfi_arch/generalizable_INR/flowformer/core/update.py
Normal file
154
gimm_vfi_arch/generalizable_INR/flowformer/core/update.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FlowHead(nn.Module):
    """Two-layer convolutional head predicting a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on [N, C, H, W] feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # update gate
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # reset gate
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # candidate

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
|
||||||
|
|
||||||
|
|
||||||
|
class SepConvGRU(nn.Module):
    """ConvGRU with separable (1x5 then 5x1) gate convolutions.

    Applies a horizontal GRU update followed by a vertical one, reaching a
    5x5 receptive field at roughly the cost of one full 5x5 gate set.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        # horizontal (1x5) gates
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        # vertical (5x1) gates
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, convz, convr, convq):
        """One GRU update with the supplied gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||||
|
|
||||||
|
|
||||||
|
class SmallMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    Output has 82 channels: 80 fused channels plus the raw 2-channel flow.
    """

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # correlation volume channels: levels * (2r+1)^2
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # append the raw flow so downstream blocks see it directly
        return torch.cat([fused, flow], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    Output has 128 channels: 126 fused channels plus the raw 2-channel flow.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # correlation volume channels: levels * (2r+1)^2
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels; appending the raw flow brings the total to 128
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class SmallUpdateBlock(nn.Module):
    """RAFT "small" update step: motion encoding -> ConvGRU -> flow delta.

    Returns ``(net, None, delta_flow)``; the middle slot keeps the
    interface aligned with BasicUpdateBlock, which returns an upsampling
    mask there.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        return net, None, self.flow_head(net)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicUpdateBlock(nn.Module):
    """RAFT update step: motion encoding -> SepConvGRU -> flow delta + mask.

    The mask head predicts per-pixel 9-neighbour weights used for 8x convex
    upsampling of the low-resolution flow.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||||
113
gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py
Normal file
113
gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from scipy import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
class InputPadder:
    """Pads images such that dimensions are divisible by 8."""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            # split padding evenly: [left, right, top, bottom]
            self._pad = [
                pad_wd // 2,
                pad_wd - pad_wd // 2,
                pad_ht // 2,
                pad_ht - pad_ht // 2,
            ]
        elif mode == "kitti400":
            # pad only the bottom, up to a fixed height of 400 rows
            self._pad = [0, 0, 0, 400 - self.ht]
        else:
            # centre horizontally; pad only the bottom vertically
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every tensor; returns a list in input order."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||||
|
|
||||||
|
|
||||||
|
def forward_interpolate(flow):
    """Forward-splat a [2, H, W] flow field back onto the regular grid.

    Each source pixel's flow vector is carried to its target location
    ``(x + dx, y + dy)``; the scattered samples are then resampled onto the
    pixel grid with nearest-neighbour interpolation (used to warm-start the
    next frame's flow estimate). Runs on CPU via numpy/scipy.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # flattened target positions of every pixel
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only vectors landing strictly inside the frame
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
    )
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
    )

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
||||||
|
|
||||||
|
|
||||||
|
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    ``coords`` holds (x, y) pixel positions in its last dimension; they are
    rescaled to grid_sample's normalized [-1, 1] range. When ``mask`` is
    True, also returns a float mask marking coords strictly inside the
    frame. (``mode`` is accepted for interface compatibility but unused —
    sampling is always bilinear, as in the original.)
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if mask:
        inbounds = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, inbounds.float()

    return sampled
|
||||||
|
|
||||||
|
|
||||||
|
def indexing(img, coords, mask=False):
    """Nearest-neighbour variant of ``bilinear_sampler`` (pixel coords).

    TODO: directly indexing features instead of sampling
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(
        img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True, mode="nearest"
    )

    if mask:
        inbounds = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, inbounds.float()

    return sampled
|
||||||
|
|
||||||
|
|
||||||
|
def coords_grid(batch, ht, wd):
    """Return a [batch, 2, ht, wd] grid of (x, y) pixel coordinates.

    Channel 0 holds x (column) indices and channel 1 holds y (row)
    indices, matching the (x, y) flow convention used by the samplers.
    """
    # indexing="ij" matches the historical torch.meshgrid default and
    # silences the deprecation warning emitted since PyTorch 1.10;
    # behavior is unchanged.
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors by 8."""
    ht, wd = flow.shape[2], flow.shape[3]
    up = F.interpolate(flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True)
    return 8 * up
|
||||||
253
gimm_vfi_arch/generalizable_INR/gimm.py
Normal file
253
gimm_vfi_arch/generalizable_INR/gimm.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# motif: https://github.com/sichun233746/MoTIF
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .configs import GIMMConfig
|
||||||
|
from .modules.coord_sampler import CoordSampler3D
|
||||||
|
from .modules.fi_components import LateralBlock
|
||||||
|
from .modules.hyponet import HypoNet
|
||||||
|
from .modules.fi_utils import warp
|
||||||
|
|
||||||
|
from .modules.softsplat import softsplat
|
||||||
|
|
||||||
|
|
||||||
|
class GIMM(nn.Module):
    """Generalizable INR for Motion Modeling.

    Encodes a pair of bidirectional optical flows into pixel-aligned
    latents, forward-splats those latents to the query timestep(s) with
    softmax splatting, refines them with a residual CNN, and decodes
    motion at continuous coordinates through a coordinate-based HypoNet.
    """

    # Config class consumed by the generic model factory.
    Config = GIMMConfig

    def __init__(self, config: GIMMConfig):
        super().__init__()
        # Work on a copy so later mutation of the caller's config can't leak in.
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.coord_sampler = CoordSampler3D(config.coord_range)
        self.fwarp_type = config.fwarp_type

        # Motion Encoder: 2-channel flow -> 16-channel pixel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Latent Refiner: residual correction computed from the two source
        # latents concatenated with the splatted latents (64 channels in,
        # 32 out — matching the splatted-latent channel count).
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Fixed 3x3 Gaussian kernel (as a conv3d filter) used to compute the
        # local flow-variance metric below; registered as a frozen Parameter.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        # Learned temperatures for the flow-variance and forward/backward
        # consistency-error terms of the splatting weights.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel splatting confidence for each flow direction.

        Combines (a) a local flow-variance metric (Gaussian-smoothed second
        moment) and (b) a forward/backward flow-consistency error; each is
        mapped through 1/(1 + alpha * metric) and the two terms are summed.
        Returns ``(weights1, weights2)`` for flow 0->1 and 1->0 respectively.
        """
        batch_size = raft_flow01.shape[0]
        # Stack both directions along batch so one conv pass covers both.
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        # Flow variance metric: E[f^2] and E[f] via one Gaussian conv3d pass,
        # then sqrt(E[f^2] - E[f]^2), clamped for numerical safety.
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        # Flow warp (forward/backward consistency) metric: warp the opposite
        # flow and compare against this direction's flow.
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2

    def forward(
        self, xs, coord=None, keep_xs_shape=True, ori_flow=None, timesteps=None
    ):
        """Decode motion at the given coordinates/timesteps.

        Args:
            xs: flow stack; ``xs[:, :, k]`` selects frame k — appears to be
                [B, 2, T=2, H, W] (2 flow channels, 2 frames). TODO confirm.
            coord: coordinate input (or a list of them when ``timesteps`` is
                a list).
            keep_xs_shape: if True, permute outputs back to channel-first.
            ori_flow: stacked bidirectional flows; ``[:, :, 0]`` is flow 0->1
                and ``[:, :, 1]`` is flow 1->0.
            timesteps: a tensor (single query) or a list of tensors (one
                output per entry).

        Returns:
            A single output tensor, or a list of them when ``timesteps`` is
            a list.
        """
        # NOTE(review): sample_coord_input requires (batch_size, s_shape,
        # t_ids, ..., device); calling it with only `xs` here would raise.
        # Confirm that callers always pass `coord` explicitly.
        coord = self.sample_coord_input(xs) if coord is None else coord
        raft_flow01 = ori_flow[:, :, 0]
        raft_flow10 = ori_flow[:, :, 1]

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        # b,c,h,w — per-frame pixel latents from the motion encoder
        pixel_latent_0 = self.cnn_encoder(xs[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(xs[:, :, 1])
        pixel_latent = []

        # No modulation parameters are used in this model variant.
        modulation_params_dict = None
        strtype = self.fwarp_type
        if isinstance(timesteps, list):
            # Multi-timestep path: build one splatted+refined latent per t,
            # then decode each with its matching coordinate grid.
            assert isinstance(coord, list)
            assert len(timesteps) == len(coord)
            for i, cur_t in enumerate(timesteps):
                cur_t = cur_t.reshape(-1, 1, 1, 1)
                # Forward-splat frame-0 latent by t * flow01 and frame-1
                # latent by (1-t) * flow10 toward the target timestep.
                tmp_pixel_latent_0 = softsplat(
                    tenIn=pixel_latent_0,
                    tenFlow=raft_flow01 * cur_t,
                    tenMetric=weights1,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent_1 = softsplat(
                    tenIn=pixel_latent_1,
                    tenFlow=raft_flow10 * (1 - cur_t),
                    tenMetric=weights2,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent = torch.cat(
                    [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
                )
                # Residual refinement conditioned on both source latents.
                tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                    torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
                )
                # Channel-last for the HypoNet.
                pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

            all_outputs = []
            for idx, c in enumerate(coord):
                outputs = self.hyponet(
                    c,
                    modulation_params_dict=modulation_params_dict,
                    pixel_latent=pixel_latent[idx],
                )
                if keep_xs_shape:
                    # Move the channel dim back to position 1.
                    permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                    outputs = outputs.permute(0, -1, *permute_idx_range)
                all_outputs.append(outputs)
            return all_outputs

        else:
            # Single-timestep path: same pipeline, one splat + one decode.
            cur_t = timesteps.reshape(-1, 1, 1, 1)
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent = tmp_pixel_latent.permute(0, 2, 3, 1)

            # predict all pixels of coord after applying the modulation_parms into hyponet
            outputs = self.hyponet(
                coord,
                modulation_params_dict=modulation_params_dict,
                pixel_latent=pixel_latent,
            )
            if keep_xs_shape:
                permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                outputs = outputs.permute(0, -1, *permute_idx_range)
            return outputs

    def compute_loss(self, preds, targets, reduction="mean", single=False):
        """Per-sample MSE (and PSNR) between predictions and targets.

        Expects a singleton time dimension at index 2. ``reduction`` selects
        mean/sum/none over the batch.
        NOTE(review): the ``single`` argument is unused — confirm before
        removing it from the signature.
        """
        assert reduction in ["mean", "sum", "none"]
        batch_size = preds.shape[0]
        sample_mses = 0
        assert preds.shape[2] == 1
        assert targets.shape[2] == 1
        # Loop kept for generality although the time dim is asserted to be 1.
        for i in range(preds.shape[2]):
            sample_mses += torch.reshape(
                (preds[:, :, i] - targets[:, :, i]) ** 2, (batch_size, -1)
            ).mean(dim=-1)
        sample_mses = sample_mses / preds.shape[2]
        if reduction == "mean":
            total_loss = sample_mses.mean()
            psnr = (-10 * torch.log10(sample_mses)).mean()
        elif reduction == "sum":
            total_loss = sample_mses.sum()
            psnr = (-10 * torch.log10(sample_mses)).sum()
        else:
            total_loss = sample_mses
            psnr = -10 * torch.log10(sample_mses)

        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample a 3-D (space + time) coordinate grid via the coord sampler.

        ``device`` is mandatory; ``coord_range`` must be None (the range
        configured at construction time is used).
        """
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs
|
||||||
471
gimm_vfi_arch/generalizable_INR/gimmvfi_f.py
Normal file
471
gimm_vfi_arch/generalizable_INR/gimmvfi_f.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# amt: https://github.com/MCG-NKU/AMT
|
||||||
|
# motif: https://github.com/sichun233746/MoTIF
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .configs import GIMMVFIConfig
|
||||||
|
from .modules.coord_sampler import CoordSampler3D
|
||||||
|
from .modules.hyponet import HypoNet
|
||||||
|
from .modules.fi_components import *
|
||||||
|
#from .flowformer import initialize_Flowformer
|
||||||
|
from .modules.fi_utils import normalize_flow, unnormalize_flow, warp, resize
|
||||||
|
from .raft.corr import BidirCorrBlock
|
||||||
|
from .modules.softsplat import softsplat
|
||||||
|
|
||||||
|
|
||||||
|
class GIMMVFI_F(nn.Module):
    """FlowFormer-based GIMM-VFI model: arbitrary-timestep video frame interpolation.

    Pipeline: bidirectional optical flow (FlowFormer) -> GIMM implicit motion
    module (HypoNet) predicts flow at arbitrary t -> AMT-style coarse-to-fine
    synthesis of the intermediate frame.

    NOTE(review): ``self.flow_estimator`` is not created here (the in-file
    initializer is commented out); it is presumably attached externally by the
    model loader — confirm against the loading code.
    """

    # Config class used by the generic model factory.
    Config = GIMMVFIConfig

    def __init__(self, dtype, config: GIMMVFIConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        # Number of refinement iterations for the flow estimator (from config).
        self.raft_iter = config.raft_iter

        self.dtype = dtype

        ######### Encoder and Decoder Settings #########
        #self.flow_estimator = initialize_Flowformer()
        # Channel dims of the two encoder feature levels consumed by the decoders.
        f_dims = [256, 128]

        skip_channels = f_dims[-1] // 2
        # Number of candidate flows predicted per pixel at full resolution.
        self.num_flows = 3
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)

        # Two residual update blocks at the 1/4 scale (low-res then high-res corr).
        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)

        # Fuses the num_flows candidate predictions into one RGB frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )

        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)

        # Fixed 3x3 Gaussian smoothing kernel used for the flow-variance metric.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        self.fwarp_type = config.fwarp_type

        # Learnable scales for the variance / flow-consistency splatting weights.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        # Per-frame flow encoder: 2-channel flow -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Residual refinement over [latent0, latent1, splatted latents] (64 ch in).
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Implicit motion model conditioned on the 32-channel pixel latent.
        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build a BasicUpdateBlock for residual flow/feature refinement."""
        return BasicUpdateBlock(
            cdim=cdim,
            hidden_dim=192,
            flow_dim=64,
            corr_dim=256,
            corr_dim2=192,
            fc_dim=188,
            scale_factor=scale_factor,
            corr_levels=4,
            radius=4,
        )

    def cal_bidirection_flow(self, im0, im1):
        """Run the flow estimator both ways and prepare normalized flows.

        Returns a tuple of (normalized flows, original flows, normalization
        scalers, per-frame features, bidirectional correlation fn, and the raw
        RAFT/FlowFormer flows preserved for output).
        """
        f01, features0, fnet0 = self.flow_estimator(
            im0, im1, return_feat=True, iters=None
        )
        f10, features1, fnet1 = self.flow_estimator(
            im1, im0, return_feat=True, iters=None
        )
        # FlowFormer returns a list of predictions; keep the first.
        f01 = f01[0]
        f10 = f10[0]
        corr_fn = BidirCorrBlock(fnet0, fnet1, radius=4)
        flow01 = f01.unsqueeze(2)
        flow10 = f10.unsqueeze(2)
        # Negate the backward flow so both point along the positive t-axis.
        noraml_flows = torch.cat([flow01, -flow10], dim=2)
        noraml_flows, flow_scalers = normalize_flow(noraml_flows)

        ori_flows = torch.cat([flow01, flow10], dim=2)
        return (
            noraml_flows,
            ori_flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
        )

    def predict_flow(self, f, coord, t, flows):
        """Predict normalized flow at each requested timestep via the HypoNet.

        f: normalized bidirectional flows (conditioning input).
        coord: list of coordinate tuples, one per timestep.
        t: list of timestep tensors aligned with `coord`.
        flows: unnormalized bidirectional flows used for forward splatting.
        """
        raft_flow01 = flows[:, :, 0].detach()
        raft_flow10 = flows[:, :, 1].detach()

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        strtype = self.fwarp_type + "-zeroeps"

        # b,c,h,w
        pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
        pixel_latent = []

        for i, cur_t in enumerate(t):
            cur_t = cur_t.reshape(-1, 1, 1, 1)

            # Forward-splat both latents toward timestep t.
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype,
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype,
            )

            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            # Residual refinement conditioned on the unsplatted latents.
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

        all_outputs = []
        permute_idx_range = [i for i in range(1, f.ndim - 1)]
        for idx, c in enumerate(coord):
            # Sanity check: coordinate's t component matches the requested t.
            assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
            assert isinstance(c, tuple)

            # c[1] is an optional subsampling index; None means full grid,
            # in which case the output is permuted back to channel-first.
            if c[1] is None:
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                ).permute(0, -1, *permute_idx_range)
            else:
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                )
            all_outputs.append(outputs)

        return all_outputs

    def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
        """Backward-warp both frames with flows upscaled by `scale` and blend
        them with a sigmoid occlusion mask."""
        ft0 = scale * resize(ft0, scale_factor=scale)
        ft1 = scale * resize(ft1, scale_factor=scale)
        mask = resize(mask, scale_factor=scale).sigmoid()
        img0_warp = warp(img0, ft0)
        img1_warp = warp(img1, ft1)
        img_warp = mask * img0_warp + (1 - mask) * img1_warp
        return img_warp

    @torch.compiler.disable()
    def frame_synthesize(
        self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
    ):
        """
        flow_t: b,2,h,w
        cur_t: b,1,1,1

        Coarse-to-fine synthesis of the frame at timestep cur_t. If `full_img`
        is given, the (possibly downscaled) prediction is re-upscaled and
        combined at the original resolution.
        """
        batch_size = img_xs.shape[0]
        # Map images from [0, 1] to [-1, 1] for the decoders.
        img0 = 2 * img_xs[:, :, 0] - 1.0
        img1 = 2 * img_xs[:, :, 1] - 1.0

        ##################### update the predicted flow #####################
        ## initialize coordinates for looking up
        lookup_coord = self.flow_estimator.build_coord(img_xs[:, :, 0]).to(
            img_xs[:, :, 0].device
        )

        # flow_t points 0 -> 1 at timestep t; derive t->0 and t->1 flows
        # under the linear-motion assumption.
        flow_t0_fullsize = flow_t * (-cur_t)
        flow_t1_fullsize = flow_t * (1.0 - cur_t)

        inv = 1 / 4
        flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
        flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)

        ############################# scale 1/4 #############################
        # i. Initialize feature t at scale 1/4
        flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
            features0[-1],
            features1[-1],
            flow_t0_inr4,
            flow_t1_inr4,
            img0=img0,
            img1=img1,
        )
        # First channel is the blend mask; the rest is the feature map.
        mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
        img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
        img_warp_4 = (img_warp_4 + 1.0) / 2
        img_warp_4 = torch.clamp(img_warp_4, 0, 1)

        # ii. residual update with correlation looked up at 1/8 scale.
        corr_4, flow_4_lr = self._amt_corr_scale_lookup(
            corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
        )

        delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        flowt0_4 = flowt0_4 + delta_flow0_4
        flowt1_4 = flowt1_4 + delta_flow1_4
        ft_4_ = ft_4_ + delta_ft_4_

        # iii. residue update with lookup corr
        corr_4 = resize(corr_4, scale_factor=2.0)

        flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
        delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
        flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
        flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
        ft_4_ = ft_4_ + delta_ft_4_

        ############################# scale 1/1 #############################
        flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
            ft_4_,
            features0[0],
            features1[0],
            flowt0_4,
            flowt1_4,
            mask=mask_4_,
            img0=img0,
            img1=img1,
        )

        # If the network ran on a downscaled input, upscale flows/mask/residual
        # back to the original resolution before combining.
        if full_img is not None:
            img0 = 2 * full_img[:, :, 0] - 1.0
            img1 = 2 * full_img[:, :, 1] - 1.0
            inv = img1.shape[2] / flowt0_1.shape[2]
            flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
            flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
            flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
            flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
            mask = resize(mask, scale_factor=inv)
            img_res = resize(img_res, scale_factor=inv)

        imgt_pred = multi_flow_combine(
            self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
        )
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        ######################################################################

        flowt0_1 = flowt0_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt1_1 = flowt1_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )

        flowt0_pred = [flowt0_1, flowt0_4]
        flowt1_pred = [flowt1_1, flowt1_4]
        other_pred = [img_warp_4]
        return imgt_pred, flowt0_pred, flowt1_pred, other_pred

    def forward(self, img_xs, coord=None, t=None, ds_factor=None):
        """Interpolate frames between img_xs[:, :, 0] and img_xs[:, :, 1].

        img_xs: (b, c, 2, h, w) pair of frames in [0, 1].
        coord / t: aligned lists of coordinate tuples and timesteps — all
            requested timesteps are synthesized in one pass.
        ds_factor: optional downscale factor; flow is estimated at the reduced
            resolution while the final frame is combined at full resolution.
        Returns a dict of predictions and intermediate flows.
        """
        assert isinstance(t, list)
        assert isinstance(coord, list)
        assert len(t) == len(coord)
        full_size_img = None
        if ds_factor is not None:
            full_size_img = img_xs.clone()
            img_xs = torch.cat(
                [
                    resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                    resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
                ],
                dim=2,
            )

        (
            normal_flows,
            flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            preserved_raft_flows,
        ) = self.cal_bidirection_flow(255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1])
        assert coord is not None

        # List of flows
        normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)

        ############ Unnormalize the predicted/reconstructed flow ############
        start_idx = 0
        if coord[0][1] is not None:
            # Subsmapled flows for reconstruction supervision in the GIMM module
            # In such case, first two coords in the list are subsampled for supervision up-mentioned
            # Normalized flow_t towards positive t-axis
            assert len(coord) > 2
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(2, len(coord))
            ]
            start_idx = 2
        else:
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(len(coord))
            ]

        imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []

        for idx in range(start_idx, len(coord)):
            cur_flow_t = flow_t[idx - start_idx]
            cur_t = t[idx].reshape(-1, 1, 1, 1)
            # squeeze() above may have dropped the batch dim for batch size 1.
            if cur_flow_t.ndim != 4:
                cur_flow_t = cur_flow_t.unsqueeze(0)
            assert cur_flow_t.ndim == 4

            imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
                img_xs,
                cur_flow_t,
                features0,
                features1,
                corr_fn,
                cur_t,
                full_img=full_size_img,
            )

            imgt_preds.append(imgt_pred)
            flowt0_preds.append(flowt0_pred)
            flowt1_preds.append(flowt1_pred)
            all_others.append(others)

        return {
            "imgt_pred": imgt_preds,
            "other_pred": all_others,
            "flowt0_pred": flowt0_preds,
            "flowt1_pred": flowt1_preds,
            "raft_flow": preserved_raft_flows,
            "ninrflow": normal_inr_flows,
            "nflow": normal_flows,
            "flowt": flow_t,
        }

    def warp_frame(self, frame, flow):
        """Backward-warp `frame` by `flow` (thin wrapper around fi_utils.warp)."""
        return warp(frame, flow)

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample spatio-temporal coordinates for the given timesteps via the
        model's CoordSampler3D. `device` is required; a custom `coord_range`
        is not supported here."""
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel confidence weights for forward splatting from
        local flow variance and bidirectional flow-consistency error."""
        batch_size = raft_flow01.shape[0]
        # Stack both directions so one conv pass handles both.
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        ## flow variance metric
        # E[x^2] and E[x] under the Gaussian filter, computed in one conv3d.
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        # Local std-dev of the flow, averaged over the 2 flow channels.
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        ## flow warp metirc
        # Forward-backward consistency: warped reverse flow should cancel.
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        # High variance or high inconsistency -> lower splatting weight.
        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2

    def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation volumes at coordinates displaced
        by the t->0 / t->1 flows rescaled to full inter-frame flows."""
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t0_scale = 1.0 / embt
        t1_scale = 1.0 / (1.0 - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow
||||||
508
gimm_vfi_arch/generalizable_INR/gimmvfi_r.py
Normal file
508
gimm_vfi_arch/generalizable_INR/gimmvfi_r.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# amt: https://github.com/MCG-NKU/AMT
|
||||||
|
# motif: https://github.com/sichun233746/MoTIF
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .configs import GIMMVFIConfig
|
||||||
|
from .modules.coord_sampler import CoordSampler3D
|
||||||
|
from .modules.hyponet import HypoNet
|
||||||
|
from .modules.fi_components import *
|
||||||
|
from .modules.fi_utils import (
|
||||||
|
normalize_flow,
|
||||||
|
unnormalize_flow,
|
||||||
|
warp,
|
||||||
|
resize,
|
||||||
|
build_coord,
|
||||||
|
)
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .raft.corr import BidirCorrBlock
|
||||||
|
from .modules.softsplat import softsplat
|
||||||
|
|
||||||
|
|
||||||
|
class GIMMVFI_R(nn.Module):
|
||||||
|
Config = GIMMVFIConfig
|
||||||
|
|
||||||
|
    def __init__(self, dtype, config: GIMMVFIConfig):
        """Build the RAFT-based GIMM-VFI model.

        dtype: torch dtype the (externally attached) RAFT estimator runs in.
        config: GIMMVFIConfig with hyponet / coord_range / fwarp_type settings.

        NOTE(review): ``self.flow_estimator`` is not created here (the in-file
        initializer is commented out); it is presumably attached externally by
        the model loader — confirm against the loading code.
        """
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        # Default RAFT refinement iterations (fixed here, unlike GIMMVFI_F
        # which reads it from config).
        self.raft_iter = 20

        ######### Encoder and Decoder Settings #########
        #self.flow_estimator = initialize_RAFT()
        # RAFT feature channel dims, projected up to the decoder dims below.
        cur_f_dims = [128, 96]
        f_dims = [256, 128]
        self.dtype = dtype

        skip_channels = f_dims[-1] // 2
        # Number of candidate flows predicted per pixel at full resolution.
        self.num_flows = 3

        # 1x1 projections mapping RAFT context/correlation features to the
        # channel counts the AMT decoders expect.
        self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1)
        self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1)
        self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1)
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)

        # Two residual update blocks at the 1/4 scale (low-res then high-res corr).
        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)

        # Fuses the num_flows candidate predictions into one RGB frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )

        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)

        # Fixed 3x3 Gaussian smoothing kernel used for the flow-variance metric.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        self.fwarp_type = config.fwarp_type

        # Learnable scales for the variance / flow-consistency splatting weights.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        # Per-frame flow encoder: 2-channel flow -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Residual refinement over [latent0, latent1, splatted latents] (64 ch in).
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Implicit motion model conditioned on the 32-channel pixel latent.
        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)
||||||
|
|
||||||
|
def _get_updateblock(self, cdim, scale_factor=None):
|
||||||
|
return BasicUpdateBlock(
|
||||||
|
cdim=cdim,
|
||||||
|
hidden_dim=192,
|
||||||
|
flow_dim=64,
|
||||||
|
corr_dim=256,
|
||||||
|
corr_dim2=192,
|
||||||
|
fc_dim=188,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
corr_levels=4,
|
||||||
|
radius=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
def cal_bidirection_flow(self, im0, im1, iters=20):
|
||||||
|
f01, features0, fnet0 = self.flow_estimator(
|
||||||
|
im0.to(self.dtype), im1.to(self.dtype), return_feat=True, iters=20
|
||||||
|
)
|
||||||
|
f10, features1, fnet1 = self.flow_estimator(
|
||||||
|
im1.to(self.dtype), im0.to(self.dtype), return_feat=True, iters=20
|
||||||
|
)
|
||||||
|
corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4)
|
||||||
|
features0 = [
|
||||||
|
self.amt_second_last_cproj(features0[0]),
|
||||||
|
self.amt_last_cproj(features0[1]),
|
||||||
|
]
|
||||||
|
features1 = [
|
||||||
|
self.amt_second_last_cproj(features1[0]),
|
||||||
|
self.amt_last_cproj(features1[1]),
|
||||||
|
]
|
||||||
|
flow01 = f01.unsqueeze(2)
|
||||||
|
flow10 = f10.unsqueeze(2)
|
||||||
|
noraml_flows = torch.cat([flow01, -flow10], dim=2)
|
||||||
|
noraml_flows, flow_scalers = normalize_flow(noraml_flows)
|
||||||
|
|
||||||
|
ori_flows = torch.cat([flow01, flow10], dim=2)
|
||||||
|
return (
|
||||||
|
noraml_flows,
|
||||||
|
ori_flows,
|
||||||
|
flow_scalers,
|
||||||
|
features0,
|
||||||
|
features1,
|
||||||
|
corr_fn,
|
||||||
|
torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
    def predict_flow(self, f, coord, t, flows):
        """Predict normalized flow at each requested timestep via the HypoNet.

        f: normalized bidirectional flows (conditioning input).
        coord: list of coordinate tuples, one per timestep.
        t: list of timestep tensors aligned with `coord`.
        flows: unnormalized bidirectional flows used for forward splatting.
        """
        raft_flow01 = flows[:, :, 0].detach()
        raft_flow10 = flows[:, :, 1].detach()

        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        strtype = self.fwarp_type + "-zeroeps"

        # b,c,h,w
        pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
        pixel_latent = []

        for i, cur_t in enumerate(t):
            cur_t = cur_t.reshape(-1, 1, 1, 1)

            # Forward-splat both latents toward timestep t.
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype,
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype,
            )

            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            # Residual refinement conditioned on the unsplatted latents.
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

        all_outputs = []
        permute_idx_range = [i for i in range(1, f.ndim - 1)]
        for idx, c in enumerate(coord):
            # Sanity check: coordinate's t component matches the requested t.
            assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
            assert isinstance(c, tuple)

            # c[1] is an optional subsampling index; None means full grid,
            # in which case the output is permuted back to channel-first.
            if c[1] is None:
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                ).permute(0, -1, *permute_idx_range)
            else:
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                )
            all_outputs.append(outputs)

        return all_outputs
||||||
|
|
||||||
|
def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
|
||||||
|
ft0 = scale * resize(ft0, scale_factor=scale)
|
||||||
|
ft1 = scale * resize(ft1, scale_factor=scale)
|
||||||
|
mask = resize(mask, scale_factor=scale).sigmoid()
|
||||||
|
img0_warp = warp(img0, ft0)
|
||||||
|
img1_warp = warp(img1, ft1)
|
||||||
|
img_warp = mask * img0_warp + (1 - mask) * img1_warp
|
||||||
|
return img_warp
|
||||||
|
|
||||||
|
    @torch.compiler.disable()
    def frame_synthesize(
        self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
    ):
        """
        flow_t: b,2,h,w
        cur_t: b,1,1,1

        Coarse-to-fine synthesis of the frame at timestep cur_t. If `full_img`
        is given, the (possibly downscaled) prediction is re-upscaled and
        combined at the original resolution.
        """
        batch_size = img_xs.shape[0]  # b,c,t,h,w
        # Map images from [0, 1] to [-1, 1] for the decoders.
        img0 = 2 * img_xs[:, :, 0] - 1.0
        img1 = 2 * img_xs[:, :, 1] - 1.0

        ##################### update the predicted flow #####################
        ##initialize coordinates for looking up
        lookup_coord = build_coord(img_xs[:, :, 0]).to(
            img_xs[:, :, 0].device
        )  # H//8,W//8

        # flow_t points 0 -> 1 at timestep t; derive t->0 and t->1 flows
        # under the linear-motion assumption.
        flow_t0_fullsize = flow_t * (-cur_t)
        flow_t1_fullsize = flow_t * (1.0 - cur_t)

        inv = 1 / 4
        flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
        flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)

        ############################# scale 1/4 #############################
        # i. Initialize feature t at scale 1/4
        flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
            features0[-1],
            features1[-1],
            flow_t0_inr4,
            flow_t1_inr4,
            img0=img0,
            img1=img1,
        )
        # The deepest feature level has been consumed by the init decoder.
        features0, features1 = features0[:-1], features1[:-1]

        # First channel is the blend mask; the rest is the feature map.
        mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
        img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
        img_warp_4 = (img_warp_4 + 1.0) / 2
        img_warp_4 = torch.clamp(img_warp_4, 0, 1)

        # ii. residual update with correlation looked up at 1/8 scale.
        corr_4, flow_4_lr = self._amt_corr_scale_lookup(
            corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
        )

        delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        flowt0_4 = flowt0_4 + delta_flow0_4
        flowt1_4 = flowt1_4 + delta_flow1_4
        ft_4_ = ft_4_ + delta_ft_4_

        # iii. residue update with lookup corr
        corr_4 = resize(corr_4, scale_factor=2.0)

        flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
        delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
        flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
        flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
        ft_4_ = ft_4_ + delta_ft_4_

        ############################# scale 1/1 #############################
        flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
            ft_4_,
            features0[0],
            features1[0],
            flowt0_4,
            flowt1_4,
            mask=mask_4_,
            img0=img0,
            img1=img1,
        )

        # If the network ran on a downscaled input, upscale flows/mask/residual
        # back to the original resolution before combining.
        if full_img is not None:
            img0 = 2 * full_img[:, :, 0] - 1.0
            img1 = 2 * full_img[:, :, 1] - 1.0
            inv = img1.shape[2] / flowt0_1.shape[2]
            flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
            flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
            flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
            flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
            mask = resize(mask, scale_factor=inv)
            img_res = resize(img_res, scale_factor=inv)

        imgt_pred = multi_flow_combine(
            self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
        )
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        ######################################################################

        flowt0_1 = flowt0_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt1_1 = flowt1_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )

        flowt0_pred = [flowt0_1, flowt0_4]
        flowt1_pred = [flowt1_1, flowt1_4]
        other_pred = [img_warp_4]
        return imgt_pred, flowt0_pred, flowt1_pred, other_pred
|
||||||
|
|
||||||
|
    def forward(self, img_xs, coord=None, t=None, iters=None, ds_factor=None):
        """Interpolate frames between img_xs[:, :, 0] and img_xs[:, :, 1].

        img_xs: (b, c, 2, h, w) pair of frames in [0, 1].
        coord / t: aligned lists of coordinate tuples and timesteps — all
            requested timesteps are synthesized in one pass.
        iters: RAFT iterations (defaults to self.raft_iter).
        ds_factor: optional downscale factor; flow is estimated at the reduced
            resolution while the final frame is combined at full resolution.
        Returns a dict of predictions and intermediate flows.
        """
        assert isinstance(t, list)
        assert isinstance(coord, list)
        assert len(t) == len(coord)
        full_size_img = None
        if ds_factor is not None:
            full_size_img = img_xs.clone()
            img_xs = torch.cat(
                [
                    resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                    resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
                ],
                dim=2,
            )

        iters = self.raft_iter if iters is None else iters
        (
            normal_flows,
            flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            preserved_raft_flows,
        ) = self.cal_bidirection_flow(
            255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters
        )
        assert coord is not None

        # List of flows
        normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)

        ############ Unnormalize the predicted/reconstructed flow ############
        start_idx = 0
        if coord[0][1] is not None:
            # Subsmapled flows for reconstruction supervision in the GIMM module
            # In such case, by default, first two coords are subsampled for supervision up-mentioned
            # normalized flow_t versus positive t-axis
            assert len(coord) > 2
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(2, len(coord))
            ]
            start_idx = 2
        else:
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(len(coord))
            ]

        imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []

        for idx in range(start_idx, len(coord)):
            cur_flow_t = flow_t[idx - start_idx]
            cur_t = t[idx].reshape(-1, 1, 1, 1)
            # squeeze() above may have dropped the batch dim for batch size 1.
            if cur_flow_t.ndim != 4:
                cur_flow_t = cur_flow_t.unsqueeze(0)
            assert cur_flow_t.ndim == 4

            imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
                img_xs,
                cur_flow_t,
                features0,
                features1,
                corr_fn,
                cur_t,
                full_img=full_size_img,
            )

            imgt_preds.append(imgt_pred)
            flowt0_preds.append(flowt0_pred)
            flowt1_preds.append(flowt1_pred)
            all_others.append(others)

        return {
            "imgt_pred": imgt_preds,
            "other_pred": all_others,
            "flowt0_pred": flowt0_preds,
            "flowt1_pred": flowt1_preds,
            "raft_flow": preserved_raft_flows,
            "ninrflow": normal_inr_flows,
            "nflow": normal_flows,
            "flowt": flow_t,
        }
|
||||||
|
|
||||||
|
    def warp_frame(self, frame, flow):
        """Backward-warp `frame` by optical flow `flow` (thin delegate to `warp`)."""
        return warp(frame, flow)
|
||||||
|
|
||||||
|
def compute_psnr(self, preds, targets, reduction="mean"):
|
||||||
|
assert reduction in ["mean", "sum", "none"]
|
||||||
|
batch_size = preds.shape[0]
|
||||||
|
sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean(
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
if reduction == "mean":
|
||||||
|
psnr = (-10 * torch.log10(sample_mses)).mean()
|
||||||
|
elif reduction == "sum":
|
||||||
|
psnr = (-10 * torch.log10(sample_mses)).sum()
|
||||||
|
else:
|
||||||
|
psnr = -10 * torch.log10(sample_mses)
|
||||||
|
|
||||||
|
return psnr
|
||||||
|
|
||||||
|
def sample_coord_input(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
s_shape,
|
||||||
|
t_ids,
|
||||||
|
coord_range=None,
|
||||||
|
upsample_ratio=1.0,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
assert device is not None
|
||||||
|
assert coord_range is None
|
||||||
|
coord_inputs = self.coord_sampler(
|
||||||
|
batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
|
||||||
|
)
|
||||||
|
return coord_inputs
|
||||||
|
|
||||||
|
    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel splatting confidence weights for both flow directions.

        Combines two cues, each mapped through 1/(1+x*alpha):
          - local flow variance (low variance -> higher confidence), and
          - forward/backward flow-consistency error (low error -> higher confidence).

        Args:
            raft_flow01: flow 0 -> 1, shape (B, 2, H, W).
            raft_flow10: flow 1 -> 0, shape (B, 2, H, W).

        Returns:
            (weights1, weights2): confidence maps of shape (B, 1, H, W) for the
            0->1 and 1->0 directions respectively.
        """
        batch_size = raft_flow01.shape[0]
        # Stack both directions so the variance filter runs in a single pass.
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        ## Flow variance metric: Var = E[x^2] - (E[x])^2 via a single conv3d
        ## over [flows^2, flows]; `self.g_filter` is presumably a box/Gaussian
        ## averaging kernel -- confirm where it is registered.
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        # NOTE: despite the names, `sqaure_mean` is E[x^2] and `mean_square`
        # is E[x] (it is squared just below). Clamp keeps sqrt stable.
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        ## Flow warp (forward-backward consistency) metric: a consistent pair
        ## satisfies flow01 ~= -warp(flow10, flow01).
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )

        # Each term is in (0, 1]; alphas control the falloff sharpness.
        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

        return weights1, weights2
|
||||||
|
|
||||||
|
def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
|
||||||
|
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
|
||||||
|
# based on linear assumption
|
||||||
|
t0_scale = 1.0 / embt
|
||||||
|
t1_scale = 1.0 / (1.0 - embt)
|
||||||
|
if downsample != 1:
|
||||||
|
inv = 1 / downsample
|
||||||
|
flow0 = inv * resize(flow0, scale_factor=inv)
|
||||||
|
flow1 = inv * resize(flow1, scale_factor=inv)
|
||||||
|
|
||||||
|
corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
|
||||||
|
corr = torch.cat([corr0, corr1], dim=1)
|
||||||
|
flow = torch.cat([flow0, flow1], dim=1)
|
||||||
|
return corr, flow
|
||||||
0
gimm_vfi_arch/generalizable_INR/modules/__init__.py
Normal file
0
gimm_vfi_arch/generalizable_INR/modules/__init__.py
Normal file
91
gimm_vfi_arch/generalizable_INR/modules/coord_sampler.py
Normal file
91
gimm_vfi_arch/generalizable_INR/modules/coord_sampler.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class CoordSampler3D(nn.Module):
    """Generates (t, y, x) coordinate grids for the GIMM INR.

    Args:
        coord_range: default (min, max) range for the spatial axes.
        t_coord_only: if True, `forward` returns only the time channel.
    """

    def __init__(self, coord_range, t_coord_only=False):
        super().__init__()
        self.coord_range = coord_range
        self.t_coord_only = t_coord_only

    def shape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Build one shared coordinate grid for a list of timesteps.

        Returns:
            Tensor of shape (B, T, H, W, 3); last dim is (t, y, x). Spatial
            axes are pixel centers affinely mapped into `coord_range`.
        """
        assert isinstance(t_ids, list)
        coords = [torch.tensor(t_ids, device=device, dtype=torch.float32)]
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            # Pixel-center positions in [0, 1), then mapped into coord_range.
            axis = (0.5 + torch.arange(num_s, device=device)) / num_s
            axis = coord_range[0] + (coord_range[1] - coord_range[0]) * axis
            coords.append(axis)
        grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=-1)
        ones_like_shape = (1,) * grid.ndim
        return grid.unsqueeze(0).repeat(batch_size, *ones_like_shape)

    def batchshape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Per-sample-timestep variant: `t_ids` is a tensor of shape (B,).

        Returns:
            Tensor of shape (B, 1, H, W, 3) whose time channel carries each
            sample's own timestep.
        """
        coords = [torch.tensor(1, device=device).to(torch.float32)]
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            axis = (0.5 + torch.arange(num_s, device=device)) / num_s
            axis = coord_range[0] + (coord_range[1] - coord_range[0]) * axis
            coords.append(axis)
        grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=-1)
        ones_like_shape = (1,) * grid.ndim
        # Grid is (B, 1, H, W, 3) with the time channel == 1 everywhere.
        grid = grid.unsqueeze(0).repeat(batch_size, *ones_like_shape)
        # Assign each sample its own timestep within the batch.
        grid[..., :1] = grid[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
        return grid

    def forward(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        coord_range = self.coord_range if coord_range is None else coord_range
        if isinstance(t_ids, list):
            coords = self.shape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        elif isinstance(t_ids, torch.Tensor):
            coords = self.batchshape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        else:
            # Previously fell through to an UnboundLocalError; fail clearly.
            raise TypeError(
                f"t_ids must be a list or a torch.Tensor, got {type(t_ids)}"
            )
        if self.t_coord_only:
            coords = coords[..., :1]
        return coords
|
||||||
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# amt: https://github.com/MCG-NKU/AMT
|
||||||
|
# motif: https://github.com/sichun233746/MoTIF
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .fi_utils import warp, resize
|
||||||
|
|
||||||
|
|
||||||
|
class LateralBlock(nn.Module):
    """Residual block: two 3x3 convs with a LeakyReLU between, plus identity skip."""

    def __init__(self, dim):
        super(LateralBlock, self).__init__()
        # Channel count is preserved so the residual add is shape-safe.
        self.layers = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
        )

    def forward(self, x):
        res = x
        x = self.layers(x)
        return x + res
|
||||||
|
|
||||||
|
|
||||||
|
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Conv2d followed by a per-channel PReLU, bundled as an nn.Sequential."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias=bias,
    )
    return nn.Sequential(conv, nn.PReLU(out_channels))
|
||||||
|
|
||||||
|
|
||||||
|
def multi_flow_combine(
    comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
    """Fuse multiple flow hypotheses into one intermediate-frame prediction.

    Each of the `num_flows` flow pairs warps img0/img1, the warps are blended
    by `mask`, per-hypothesis results are averaged, and `comb_block` predicts
    a residual on top.

    Args:
        comb_block: module mapping (B, num_flows*3, H, W) -> (B, 3, H, W) residual.
        img0, img1: input frames, (B, 3, H, W), presumably in [-1, 1]
            (the final mapping to [0, 1] suggests this -- confirm in caller).
        flow0, flow1: stacked flows, (B, num_flows*2, H, W).
        mask: blend weights, (B, num_flows*1, H, W), or None.
        img_res: per-hypothesis image residual, (B, num_flows*3, H, W), or None.
        mean: must be None in this code path (asserted).

    Returns:
        Predicted frame (B, 3, H, W) mapped into [0, 1].
    """
    assert mean is None
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the hypothesis dimension into the batch so warp() runs once.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    mask = (
        mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
        if mask is not None
        else None
    )
    img_res = (
        img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
        if img_res is not None
        else 0
    )
    # Replicate the source frames once per flow hypothesis.
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = (
        torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
        if mean is not None
        else 0
    )

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    # Mask-blend the two warped views, then add optional mean/residual terms.
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)

    # Residual refinement over all hypotheses, then average them.
    res = comb_block(img_warps.view(b, -1, h, w))
    imgt_pred = img_warps.mean(1) + res

    # Map from [-1, 1] to [0, 1].
    imgt_pred = (imgt_pred + 1.0) / 2

    return imgt_pred
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Residual block with a narrow "side" channel path refined separately.

    The last `side_channels` channels are processed by extra convs between the
    main convolutions, then re-joined before the residual add.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        # conv2/conv4 act only on the side channels.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Split off the trailing side channels, refine them, and re-join.
        res_feat = out[:, : -self.side_channels, ...]
        side_feat = out[:, -self.side_channels :, :, :]
        side_feat = self.conv2(side_feat)
        out = self.conv3(torch.cat([res_feat, side_feat], 1))

        res_feat = out[:, : -self.side_channels, ...]
        side_feat = out[:, -self.side_channels :, :, :]
        side_feat = self.conv4(side_feat)
        out = self.conv5(torch.cat([res_feat, side_feat], 1))

        # Residual connection around the whole block.
        out = self.prelu(x + out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class BasicUpdateBlock(nn.Module):
    """RAFT-style update block predicting deltas for hidden state and flows.

    Encodes correlation and (bidirectional) flow separately, fuses them with
    the hidden state through a small conv "GRU", and emits `delta_net` and
    `delta_flow`. If `scale_factor` is set, the update runs at a lower
    resolution and the deltas are resized (and the flow rescaled) back up.
    """

    def __init__(
        self,
        cdim,
        hidden_dim,
        flow_dim,
        corr_dim,
        corr_dim2,
        fc_dim,
        corr_levels=4,
        radius=3,
        scale_factor=None,
        out_num=1,
    ):
        super(BasicUpdateBlock, self).__init__()
        # Channels of one correlation pyramid lookup; the input carries two
        # directions, hence the 2x below.
        cor_planes = corr_levels * (2 * radius + 1) ** 2

        self.scale_factor = scale_factor
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        # Flow encoder input is 4 channels: two 2-channel flows.
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        # Downscale the hidden state to the update resolution if configured.
        net = (
            resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
        )
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Flow magnitudes scale with spatial resolution.
            delta_flow = self.scale_factor * resize(
                delta_flow, scale_factor=self.scale_factor
            )
        return delta_net, delta_flow
|
||||||
|
|
||||||
|
|
||||||
|
def get_bn():
    """Return the normalization-layer factory used by the decoders (BatchNorm2d)."""
    return nn.BatchNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
class NewInitDecoder(nn.Module):
    """Initial decoder: refines a coarse bidirectional flow pair and produces
    an intermediate feature map.

    Upsamples both feature maps 2x, warps them (and resized input frames) by
    the incoming flows, and predicts flow residuals plus a feature tensor.
    """

    def __init__(self, in_ch, skip_ch):
        super().__init__()
        norm_layer = get_bn()

        # 2x upsample via PixelShuffle (in_ch -> in_ch // 4), then widen back.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )

        in_ch = in_ch // 2
        # Input: two warped feature maps + two flows (4ch) + 4 images (12ch).
        self.convblock = nn.Sequential(
            convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            # Output: 4 flow-residual channels + (in_ch + 1) feature channels.
            nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True),
        )

    def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        f0_warp_ks = warp(f0, flow0_in)
        f1_warp_ks = warp(f1, flow1_in)

        f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)

        assert img0 is not None
        assert img1 is not None
        # Bring the input frames to the working (feature) resolution.
        scale_factor = f_in.shape[2] / img0.shape[2]
        img0 = resize(img0, scale_factor=scale_factor)
        img1 = resize(img1, scale_factor=scale_factor)
        warped_img0 = warp(img0, flow0_in)
        warped_img1 = warp(img1, flow1_in)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)

        out = self.convblock(f_in)
        # Channels 0..3 are flow residuals; the rest is the feature output.
        ft_ = out[:, 4:, ...]
        flow0 = flow0_in + out[:, :2, ...]
        flow1 = flow1_in + out[:, 2:4, ...]
        return flow0, flow1, ft_
|
||||||
|
|
||||||
|
|
||||||
|
class NewMultiFlowDecoder(nn.Module):
    """Final decoder: expands one flow pair into `num_flows` hypotheses at 4x
    resolution, with per-hypothesis blend masks and image residuals.
    """

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(NewMultiFlowDecoder, self).__init__()
        norm_layer = get_bn()

        # 4x upsample via two PixelShuffle(2) stages (in_ch -> in_ch // 16).
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            nn.PixelShuffle(2),
            convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )

        self.num_flows = num_flows
        ch_factor = 2
        # Output per hypothesis: 2+2 flow deltas + 1 mask delta + 3 image
        # residual channels = 8.
        self.convblock = nn.Sequential(
            convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
        )

    def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        n = self.num_flows
        # Flows scale with resolution: 4x upsample -> 4x magnitude.
        flow0 = 4.0 * resize(flow0, scale_factor=4.0)
        flow1 = 4.0 * resize(flow1, scale_factor=4.0)

        ft_ = resize(ft_, scale_factor=4.0)
        mask = resize(mask, scale_factor=4.0)
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)

        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)

        assert mask is not None
        f_in = torch.cat([f_in, mask], 1)

        assert img0 is not None
        assert img1 is not None
        warped_img0 = warp(img0, flow0)
        warped_img1 = warp(img1, flow1)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)

        out = self.convblock(f_in)
        delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
            out, [2 * n, 2 * n, n, 3 * n], 1
        )
        # Each hypothesis starts from the shared flow/mask plus its own delta.
        mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1)
        mask = torch.sigmoid(mask)
        flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1)

        return flow0, flow1, mask, img_res
|
||||||
81
gimm_vfi_arch/generalizable_INR/modules/fi_utils.py
Normal file
81
gimm_vfi_arch/generalizable_INR/modules/fi_utils.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# raft: https://github.com/princeton-vl/RAFT
|
||||||
|
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# Cache of base sampling grids, keyed by (device, flow size).
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` by the pixel-unit flow `tenFlow`.

    Builds (and caches) the identity sampling grid for the flow's shape and
    device, offsets it by the flow normalized to grid_sample's [-1, 1] space,
    and samples bilinearly with border padding.
    """
    key = (str(tenFlow.device), str(tenFlow.size()))
    if key not in backwarp_tenGrid:
        n, _, h, w = tenFlow.shape
        xs = (
            torch.linspace(-1.0, 1.0, w, device=tenFlow.device)
            .view(1, 1, 1, w)
            .expand(n, -1, h, -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, h, device=tenFlow.device)
            .view(1, 1, h, 1)
            .expand(n, -1, -1, w)
        )
        backwarp_tenGrid[key] = torch.cat([xs, ys], 1).to(tenFlow.device)

    # Convert pixel displacements to the normalized [-1, 1] coordinate space.
    norm_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    grid = (backwarp_tenGrid[key] + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_flow(flows):
    """Normalize flows into [0, 1] per sample.

    Divides by each sample's maximum absolute value (mapping to [-1, 1]) and
    then shifts into [0, 1]. Returns the normalized flows together with the
    per-sample scaler needed to invert the mapping (`unnormalize_flow`).
    """
    # FIXME: MULTI-DIMENSION
    scaler = torch.abs(flows).flatten(1).max(dim=-1)[0].reshape(-1, 1, 1, 1, 1)
    normalized = flows / scaler  # [-1, 1]
    normalized = (normalized + 1.0) / 2.0  # [0, 1]
    return normalized, scaler
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize_flow(flows, flow_scaler):
    """Invert `normalize_flow`: map [0, 1] flows back to pixel units."""
    return flow_scaler * (2.0 * flows - 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def resize(x, scale_factor):
    """Bilinearly rescale `x` by `scale_factor` (align_corners=False)."""
    return F.interpolate(
        x, scale_factor=scale_factor, mode="bilinear", align_corners=False
    )
|
||||||
|
|
||||||
|
|
||||||
|
def coords_grid(batch, ht, wd):
    """Pixel-coordinate grid of shape (batch, 2, ht, wd); channel 0 = x, 1 = y.

    `indexing="ij"` is passed explicitly: it matches the historical default of
    `torch.meshgrid` (so values are unchanged) while avoiding the deprecation
    warning on torch >= 1.10.
    """
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack([xs, ys], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def build_coord(img):
    """Coarse (1/8-resolution) coordinate grid matching `img`'s batch size."""
    N, C, H, W = img.shape
    coords = coords_grid(N, H // 8, W // 8)
    return coords
|
||||||
198
gimm_vfi_arch/generalizable_INR/modules/hyponet.py
Normal file
198
gimm_vfi_arch/generalizable_INR/modules/hyponet.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import einops
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from ..configs import HypoNetConfig
|
||||||
|
from .utils import create_params_with_init, create_activation
|
||||||
|
|
||||||
|
|
||||||
|
class HypoNet(nn.Module):
    r"""
    The Hyponetwork with a coordinate-based MLP to be modulated.

    Each layer's weight matrix is stored with its bias folded in as the last
    row ("wb"). At forward time the base weights are element-wise modulated by
    per-sample parameters from a hypernetwork, and the MLP is evaluated over a
    flattened coordinate grid concatenated with a per-pixel latent.
    """

    def __init__(self, config: HypoNetConfig, add_coord_dim=32):
        super().__init__()
        self.config = config
        self.use_bias = config.use_bias
        self.init_config = config.initialization
        self.num_layer = config.n_layer
        self.hidden_dims = config.hidden_dim
        # Extra input width for the per-pixel latent concatenated to coords.
        self.add_coord_dim = add_coord_dim

        # A single entry is broadcast to every hidden layer.
        if len(self.hidden_dims) == 1:
            self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
                self.num_layer - 1
            )  # exclude output layer
        else:
            assert len(self.hidden_dims) == self.num_layer - 1

        # SIREN requires its dedicated weight/bias initialization.
        if self.config.activation.type == "siren":
            assert self.init_config.weight_init_type == "siren"
            assert self.init_config.bias_init_type == "siren"

        # after computes the shape of trainable parameters, initialize them
        self.params_dict = None
        self.params_shape_dict = self.compute_params_shape()
        self.activation = create_activation(self.config.activation)
        self.build_base_params_dict(self.config.initialization)
        self.output_bias = config.output_bias

        self.normalize_weight = config.normalize_weight

        # Per-layer switch: when True the base weight is replaced by 1.0 so
        # only the modulation defines that layer's weights.
        self.ignore_base_param_dict = {name: False for name in self.params_dict}

    @staticmethod
    def subsample_coords(coords, subcoord_idx=None):
        """Select a per-sample subset of flattened coordinates (or pass through)."""
        if subcoord_idx is None:
            return coords

        batch_size = coords.shape[0]
        sub_coords = []
        coords = coords.view(batch_size, -1, coords.shape[-1])
        for idx in range(batch_size):
            sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
        sub_coords = torch.cat(sub_coords, dim=0)
        return sub_coords

    def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
        """Evaluate the modulated MLP at `coord`.

        Args:
            coord: (B, *spatial, input_dim) grid, or a (coord, subcoord_idx)
                tuple for supervision on a coordinate subset.
            modulation_params_dict: optional per-layer modulation tensors
                keyed like `params_dict` ("linear_wb{i}").
            pixel_latent: required (B, H, W, C) latent, bilinearly resized to
                the coordinate grid and concatenated to each coordinate.
        """
        sub_idx = None
        if isinstance(coord, tuple):
            coord, sub_idx = coord[0], coord[1]

        if modulation_params_dict is not None:
            self.check_valid_param_keys(modulation_params_dict)

        batch_size, coord_shape, input_dim = (
            coord.shape[0],
            coord.shape[1:-1],
            coord.shape[-1],
        )
        coord = coord.view(batch_size, -1, input_dim)  # flatten the coordinates
        assert pixel_latent is not None
        # Resize the latent to the spatial size of the coordinate grid.
        # NOTE(review): assumes coord_shape is (T, H, W) so indices 1 and 2
        # are the spatial dims — confirm against the coord sampler.
        pixel_latent = F.interpolate(
            pixel_latent.permute(0, 3, 1, 2),
            size=(coord_shape[1], coord_shape[2]),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        pixel_latent_dim = pixel_latent.shape[-1]
        pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)
        hidden = coord

        hidden = torch.cat([pixel_latent, hidden], dim=-1)

        hidden = self.subsample_coords(hidden, sub_idx)

        for idx in range(self.config.n_layer):
            param_key = f"linear_wb{idx}"
            base_param = einops.repeat(
                self.params_dict[param_key], "n m -> b n m", b=batch_size
            )

            if (modulation_params_dict is not None) and (
                param_key in modulation_params_dict.keys()
            ):
                modulation_param = modulation_params_dict[param_key]
            else:
                # Identity modulation for layers without provided parameters.
                if self.config.use_bias:
                    modulation_param = torch.ones_like(base_param[:, :-1])
                else:
                    modulation_param = torch.ones_like(base_param)

            if self.config.use_bias:
                # Homogeneous coordinates: append 1 so the folded-in bias row
                # of the weight matrix acts as the bias.
                ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device)
                hidden = torch.cat([hidden, ones], dim=-1)

                base_param_w, base_param_b = (
                    base_param[:, :-1, :],
                    base_param[:, -1:, :],
                )

                if self.ignore_base_param_dict[param_key]:
                    base_param_w = 1.0
                param_w = base_param_w * modulation_param
                if self.normalize_weight:
                    param_w = F.normalize(param_w, dim=1)
                modulated_param = torch.cat([param_w, base_param_b], dim=1)
            else:
                if self.ignore_base_param_dict[param_key]:
                    base_param = 1.0
                if self.normalize_weight:
                    modulated_param = F.normalize(base_param * modulation_param, dim=1)
                else:
                    modulated_param = base_param * modulation_param
            hidden = torch.bmm(hidden, modulated_param)

            # No activation after the output layer.
            if idx < (self.config.n_layer - 1):
                hidden = self.activation(hidden)

        outputs = hidden + self.output_bias
        if sub_idx is None:
            outputs = outputs.view(batch_size, *coord_shape, -1)
        return outputs

    def compute_params_shape(self):
        """
        Computes the shape of MLP parameters.
        The computed shapes are used to build the initial weights by `build_base_params_dict`.
        """
        config = self.config
        use_bias = self.use_bias

        param_shape_dict = dict()

        # Input width = coordinate dim + latent dim (+1 row for folded bias).
        fan_in = config.input_dim
        add_dim = self.add_coord_dim
        fan_in = fan_in + add_dim
        fan_in = fan_in + 1 if use_bias else fan_in

        for i in range(config.n_layer - 1):
            fan_out = self.hidden_dims[i]
            param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
            fan_in = fan_out + 1 if use_bias else fan_out

        param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
        return param_shape_dict

    def build_base_params_dict(self, init_config):
        """Create and register the base (unmodulated) MLP parameters."""
        assert self.params_shape_dict
        params_dict = nn.ParameterDict()
        for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
            is_first = idx == 0
            params = create_params_with_init(
                shape,
                init_type=init_config.weight_init_type,
                include_bias=self.use_bias,
                bias_init_type=init_config.bias_init_type,
                is_first=is_first,
                siren_w0=self.config.activation.siren_w0,  # valid only for siren
            )
            params = nn.Parameter(params)
            params_dict[name] = params
        self.set_params_dict(params_dict)

    def check_valid_param_keys(self, params_dict):
        """Raise KeyError if `params_dict` contains an unknown layer key."""
        predefined_params_keys = self.params_shape_dict.keys()
        for param_key in params_dict.keys():
            if param_key in predefined_params_keys:
                continue
            else:
                raise KeyError

    def set_params_dict(self, params_dict):
        # Validate keys before overwriting the registered parameters.
        self.check_valid_param_keys(params_dict)
        self.params_dict = params_dict
|
||||||
42
gimm_vfi_arch/generalizable_INR/modules/layers.py
Normal file
42
gimm_vfi_arch/generalizable_INR/modules/layers.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# define siren layer & Siren model
|
||||||
|
class Sine(nn.Module):
    """Sine activation with frequency scaling (SIREN).

    Args:
        w0 (float): Omega_0 parameter from SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(x * self.w0)
|
||||||
|
|
||||||
|
|
||||||
|
# Damping activation from http://arxiv.org/abs/2306.15242
|
||||||
|
class Damping(nn.Module):
    """Sine activation with a sublinear sqrt factor (http://arxiv.org/abs/2306.15242).

    Computes ``sin(w0 * x) * sqrt(|x|)`` on the clamped input.

    Args:
        w0 (float): Omega_0 parameter from the SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # NOTE(review): the clamp floors x at 1e-30, so negative inputs are
        # flushed to ~0 before the sine — preserved from the vendored source.
        clamped = torch.clamp(x, min=1e-30)
        return torch.sqrt(clamped.abs()) * torch.sin(self.w0 * clamped)
|
||||||
52
gimm_vfi_arch/generalizable_INR/modules/module_config.py
Normal file
52
gimm_vfi_arch/generalizable_INR/modules/module_config.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from omegaconf import MISSING
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HypoNetActivationConfig:
    """Activation settings for the hypo-network MLP."""

    # Activation name; create_activation accepts "relu", "siren", "silu", "damp".
    type: str = "relu"
    # Omega_0 frequency scale; only meaningful for the "siren"/"damp" activations.
    siren_w0: Optional[float] = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HypoNetInitConfig:
    """Weight/bias initialization scheme names, consumed by create_params_with_init."""

    # One of: "normal", "kaiming_uniform", "uniform_fan_in", "zero", "siren".
    weight_init_type: Optional[str] = "kaiming_uniform"
    # Same set of names; applied to the bias row only.
    bias_init_type: Optional[str] = "zero"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HypoNetConfig:
    """Architecture configuration for the hypo-network (coordinate MLP)."""

    type: str = "mlp"
    # Number of layers in the MLP.
    n_layer: int = 5
    # Per-layer hidden widths; MISSING means omegaconf requires it to be supplied.
    hidden_dim: List[int] = MISSING
    use_bias: bool = True
    # Coordinate dimensionality in, channel dimensionality out (e.g. (x, y) -> RGB).
    input_dim: int = 2
    output_dim: int = 3
    # Constant added to the network output.
    output_bias: float = 0.5
    activation: HypoNetActivationConfig = field(default_factory=HypoNetActivationConfig)
    initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig)

    normalize_weight: bool = True
    linear_interpo: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CoordSamplerConfig:
    """Configuration for sampling input coordinates fed to the INR."""

    data_type: str = "image"
    # If True, only the temporal coordinate is sampled.
    t_coord_only: bool = False
    # MISSING fields must be provided by the loaded omegaconf config file.
    coord_range: List[float] = MISSING
    time_range: List[float] = MISSING
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING
    patch_size: Optional[int] = MISSING
|
||||||
672
gimm_vfi_arch/generalizable_INR/modules/softsplat.py
Normal file
672
gimm_vfi_arch/generalizable_INR/modules/softsplat.py
Normal file
@@ -0,0 +1,672 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# softmax-splatting: https://github.com/sniklaus/softmax-splatting
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import cupy
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
objCudacache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_int32(intIn: int):
    """Wrap a Python int as cupy.int32 for passing as a raw-kernel argument."""
    return cupy.int32(intIn)
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_float32(fltIn: float):
    """Wrap a Python float as cupy.float32 for passing as a raw-kernel argument."""
    return cupy.float32(fltIn)
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
    """Specialize a templated CUDA kernel source for concrete tensor arguments.

    Replaces ``{{name}}`` placeholders with the given scalar/string values,
    ``{{type}}`` with the C type matching the tensors' dtype, and expands the
    ``SIZE_n(t)``, ``OFFSET_n(t, ...)`` and ``VALUE_n(t, ...)`` macros using the
    tensors' concrete sizes and strides. The specialized source is cached in
    the module-level ``objCudacache`` keyed by function name + every argument's
    dtype/shape/stride + device name, and that cache key is returned (for use
    with ``cuda_launch``).
    """
    if "device" not in objCudacache:
        # Cache the device name once; it is part of every cache key.
        objCudacache["device"] = torch.cuda.get_device_name()
    # end

    strKey = strFunction

    # Build a cache key that uniquely identifies this specialization.
    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            # dtype + shape + stride fully determine the generated indexing code.
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        elif True:
            print(strVariable, type(objValue))
            assert False

        # end
    # end

    strKey += objCudacache["device"]

    if strKey not in objCudacache:
        # First pass: substitute scalar placeholders and the {{type}} token.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace("{{type}}", "unsigned char")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace("{{type}}", "half")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace("{{type}}", "float")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace("{{type}}", "double")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace("{{type}}", "int")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace("{{type}}", "long")

            elif type(objValue) == torch.Tensor:
                # Unsupported dtype: fail loudly with the offending variable.
                print(strVariable, objValue.dtype)
                assert False

            elif True:
                print(strVariable, type(objValue))
                assert False

            # end
        # end

        # Expand SIZE_n(tensor) into the literal dimension size.
        while True:
            objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(
                objMatch.group(),
                str(
                    intSizes[intArg]
                    if torch.is_tensor(intSizes[intArg]) == False
                    else intSizes[intArg].item()
                ),
            )
        # end

        # Expand OFFSET_n(tensor, i0, ..., i{n-1}) into a linear stride offset.
        while True:
            objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            # Scan to the matching closing parenthesis of the macro call.
            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")

            assert intArgs == len(strArgs) - 1

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                # Each index expression is multiplied by the tensor's stride.
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end

            strKernel = strKernel.replace(
                "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                "(" + str.join("+", strIndex) + ")",
            )
        # end

        # Expand VALUE_n(tensor, i0, ..., i{n-1}) into a strided element access.
        while True:
            objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")

            assert intArgs == len(strArgs) - 1

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end

            strKernel = strKernel.replace(
                "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                strTensor + "[" + str.join("+", strIndex) + "]",
            )
        # end

        objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
    # end

    return strKey


# end
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
@cupy.memoize(for_each_device=True)
@torch.compiler.disable()
def cuda_launch(strKey: str):
    """Compile the cached kernel source for *strKey* and return the raw CUDA function.

    Memoized per device by cupy, so compilation happens once per (key, device).

    Raises:
        RuntimeError: if CUDA_HOME is unset and cupy cannot locate the CUDA toolkit.
    """
    try:
        # Let cupy discover the toolkit path when CUDA_HOME is not already set.
        os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
    except Exception:
        if "CUDA_HOME" not in os.environ:
            raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")

    strKernel = objCudacache[strKey]["strKernel"]
    strFunction = objCudacache[strKey]["strFunction"]

    return cupy.RawModule(
        code=strKernel,
        options=(
            "-I " + os.environ["CUDA_HOME"],
            "-I " + os.environ["CUDA_HOME"] + "/include",
        ),
    ).get_function(strFunction)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compiler.disable()
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
    """Forward-warp *tenIn* along *tenFlow* using softmax splatting.

    Modes ("sum", "avg", "linear", "softmax", optionally suffixed with
    "-addeps"/"-zeroeps"/"-clipeps" to pick the normalization epsilon scheme):
      - "sum":     plain accumulation, tenMetric must be None.
      - "avg":     accumulation normalized by splat weight, tenMetric must be None.
      - "linear":  input weighted by tenMetric before splatting.
      - "softmax": input weighted by exp(tenMetric) before splatting.

    When *return_norm* is True (and the mode normalizes), returns the
    unnormalized splat and the normalization map separately.

    NOTE(review): NaN checks deliberately hard-stop via assert; these are
    debugging tripwires preserved from the vendored source.
    """
    assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]

    if strMode == "sum":
        assert tenMetric is None
    if strMode == "avg":
        assert tenMetric is None
    if strMode.split("-")[0] == "linear":
        assert tenMetric is not None
    if strMode.split("-")[0] == "softmax":
        assert tenMetric is not None

    if strMode == "avg":
        # Append a ones channel whose splat becomes the per-pixel weight sum.
        tenIn = torch.cat(
            [
                tenIn,
                tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]),
            ],
            1,
        )

    elif strMode.split("-")[0] == "linear":
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)

    elif strMode.split("-")[0] == "softmax":
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    # end
    if torch.isnan(tenIn).any():
        print("NaN values detected during training in tenIn. Exiting.")
        assert False

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if torch.isnan(tenOut).any():
        print("NaN values detected during training in tenOut_1. Exiting.")
        assert False

    if strMode.split("-")[0] in ["avg", "linear", "softmax"]:
        # Last channel holds the accumulated weights used for normalization.
        tenNormalize = tenOut[:, -1:, :, :]

        if len(strMode.split("-")) == 1:
            # Default epsilon scheme is additive.
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split("-")[1] == "addeps":
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split("-")[1] == "zeroeps":
            tenNormalize[tenNormalize == 0.0] = 1.0

        elif strMode.split("-")[1] == "clipeps":
            tenNormalize = tenNormalize.clip(0.0000001, None)

        # end

        if return_norm:
            return tenOut[:, :-1, :, :], tenNormalize

        tenOut = tenOut[:, :-1, :, :] / tenNormalize

        if torch.isnan(tenOut).any():
            print("NaN values detected during training in tenOut_2. Exiting.")
            assert False

    # end

    return tenOut
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
class softsplat_func(torch.autograd.Function):
    """Custom autograd op that forward-warps (splats) tenIn along tenFlow.

    CUDA-only: the forward/backward launch JIT-compiled cupy kernels that
    bilinearly distribute each source pixel to its four destination neighbors
    via atomicAdd. The ``self`` parameter below is the autograd context object.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        # Output starts at zero; the kernel accumulates into it with atomicAdd.
        tenOut = tenIn.new_zeros(
            [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
        )

        if tenIn.is_cuda == True:
            cuda_launch(
                cuda_kernel(
                    "softsplat_out",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_out(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        {{type}}* __restrict__ tenOut
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                        const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
                        const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                        const int intX = ( intIndex ) % SIZE_3(tenOut);

                        assert(SIZE_1(tenFlow) == 2);

                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }

                        {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;

                        {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                        {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                        {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                        {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
                        }

                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
                        }

                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
                        }

                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
                        }
                    } }
                    """,
                    {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut},
                )
            )(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenOut.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOut.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )

        elif tenIn.is_cuda != True:
            # CPU tensors are unsupported; there is no CPU fallback kernel.
            assert False

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut

    # end

    @staticmethod
    @torch.compiler.disable()
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(self, tenOutgrad):
        tenIn, tenFlow = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous()
        assert tenOutgrad.is_cuda == True

        # Gradients are only allocated (and computed) when autograd needs them.
        tenIngrad = (
            tenIn.new_zeros(
                [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
            )
            if self.needs_input_grad[0] == True
            else None
        )
        tenFlowgrad = (
            tenFlow.new_zeros(
                [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]
            )
            if self.needs_input_grad[1] == True
            else None
        )

        if tenIngrad is not None:
            # d(out)/d(in): each source pixel gathers the four bilinear weights
            # it contributed with, applied to the incoming output gradient.
            cuda_launch(
                cuda_kernel(
                    "softsplat_ingrad",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        const {{type}}* __restrict__ tenOutgrad,
                        {{type}}* __restrict__ tenIngrad,
                        {{type}}* __restrict__ tenFlowgrad
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
                        const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
                        const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
                        const int intX = ( intIndex ) % SIZE_3(tenIngrad);

                        assert(SIZE_1(tenFlow) == 2);

                        {{type}} fltIngrad = 0.0f;

                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }

                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;

                        {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                        {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                        {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                        {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
                        }

                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
                        }

                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
                        }

                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
                        }

                        tenIngrad[intIndex] = fltIngrad;
                    } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenIngrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenIngrad.data_ptr(),
                    None,
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end

        if tenFlowgrad is not None:
            # d(out)/d(flow): the derivative of each bilinear weight w.r.t. the
            # flow component (intC selects x or y), summed over input channels.
            cuda_launch(
                cuda_kernel(
                    "softsplat_flowgrad",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        const {{type}}* __restrict__ tenOutgrad,
                        {{type}}* __restrict__ tenIngrad,
                        {{type}}* __restrict__ tenFlowgrad
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
                        const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
                        const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
                        const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);

                        assert(SIZE_1(tenFlow) == 2);

                        {{type}} fltFlowgrad = 0.0f;

                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }

                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;

                        {{type}} fltNorthwest = 0.0f;
                        {{type}} fltNortheast = 0.0f;
                        {{type}} fltSouthwest = 0.0f;
                        {{type}} fltSoutheast = 0.0f;

                        if (intC == 0) {
                            fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                            fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                            fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                            fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));

                        } else if (intC == 1) {
                            fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                            fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                            fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                            fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));

                        }

                        for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                            {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);

                            if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                            }

                            if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                            }

                            if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                            }

                            if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                            }
                        }

                        tenFlowgrad[intIndex] = fltFlowgrad;
                    } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenFlowgrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    None,
                    tenFlowgrad.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end

        return tenIngrad, tenFlowgrad

    # end
|
||||||
|
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
# end
|
||||||
76
gimm_vfi_arch/generalizable_INR/modules/utils.py
Normal file
76
gimm_vfi_arch/generalizable_INR/modules/utils.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .layers import Sine, Damping
|
||||||
|
|
||||||
|
|
||||||
|
def convert_int_to_list(size, len_list=2):
    """Broadcast an int to a list of length *len_list*; pass sequences through.

    A sequence argument must already have exactly *len_list* elements.
    """
    if not isinstance(size, int):
        assert len(size) == len_list
        return size
    return [size] * len_list
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_params(params, init_type, **kwargs):
    """Initialize a 2-D parameter tensor in place according to *init_type*.

    Supported types: None/"normal", "kaiming_uniform", "uniform_fan_in",
    "zero", and "siren" (which requires ``siren_w0`` and ``is_first`` kwargs).

    Raises:
        NotImplementedError: for any other init type.
    """
    n_in, n_out = params.shape[0], params.shape[1]

    if init_type is None or init_type == "normal":
        nn.init.normal_(params)
        return
    if init_type == "kaiming_uniform":
        nn.init.kaiming_uniform_(params, a=math.sqrt(5))
        return
    if init_type == "uniform_fan_in":
        limit = 1 / math.sqrt(n_in) if n_in > 0 else 0
        nn.init.uniform_(params, -limit, limit)
        return
    if init_type == "zero":
        nn.init.zeros_(params)
        return
    if init_type == "siren":
        assert "siren_w0" in kwargs and "is_first" in kwargs
        omega = kwargs["siren_w0"]
        # SIREN paper: first layer uses 1/fan_in, later layers sqrt(6/fan_in)/w0.
        if kwargs["is_first"]:
            spread = 1 / n_in
        else:
            spread = math.sqrt(6.0 / n_in) / omega
        nn.init.uniform_(params, -spread, spread)
        return
    raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def create_params_with_init(
    shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
):
    """Allocate and initialize a ``[shape[0], shape[1]]`` parameter matrix.

    With *include_bias*, the final row is treated as a bias and initialized
    with *bias_init_type* while the remaining rows use *init_type*.
    """
    rows, cols = shape[0], shape[1]

    if include_bias:
        # Initialize weight rows and the bias row independently, then stack.
        weight = torch.empty([rows - 1, cols])
        bias = torch.empty([1, cols])

        initialize_params(weight, init_type, **kwargs)
        initialize_params(bias, bias_init_type, **kwargs)
        return torch.cat([weight, bias], dim=0)

    weight = torch.empty([rows, cols])
    initialize_params(weight, init_type, **kwargs)
    return weight
|
||||||
|
|
||||||
|
|
||||||
|
def create_activation(config):
    """Build the activation module named by ``config.type``.

    Supported: "relu", "siren" (Sine with config.siren_w0), "silu",
    "damp" (Damping with config.siren_w0).

    Raises:
        NotImplementedError: for any other activation name.
    """
    kind = config.type
    if kind == "relu":
        return nn.ReLU()
    if kind == "siren":
        return Sine(config.siren_w0)
    if kind == "silu":
        return nn.SiLU()
    if kind == "damp":
        return Damping(config.siren_w0)
    raise NotImplementedError
|
||||||
1
gimm_vfi_arch/generalizable_INR/raft/__init__.py
Normal file
1
gimm_vfi_arch/generalizable_INR/raft/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .raft import RAFT
|
||||||
175
gimm_vfi_arch/generalizable_INR/raft/corr.py
Normal file
175
gimm_vfi_arch/generalizable_INR/raft/corr.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# amt: https://github.com/MCG-NKU/AMT
|
||||||
|
# raft: https://github.com/princeton-vl/RAFT
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .utils.utils import bilinear_sampler, coords_grid
|
||||||
|
|
||||||
|
try:
|
||||||
|
import alt_cuda_corr
|
||||||
|
except:
|
||||||
|
# alt_cuda_corr is not compiled
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BidirCorrBlock:
    """Bidirectional all-pairs correlation volume with a multi-level pyramid.

    Computes the dense correlation between ``fmap1`` and ``fmap2`` once,
    keeps both the volume and its transpose (fmap2 vs. fmap1), and
    average-pools each ``num_levels - 1`` times.  Calling the object samples
    a (2*radius+1)^2 window around per-pixel lookup coordinates at every
    level, for both flow directions at once.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # Swap source/target grids to obtain the reverse-direction volume.
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)

        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        # Coarser pyramid levels: pool over the target spatial dims only.
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        """Sample both pyramids at per-pixel lookup centers.

        ``coords0``/``coords1`` are permuted to (N, H, W, 2) lookup centers
        for the two directions; returns two (N, num_levels*(2r+1)^2, H, W)
        correlation feature maps.
        """
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert (
            coords0.shape == coords1.shape
        ), f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1)x(2r+1) offset window shared by every lookup center.
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

            # Lookup centers are scaled down to the current pyramid level.
            centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return (
            out.permute(0, 3, 1, 2).contiguous().float(),
            out_T.permute(0, 3, 1, 2).contiguous().float(),
        )

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense dot-product correlation, scaled by sqrt(feature dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
|
||||||
|
|
||||||
|
class AlternateCorrBlock:
    """Memory-efficient correlation lookup backed by the optional
    ``alt_cuda_corr`` CUDA extension.

    Instead of materializing the full all-pairs volume it keeps a feature
    pyramid and computes windowed correlations on demand.  Only usable when
    ``alt_cuda_corr`` was compiled (see the guarded import at the top of
    this module); otherwise calling it raises ``NameError``.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # NOTE: fmap1 is always taken at full resolution (level 0);
            # only fmap2 comes from the pooled pyramid level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
|
||||||
|
|
||||||
|
class CorrBlock:
    """Standard RAFT all-pairs correlation volume with pyramid lookup.

    The full 4D correlation is computed once in ``__init__`` and
    average-pooled into ``num_levels`` levels; ``__call__`` samples a
    (2*radius+1)^2 window around per-pixel coordinates at each level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        # Coarser levels by 2x average pooling of the target grid.
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample a (2r+1)^2 window around ``coords`` at every level.

        Returns an (N, num_levels*(2r+1)^2, H, W) correlation feature map.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Lookup centers are rescaled to the current pyramid level.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense dot-product correlation, scaled by sqrt(feature dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
293
gimm_vfi_arch/generalizable_INR/raft/extractor.py
Normal file
293
gimm_vfi_arch/generalizable_INR/raft/extractor.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with configurable normalization and a residual
    shortcut; a 1x1 strided conv projects the shortcut when ``stride != 1``.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # One factory per supported norm flavor ("none" -> identity).
        norm_factories = {
            "group": lambda: nn.GroupNorm(num_groups=num_groups, num_channels=planes),
            "batch": lambda: nn.BatchNorm2d(planes),
            "instance": lambda: nn.InstanceNorm2d(planes),
            "none": lambda: nn.Sequential(),
        }
        if norm_fn in norm_factories:
            make_norm = norm_factories[norm_fn]
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            if not stride == 1:
                # Extra norm for the projected shortcut branch.
                self.norm3 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        residual = self.relu(self.norm1(self.conv1(x)))
        residual = self.relu(self.norm2(self.conv2(residual)))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + residual)
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
|
||||||
|
)
|
||||||
|
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes // 4)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes // 4)
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes // 4)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes // 4)
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.relu(self.norm1(self.conv1(y)))
|
||||||
|
y = self.relu(self.norm2(self.conv2(y)))
|
||||||
|
y = self.relu(self.norm3(self.conv3(y)))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder producing 1/8-resolution features.

    With ``only_feat=True`` the final projection conv is skipped and only
    intermediate pyramid features are meaningful.  With ``mif=True`` in
    ``forward``, 1/2- and 1/4-scale copies of the input are additionally
    encoded and concatenated into the returned feature list (multi-scale
    input features).
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0, only_feat=False):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.only_feat = only_feat

        # Stem normalization; unknown norm_fn leaves norm1 undefined.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        if not self.only_feat:
            # output convolution
            self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; affine norm params start at identity (1, 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks per stage; the first may downsample.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x, return_feature=False, mif=False):
        """Encode a batch, or a pair of batches given as a list/tuple.

        Returns the projected features (unless ``only_feat``), plus the list
        of intermediate features when ``return_feature`` is True.  For a
        pair input, outputs are split back into two tensors per entry.
        """
        features = []

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        # Downscaled copies for the mif path; computed unconditionally but
        # only consumed when mif=True.
        x_2 = F.interpolate(x, scale_factor=1 / 2, mode="bilinear", align_corners=False)
        x_4 = F.interpolate(x, scale_factor=1 / 4, mode="bilinear", align_corners=False)

        def f1(feat):
            # Stem + first stage, shared by every input scale.
            feat = self.conv1(feat)
            feat = self.norm1(feat)
            feat = self.relu1(feat)
            feat = self.layer1(feat)
            return feat

        x = f1(x)
        features.append(x)
        x = self.layer2(x)
        if mif:
            x_2_2 = f1(x_2)
            features.append(torch.cat([x, x_2_2], dim=1))
        else:
            features.append(x)
        x = self.layer3(x)
        if mif:
            # Reuses x_2_2 from the previous mif branch (same flag).
            x_2_4 = self.layer2(x_2_2)
            x_4_4 = f1(x_4)
            features.append(torch.cat([x, x_2_4, x_4_4], dim=1))
        else:
            features.append(x)

        if not self.only_feat:
            x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): the split sizes assume exactly two inputs were
            # passed in the list/tuple — confirm against callers.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
            features = [torch.split(f, [batch_dim, batch_dim], dim=0) for f in features]
        if return_feature:
            return x, features
        else:
            return x
|
|
||||||
|
|
||||||
|
class SmallEncoder(nn.Module):
    """Compact RAFT encoder: a strided stem conv plus three bottleneck
    stages, projecting the 1/8-resolution features to ``output_dim``
    channels.  Accepts a single tensor or a pair (list/tuple) of batches.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; registered before the convs so the module /
        # parameter order matches the reference implementation.
        stem_norms = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=32),
            "batch": lambda: nn.BatchNorm2d(32),
            "instance": lambda: nn.InstanceNorm2d(32),
            "none": lambda: nn.Sequential(),
        }
        if self.norm_fn in stem_norms:
            self.norm1 = stem_norms[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; affine norms start at identity (1, 0).
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(mod, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if mod.weight is not None:
                    nn.init.constant_(mod.weight, 1)
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may change resolution.
        stage = nn.Sequential(
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return stage

    def forward(self, x):
        """Encode one batch, or a pair concatenated along the batch dim and
        split back into two outputs."""
        paired = isinstance(x, (tuple, list))
        if paired:
            half = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [half, half], dim=0)

        return x
238
gimm_vfi_arch/generalizable_INR/raft/other_raft.py
Normal file
238
gimm_vfi_arch/generalizable_INR/raft/other_raft.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .update import BasicUpdateBlock, SmallUpdateBlock
|
||||||
|
from .extractor import BasicEncoder, SmallEncoder
|
||||||
|
from .corr import BidirCorrBlock, AlternateCorrBlock
|
||||||
|
from .utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||||
|
|
||||||
|
try:
|
||||||
|
autocast = torch.cuda.amp.autocast
|
||||||
|
except:
|
||||||
|
# dummy autocast for PyTorch < 1.6
|
||||||
|
class autocast:
|
||||||
|
def __init__(self, enabled):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# BiRAFT
|
||||||
|
class RAFT(nn.Module):
    """Bidirectional RAFT ("BiRAFT") optical-flow estimator.

    Estimates flow in both directions (image1->image2 and image2->image1)
    within one set of recurrent updates, sharing a single bidirectional
    correlation volume.  ``forward`` returns upsampled and low-resolution
    flows for both directions plus the context-encoder feature pyramids.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            # Compact variant: smaller hidden/context dims, lookup radius 3.
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
            self.corr_levels = 4
            self.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4
            self.corr_levels = 4
            self.corr_radius = 4

        # NOTE(review): args._get_kwargs() returns (name, value) tuples, so
        # a bare string is never "in" it — both conditions below are always
        # True and the defaults always overwrite any caller-provided values.
        # Preserved as-is from upstream RAFT.
        if "dropout" not in args._get_kwargs():
            self.args.dropout = 0

        if "alternate_corr" not in args._get_kwargs():
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(
                output_dim=128, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = SmallEncoder(
                output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
            )
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = BasicEncoder(
                output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
            )
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Freeze every BatchNorm2d layer (eval mode, fixed running stats)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def build_coord(self, img):
        """Return the base 1/8-resolution coordinate grid for ``img``."""
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8, device=img.device)
        return coords

    def initialize_flow(self, img, img2):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        assert img.shape == img2.shape
        N, C, H, W = img.shape
        # Reference grids (coords01/coords02) stay fixed; coords1/coords2
        # are iteratively refined for the two flow directions.
        coords01 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords02 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords2 = coords_grid(N, H // 8, W // 8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords01, coords02, coords1, coords2

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # 9 softmax weights per output pixel within each 8x8 cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # 3x3 neighborhoods of the coarse flow; values scaled by the 8x
        # resolution change.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def get_corr_fn(self, image1, image2, projector=None):
        """Build the correlation lookup(s) for an image pair.

        Returns ``(corr_fn, corr_fn1)``.  When ``projector`` is given, a
        second lookup is built from projected deepest encoder features;
        otherwise the same object is returned twice.
        """
        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmaps, feats = self.fnet([image1, image2], return_feature=True)
        fmap1, fmap2 = fmaps
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn1 = None
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
            if projector is not None:
                corr_fn1 = AlternateCorrBlock(
                    projector(feats[-1][0]),
                    projector(feats[-1][1]),
                    radius=self.args.corr_radius,
                )
        else:
            corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
            if projector is not None:
                corr_fn1 = BidirCorrBlock(
                    projector(feats[-1][0]),
                    projector(feats[-1][1]),
                    radius=self.args.corr_radius,
                )
        if corr_fn1 is None:
            return corr_fn, corr_fn
        else:
            return corr_fn, corr_fn1

    def get_corr_fn_from_feat(self, fmap1, fmap2):
        """Build a correlation lookup from precomputed feature maps."""
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        return corr_fn

    def forward(
        self,
        image1,
        image2,
        iters=12,
        flow_init=None,
        upsample=True,
        test_mode=False,
        corr_fn=None,
        mif=False,
    ):
        """Estimate optical flow between pair of frames"""
        # This bidirectional variant does not support warm-starting.
        assert flow_init is None

        # Inputs are expected in [0, 255]; map to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # A precomputed correlation lookup may be passed in to avoid
        # re-running the feature network.
        if corr_fn is None:
            corr_fn, _ = self.get_corr_fn(image1, image2)

        # run the context network (once per direction)
        with autocast(enabled=self.args.mixed_precision):
            # for image1
            cnet1, features1 = self.cnet(image1, return_feature=True, mif=mif)
            net1, inp1 = torch.split(cnet1, [hdim, cdim], dim=1)
            net1 = torch.tanh(net1)
            inp1 = torch.relu(inp1)
            # for image2
            cnet2, features2 = self.cnet(image2, return_feature=True, mif=mif)
            net2, inp2 = torch.split(cnet2, [hdim, cdim], dim=1)
            net2 = torch.tanh(net2)
            inp2 = torch.relu(inp2)

        coords01, coords02, coords1, coords2 = self.initialize_flow(image1, image2)

        for itr in range(iters):
            coords1 = coords1.detach()
            coords2 = coords2.detach()
            corr1, corr2 = corr_fn(coords1, coords2)  # index correlation volume

            flow1 = coords1 - coords01
            flow2 = coords2 - coords02

            with autocast(enabled=self.args.mixed_precision):
                net1, up_mask1, delta_flow1 = self.update_block(
                    net1, inp1, corr1, flow1
                )
                net2, up_mask2, delta_flow2 = self.update_block(
                    net2, inp2, corr2, flow2
                )

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow1
            coords2 = coords2 + delta_flow2
            flow_low1 = coords1 - coords01
            flow_low2 = coords2 - coords02
            # upsample predictions
            if up_mask1 is None:
                flow_up1 = upflow8(coords1 - coords01)
                flow_up2 = upflow8(coords2 - coords02)
            else:
                flow_up1 = self.upsample_flow(coords1 - coords01, up_mask1)
                flow_up2 = self.upsample_flow(coords2 - coords02, up_mask2)

        # Return results from the final iteration for both directions.
        return flow_up1, flow_up2, flow_low1, flow_low2, features1, features2
||||||
169
gimm_vfi_arch/generalizable_INR/raft/raft.py
Normal file
169
gimm_vfi_arch/generalizable_INR/raft/raft.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .update import BasicUpdateBlock, SmallUpdateBlock
|
||||||
|
from .extractor import BasicEncoder, SmallEncoder
|
||||||
|
from .corr import CorrBlock, AlternateCorrBlock
|
||||||
|
from .utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||||
|
|
||||||
|
try:
|
||||||
|
autocast = torch.cuda.amp.autocast
|
||||||
|
except:
|
||||||
|
# dummy autocast for PyTorch < 1.6
|
||||||
|
class autocast:
|
||||||
|
def __init__(self, enabled):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RAFT(nn.Module):
|
||||||
|
    def __init__(self, args):
        """Build the unidirectional RAFT flow network.

        ``args.small`` selects the compact encoder/update variant; missing
        optional flags (``dropout``, ``alternate_corr``) are defaulted here.
        """
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
            self.corr_levels = 4
            self.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4
            self.corr_levels = 4
            self.corr_radius = 4

        # NOTE(review): args._get_kwargs() returns (name, value) tuples, so
        # these string membership tests never match — both defaults below
        # are therefore always applied. Preserved as-is from upstream RAFT.
        if "dropout" not in args._get_kwargs():
            self.args.dropout = 0

        if "alternate_corr" not in args._get_kwargs():
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(
                output_dim=128, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = SmallEncoder(
                output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
            )
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = BasicEncoder(
                output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
            )
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
||||||
|
    def freeze_bn(self):
        """Freeze every BatchNorm2d layer (eval mode, fixed running stats)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
|
|
||||||
|
    def initialize_flow(self, img):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        # Grids live at 1/8 resolution, matching the encoder's total stride.
        coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords1 = coords_grid(N, H // 8, W // 8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1
|
||||||
|
    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # 9 softmax weights per output pixel within each 8x8 cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Gather 3x3 neighborhoods of the coarse flow; values are scaled by
        # 8 to account for the resolution change.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)
|
||||||
|
def forward(
    self,
    image1,
    image2,
    iters=12,
    flow_init=None,
    upsample=True,
    test_mode=False,
    return_feat=True,
):
    """Estimate optical flow between pair of frames.

    Args:
        image1, image2: image tensors in [0, 255] (normalized internally).
        iters: number of GRU refinement iterations.
        flow_init: optional coarse flow added to coords1 as a warm start.
        upsample: unused here; kept for upstream RAFT API compatibility.
        test_mode: if True, return (low-res flow, final upsampled flow).
        return_feat: if True, return (flow_up, context feats[1:], fmap1) —
            the extra features are consumed by GIMM-VFI downstream.

    Returns:
        Depending on the flags above; otherwise the list of per-iteration
        upsampled flow predictions.
    """

    # Normalize from [0, 255] to [-1, 1].
    image1 = 2 * (image1 / 255.0) - 1.0
    image2 = 2 * (image2 / 255.0) - 1.0

    image1 = image1.contiguous()
    image2 = image2.contiguous()

    hdim = self.hidden_dim
    cdim = self.context_dim

    # run the feature network
    with autocast(enabled=self.args.mixed_precision):
        fmap1, fmap2 = self.fnet([image1, image2])

    # Correlation volume is built in fp32 regardless of autocast.
    fmap1 = fmap1.float()
    fmap2 = fmap2.float()
    if self.args.alternate_corr:
        corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
    else:
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

    # run the context network
    with autocast(enabled=self.args.mixed_precision):
        cnet, feats = self.cnet(image1, return_feature=True)
        # Split context features into GRU hidden state (tanh) and input (relu).
        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)

    coords0, coords1 = self.initialize_flow(image1)

    if flow_init is not None:
        coords1 = coords1 + flow_init

    flow_predictions = []
    for itr in range(iters):
        # Detach so gradients do not flow through previous iterations.
        coords1 = coords1.detach()
        corr = corr_fn(coords1)  # index correlation volume

        flow = coords1 - coords0
        with autocast(enabled=self.args.mixed_precision):
            net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

        # F(t+1) = F(t) + \Delta(t)
        coords1 = coords1 + delta_flow

        # upsample predictions
        if up_mask is None:
            flow_up = upflow8(coords1 - coords0)
        else:
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)

        flow_predictions.append(flow_up)

    if test_mode:
        return coords1 - coords0, flow_up

    if return_feat:
        return flow_up, feats[1:], fmap1

    return flow_predictions
|
||||||
154
gimm_vfi_arch/generalizable_INR/raft/update.py
Normal file
154
gimm_vfi_arch/generalizable_INR/raft/update.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FlowHead(nn.Module):
    """Two-layer conv head mapping the GRU hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        # NOTE: attribute names must stay as-is to match pretrained state dicts.
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on [N, C, H, W] feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super().__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))  # candidate state

        return (1 - z) * h + z * q
|
||||||
|
|
||||||
|
|
||||||
|
class SepConvGRU(nn.Module):
    """GRU cell with separable convolutions: a horizontal (1x5) pass followed
    by a vertical (5x1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super().__init__()
        combined = hidden_dim + input_dim
        # horizontal (1x5) gates -- names fixed by pretrained state dicts
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        # vertical (5x1) gates
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_step(self, h, x, convz, convr, convq):
        # One standard GRU update with the given gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||||
|
|
||||||
|
|
||||||
|
class SmallMotionEncoder(nn.Module):
    """Encodes (flow, correlation) into 82-channel motion features
    (80 learned channels + the raw 2-channel flow appended)."""

    def __init__(self, args):
        super().__init__()
        # Channels of the pyramid correlation lookup.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers see it directly.
        return torch.cat([fused, flow], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicMotionEncoder(nn.Module):
    """Encodes (flow, correlation) into 128-channel motion features
    (126 learned channels + the raw 2-channel flow appended)."""

    def __init__(self, args):
        super().__init__()
        # Channels of the pyramid correlation lookup.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        # Two-stage correlation branch.
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        # Two-stage flow branch.
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers see it directly.
        return torch.cat([fused, flow], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class SmallUpdateBlock(nn.Module):
    """Single GRU refinement step for the small RAFT variant (no upsample mask)."""

    def __init__(self, args, hidden_dim=96):
        super().__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, gru_input)
        # The small model has no learned convex-upsampling mask, hence None.
        return net, None, self.flow_head(net)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicUpdateBlock(nn.Module):
    """GRU refinement step for the full RAFT model: motion encoder +
    SepConvGRU + flow head, plus a learned convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts 9 * (8x8) convex-combination weights consumed by
        # RAFT.upsample_flow to upsample the coarse flow 8x.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        # .to(inp) moves flow/corr to inp's device and dtype — part of the
        # vendored device fixes for ComfyUI integration.
        motion_features = self.encoder(flow.to(inp), corr.to(inp))
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||||
93
gimm_vfi_arch/generalizable_INR/raft/utils/utils.py
Normal file
93
gimm_vfi_arch/generalizable_INR/raft/utils/utils.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from scipy import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        # Amount needed to round each spatial dim up to a multiple of 8.
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8
        if mode == "sintel":
            # Split padding evenly between opposite edges.
            self._pad = [
                pad_wd // 2,
                pad_wd - pad_wd // 2,
                pad_ht // 2,
                pad_ht - pad_ht // 2,
            ]
        else:
            # Pad height only at the bottom.
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode="replicate"))
        return padded

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||||
|
|
||||||
|
|
||||||
|
def forward_interpolate(flow):
    """Forward-splat a flow field onto a regular grid using nearest-neighbour
    scattered interpolation (used to warm-start the next frame's flow)."""
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # Destination of every source pixel under the flow.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only pixels whose destination lands strictly inside the frame.
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[valid], y1[valid]
    dx, dy = dx[valid], dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
    )
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
    )

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
||||||
|
|
||||||
|
|
||||||
|
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates"""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's normalized [-1, 1] range
    # (align_corners=True convention).
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(
        img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True
    )

    if not mask:
        return sampled

    # Validity mask: coordinates strictly inside the normalized frame.
    inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, inside.float()
|
||||||
|
|
||||||
|
|
||||||
|
def coords_grid(batch, ht, wd, device):
    """Build a [batch, 2, ht, wd] grid of pixel coordinates.

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate, matching RAFT's flow convention.
    """
    # indexing="ij" pins the pre-1.10 default explicitly and silences the
    # torch.meshgrid deprecation warning without changing the output.
    coords = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing="ij",
    )
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors to match."""
    target = (8 * flow.shape[2], 8 * flow.shape[3])
    upsampled = F.interpolate(flow, size=target, mode=mode, align_corners=True)
    return 8 * upsampled
|
||||||
0
gimm_vfi_arch/utils/__init__.py
Normal file
0
gimm_vfi_arch/utils/__init__.py
Normal file
52
gimm_vfi_arch/utils/utils.py
Normal file
52
gimm_vfi_arch/utils/utils.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from easydict import EasyDict as edict
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class InputPadder:
    """Pads images such that dimensions are divisible by divisor"""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to round each spatial dim up to a multiple of divisor.
        pad_ht = (-self.ht) % divisor
        pad_wd = (-self.wd) % divisor
        # [left, right, top, bottom], split evenly between opposite edges.
        self._pad = [
            pad_wd // 2,
            pad_wd - pad_wd // 2,
            pad_ht // 2,
            pad_ht - pad_ht // 2,
        ]

    def pad(self, *inputs):
        # A single input returns a tensor; multiple inputs return a list.
        padded = [F.pad(x, self._pad, mode="replicate") for x in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs):
        # Mirrors pad(): single tensor in -> single tensor out.
        cropped = [self._unpad(x) for x in inputs]
        return cropped[0] if len(cropped) == 1 else cropped

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||||
|
def easydict_to_dict(obj):
    """Recursively convert an EasyDict tree into plain dicts (leaves unchanged)."""
    if isinstance(obj, edict):
        return {key: easydict_to_dict(value) for key, value in obj.items()}
    return obj
|
||||||
|
|
||||||
|
|
||||||
|
class RaftArgs:
    """Minimal namespace standing in for RAFT's argparse args object."""

    # Field names consumed by RAFT.__init__ and the update blocks.
    _FIELDS = ("small", "mixed_precision", "alternate_corr")

    def __init__(self, small, mixed_precision, alternate_corr):
        self.small = small
        self.mixed_precision = mixed_precision
        self.alternate_corr = alternate_corr

    def _get_kwargs(self):
        """Return the args as a plain dict (argparse-Namespace-style helper)."""
        return {name: getattr(self, name) for name in self._FIELDS}
|
||||||
180
inference.py
180
inference.py
@@ -441,3 +441,183 @@ class SGMVFIModel:
|
|||||||
pred = self._inference(img0, img1, timestep=time_step)
|
pred = self._inference(img0, img1, timestep=time_step)
|
||||||
pred = padder.unpad(pred)
|
pred = padder.unpad(pred)
|
||||||
return torch.clamp(pred, 0, 1)
|
return torch.clamp(pred, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GIMM-VFI model wrapper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GIMMVFIModel:
    """Clean inference wrapper around GIMM-VFI for ComfyUI integration.

    Supports two modes:
    - interpolate_batch(): standard single-midpoint interface (compatible with
      recursive _interpolate_frames machinery used by other models)
    - interpolate_multi(): GIMM-VFI's unique single-pass mode, generates all
      N-1 intermediate frames between each pair in one forward pass
    """

    def __init__(self, checkpoint_path, flow_checkpoint_path, variant="auto",
                 ds_factor=1.0, device="cpu"):
        # Imports are local so the package can be imported even when the
        # optional GIMM-VFI dependencies are not installed.
        import os
        import yaml
        from omegaconf import OmegaConf
        from .gimm_vfi_arch import (
            GIMMVFI_R, GIMMVFI_F, GIMMVFIConfig,
            GIMM_RAFT, GIMM_FlowFormer, gimm_get_flowformer_cfg,
            GIMMInputPadder, GIMMRaftArgs, easydict_to_dict,
        )
        import comfy.utils

        self.ds_factor = ds_factor        # internal downscale factor for flow/INR
        self.device = device              # bookkeeping only; .to() moves the model
        self._InputPadder = GIMMInputPadder

        filename = os.path.basename(checkpoint_path).lower()

        # Detect variant from filename
        if variant == "auto":
            # "gimmvfi_f_*" checkpoints are the FlowFormer variant.
            self.is_flowformer = "gimmvfi_f" in filename
        else:
            self.is_flowformer = (variant == "flowformer")

        self.variant_name = "flowformer" if self.is_flowformer else "raft"

        # Load config
        script_dir = os.path.dirname(os.path.abspath(__file__))
        if self.is_flowformer:
            config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_f_arb.yaml")
        else:
            config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_r_arb.yaml")

        with open(config_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config = easydict_to_dict(config)
        config = OmegaConf.create(config)
        arch_defaults = GIMMVFIConfig.create(config.arch)
        # Merge so explicit yaml values override the arch defaults.
        config = OmegaConf.merge(arch_defaults, config.arch)

        # Build model + flow estimator
        dtype = torch.float32

        if self.is_flowformer:
            self.model = GIMMVFI_F(dtype, config)
            cfg = gimm_get_flowformer_cfg()
            flow_estimator = GIMM_FlowFormer(cfg.latentcostformer)
            flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
            flow_estimator.load_state_dict(flow_sd, strict=True)
        else:
            self.model = GIMMVFI_R(dtype, config)
            raft_args = GIMMRaftArgs(small=False, mixed_precision=False, alternate_corr=False)
            flow_estimator = GIMM_RAFT(raft_args)
            flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
            flow_estimator.load_state_dict(flow_sd, strict=True)

        # Load main model weights
        sd = comfy.utils.load_torch_file(checkpoint_path)
        # strict=False: the main checkpoint does not include the flow
        # estimator weights, which were loaded separately above.
        self.model.load_state_dict(sd, strict=False)

        self.model.flow_estimator = flow_estimator
        self.model.eval()

    def to(self, device):
        """Move model to device (returns self for chaining)."""
        self.device = device if isinstance(device, str) else str(device)
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_batch(self, frames0, frames1, time_step=0.5):
        """Interpolate a single midpoint frame per pair (standard interface).

        Args:
            frames0: [B, C, H, W] tensor, float32, range [0, 1]
            frames1: [B, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1)

        Returns:
            Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        results = []

        # Pairs are processed one at a time; each forward pass already carries
        # the full two-frame spatio-temporal volume.
        for i in range(frames0.shape[0]):
            I0 = frames0[i:i+1].to(device)
            I2 = frames1[i:i+1].to(device)

            # Model requires spatial dims divisible by 32.
            padder = self._InputPadder(I0.shape, 32)
            I0_p, I2_p = padder.pad(I0, I2)

            # [B, C, 2, H, W]: the two frames stacked along a time axis.
            xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
            batch_size = xs.shape[0]
            s_shape = xs.shape[-2:]

            # One (coordinate grid, aux) tuple per requested timestep.
            coord_inputs = [(
                self.model.sample_coord_input(
                    batch_size, s_shape, [time_step],
                    device=xs.device, upsample_ratio=self.ds_factor,
                ),
                None,
            )]
            timesteps = [
                time_step * torch.ones(xs.shape[0]).to(xs.device)
            ]

            all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
            pred = padder.unpad(all_outputs["imgt_pred"][0])
            results.append(torch.clamp(pred, 0, 1))

        return torch.cat(results, dim=0)

    @torch.no_grad()
    def interpolate_multi(self, frame0, frame1, num_intermediates):
        """Generate all intermediate frames between a pair in one forward pass.

        This is GIMM-VFI's unique capability -- arbitrary timestep interpolation
        without recursive 2x passes.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            num_intermediates: int, number of intermediate frames to generate

        Returns:
            List of [1, C, H, W] tensors, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        I0 = frame0.to(device)
        I2 = frame1.to(device)

        padder = self._InputPadder(I0.shape, 32)
        I0_p, I2_p = padder.pad(I0, I2)

        xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
        batch_size = xs.shape[0]
        s_shape = xs.shape[-2:]
        interp_factor = num_intermediates + 1

        # One coordinate grid + timestep per intermediate frame, at
        # t = i / interp_factor for i in 1..interp_factor-1.
        coord_inputs = [
            (
                self.model.sample_coord_input(
                    batch_size, s_shape,
                    [1.0 / interp_factor * i],
                    device=xs.device,
                    upsample_ratio=self.ds_factor,
                ),
                None,
            )
            for i in range(1, interp_factor)
        ]
        timesteps = [
            i * 1.0 / interp_factor * torch.ones(xs.shape[0]).to(xs.device)
            for i in range(1, interp_factor)
        ]

        all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)

        results = []
        for pred in all_outputs["imgt_pred"]:
            unpadded = padder.unpad(pred)
            results.append(torch.clamp(unpadded, 0, 1))

        return results
|
||||||
|
|||||||
396
nodes.py
396
nodes.py
@@ -8,10 +8,11 @@ import torch
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel
|
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
|
||||||
from .bim_vfi_arch import clear_backwarp_cache
|
from .bim_vfi_arch import clear_backwarp_cache
|
||||||
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
|
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
|
||||||
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
|
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
|
||||||
|
from .gimm_vfi_arch import clear_gimm_caches
|
||||||
|
|
||||||
logger = logging.getLogger("Tween")
|
logger = logging.getLogger("Tween")
|
||||||
|
|
||||||
@@ -40,6 +41,17 @@ SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi")
|
|||||||
if not os.path.exists(SGM_MODEL_DIR):
|
if not os.path.exists(SGM_MODEL_DIR):
|
||||||
os.makedirs(SGM_MODEL_DIR, exist_ok=True)
|
os.makedirs(SGM_MODEL_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# GIMM-VFI
|
||||||
|
GIMM_HF_REPO = "Kijai/GIMM-VFI_safetensors"
|
||||||
|
GIMM_AVAILABLE_MODELS = [
|
||||||
|
"gimmvfi_r_arb_lpips_fp32.safetensors",
|
||||||
|
"gimmvfi_f_arb_lpips_fp32.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
GIMM_MODEL_DIR = os.path.join(folder_paths.models_dir, "gimm-vfi")
|
||||||
|
if not os.path.exists(GIMM_MODEL_DIR):
|
||||||
|
os.makedirs(GIMM_MODEL_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def get_available_models():
|
def get_available_models():
|
||||||
"""List available checkpoint files in the bim-vfi model directory."""
|
"""List available checkpoint files in the bim-vfi model directory."""
|
||||||
@@ -1113,3 +1125,385 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
|
|||||||
result = result[1:] # skip duplicate boundary frame
|
result = result[1:] # skip duplicate boundary frame
|
||||||
|
|
||||||
return (result, model)
|
return (result, model)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GIMM-VFI nodes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_available_gimm_models():
    """List available GIMM-VFI checkpoint files in the gimm-vfi model directory."""
    found = []
    if os.path.isdir(GIMM_MODEL_DIR):
        found = [
            name
            for name in os.listdir(GIMM_MODEL_DIR)
            if name.endswith((".safetensors", ".pth", ".pt", ".ckpt"))
            # Exclude flow estimator checkpoints from the model list
            and not name.startswith(("raft-", "flowformer_"))
        ]
    # Fall back to the downloadable defaults so the dropdown is never empty.
    if not found:
        found = list(GIMM_AVAILABLE_MODELS)
    return sorted(found)
|
||||||
|
|
||||||
|
|
||||||
|
def download_gimm_model(filename, dest_dir):
    """Download a GIMM-VFI file from HuggingFace.

    Args:
        filename: file name inside the HF repo (also the local file name).
        dest_dir: directory to place the downloaded file in.

    Raises:
        RuntimeError: if huggingface_hub is not installed or the file is
            missing after the download.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise RuntimeError(
            "huggingface_hub is required to auto-download GIMM-VFI models. "
            "Install it with: pip install huggingface_hub"
        ) from e
    # Bug fix: these messages previously printed the literal "(unknown)"
    # instead of the file being downloaded.
    logger.info(f"Downloading {filename} from HuggingFace ({GIMM_HF_REPO})...")
    hf_hub_download(
        repo_id=GIMM_HF_REPO,
        filename=filename,
        local_dir=dest_dir,
        # NOTE: deprecated in recent huggingface_hub (copies are the default);
        # kept for compatibility with older installs.
        local_dir_use_symlinks=False,
    )
    dest_path = os.path.join(dest_dir, filename)
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download {filename} to {dest_path}")
    logger.info(f"Downloaded {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadGIMMVFIModel:
    """ComfyUI node: load a GIMM-VFI checkpoint plus its matching flow
    estimator, auto-downloading both from HuggingFace when missing."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_gimm_models(), {
                    "default": GIMM_AVAILABLE_MODELS[0],
                    "tooltip": "Checkpoint file from models/gimm-vfi/. Auto-downloads from HuggingFace on first use. "
                    "RAFT variant (~80MB) or FlowFormer variant (~123MB) auto-detected from filename.",
                }),
                "ds_factor": ("FLOAT", {
                    "default": 1.0, "min": 0.125, "max": 1.0, "step": 0.125,
                    "tooltip": "Downscale factor for internal processing. 1.0 = full resolution. "
                    "Lower values reduce VRAM usage and speed up inference at the cost of quality. "
                    "Try 0.5 for 4K inputs.",
                }),
            }
        }

    RETURN_TYPES = ("GIMM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/GIMM-VFI"

    def load_model(self, model_path, ds_factor):
        """Load (downloading if needed) the checkpoint and flow estimator.

        Returns a CPU-resident GIMMVFIModel wrapper; the interpolate nodes
        move it to the GPU as needed.
        """
        full_path = os.path.join(GIMM_MODEL_DIR, model_path)

        # Auto-download main model if missing
        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_gimm_model(model_path, GIMM_MODEL_DIR)

        # Detect and download matching flow estimator
        # (FlowFormer checkpoints pair with the Sintel FlowFormer weights,
        # RAFT checkpoints with raft-things).
        if "gimmvfi_f" in model_path.lower():
            flow_filename = "flowformer_sintel_fp32.safetensors"
        else:
            flow_filename = "raft-things_fp32.safetensors"

        flow_path = os.path.join(GIMM_MODEL_DIR, flow_filename)
        if not os.path.exists(flow_path):
            logger.info(f"Flow estimator not found, downloading {flow_filename}...")
            download_gimm_model(flow_filename, GIMM_MODEL_DIR)

        wrapper = GIMMVFIModel(
            checkpoint_path=full_path,
            flow_checkpoint_path=flow_path,
            variant="auto",
            ds_factor=ds_factor,
            device="cpu",
        )

        logger.info(f"GIMM-VFI model loaded (variant={wrapper.variant_name}, ds_factor={ds_factor})")
        return (wrapper,)
|
||||||
|
|
||||||
|
|
||||||
|
class GIMMVFIInterpolate:
|
||||||
|
@classmethod
def INPUT_TYPES(cls):
    """ComfyUI input schema for the GIMM-VFI interpolate node.

    Dict order determines the widget order in the UI; tooltips document
    each parameter for the end user.
    """
    return {
        "required": {
            "images": ("IMAGE", {
                "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
            }),
            "model": ("GIMM_VFI_MODEL", {
                "tooltip": "GIMM-VFI model from the Load GIMM-VFI Model node.",
            }),
            "multiplier": ([2, 4, 8], {
                "default": 2,
                "tooltip": "Frame rate multiplier. In single-pass mode, all intermediate frames are generated "
                "in one forward pass per pair. In recursive mode, uses 2x passes like other models.",
            }),
            "single_pass": ("BOOLEAN", {
                "default": True,
                "tooltip": "Use GIMM-VFI's single-pass arbitrary-timestep mode. Generates all intermediate frames "
                "per pair in one forward pass (no recursive 2x passes). Disable to use the standard "
                "recursive approach (same as BIM/EMA/SGM).",
            }),
            "clear_cache_after_n_frames": ("INT", {
                "default": 10, "min": 1, "max": 100, "step": 1,
                "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
            }),
            "keep_device": ("BOOLEAN", {
                "default": True,
                "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
            }),
            "all_on_gpu": ("BOOLEAN", {
                "default": False,
                "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
            }),
            "batch_size": ("INT", {
                "default": 1, "min": 1, "max": 64, "step": 1,
                "tooltip": "Number of frame pairs to process simultaneously in recursive mode. Ignored in single-pass mode (pairs are processed one at a time since each generates multiple frames).",
            }),
            "chunk_size": ("INT", {
                "default": 0, "min": 0, "max": 10000, "step": 1,
                "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
            }),
        }
    }

RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "interpolate"
CATEGORY = "video/GIMM-VFI"
|
||||||
|
|
||||||
|
def _interpolate_frames_single_pass(self, frames, model, multiplier,
                                    device, storage_device, keep_device, all_on_gpu,
                                    clear_cache_after_n_frames, pbar, step_ref):
    """Single-pass interpolation using GIMM-VFI's arbitrary timestep capability.

    Each frame pair yields all (multiplier - 1) intermediate frames in one
    forward pass. Results land on storage_device; step_ref is a one-element
    list used as a mutable progress counter shared with the caller.
    """
    num_intermediates = multiplier - 1
    new_frames = []
    num_pairs = frames.shape[0] - 1
    pairs_since_clear = 0

    for i in range(num_pairs):
        frame0 = frames[i:i+1]
        frame1 = frames[i+1:i+2]

        # Model is moved to the GPU per pair when keep_device is off.
        if not keep_device:
            model.to(device)

        mids = model.interpolate_multi(frame0, frame1, num_intermediates)
        # Offload intermediates immediately (CPU unless all_on_gpu).
        mids = [m.to(storage_device) for m in mids]

        if not keep_device:
            model.to("cpu")

        # Original frame, then its intermediates; the right endpoint is the
        # next iteration's left endpoint.
        new_frames.append(frames[i:i+1])
        for m in mids:
            new_frames.append(m)

        step_ref[0] += 1
        pbar.update_absolute(step_ref[0])

        # Periodic cache clearing to cap VRAM growth.
        pairs_since_clear += 1
        if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
            clear_gimm_caches()
            torch.cuda.empty_cache()
            pairs_since_clear = 0

    # Append the final frame, which is never a left endpoint.
    new_frames.append(frames[-1:])
    result = torch.cat(new_frames, dim=0)

    if not all_on_gpu and torch.cuda.is_available():
        clear_gimm_caches()
        torch.cuda.empty_cache()

    return result
|
||||||
|
|
||||||
|
def _interpolate_frames(self, frames, model, num_passes, batch_size,
|
||||||
|
device, storage_device, keep_device, all_on_gpu,
|
||||||
|
clear_cache_after_n_frames, pbar, step_ref):
|
||||||
|
"""Recursive 2x interpolation (standard approach, same as other models)."""
|
||||||
|
for pass_idx in range(num_passes):
|
||||||
|
new_frames = []
|
||||||
|
num_pairs = frames.shape[0] - 1
|
||||||
|
pairs_since_clear = 0
|
||||||
|
|
||||||
|
for i in range(0, num_pairs, batch_size):
|
||||||
|
batch_end = min(i + batch_size, num_pairs)
|
||||||
|
actual_batch = batch_end - i
|
||||||
|
|
||||||
|
frames0 = frames[i:batch_end]
|
||||||
|
frames1 = frames[i + 1:batch_end + 1]
|
||||||
|
|
||||||
|
if not keep_device:
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
|
||||||
|
mids = mids.to(storage_device)
|
||||||
|
|
||||||
|
if not keep_device:
|
||||||
|
model.to("cpu")
|
||||||
|
|
||||||
|
for j in range(actual_batch):
|
||||||
|
new_frames.append(frames[i + j:i + j + 1])
|
||||||
|
new_frames.append(mids[j:j+1])
|
||||||
|
|
||||||
|
step_ref[0] += actual_batch
|
||||||
|
pbar.update_absolute(step_ref[0])
|
||||||
|
|
||||||
|
pairs_since_clear += actual_batch
|
||||||
|
if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
|
||||||
|
clear_gimm_caches()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
pairs_since_clear = 0
|
||||||
|
|
||||||
|
new_frames.append(frames[-1:])
|
||||||
|
frames = torch.cat(new_frames, dim=0)
|
||||||
|
|
||||||
|
if not all_on_gpu and torch.cuda.is_available():
|
||||||
|
clear_gimm_caches()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _count_steps(num_frames, num_passes):
|
||||||
|
"""Count total interpolation steps for recursive mode."""
|
||||||
|
n = num_frames
|
||||||
|
total = 0
|
||||||
|
for _ in range(num_passes):
|
||||||
|
total += n - 1
|
||||||
|
n = 2 * n - 1
|
||||||
|
return total
|
||||||
|
|
||||||
|
    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size):
        """Interpolate intermediate frames between every adjacent pair of `images`.

        Args:
            images: ComfyUI IMAGE batch, [B, H, W, C] (permuted to [B, C, H, W]
                for the model, and back on return).
            model: GIMM-VFI model exposing .to(), interpolate_multi() and
                interpolate_batch() (see the helper methods below).
            multiplier: output/input frame-rate factor. Recursive mode supports
                only 2/4/8, mapped to 1/2/3 passes; single-pass mode uses it
                directly as (multiplier - 1) intermediates per pair.
            single_pass: True = one arbitrary-timestep pass per pair;
                False = recursive 2x passes (like the other VFI models).
            clear_cache_after_n_frames: CUDA cache flush cadence (in pairs).
            keep_device: keep the model on the compute device across pairs
                (forced True when all_on_gpu).
            all_on_gpu: also keep frame storage on the GPU.
            batch_size: pairs per model call (recursive mode only).
            chunk_size: process input in chunks of this many frames
                (0/1 = disabled); bounds VRAM, not RAM.

        Returns:
            1-tuple of the interpolated IMAGE batch on CPU, [B', H, W, C].
        """
        # Fewer than two frames: nothing to interpolate.
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # num_passes is only defined (and only needed) in recursive mode.
        if not single_pass:
            num_passes = {2: 1, 4: 2, 8: 3}[multiplier]

        # keep_device is implied by all_on_gpu: shuttling the model off-GPU
        # while frames stay on it would be pointless.
        if all_on_gpu:
            keep_device = True

        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        all_frames = images.permute(0, 3, 1, 2).to(storage_device)
        total_input = all_frames.shape[0]

        # Build chunk boundaries (1-frame overlap between consecutive chunks)
        if chunk_size < 2 or chunk_size >= total_input:
            chunks = [(0, total_input)]
        else:
            chunks = []
            start = 0
            while start < total_input - 1:
                end = min(start + chunk_size, total_input)
                chunks.append((start, end))
                start = end - 1  # overlap by 1 frame
                if end == total_input:
                    break

        # Calculate total progress steps across all chunks
        # (single-pass: one step per pair; recursive: see _count_steps).
        if single_pass:
            total_steps = sum(ce - cs - 1 for cs, ce in chunks)
        else:
            total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
        pbar = ProgressBar(total_steps)
        # One-element list: mutable progress counter shared with the helpers.
        step_ref = [0]

        if keep_device:
            model.to(device)

        result_chunks = []
        for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
            # clone() so the helpers never mutate or alias the full frame store.
            chunk_frames = all_frames[chunk_start:chunk_end].clone()

            if single_pass:
                chunk_result = self._interpolate_frames_single_pass(
                    chunk_frames, model, multiplier,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )
            else:
                chunk_result = self._interpolate_frames(
                    chunk_frames, model, num_passes, batch_size,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )

            # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
            if chunk_idx > 0:
                chunk_result = chunk_result[1:]

            # Move completed chunk to CPU to bound memory when chunking
            if len(chunks) > 1:
                chunk_result = chunk_result.cpu()

            result_chunks.append(chunk_result)

        result = torch.cat(result_chunks, dim=0)
        # Convert back to ComfyUI [B, H, W, C], on CPU
        result = result.cpu().permute(0, 2, 3, 1)
        return (result,)
|
||||||
|
|
||||||
|
class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
    """Process a numbered segment of the input batch for GIMM-VFI.

    Chain multiple instances with Save nodes between them to bound peak RAM.
    The model pass-through output forces sequential execution so each segment
    saves and frees from RAM before the next starts.
    """

    RETURN_TYPES = ("IMAGE", "GIMM_VFI_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    @classmethod
    def INPUT_TYPES(cls):
        # Extend the parent node's inputs with segment selection controls.
        inputs = GIMMVFIInterpolate.INPUT_TYPES()
        required = inputs["required"]
        required["segment_index"] = ("INT", {
            "default": 0, "min": 0, "max": 10000, "step": 1,
            "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
                       "unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
                       "Chain the model output to the next Segment Interpolate to force sequential execution.",
        })
        required["segment_size"] = ("INT", {
            "default": 500, "min": 2, "max": 10000, "step": 1,
            "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
                       "Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
        })
        return inputs

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size, segment_index, segment_size):
        frame_total = images.shape[0]

        # Adjacent segments share one boundary frame, so segment k begins at
        # k * (segment_size - 1).
        seg_start = segment_index * (segment_size - 1)
        seg_end = min(seg_start + segment_size, frame_total)

        # Past the end: nothing to interpolate, hand back a single frame plus
        # the model pass-through.
        if seg_start >= frame_total - 1:
            return (images[:1], model)

        # Delegate the actual interpolation to the parent node's logic.
        (interpolated,) = super().interpolate(
            images[seg_start:seg_end], model, multiplier, single_pass,
            clear_cache_after_n_frames, keep_device, all_on_gpu,
            batch_size, chunk_size,
        )

        # Every segment after the first starts on the previous segment's last
        # input frame; drop that duplicated boundary frame.
        if segment_index > 0:
            interpolated = interpolated[1:]

        return (interpolated, model)
|||||||
@@ -1 +1,6 @@
|
|||||||
gdown
|
gdown
|
||||||
|
omegaconf
|
||||||
|
yacs
|
||||||
|
easydict
|
||||||
|
einops
|
||||||
|
huggingface_hub
|
||||||
|
|||||||
Reference in New Issue
Block a user