Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation

Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 13:11:45 +01:00
parent 3c3d4b2537
commit d642255e70
56 changed files with 9774 additions and 1 deletions

View File

@@ -34,6 +34,14 @@ def _auto_install_deps():
except Exception as e:
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
# GIMM-VFI dependencies
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"):
try:
__import__(pkg)
except ImportError:
logger.info(f"[Tween] Installing {pkg}...")
subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
_auto_install_deps()
@@ -41,6 +49,7 @@ from .nodes import (
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
)
WEB_DIRECTORY = "./web"
@@ -56,6 +65,9 @@ NODE_CLASS_MAPPINGS = {
"LoadSGMVFIModel": LoadSGMVFIModel,
"SGMVFIInterpolate": SGMVFIInterpolate,
"SGMVFISegmentInterpolate": SGMVFISegmentInterpolate,
"LoadGIMMVFIModel": LoadGIMMVFIModel,
"GIMMVFIInterpolate": GIMMVFIInterpolate,
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -69,4 +81,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadSGMVFIModel": "Load SGM-VFI Model",
"SGMVFIInterpolate": "SGM-VFI Interpolate",
"SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate",
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
}

15
gimm_vfi_arch/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .generalizable_INR.gimmvfi_r import GIMMVFI_R
from .generalizable_INR.gimmvfi_f import GIMMVFI_F
from .generalizable_INR.configs import GIMMVFIConfig
from .generalizable_INR.raft.raft import RAFT as GIMM_RAFT
from .generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer as GIMM_FlowFormer
from .generalizable_INR.flowformer.configs.submission import get_cfg as gimm_get_flowformer_cfg
from .utils.utils import InputPadder as GIMMInputPadder, RaftArgs as GIMMRaftArgs, easydict_to_dict
from .generalizable_INR.modules.softsplat import objCudacache as gimm_softsplat_cache
def clear_gimm_caches():
    """Clear cached CUDA kernels and warp grids for GIMM-VFI.

    Empties the module-level caches that accumulate across inference calls:
    `backwarp_tenGrid` (sampling grids kept by fi_utils) and the softsplat
    `objCudacache` (presumably compiled-kernel/cache objects — named so in
    the vendored source). Freeing both releases memory they pin between runs.
    """
    # Lazy import so merely importing this package does not pull in fi_utils.
    from .generalizable_INR.modules.fi_utils import backwarp_tenGrid
    backwarp_tenGrid.clear()
    gimm_softsplat_cache.clear()

View File

View File

@@ -0,0 +1,57 @@
trainer: stage_inr
dataset:
type: vimeo_arb
path: ./data/vimeo90k/vimeo_septuplet
aug: true
arch:
type: gimmvfi_f
ema: true
modulated_layer_idxs: [1]
coord_range: [-1., 1.]
hyponet:
type: mlp
n_layer: 5 # including the output layer
hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
use_bias: true
input_dim: 3
output_dim: 2
output_bias: 0.5
activation:
type: siren
siren_w0: 1.0
initialization:
weight_init_type: siren
bias_init_type: siren
loss:
subsample:
type: random
ratio: 0.1
optimizer:
type: adamw
init_lr: 0.00008
weight_decay: 0.00004
betas: [0.9, 0.999]
ft: true
warmup:
epoch: 1
multiplier: 1
buffer_epoch: 0
min_lr: 0.000008
mode: fix
start_from_zero: True
max_gn: null
experiment:
amp: True
batch_size: 4
total_batch_size: 32
epochs: 60
save_ckpt_freq: 10
test_freq: 10
test_imlog_freq: 10

View File

@@ -0,0 +1,57 @@
trainer: stage_inr
dataset:
type: vimeo_arb
path: ./data/vimeo90k/vimeo_septuplet
aug: true
arch:
type: gimmvfi_r
ema: true
modulated_layer_idxs: [1]
coord_range: [-1., 1.]
hyponet:
type: mlp
n_layer: 5 # including the output layer
hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
use_bias: true
input_dim: 3
output_dim: 2
output_bias: 0.5
activation:
type: siren
siren_w0: 1.0
initialization:
weight_init_type: siren
bias_init_type: siren
loss:
subsample:
type: random
ratio: 0.1
optimizer:
type: adamw
init_lr: 0.00008
weight_decay: 0.00004
betas: [0.9, 0.999]
ft: true
warmup:
epoch: 1
multiplier: 1
buffer_epoch: 0
min_lr: 0.000008
mode: fix
start_from_zero: True
max_gn: null
experiment:
amp: True
batch_size: 4
total_batch_size: 32
epochs: 60
save_ckpt_freq: 10
test_freq: 10
test_imlog_freq: 10

View File

@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
from typing import List, Optional
from dataclasses import dataclass, field
from omegaconf import OmegaConf, MISSING
from .modules.module_config import HypoNetConfig
@dataclass
class GIMMConfig:
    """Structured schema for the plain GIMM architecture.

    Merged over the user-supplied OmegaConf node in :meth:`create`;
    `coord_range` is deliberately MISSING so OmegaConf errors out when the
    user config omits it.
    """

    type: str = "gimm"
    ema: Optional[bool] = None  # whether an EMA weight copy is kept (resolved to bool in create)
    ema_value: Optional[float] = None  # EMA decay — presumably; confirm against trainer usage
    fwarp_type: str = "linear"  # forward-warp mode — name-based, verify in softsplat callers
    hyponet: HypoNetConfig = field(default_factory=HypoNetConfig)
    coord_range: List[float] = MISSING  # required, e.g. [-1., 1.]
    modulated_layer_idxs: Optional[List[int]] = None  # hyponet layer indices to modulate

    @classmethod
    def create(cls, config):
        """Merge `config` onto structured defaults and return the merged node."""
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        defaults = OmegaConf.structured(cls(ema=False))
        config = OmegaConf.merge(defaults, config)
        return config
@dataclass
class GIMMVFIConfig:
    """Structured schema for the GIMM-VFI (frame interpolation) architecture.

    Same merging contract as GIMMConfig, plus VFI-specific knobs:
    `rec_weight` (reconstruction-loss weight) and `raft_iter` (RAFT flow
    refinement iterations). `coord_range` is MISSING on purpose so user
    configs must supply it.
    """

    type: str = "gimmvfi"
    ema: Optional[bool] = None  # whether an EMA weight copy is kept
    ema_value: Optional[float] = None  # EMA decay — presumably; confirm against trainer usage
    fwarp_type: str = "linear"  # forward-warp mode — name-based, verify in softsplat callers
    rec_weight: float = 0.1  # reconstruction loss weight — verify in loss code
    hyponet: HypoNetConfig = field(default_factory=HypoNetConfig)
    raft_iter: int = 20  # number of RAFT update iterations for flow estimation
    coord_range: List[float] = MISSING  # required, e.g. [-1., 1.]
    modulated_layer_idxs: Optional[List[int]] = None  # hyponet layer indices to modulate

    @classmethod
    def create(cls, config):
        """Merge `config` onto structured defaults and return the merged node."""
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        defaults = OmegaConf.structured(cls(ema=False))
        config = OmegaConf.merge(defaults, config)
        return config

View File

@@ -0,0 +1,77 @@
# Default FlowFormer "submission" configuration, vendored from the FlowFormer
# repository. GIMM-VFI_F reads these values (via get_cfg) to build its
# FlowFormer-based optical-flow encoder. Values are FlowFormer defaults;
# training-only fields (trainer, batch_size, ...) are unused at inference.
from yacs.config import CfgNode as CN

_CN = CN()

_CN.name = ""
_CN.suffix = ""
_CN.gamma = 0.8
_CN.max_flow = 400
_CN.batch_size = 6
_CN.sum_freq = 100
_CN.val_freq = 5000000
_CN.image_size = [432, 960]
_CN.add_noise = False
_CN.critical_params = []

_CN.transformer = "latentcostformer"
# Checkpoint path used by the original repo; the ComfyUI integration loads
# weights separately, so this default is effectively ignored here.
_CN.model = "pretrained_ckpt/flowformer_sintel.pth"

# latentcostformer
_CN.latentcostformer = CN()
_CN.latentcostformer.pe = "linear"
_CN.latentcostformer.dropout = 0.0
_CN.latentcostformer.encoder_latent_dim = 256  # in twins, this is 256
_CN.latentcostformer.query_latent_dim = 64
_CN.latentcostformer.cost_latent_input_dim = 64
_CN.latentcostformer.cost_latent_token_num = 8
_CN.latentcostformer.cost_latent_dim = 128
_CN.latentcostformer.arc_type = "transformer"
_CN.latentcostformer.cost_heads_num = 1
# encoder
_CN.latentcostformer.pretrain = True
_CN.latentcostformer.context_concat = False
_CN.latentcostformer.encoder_depth = 3
_CN.latentcostformer.feat_cross_attn = False
_CN.latentcostformer.patch_size = 8
_CN.latentcostformer.patch_embed = "single"
_CN.latentcostformer.no_pe = False
_CN.latentcostformer.gma = "GMA"
_CN.latentcostformer.kernel_size = 9
_CN.latentcostformer.rm_res = True
_CN.latentcostformer.vert_c_dim = 64
_CN.latentcostformer.cost_encoder_res = True
_CN.latentcostformer.cnet = "twins"
_CN.latentcostformer.fnet = "twins"
_CN.latentcostformer.no_sc = False
_CN.latentcostformer.only_global = False
_CN.latentcostformer.add_flow_token = True
_CN.latentcostformer.use_mlp = False
_CN.latentcostformer.vertical_conv = False
# decoder
_CN.latentcostformer.decoder_depth = 32
# Params that must match between config and checkpoint (checked upstream).
_CN.latentcostformer.critical_params = [
    "cost_heads_num",
    "vert_c_dim",
    "cnet",
    "pretrain",
    "add_flow_token",
    "encoder_depth",
    "gma",
    "cost_encoder_res",
]

### TRAINER
_CN.trainer = CN()
_CN.trainer.scheduler = "OneCycleLR"
_CN.trainer.optimizer = "adamw"
_CN.trainer.canonical_lr = 12.5e-5
_CN.trainer.adamw_decay = 1e-4
_CN.trainer.clip = 1.0
_CN.trainer.num_steps = 120000
_CN.trainer.epsilon = 1e-8
_CN.trainer.anneal_strategy = "linear"


def get_cfg():
    """Return a mutable clone of the default FlowFormer config node."""
    return _CN.clone()

View File

@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops.layers.torch import Rearrange
from einops import rearrange
class BroadMultiHeadAttention(nn.Module):
    """Multi-head attention where one shared query set attends to a batch of keys/values.

    The query carries no batch dim after squeezing; its heads are broadcast
    against every batch element of K/V in the einsum.
    """

    def __init__(self, dim, heads):
        super(BroadMultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        # 1/sqrt(head_dim) logit scaling.
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        # NOTE(review): Q.squeeze() removes *all* size-1 dims, so this assumes
        # Q arrives as [1, i, dim]-like and collapses to [i, dim]; a query that
        # is already 2-D or has other singleton dims would break — confirm
        # against callers.
        Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        # Broadcast shared query heads (h) over the key batch (b).
        dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale  # (b hw) heads 1 pointnum
        return self.attend(dots)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        B, _, _ = K.shape
        _, N, _ = Q.shape
        V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        out = einsum("bhij, bhjd -> bhid", attn, V)
        # Re-merge heads into the channel dimension.
        out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N)
        return out
class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled dot-product attention over flattened tokens."""

    def __init__(self, dim, heads):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        # 1/sqrt(head_dim) keeps the attention logits well-conditioned.
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        """Return softmax-normalized attention weights between Q and K."""
        q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        k = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        logits = einsum("bhid, bhjd -> bhij", q, k) * self.scale
        return self.attend(logits)

    def forward(self, Q, K, V):
        weights = self.attend_with_rpe(Q, K)
        batch, tokens, _ = Q.shape
        values = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        mixed = einsum("bhij, bhjd -> bhid", weights, values)
        # Fold the heads back into the channel dimension.
        return rearrange(mixed, "b heads hw d -> b hw (heads d)", b=batch, hw=tokens)
# class MultiHeadAttentionRelative_encoder(nn.Module):
# def __init__(self, dim, heads):
# super(MultiHeadAttentionRelative, self).__init__()
# self.dim = dim
# self.heads = heads
# self.scale = (dim/heads) ** -0.5
# self.attend = nn.Softmax(dim=-1)
# def attend_with_rpe(self, Q, K, Q_r, K_r):
# """
# Q: [BH1W1, H3W3, dim]
# K: [BH1W1, H3W3, dim]
# Q_r: [BH1W1, H3W3, H3W3, dim]
# K_r: [BH1W1, H3W3, H3W3, dim]
# """
# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# # context-context similarity
# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3]
# # context-position similarity
# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3]
# # position-context similarity
# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:])
# p_c = torch.squeeze(p_c, dim=4)
# p_c = p_c.permute(0, 1, 3, 2)
# dots = c_c + c_p + p_c
# return self.attend(dots)
# def forward(self, Q, K, V, Q_r, K_r):
# attn = self.attend_with_rpe(Q, K, Q_r, K_r)
# B, HW, _ = Q.shape
# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
# out = einsum('bhij, bhjd -> bhid', attn, V)
# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
# return out
class MultiHeadAttentionRelative(nn.Module):
    """Multi-head attention with relative positional terms.

    Logits are the sum of three similarities: context-context (Q·K),
    context-position (Q·K_r), and position-context (Q_r·K).
    """

    def __init__(self, dim, heads):
        super(MultiHeadAttentionRelative, self).__init__()
        self.dim = dim
        self.heads = heads
        # 1/sqrt(head_dim) logit scaling.
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K, Q_r, K_r):
        """
        Q: [BH1W1, 1, dim]
        K: [BH1W1, H3W3, dim]
        Q_r: [BH1W1, H3W3, dim]
        K_r: [BH1W1, H3W3, dim]
        """
        Q = rearrange(
            Q, "b i (heads d) -> b heads i d", heads=self.heads
        )  # [BH1W1, heads, 1, dim]
        K = rearrange(
            K, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]
        K_r = rearrange(
            K_r, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]
        Q_r = rearrange(
            Q_r, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]

        # context-context similarity
        c_c = einsum("bhid, bhjd -> bhij", Q, K) * self.scale  # [(B H1W1) heads 1 H3W3]
        # context-position similarity
        c_p = (
            einsum("bhid, bhjd -> bhij", Q, K_r) * self.scale
        )  # [(B H1W1) heads 1 H3W3]
        # position-context similarity: per-position dot products via singleton
        # einsum dims, then squeeze/transpose back into logit layout.
        p_c = (
            einsum("bhijd, bhikd -> bhijk", Q_r[:, :, :, None, :], K[:, :, :, None, :])
            * self.scale
        )
        p_c = torch.squeeze(p_c, dim=4)
        p_c = p_c.permute(0, 1, 3, 2)
        dots = c_c + c_p + p_c
        return self.attend(dots)

    def forward(self, Q, K, V, Q_r, K_r):
        attn = self.attend_with_rpe(Q, K, Q_r, K_r)
        B, HW, _ = Q.shape
        V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        out = einsum("bhij, bhjd -> bhid", attn, V)
        # Re-merge heads into the channel dimension.
        out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        return out
def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sine/cosine positional embedding with linearly spaced frequency bands.

    x: [..., 2] coordinate tensor; the last dim holds the two position axes.
    Returns [..., dim]: dim//4 sin and cos features per axis.
    NORMALIZE_FACOR of 1/200 suffices for an 8x-downsampled image.
    """
    # dim//4 integer-valued frequency bands 0..dim//4-1 on x's device.
    bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    # 3.14 (approximate pi) kept verbatim for checkpoint compatibility.
    phase_a = 3.14 * x[..., -2:-1] * bands * NORMALIZE_FACOR
    phase_b = 3.14 * x[..., -1:] * bands * NORMALIZE_FACOR
    return torch.cat(
        [torch.sin(phase_a), torch.cos(phase_a), torch.sin(phase_b), torch.cos(phase_b)],
        dim=-1,
    )
def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sine/cosine positional embedding with exponentially spaced frequencies.

    x: [..., 2] coordinate tensor; the last dim holds the two position axes.
    Returns [..., dim]: dim//4 sin and cos features per axis, with frequencies
    NORMALIZE_FACOR * 2**k for k = 0..dim//4-1.
    NORMALIZE_FACOR of 1/200 suffices for an 8x-downsampled image.
    """
    bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    freqs = NORMALIZE_FACOR * 2**bands
    phase_a = x[..., -2:-1] * freqs
    phase_b = x[..., -1:] * freqs
    return torch.cat(
        [torch.sin(phase_a), torch.cos(phase_a), torch.sin(phase_b), torch.cos(phase_b)],
        dim=-1,
    )

View File

@@ -0,0 +1,649 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
import math
import numpy as np
class ResidualBlock(nn.Module):
    """Two 3x3 convs with configurable normalization and a residual shortcut.

    When stride != 1 the shortcut is a strided 1x1 conv projection; otherwise
    the input is added directly. norm_fn selects group/batch/instance/none.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        groups = planes // 8

        def build_norm():
            # One normalization layer of the requested flavor over `planes` channels.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=groups, num_channels=planes)
            if norm_fn == "batch":
                return nn.BatchNorm2d(planes)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(planes)
            if norm_fn == "none":
                return nn.Sequential()

        self.norm1 = build_norm()
        self.norm2 = build_norm()
        if not stride == 1:
            # Third norm only exists when the shortcut needs a projection.
            self.norm3 = build_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        branch = self.relu(self.norm1(self.conv1(x)))
        branch = self.relu(self.norm2(self.conv2(branch)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + branch)
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce (planes//4) -> 3x3 -> 1x1 expand.

    Same norm_fn/stride contract as ResidualBlock; when stride != 1 the
    shortcut is a strided 1x1 conv projection with its own norm (norm4).
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        # Channel bottleneck: reduce to planes//4, process, expand back.
        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        # Project the shortcut only when spatial size / channels change.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BasicEncoder(nn.Module):
    """RAFT-style feature encoder: 7x7 stem + three residual stages (1/8 scale).

    `mul = input_dim // 3` scales all stage widths, so input_dim is assumed to
    be a multiple of 3 (e.g. 3 for one RGB image, 6 for a stacked pair) —
    confirm against callers. Accepts a tensor or a 2-element list/tuple of
    tensors; list inputs are batched together and split back on return.
    """

    def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        # Width multiplier derived from the number of stacked RGB inputs.
        mul = input_dim // 3
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64 * mul)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64 * mul)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        # Stem halves resolution; layer2/layer3 halve it again (1/8 total).
        self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64 * mul
        self.layer1 = self._make_layer(64 * mul, stride=1)
        self.layer2 = self._make_layer(96 * mul, stride=2)
        self.layer3 = self._make_layer(128 * mul, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit/zero affine for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two ResidualBlocks (first one optionally strided)."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        # Track the running channel count for the next stage.
        self.in_planes = dim
        return nn.Sequential(*layers)

    def compute_params(self):
        """Return the total number of parameter elements."""
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())

        return num

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        # Dropout only during training.
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
class SmallEncoder(nn.Module):
    """Compact RAFT-style encoder: BottleneckBlock stages at 1/8 resolution.

    Fixed 3-channel input. Accepts a tensor or a 2-element list/tuple of
    tensors; list inputs are batched together and split back on return.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        # Stem halves resolution; layer2/layer3 halve it again (1/8 total).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit/zero affine for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two BottleneckBlocks (first one optionally strided)."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        # Track the running channel count for the next stage.
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        # Dropout only during training.
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
class ConvNets(nn.Module):
    """Conv stem -> `depth` norm-free residual blocks -> conv head."""

    def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1):
        super(ConvNets, self).__init__()
        self.conv_first = nn.Conv2d(
            in_dim, inter_dim, kernel_size=3, padding=1, stride=stride
        )
        self.conv_last = nn.Conv2d(
            inter_dim, out_dim, kernel_size=3, padding=1, stride=stride
        )
        self.relu = nn.ReLU(inplace=True)
        # `depth` residual blocks at constant width, no normalization.
        self.inter_convs = nn.ModuleList(
            [
                ResidualBlock(inter_dim, inter_dim, norm_fn="none", stride=1)
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        feat = self.relu(self.conv_first(x))
        for block in self.inter_convs:
            feat = block(feat)
        return self.conv_last(feat)
class FlowHead(nn.Module):
    """Two-layer conv head projecting features to a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convr1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convq1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convz2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
self.convr2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
self.convq2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.motion_feature_dim
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicFuseMotion(nn.Module):
    """Fuse flow with feature/context maps into query_latent_dim features."""

    def __init__(self, args):
        super(BasicFuseMotion, self).__init__()
        cor_planes = args.motion_feature_dim
        out_planes = args.query_latent_dim

        # Flow branch with instance norm.
        self.normf1 = nn.InstanceNorm2d(128)
        self.normf2 = nn.InstanceNorm2d(128)

        self.convf1 = nn.Conv2d(2, 128, 3, padding=1)
        self.convf2 = nn.Conv2d(128, 128, 3, padding=1)
        self.convf3 = nn.Conv2d(128, 64, 3, padding=1)

        # s is a width multiplier left fixed at 1 in this vendored version.
        s = 1
        self.normc1 = nn.InstanceNorm2d(256 * s)
        self.normc2 = nn.InstanceNorm2d(256 * s)
        self.normc3 = nn.InstanceNorm2d(256 * s)

        # Expects feat concatenated with a 128-channel context map.
        self.convc1 = nn.Conv2d(cor_planes + 128, 256 * s, 1, padding=0)
        self.convc2 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc3 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc4 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)

        self.conv = nn.Conv2d(256 * s + 64, out_planes, 1, padding=0)

    def forward(self, flow, feat, context1=None):
        # NOTE(review): despite the None default, context1 is effectively
        # required — torch.cat below fails if it is None, and convc1's input
        # width assumes a 128-channel context. Confirm callers always pass it.
        flo = F.relu(self.normf1(self.convf1(flow)))
        flo = F.relu(self.normf2(self.convf2(flo)))
        flo = self.convf3(flo)

        feat = torch.cat([feat, context1], dim=1)
        feat = F.relu(self.normc1(self.convc1(feat)))
        feat = F.relu(self.normc2(self.convc2(feat)))
        feat = F.relu(self.normc3(self.convc3(feat)))
        feat = self.convc4(feat)

        feat = torch.cat([flo, feat], dim=1)
        feat = F.relu(self.conv(feat))

        return feat
class BasicUpdateBlock(nn.Module):
    """RAFT-style recurrent update: motion encoding -> GRU -> flow delta + mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # 64*9 channels — presumably convex-upsampling weights (RAFT-style
        # 8x upsampling, 9 neighbors per coarse pixel); confirm at call site.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        # NOTE(review): `upsample` is accepted but never read in this body;
        # kept for signature compatibility with callers.
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
class DirectMeanMaskPredictor(nn.Module):
    """Predict an upsampling mask and flow delta directly from motion features."""

    def __init__(self, args):
        super(DirectMeanMaskPredictor, self).__init__()
        self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(args.predictor_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, motion_features):
        flow_delta = self.flow_head(motion_features)
        # 0.25 scaling keeps the mask gradients balanced against the flow head.
        upsample_mask = 0.25 * self.mask(motion_features)
        return upsample_mask, flow_delta
class BaiscMeanPredictor(nn.Module):
    """Predict an upsampling mask and flow delta from a latent plus flow.

    NOTE(review): the class name misspells "Basic"; kept verbatim because it
    is the public identifier other vendored modules and checkpoints refer to.
    """

    def __init__(self, args, hidden_dim=128):
        super(BaiscMeanPredictor, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # 64*9 channels — presumably convex-upsampling weights (RAFT-style);
        # confirm at call site.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, latent, flow):
        motion_features = self.encoder(flow, latent)
        delta_flow = self.flow_head(motion_features)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(motion_features)
        return mask, delta_flow
class BasicRPEEncoder(nn.Module):
    """MLP lifting 2-D relative-position tokens to query_latent_dim features."""

    def __init__(self, args):
        super(BasicRPEEncoder, self).__init__()
        self.args = args
        width = args.query_latent_dim
        # Three linear layers widening 2 -> width//2 -> width -> width.
        layers = [
            nn.Linear(2, width // 2),
            nn.ReLU(inplace=True),
            nn.Linear(width // 2, width),
            nn.ReLU(inplace=True),
            nn.Linear(width, width),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, rpe_tokens):
        return self.encoder(rpe_tokens)
from .twins import Block, CrossBlock
class TwinsSelfAttentionLayer(nn.Module):
    """Twins-style self-attention: a windowed local block then a global block.

    Both `x` and `tgt` are processed independently through the SAME shared
    blocks (weights are reused for both streams).
    """

    def __init__(self, args):
        super(TwinsSelfAttentionLayer, self).__init__()
        self.args = args
        # Fixed Twins hyperparameters used by FlowFormer's encoder.
        embed_dim = 256
        num_heads = 8
        mlp_ratio = 4
        ws = 7          # local attention window size
        sr_ratio = 4
        dpr = 0.0
        drop_rate = 0.0
        attn_drop_rate = 0.0

        # ws=7 -> locally-grouped attention within 7x7 windows.
        self.local_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
        )
        # ws=1 -> global (sub-sampled) attention.
        self.global_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=1,
            with_rpe=True,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, unit LayerNorm, fan-out-normal convs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Same weights applied to both streams, local then global.
        x = self.local_block(x, size)
        x = self.global_block(x, size)
        tgt = self.local_block(tgt, size)
        tgt = self.global_block(tgt, size)
        return x, tgt
class TwinsCrossAttentionLayer(nn.Module):
    """Twins-style layer: shared local self-attention, then global CROSS-attention.

    Unlike TwinsSelfAttentionLayer, the global block is a CrossBlock that
    exchanges information between the `x` and `tgt` streams.
    """

    def __init__(self, args):
        super(TwinsCrossAttentionLayer, self).__init__()
        self.args = args
        # Fixed Twins hyperparameters, matching TwinsSelfAttentionLayer.
        embed_dim = 256
        num_heads = 8
        mlp_ratio = 4
        ws = 7          # local attention window size
        sr_ratio = 4
        dpr = 0.0
        drop_rate = 0.0
        attn_drop_rate = 0.0

        # ws=7 -> locally-grouped self-attention within 7x7 windows.
        self.local_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
        )
        # ws=1 -> global cross-attention between the two streams.
        self.global_block = CrossBlock(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=1,
            with_rpe=True,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, unit LayerNorm, fan-out-normal convs
        # (identical scheme to TwinsSelfAttentionLayer).
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Local self-attention on each stream, then joint global cross-attention.
        x = self.local_block(x, size)
        tgt = self.local_block(tgt, size)
        x, tgt = self.global_block(x, tgt, size)
        return x, tgt

View File

@@ -0,0 +1,98 @@
#from turtle import forward
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class ConvNextLayer(nn.Module):
    """Sequential stack of `depth` ConvNeXt blocks at constant width."""

    def __init__(self, dim, depth=4):
        super().__init__()
        blocks = [ConvNextBlock(dim=dim) for _ in range(depth)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

    def compute_params(self):
        """Return the total number of parameter elements."""
        total = 0
        for p in self.parameters():
            total += np.prod(p.size())
        return total
class ConvNextBlock(nn.Module):
    r"""ConvNeXt block: depthwise 7x7 conv, LayerNorm, inverted MLP bottleneck.

    Implemented in the "channels_last" formulation: after the depthwise
    conv the tensor is permuted to (N, H, W, C) so the norm and the 1x1
    convs can run as Linear layers, then permuted back.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale;
            <= 0 disables the learned per-channel scale. Default: 1e-6.
    """

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise conv: large 7x7 receptive field, channels independent.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs implemented with linear layers on NHWC data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        residual = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y  # learned per-channel layer scale
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + y
class LayerNorm(nn.Module):
    r"""LayerNorm supporting NHWC ("channels_last") and NCHW ("channels_first").

    channels_last expects (batch, height, width, channels) and uses the
    fused ``F.layer_norm``; channels_first expects
    (batch, channels, height, width) and normalizes dim 1 manually.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Normalize over the channel dim, then apply affine params
            # broadcast over the spatial dims.
            mu = x.mean(1, keepdim=True)
            var = (x - mu).pow(2).mean(1, keepdim=True)
            x_hat = (x - mu) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * x_hat + self.bias[:, None, None]

View File

@@ -0,0 +1,316 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ...utils.utils import coords_grid, bilinear_sampler
from .attention import (
MultiHeadAttention,
LinearPositionEmbeddingSine,
ExpPositionEmbeddingSine,
)
from timm.models.layers import DropPath
from .gru import BasicUpdateBlock, GMAUpdateBlock
from .gma import Attention
def initialize_flow(img):
    """Return two identical pixel-coordinate grids on ``img``'s device.

    Flow is represented as the difference between two means:
    flow = mean1 - mean0, so both grids start out equal (zero flow).
    """
    batch, _, height, width = img.shape
    mean = coords_grid(batch, height, width).to(img.device)
    mean_init = coords_grid(batch, height, width).to(img.device)
    return mean, mean_init
class CrossAttentionLayer(nn.Module):
    """Cross-attention from per-pixel flow query tokens to the cost memory.

    Queries carry a sinusoidal encoding of their current flow coordinate
    (optionally added to the query content, see ``add_flow_token``);
    keys/values are projected from the latent cost memory once and then
    cached by the caller across refinement iterations.
    """

    # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.):
    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        add_flow_token=True,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
        pe="linear",
    ):
        super(CrossAttentionLayer, self).__init__()

        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5
        self.query_token_dim = query_token_dim
        self.pe = pe

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        # Projects the concatenation [attended, shortcut] back to token dim.
        self.proj = nn.Linear(v_dim * 2, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )
        self.add_flow_token = add_flow_token
        self.dim = qk_dim

    def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3):
        """One cross-attention step.

        query_coord [B, 2, H1, W1] — current flow coordinates per pixel.
        key/value may be None on the first call; they are then computed
        from ``memory`` and returned so the caller can cache them.
        Returns (updated query tokens, key, value).
        """
        B, _, H1, W1 = query_coord.shape

        if key is None and value is None:
            key = self.k(memory)
            value = self.v(memory)

        # [B, 2, H1, W1] -> [BH1W1, 1, 2]
        query_coord = query_coord.contiguous()
        query_coord = (
            query_coord.view(B, 2, -1)
            .permute(0, 2, 1)[:, :, None, :]
            .contiguous()
            .view(B * H1 * W1, 1, 2)
        )
        # NOTE(review): a pe value other than "linear"/"exp" leaves
        # query_coord_enc unbound and raises UnboundLocalError below.
        if self.pe == "linear":
            query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim)
        elif self.pe == "exp":
            query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim)

        short_cut = query
        query = self.norm1(query)

        if self.add_flow_token:
            q = self.q(query + query_coord_enc)
        else:
            # Position-only query (content ignored).
            q = self.q(query_coord_enc)
        k, v = key, value

        x = self.multi_head_attn(q, k, v)

        x = self.proj(torch.cat([x, short_cut], dim=2))
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x, k, v
class MemoryDecoderLayer(nn.Module):
    """Wrap a CrossAttentionLayer and reshape its token output to a map."""

    def __init__(self, dim, cfg):
        super(MemoryDecoderLayer, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size  # for converting coords into H2', W2' space

        q_dim = cfg.query_latent_dim
        t_dim = cfg.cost_latent_dim
        # qk_dim and v_dim both equal the query token dim.
        self.cross_attend = CrossAttentionLayer(
            q_dim,
            q_dim,
            q_dim,
            t_dim,
            add_flow_token=cfg.add_flow_token,
            dropout=cfg.dropout,
        )

    def forward(self, query, key, value, memory, coords1, size, size_h3w3):
        """Attend queries to the cost memory and return a feature map.

        query:  [B*H1*W1, 1, C] flow query tokens
        memory: [B*H1*W1, H2'*W2', C] latent cost memory
        coords1:[B, 2, H2, W2] current flow coordinates
        size:   (B, C, H1, W1) of the context feature map
        Returns (x_global [B, C, H1, W1], cached key, cached value).
        """
        x_global, k, v = self.cross_attend(
            query, key, value, memory, coords1, self.patch_size, size_h3w3
        )
        B, _, H1, W1 = size
        C = self.cfg.query_latent_dim
        x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2)
        return x_global, k, v
class ReverseCostExtractor(nn.Module):
    """Sample the cost volume in the reverse direction.

    Bilinearly samples the cost maps at the target coordinates
    ``coords1``, transposes the volume with einops, then gathers a
    (2r+1)x(2r+1) window around the source coordinates ``coords0``.
    """

    def __init__(self, cfg):
        super(ReverseCostExtractor, self).__init__()
        self.cfg = cfg

    def forward(self, cost_maps, coords0, coords1):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1
        """
        BH1W1, heads, H2, W2 = cost_maps.shape
        B, _, H1, W1 = coords1.shape

        # This extractor requires source and target grids of equal size.
        assert (H1 == H2) and (W1 == W2)
        assert BH1W1 == B * H1 * W1

        cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2)
        coords = coords1.permute(0, 2, 3, 1)
        corr = bilinear_sampler(cost_maps, coords)  # [B, H1*W1*heads, H2, W2]
        # Transpose so that the sampled (h2, w2) positions become the batch
        # and the (h1, w1) grid becomes the spatial extent.
        corr = rearrange(
            corr,
            "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1",
            b=B,
            heads=heads,
            h1=H1,
            w1=W1,
            h2=H2,
            w2=W2,
        )

        # Gather a square window of radius r around each source coordinate.
        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device)
        centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(corr, coords)
        corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2)
        return corr
class MemoryDecoder(nn.Module):
    """Iteratively decode the latent cost memory into an optical-flow field.

    Follows the RAFT recurrent-refinement scheme: each iteration samples
    the raw cost maps at the current flow coordinates, cross-attends to
    the latent cost memory, and feeds both into a GRU update block that
    predicts a flow delta and a convex-upsampling mask.
    """

    def __init__(self, cfg):
        super(MemoryDecoder, self).__init__()
        dim = self.dim = cfg.query_latent_dim
        self.cfg = cfg

        # Projects the sampled 9x9 (=81) cost window into a query token.
        self.flow_token_encoder = nn.Sequential(
            nn.Conv2d(81 * cfg.cost_heads_num, dim, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1, 1),
        )
        self.proj = nn.Conv2d(256, 256, 1)
        self.depth = cfg.decoder_depth  # default number of refinement iterations
        self.decoder_layer = MemoryDecoderLayer(dim, cfg)

        if self.cfg.gma:
            # GMA variant: globally aggregate motion features via attention.
            self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128)
            self.att = Attention(
                args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128
            )
        else:
            self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128)

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)  # convex weights over 3x3 neighborhood

        # Flow values are scaled by 8 to match the upsampled resolution.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def encode_flow_token(self, cost_maps, coords):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1

        Samples a 9x9 cost window around each flow coordinate.
        """
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(cost_maps, coords)
        corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2)
        return corr

    def forward(self, cost_memory, context, data={}, flow_init=None, iters=None):
        """
        memory: [B*H1*W1, H2'*W2', C]
        context: [B, D, H1, W1]

        Returns (upsampled flow of the final iteration, low-res flow).
        NOTE(review): ``data={}`` is a mutable default shared across
        calls; callers in this file always pass ``data`` explicitly.
        """
        cost_maps = data["cost_maps"]
        coords0, coords1 = initialize_flow(context)

        if flow_init is not None:
            # Warm start from a previous flow estimate.
            coords1 = coords1 + flow_init

        flow_predictions = []

        # Split the projected context into GRU hidden state (net) and input (inp).
        context = self.proj(context)
        net, inp = torch.split(context, [128, 128], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)
        if self.cfg.gma:
            attention = self.att(inp)

        size = net.shape
        key, value = None, None  # cross-attention K/V, cached across iterations

        if iters is None:
            iters = self.depth
        for idx in range(iters):
            # Stop gradients flowing back through earlier iterations.
            coords1 = coords1.detach()

            cost_forward = self.encode_flow_token(cost_maps, coords1)
            # cost_backward = self.reverse_cost_extractor(cost_maps, coords0, coords1)

            query = self.flow_token_encoder(cost_forward)
            query = (
                query.permute(0, 2, 3, 1)
                .contiguous()
                .view(size[0] * size[2] * size[3], 1, self.dim)
            )
            cost_global, key, value = self.decoder_layer(
                query, key, value, cost_memory, coords1, size, data["H3W3"]
            )
            if self.cfg.only_global:
                corr = cost_global
            else:
                corr = torch.cat([cost_global, cost_forward], dim=1)

            flow = coords1 - coords0

            if self.cfg.gma:
                net, up_mask, delta_flow = self.update_block(
                    net, inp, corr, flow, attention
                )
            else:
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        # Inference path: only the last prediction is returned.
        return flow_predictions[-1], coords1 - coords0

View File

@@ -0,0 +1,534 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
import numpy as np
from einops import rearrange
from ...utils.utils import coords_grid
from .attention import (
BroadMultiHeadAttention,
MultiHeadAttention,
LinearPositionEmbeddingSine,
ExpPositionEmbeddingSine,
)
from ..encoders import twins_svt_large
from typing import Tuple
from .twins import Size_
from .cnn import BasicEncoder
from .mlpmixer import MLPMixerLayer
from .convnext import ConvNextLayer
from timm.models.layers import DropPath
class PatchEmbed(nn.Module):
    """Embed a per-pixel cost map into patch tokens with positional encoding.

    Downsamples the cost map by ``patch_size`` (8 or 4) with strided
    convs, appends a sinusoidal encoding of each patch center, mixes the
    result with a 1x1-conv FFN, and layer-normalizes the flattened tokens.
    """

    def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe="linear"):
        super().__init__()
        self.patch_size = patch_size
        self.dim = embed_dim
        self.pe = pe

        # assert patch_size == 8
        if patch_size == 8:
            # Three stride-2 convs: total downsample factor 8.
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2
                ),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        elif patch_size == 4:
            # Two stride-2 convs: total downsample factor 4.
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        else:
            # NOTE(review): unsupported sizes only print a message;
            # self.proj is never set and forward() would then fail
            # with AttributeError.
            print(f"patch size = {patch_size} is unacceptable.")

        self.ffn_with_coord = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(embed_dim * 2)

    def forward(self, x) -> Tuple[torch.Tensor, Size_]:
        """Return (tokens [B, H3*W3, 2*embed_dim], token grid size (H3, W3))."""
        B, C, H, W = x.shape  # C == 1

        # Pad right/bottom so H and W divide evenly by patch_size.
        pad_l = pad_t = 0
        pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        x = self.proj(x)
        out_size = x.shape[2:]

        patch_coord = (
            coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size
            + self.patch_size / 2
        )  # in feature coordinate space
        patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
        # NOTE(review): any pe other than "linear"/"exp" leaves
        # patch_coord_enc unbound and raises UnboundLocalError below.
        if self.pe == "linear":
            patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == "exp":
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(
            B, -1, out_size[0], out_size[1]
        )

        x_pe = torch.cat([x, patch_coord_enc], dim=1)
        x = self.ffn_with_coord(x_pe)
        x = self.norm(x.flatten(2).transpose(1, 2))

        return x, out_size
from .twins import Block, CrossBlock
class GroupVerticalSelfAttentionLayer(nn.Module):
    """A single grouped Twins block applied across the pixel ('vertical') axis.

    Note: the attn_drop/proj_drop/drop_path arguments are accepted for
    interface symmetry but not used — those rates are hard-coded to 0.
    """

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(GroupVerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.block = Block(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            ws=7,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
            groupattention=True,
            cfg=self.cfg,
        )

    def forward(self, x, size, context=None):
        return self.block(x, size, context)
class VerticalSelfAttentionLayer(nn.Module):
    """Local (windowed) + global Twins self-attention across the pixel axis.

    Note: the attn_drop/proj_drop/drop_path arguments are accepted for
    interface symmetry but not used — those rates are hard-coded to 0.
    """

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(VerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        shared = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )
        # Windowed (ws=7) attention followed by global (ws=1) attention.
        self.local_block = Block(ws=7, **shared)
        self.global_block = Block(ws=1, **shared)

    def forward(self, x, size, context=None):
        x = self.local_block(x, size, context)
        return self.global_block(x, size, context)

    def compute_params(self):
        """Total number of parameters in this layer."""
        return sum(np.prod(p.size()) for p in self.parameters())
class SelfAttentionLayer(nn.Module):
    """Pre-norm multi-head self-attention + FFN over latent tokens."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(SelfAttentionLayer, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
        )

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: [B*H1*W1, H3*W3, D] -> same shape."""
        residual = x
        y = self.norm1(x)
        y = self.multi_head_attn(self.q(y), self.k(y), self.v(y))
        y = residual + self.proj_drop(self.proj(y))
        return y + self.drop_path(self.ffn(self.norm2(y)))

    def compute_params(self):
        """Total number of parameters in this layer."""
        return sum(np.prod(p.size()) for p in self.parameters())
class CrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention: latent query tokens attend to cost tokens.

    Query tokens [N, C] are projected to qk_dim (Q); target tokens
    [M, D] are projected to qk_dim (K) and v_dim (V).
    """

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(CrossAttentionLayer, self).__init__()
        assert (
            qk_dim % num_heads == 0
        ), f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert (
            v_dim % num_heads == 0
        ), f"dim {v_dim} should be divided by num_heads {num_heads}."

        self.num_heads = num_heads
        self.scale = (qk_dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        # Broadcast attention: one set of queries attends to all targets.
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )

    def forward(self, query, tgt_token):
        """query, tgt_token: token sequences -> updated query tokens."""
        residual = query
        normed = self.norm1(query)
        attended = self.multi_head_attn(
            self.q(normed), self.k(tgt_token), self.v(tgt_token)
        )
        out = residual + self.proj_drop(self.proj(attended))
        return out + self.drop_path(self.ffn(self.norm2(out)))
class CostPerceiverEncoder(nn.Module):
    """Compress the 4-D cost volume into a small set of latent tokens.

    Each pixel's cost map is patch-embedded, projected onto learned
    latent tokens via cross-attention, then refined by alternating
    intra-pixel layers (self-attention or MLP-mixer) and cross-pixel
    'vertical' layers (attention or ConvNeXt).
    """

    def __init__(self, cfg):
        super(CostPerceiverEncoder, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size
        self.patch_embed = PatchEmbed(
            in_chans=self.cfg.cost_heads_num,
            patch_size=self.patch_size,
            embed_dim=cfg.cost_latent_input_dim,
            pe=cfg.pe,
        )

        self.depth = cfg.encoder_depth

        # Learned latent tokens, shared across all pixels.
        self.latent_tokens = nn.Parameter(
            torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)
        )

        query_token_dim, tgt_token_dim = (
            cfg.cost_latent_dim,
            cfg.cost_latent_input_dim * 2,
        )
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.input_layer = CrossAttentionLayer(
            qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout
        )

        # Intra-pixel token mixing: MLP-mixer or self-attention.
        if cfg.use_mlp:
            self.encoder_layers = nn.ModuleList(
                [
                    MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )
        else:
            self.encoder_layers = nn.ModuleList(
                [
                    SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )

        # Cross-pixel ('vertical') mixing: ConvNeXt or Twins attention.
        if self.cfg.vertical_conv:
            self.vertical_encoder_layers = nn.ModuleList(
                [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]
            )
        else:
            self.vertical_encoder_layers = nn.ModuleList(
                [
                    VerticalSelfAttentionLayer(
                        cfg.cost_latent_dim, cfg, dropout=cfg.dropout
                    )
                    for idx in range(self.depth)
                ]
            )
        self.cost_scale_aug = None
        if "cost_scale_aug" in cfg.keys():
            self.cost_scale_aug = cfg.cost_scale_aug
            print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug))

    def forward(self, cost_volume, data, context=None):
        """cost_volume [B, heads, H1, W1, H2, W2] -> latent tokens
        [B*H1*W1, K, D]; also stores 'cost_maps' and 'H3W3' in ``data``."""
        B, heads, H1, W1, H2, W2 = cost_volume.shape
        cost_maps = (
            cost_volume.permute(0, 2, 3, 1, 4, 5)
            .contiguous()
            .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
        )
        data["cost_maps"] = cost_maps

        # Training-time augmentation: random per-element cost rescaling.
        if self.cost_scale_aug is not None:
            scale_factor = (
                torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
                .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1])
                .to(cost_maps.device)
            )
            cost_maps = cost_maps * scale_factor

        x, size = self.patch_embed(cost_maps)  # B*H1*W1, size[0]*size[1], C
        data["H3W3"] = size
        H3, W3 = size

        # Project patch tokens onto the learned latent tokens.
        x = self.input_layer(self.latent_tokens, x)

        short_cut = x

        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if self.cfg.vertical_conv:
                # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 3, 1, 2)
                    .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1)
                )
                x = self.vertical_encoder_layers[idx](x)
                # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1)
                    .permute(0, 2, 3, 1)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )
            else:
                # Regroup tokens so the pixel axis becomes the sequence axis.
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1)
                )
                x = self.vertical_encoder_layers[idx](x, (H1, W1), context)
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )

        # Optional residual from the initial latent projection.
        if self.cfg.cost_encoder_res is True:
            x = x + short_cut
        return x
class MemoryEncoder(nn.Module):
    """Build the latent cost memory for an image pair.

    Extracts features for both frames, computes an all-pairs dot-product
    correlation volume, and compresses it with the CostPerceiverEncoder.
    """

    def __init__(self, cfg):
        super(MemoryEncoder, self).__init__()
        self.cfg = cfg

        if cfg.fnet == "twins":
            self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.fnet == "basicencoder":
            self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance")
        else:
            # NOTE(review): hard process exit on unknown fnet; raising a
            # ValueError would be a friendlier failure mode.
            exit()
        self.channel_convertor = nn.Conv2d(
            cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False
        )
        self.cost_perceiver_encoder = CostPerceiverEncoder(cfg)

    def corr(self, fmap1, fmap2):
        # All-pairs correlation split into cost_heads_num head groups.
        # Output: [B, heads, H, W, H, W].
        batch, dim, ht, wd = fmap1.shape
        fmap1 = rearrange(
            fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        fmap2 = rearrange(
            fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2)
        corr = corr.permute(0, 2, 1, 3).view(
            batch * ht * wd, self.cfg.cost_heads_num, ht, wd
        )
        # corr = self.norm(self.relu(corr))
        corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute(
            0, 2, 1, 3
        )
        corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd)

        return corr

    def forward(self, img1, img2, data, context=None, return_feat=False):
        # The original implementation ran the encoder per frame:
        # feat_s = self.feat_encoder(img1)
        # feat_t = self.feat_encoder(img2)
        # feat_s = self.channel_convertor(feat_s)
        # feat_t = self.channel_convertor(feat_t)
        # Here both frames are batched through in a single pass instead.
        imgs = torch.cat([img1, img2], dim=0)
        feats = self.feat_encoder(imgs)
        feats = self.channel_convertor(feats)
        B = feats.shape[0] // 2

        feat_s = feats[:B]
        if return_feat:
            ffeat = feats[:B]
        feat_t = feats[B:]

        B, C, H, W = feat_s.shape
        size = (H, W)

        if self.cfg.feat_cross_attn:
            # NOTE(review): self.layers is never defined in __init__, so
            # this branch raises AttributeError when cfg.feat_cross_attn
            # is enabled — confirm against the upstream FlowFormer code.
            feat_s = feat_s.flatten(2).transpose(1, 2)
            feat_t = feat_t.flatten(2).transpose(1, 2)

            for layer in self.layers:
                feat_s, feat_t = layer(feat_s, feat_t, size)

            feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

        cost_volume = self.corr(feat_s, feat_t)
        x = self.cost_perceiver_encoder(cost_volume, data, context)

        if return_feat:
            return x, ffeat
        return x

View File

@@ -0,0 +1,123 @@
import torch
from torch import nn, einsum
from einops import rearrange
class RelPosEmb(nn.Module):
    """Decomposed 2-D relative positional scores (height + width terms)."""

    def __init__(self, max_pos_size, dim_head):
        super().__init__()
        self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
        self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)

        # rel_ind[i, j] = (j - i) shifted into [0, 2*max_pos_size - 2].
        pos = torch.arange(max_pos_size)
        rel_ind = pos.view(1, -1) - pos.view(-1, 1) + max_pos_size - 1
        self.register_buffer("rel_ind", rel_ind)

    def forward(self, q):
        _, _, h, w, _ = q.shape
        h_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
        w_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))

        h_emb = rearrange(h_emb, "(x u) d -> x u () d", x=h)
        w_emb = rearrange(w_emb, "(y v) d -> y () v d", y=w)

        h_score = einsum("b h x y d, x u v d -> b h x y u v", q, h_emb)
        w_score = einsum("b h x y d, y u v d -> b h x y u v", q, w_emb)

        return h_score + w_score
class Attention(nn.Module):
    """Content self-attention over a 2-D feature map (GMA).

    Returns the softmax-normalized attention map
    [b, heads, h*w, h*w]; values are aggregated separately (see Aggregate).
    """

    def __init__(
        self,
        *,
        args,
        dim,
        max_pos_size=100,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # Positional embedding is kept (frozen) for checkpoint
        # compatibility; this forward uses content similarity only.
        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
        for p in self.pos_emb.parameters():
            p.requires_grad = False

    def forward(self, fmap):
        q, k = self.to_qk(fmap).chunk(2, dim=1)

        q = rearrange(q, "b (h d) x y -> b h x y d", h=self.heads)
        k = rearrange(k, "b (h d) x y -> b h x y d", h=self.heads)
        q = self.scale * q

        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)
        sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
        return sim.softmax(dim=-1)
class Aggregate(nn.Module):
    """Aggregate per-pixel values with a precomputed attention map (GMA).

    Output is a gated residual: fmap + gamma * attended(values), where
    gamma starts at zero so the module is the identity at init.
    """

    def __init__(
        self,
        args,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)

        self.gamma = nn.Parameter(torch.zeros(1))

        # Project back only when the inner dim differs from the input dim.
        self.project = (
            nn.Conv2d(inner_dim, dim, 1, bias=False) if dim != inner_dim else None
        )

    def forward(self, attn, fmap):
        _, _, h, w = fmap.shape

        v = rearrange(self.to_v(fmap), "b (h d) x y -> b h (x y) d", h=self.heads)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        if self.project is not None:
            out = self.project(out)

        return fmap + self.gamma * out
if __name__ == "__main__":
    # Smoke test. `args` is a required keyword-only parameter of Attention
    # (the original call omitted it and raised TypeError); the content-only
    # forward path never reads it, so None is sufficient here.
    att = Attention(args=None, dim=128, heads=1)
    fmap = torch.randn(2, 128, 40, 90)
    out = att(fmap)
    print(out.shape)  # attention map: [2, 1, 3600, 3600]

View File

@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer conv head predicting a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        cat_dim = hidden_dim + input_dim
        self.convz = nn.Conv2d(cat_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(cat_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(cat_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(hx))  # z gate
        reset = torch.sigmoid(self.convr(hx))   # r gate
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """GRU cell with separable convolutions: a 1x5 pass then a 5x1 pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        cat_dim = hidden_dim + input_dim
        # Horizontal (1x5) gate convolutions.
        self.convz1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        # Vertical (5x1) gate convolutions.
        self.convz2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))

    def _step(self, h, x, convz, convr, convq):
        # One standard GRU update using the given gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # Horizontal pass, then vertical pass.
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class BasicMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    Output has 128 channels: 126 learned channels with the raw
    2-channel flow appended at the end.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        if args.only_global:
            print("[Decoding with only global cost]")
            cor_planes = args.query_latent_dim
        else:
            # 81 locally-sampled cost channels + the global cost token.
            cor_planes = 81 + args.query_latent_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flo = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([cor, flo], dim=1)))
        # Keep the raw flow available to the downstream GRU.
        return torch.cat([fused, flow], dim=1)
class BasicUpdateBlock(nn.Module):
    """GRU update: motion features -> new hidden state, upsampling mask,
    and flow delta."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)

        # 0.25 factor balances the mask head's gradient magnitude.
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
from .gma import Aggregate
class GMAUpdateBlock(nn.Module):
    """Update block augmented with GMA global motion aggregation."""

    def __init__(self, args, hidden_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # Input: context (128) + local motion + globally-aggregated motion.
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim
        )
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

        self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1)

    def forward(self, net, inp, corr, flow, attention):
        motion = self.encoder(flow, corr)
        motion_global = self.aggregator(attention, motion)

        # Attentional GRU update over the concatenated features.
        net = self.gru(net, torch.cat([inp, motion, motion_global], dim=1))
        delta_flow = self.flow_head(net)

        # 0.25 factor balances the mask head's gradient magnitude.
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,55 @@
from torch import nn
from einops.layers.torch import Rearrange, Reduce
from functools import partial
import numpy as np
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: x + fn(LayerNorm(x))."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.fn(self.norm(x))
def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    """Build an MLP: dim -> dim*expansion_factor -> dim with GELU + dropout.

    ``dense`` lets the caller swap in another layer type (e.g. a 1x1
    Conv1d for token mixing).
    """
    hidden = dim * expansion_factor
    return nn.Sequential(
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    )
class MLPMixerLayer(nn.Module):
    """One MLP-Mixer layer: token-mixing MLP followed by channel-mixing MLP."""

    def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
        super(MLPMixerLayer, self).__init__()
        num_tokens = cfg.cost_latent_token_num
        expansion = cfg.mlp_expansion_factor
        # Token mixing uses a 1x1 Conv1d (mixes along dim -2);
        # channel mixing uses a plain Linear (mixes along dim -1).
        chan_first = partial(nn.Conv1d, kernel_size=1)
        chan_last = nn.Linear

        self.mlpmixer = nn.Sequential(
            PreNormResidual(dim, FeedForward(num_tokens, expansion, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion, dropout, chan_last)),
        )

    def compute_params(self):
        """Total number of parameters in the mixer stack."""
        return sum(np.prod(p.size()) for p in self.mlpmixer.parameters())

    def forward(self, x):
        """x: [BH1W1, K, D] -> same shape."""
        return self.mlpmixer(x)

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from ...utils.utils import coords_grid
from ..encoders import twins_svt_large
from .encoder import MemoryEncoder
from .decoder import MemoryDecoder
from .cnn import BasicEncoder
class FlowFormer(nn.Module):
    """End-to-end FlowFormer optical-flow estimator.

    Encodes an all-pairs cost volume into latent memory tokens and
    decodes them iteratively into a flow field.
    """

    def __init__(self, cfg):
        super(FlowFormer, self).__init__()
        self.cfg = cfg

        self.memory_encoder = MemoryEncoder(cfg)
        self.memory_decoder = MemoryDecoder(cfg)
        if cfg.cnet == "twins":
            self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.cnet == "basicencoder":
            self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")

    def build_coord(self, img):
        # 1/8-resolution coordinate grid matching the feature stride.
        N, _, H, W = img.shape
        return coords_grid(N, H // 8, W // 8)

    def forward(
        self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
    ):
        # Following https://github.com/princeton-vl/RAFT/:
        # normalize images from [0, 255] to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        data = {}

        if self.cfg.context_concat:
            context = self.context_encoder(torch.cat([image1, image2], dim=1))
        elif return_feat:
            context, cfeat = self.context_encoder(image1, return_feat=return_feat)
        else:
            context = self.context_encoder(image1)

        if return_feat:
            cost_memory, ffeat = self.memory_encoder(
                image1, image2, data, context, return_feat=return_feat
            )
        else:
            cost_memory = self.memory_encoder(image1, image2, data, context)

        flow_predictions = self.memory_decoder(
            cost_memory, context, data, flow_init=flow_init, iters=iters
        )

        if return_feat:
            return flow_predictions, cfeat, ffeat
        return flow_predictions

View File

@@ -0,0 +1,7 @@
def build_flowformer(cfg):
    """Instantiate the FlowFormer variant selected by ``cfg.transformer``.

    Raises ValueError for any name other than "latentcostformer"; the model
    is constructed from the matching sub-config ``cfg[name]``.
    """
    name = cfg.transformer
    if name != "latentcostformer":
        raise ValueError(f"FlowFormer = {name} is not a valid architecture!")
    # Import deferred so the error path above needs no heavy dependencies.
    from .LatentCostFormer.transformer import FlowFormer

    return FlowFormer(cfg[name])

View File

@@ -0,0 +1,562 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from ..utils.utils import bilinear_sampler, indexing
def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300):
    """NeRF-style positional encoding of 2-D coordinates.

    x has shape [*, 2] (last dim = x/y coordinate). Returns the normalized
    coordinates concatenated with sin/cos features at L-1 frequency bands per
    coordinate, i.e. shape [*, 2 + 4*(L-1)].
    """
    # NOTE(review): linspace(0, L, L-1) spreads exponents over [0, L] in L-1
    # non-integer steps (not octaves 2**0..2**(L-1)) — presumably intentional
    # in the vendored source; kept as-is.
    freq_bands = 2.0 ** torch.linspace(0, L, L - 1).to(x.device)
    xs = 3.14 * x[..., -2:-1] * freq_bands * NORMALIZE_FACOR
    ys = 3.14 * x[..., -1:] * freq_bands * NORMALIZE_FACOR
    parts = [
        x * NORMALIZE_FACOR,
        torch.sin(xs),
        torch.cos(xs),
        torch.sin(ys),
        torch.cos(ys),
    ]
    return torch.cat(parts, dim=-1)
def sampler_gaussian(latent, mean, std, image_size, point_num=25):
    """Sample latents on a grid of half-extent 3*sigmoid(std)*STD_MAX around `mean`.

    NOTE(review): DEAD CODE — this definition is shadowed by the second
    `sampler_gaussian` defined later in this module (same name, extra
    `return_deltaXY` parameter). Kept only for fidelity with the vendored
    upstream source.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Log-space weights from squared grid distance (closer points score higher).
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))
    return sampled_latents, sampled_weights
def sampler_gaussian_zy(
    latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1
):
    """Sample latents on a grid of half-extent 3*std around `mean` (raw std, no sigmoid).

    Returns (sampled_latents, grid-distance log-weights scaled by 1/beta) and,
    when return_deltaXY is True, also the per-point pixel offsets.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Temperature beta flattens (beta > 1) or sharpens (beta < 1) the weights.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta
    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False):
    """Sample latents on a grid of half-extent 3*sigmoid(std)*STD_MAX around `mean`.

    Effective definition: it shadows the earlier `sampler_gaussian` in this
    module and additionally supports returning the per-point offsets via
    `return_deltaXY`.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))
    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
def sampler_gaussian_fix(latent, mean, image_size, point_num=49):
    """Sample latents on a fixed integer-radius grid centered at `mean`.

    radius = (sqrt(point_num) - 1) / 2 pixels; the weights are a smoothness
    prior based only on grid distance (no learned std in this variant).
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    # NOTE(review): STD_MAX is unused in this fixed-radius variant.
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
def sampler_gaussian_fix_pyramid(
    latent, feat_pyramid, scale_weight, mean, image_size, point_num=25
):
    """Sample a fixed grid around `mean` from every pyramid level and blend.

    Each level is sampled at coordinates scaled by 1/2**i; the per-position
    softmax of `scale_weight` blends levels. Returns (blended latents,
    grid-distance weights, the softmaxed scale weights for visualization).
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]
    H, W = image_size
    B, HW, D = latent.shape
    # NOTE(review): STD_MAX and the rearranged `latent` are unused below —
    # sampling reads feat_pyramid only; `latent` supplies just B/HW/D.
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        coords = (centroid + delta) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights, vis_out
def sampler_gaussian_pyramid(
    latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25
):
    """Sample a 3*std-scaled grid around `mean` from every pyramid level and blend.

    Like sampler_gaussian_fix_pyramid but the grid extent comes from the raw
    `std` (no sigmoid). Returns (blended latents, grid-distance weights,
    softmaxed scale weights for visualization).
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]
    H, W = image_size
    B, HW, D = latent.shape
    # NOTE(review): STD_MAX, `radius` and the rearranged `latent` are unused
    # below — sampling reads feat_pyramid only.
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        coords = (centroid + delta_3sigma) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights, vis_out
def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25):
    """Multi-head variant of sampler_gaussian_fix: each head has its own mean.

    `mean` is [B, 2, H, W, heads]; a fixed integer-radius grid is sampled
    around every head's center. Returns latents shaped
    [B, H*W*heads, point_num, dim] and the shared grid-distance weights.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]
    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    # NOTE(review): STD_MAX is unused in this fixed-radius variant.
    STD_MAX = 20
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = (
        torch.stack(torch.meshgrid(dy, dx), axis=-1)
        .to(mean.device)
        .repeat(HEADS, 1, 1, 1)
    )  # [HEADS, point_num**0.5, point_num**0.5, 2]
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(
        coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
    )
    sampled_latents = bilinear_sampler(latent, coords)  # [B, dim, H*W*HEADS, pointnum]
    sampled_latents = sampled_latents.permute(
        0, 2, 3, 1
    )  # [B, H*W*HEADS, pointnum, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
def sampler_gaussian_fix_pyramid_MH(
    latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25
):
    """Multi-head pyramid sampling: per-head centers, per-(head, level) blending.

    Each pyramid level is sampled around the level-scaled per-head centroid;
    `scale_head_weight` is softmaxed over levels to blend them.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]
    # scale_head weight [B, H*W, layer_num*heads]
    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    sampled_latents = []
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    for i in range(len(feat_pyramid)):
        # NOTE(review): unlike sampler_gaussian_fix_pyramid, only the
        # centroid is divided by 2**i here — delta keeps full-resolution
        # pixel units at every level. Confirm this is intended upstream.
        coords = (centroid) / 2**i + delta
        coords = rearrange(
            coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
        )
        sampled_latents.append(
            bilinear_sampler(feat_pyramid[i], coords)
        )  # [B, dim, H*W*HEADS, point_num]
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W*HEADS, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W*HEADS, point_num, dim, layer_num]
    scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1)
    scale_head_weight = F.softmax(scale_head_weight, dim=2)  # [B, H*W*HEADS, layer_num]
    scale_head_weight = torch.unsqueeze(
        torch.unsqueeze(scale_head_weight, dim=2), dim=2
    )  # [B, H*W*HEADS, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_head_weight, dim=-1
    )  # [B, H*W*HEADS, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights
def sampler(feat, center, window_size):
    """Bilinearly sample a window_size x window_size grid of `feat` around `center`.

    Returns the raw bilinear_sampler output (channel-first), without the
    permutation the gaussian variants apply.
    """
    # feat [B, C, H, W]
    # center [B, 2, H, W]
    center = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    B, H, W, C = center.shape
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [B*H*W, window_size, point_num**0.5, 2]
    center = center.reshape(B * H * W, 1, 1, 2)
    coords = center + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        feat, coords
    )  # [B*H*W, dim, window_size, window_size]
    # sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    return sampled_latents
def retrieve_tokens(feat, center, window_size, sampler):
    """Gather a window of cost tokens around each center from one feature map.

    `center` is expected already channel-last ([B, H, W, 2]); `sampler`
    selects nearest-neighbor ("nn") or bilinear interpolation.
    Returns tokens shaped [B, dim, H*W, window_size**2].
    """
    # feat [B, C, H, W]
    # center [B, 2, H, W]
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    B, H, W, C = center.shape
    centroid = center.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    if sampler == "nn":
        sampled_latents = indexing(feat, coords)
    elif sampler == "bilinear":
        sampled_latents = bilinear_sampler(feat, coords)
    else:
        raise ValueError("invalid sampler")
    # [B, dim, H*W, point_num]
    return sampled_latents
def pyramid_retrieve_tokens(
    feat_pyramid, center, image_size, window_sizes, sampler="bilinear"
):
    """Gather cost tokens at every pyramid level and concatenate them.

    `center` is [B, 2, H, W]; it is halved after each level so the same
    spatial location is indexed at every resolution. `image_size` is unused
    (kept for interface compatibility).
    """
    centers = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    levels = []
    for idx, win in enumerate(window_sizes):
        levels.append(retrieve_tokens(feat_pyramid[idx], centers, win, sampler))
        centers = centers / 2
    return torch.cat(levels, dim=-1)
class FeedForward(nn.Module):
    """Two-layer MLP (dim -> dim -> dim) with GELU activation and dropout."""

    def __init__(self, dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class MLP(nn.Module):
    """LeakyReLU(0.2) MLP: in_dim -> innter_dim (x depth+1) -> out_dim.

    (The `innter_dim` spelling is kept: it is part of the public signature.)
    """

    def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5):
        super().__init__()
        self.FC1 = nn.Linear(in_dim, innter_dim)
        self.FC_out = nn.Linear(innter_dim, out_dim)
        self.relu = torch.nn.LeakyReLU(0.2)
        self.FC_inter = nn.ModuleList(
            nn.Linear(innter_dim, innter_dim) for _ in range(depth)
        )

    def forward(self, x):
        h = self.relu(self.FC1(x))
        for layer in self.FC_inter:
            h = self.relu(layer(h))
        return self.FC_out(h)
class MultiHeadAttention(nn.Module):
    """Multi-head attention with several relative-position-encoding (rpe) modes.

    cfg.rpe selects how the learned bias enters the attention logits:
    "element-wise" (per head/token/channel bias contracted with Q),
    "head-wise"/"token-wise" (additive bias on the logits), "implicit"
    (bias tensor supplied by the caller), or "element-wise-value".
    forward() returns (attention output or None when V is None, raw logits).
    """

    def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.num_kv_tokens = num_kv_tokens
        self.scale = (dim / heads) ** -0.5
        self.rpe = cfg.rpe
        self.attend = nn.Softmax(dim=-1)
        self.use_rpe = use_rpe
        if use_rpe:
            if rpe_bias is None:
                if self.rpe == "element-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                elif self.rpe == "head-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, heads, 1, self.num_kv_tokens)
                    )
                elif self.rpe == "token-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, 1, 1, self.num_kv_tokens)
                    )  # 81 is point_num
                elif self.rpe == "implicit":
                    # NOTE(review): no rpe_bias attribute is created for
                    # "implicit" — forward() must always receive rpe_bias
                    # from the caller in this mode, otherwise it hits an
                    # AttributeError. Confirm against callers.
                    pass
                # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2)
                # raise ValueError('Implicit Encoding Not Implemented')
                elif self.rpe == "element-wise-value":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                    # NOTE(review): rpe_value is not read anywhere in this
                    # class — presumably consumed by callers; verify.
                    self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim))
                else:
                    raise ValueError("Not Implemented")
            else:
                self.rpe_bias = rpe_bias

    def attend_with_rpe(self, Q, K, rpe_bias):
        """Return (softmaxed attention, raw logits) for multi-head Q/K."""
        Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        dots = (
            einsum("bhid, bhjd -> bhij", Q, K) * self.scale
        )  # (b hw) heads 1 pointnum
        if self.use_rpe:
            if self.rpe == "element-wise":
                rpe_bias_weight = (
                    einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "implicit":
                # The leading `pass` is a leftover no-op; the bias below is
                # contracted per batch element (rpe_bias carries a batch dim).
                pass
                rpe_bias_weight = (
                    einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "head-wise" or self.rpe == "token-wise":
                dots = dots + rpe_bias
        return self.attend(dots), dots

    def forward(self, Q, K, V, rpe_bias=None):
        """Attend Q over K (and V when given); returns (out or None, logits)."""
        if self.use_rpe:
            if rpe_bias is None or self.rpe == "element-wise":
                rpe_bias = self.rpe_bias
            else:
                rpe_bias = rearrange(
                    rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads
                )
            attn, dots = self.attend_with_rpe(Q, K, rpe_bias)
        else:
            attn, dots = self.attend_with_rpe(Q, K, None)
        B, HW, _ = Q.shape
        if V is not None:
            V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
            out = einsum("bhij, bhjd -> bhid", attn, V)
            out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        else:
            out = None
        # dots = torch.squeeze(dots, 2)
        # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW)
        return out, dots

View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import timm
import numpy as np
class twins_svt_large(nn.Module):
    """Twins-SVT-Large backbone truncated to its first two stages.

    The classification head and the last two stages are removed so the model
    acts as a 1/8-resolution feature extractor; the final norm is frozen.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large", pretrained=pretrained)
        del self.svt.head
        # Deleting index 2 twice removes the original stages 3 and 4.
        del self.svt.patch_embeds[2]
        del self.svt.patch_embeds[2]
        del self.svt.blocks[2]
        del self.svt.blocks[2]
        del self.svt.pos_block[2]
        del self.svt.pos_block[2]
        # Freeze the final norm so fine-tuning does not drift it.
        self.svt.norm.weight.requires_grad = False
        self.svt.norm.bias.requires_grad = False

    def forward(self, x, data=None, layer=2, return_feat=False):
        """Run the first `layer` stages; optionally also return per-stage features.

        `data` is unused (kept for interface compatibility).
        """
        B = x.shape[0]
        if return_feat:
            feat = []
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if return_feat:
                feat.append(x)
            if i == layer - 1:
                break
        if return_feat:
            return x, feat
        return x

    def compute_params(self, layer=2):
        """Count parameters of the first `layer` stages (debug helper).

        NOTE(review): the final loop reads self.svt.head, which __init__
        deleted — calling this would raise AttributeError; presumably unused.
        """
        num = 0
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            for param in embed.parameters():
                num += np.prod(param.size())
            for param in drop.parameters():
                num += np.prod(param.size())
            for param in blocks.parameters():
                num += np.prod(param.size())
            for param in pos_blk.parameters():
                num += np.prod(param.size())
            if i == layer - 1:
                break
        for param in self.svt.head.parameters():
            num += np.prod(param.size())
        return num
class twins_svt_large_context(nn.Module):
    """Context variant of the Twins-SVT-Large feature extractor.

    Uses a "twins_svt_large_context" timm model (no stage pruning here) and
    runs its first `layer` stages.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained)

    def forward(self, x, data=None, layer=2):
        """Run the first `layer` stages. `data` is unused."""
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if i == layer - 1:
                break
        return x
if __name__ == "__main__":
    # Smoke test: push a dummy batch through the truncated Twins-SVT encoder
    # and print the feature shape. The class exposes features via forward()
    # (there is no extract_feature method), so call the module directly.
    m = twins_svt_large()
    dummy = torch.randn(2, 3, 400, 800)
    out = m(dummy)
    print(out.shape)

View File

@@ -0,0 +1,90 @@
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
# Optional CUDA extension used by AlternateCorrBlock; fall back silently when
# it has not been compiled. Catch only ImportError so unrelated failures
# (e.g. a broken extension raising at load time) are not hidden.
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is not compiled
    pass
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid (RAFT-style).

    Built once from two feature maps; __call__ looks up a (2r+1)^2 window of
    correlation values around the given coordinates at every pyramid level
    and concatenates them channel-wise.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            # Each level halves the target-image resolution of the volume.
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation windows at `coords` ([B, 2, H, W]) from all levels.

        Returns [B, num_levels*(2r+1)^2, H, W] float features.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
            # Coordinates shrink by 2 per level to track the pooled volume.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense dot-product correlation, scaled by 1/sqrt(dim).

        Returns [B, H, W, 1, H, W].
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
    """Memory-efficient correlation lookup via the `alt_cuda_corr` extension.

    Keeps a pyramid of raw feature maps instead of a precomputed volume and
    computes correlations on demand. Requires the optional CUDA extension
    imported at module top; if that import failed, calling this raises a
    NameError on `alt_cuda_corr`.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Sample correlation windows at `coords` ([B, 2, H, W]) per level.

        Returns [B, num_levels*(2r+1)^2, H, W], scaled by 1/sqrt(dim).
        """
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]
        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # Query features stay at full resolution; targets come from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))
        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())

View File

@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3 convs with a residual connection (RAFT-style encoder block).

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the first conv; when != 1 a strided 1x1 downsample
            path (with its own norm) matches the shortcut to the main branch.

    Raises:
        ValueError: if norm_fn is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        num_groups = planes // 8
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            # Fail fast on a typo'd norm name instead of surfacing later as
            # an AttributeError in forward().
            raise ValueError(f"unknown norm_fn: {norm_fn}")
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Return relu(shortcut(x) + conv_branch(x))."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (planes//4 inner width).

    norm_fn selects "group"/"batch"/"instance"/"none"; when stride != 1 a
    strided 1x1 downsample path (with its own norm) matches the shortcut.
    NOTE(review): an unrecognized norm_fn leaves the norm layers unset and
    fails later in forward() with AttributeError.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        num_groups = planes // 8
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        """Return relu(shortcut(x) + bottleneck_branch(x))."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BasicEncoder(nn.Module):
    """RAFT basic CNN encoder: RGB input to 1/8-resolution features.

    conv1 (stride 2) plus residual stages at strides 1/2/2 give an overall
    stride of 8; a 1x1 conv maps to `output_dim` channels. A tuple/list input
    of two equal-batch tensors is concatenated, encoded once, and split back.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)
        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        # Kaiming init for convs; unit weight / zero bias for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; the first carries the stage stride.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode x (tensor, or pair of tensors sharing a batch size)."""
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)
        if self.training and self.dropout is not None:
            x = self.dropout(x)
        if is_list:
            # Split assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
class SmallEncoder(nn.Module):
    """RAFT small CNN encoder: like BasicEncoder but with bottleneck blocks
    and narrower widths (32/64/96); overall stride 8.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)
        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
        # Kaiming init for convs; unit weight / zero bias for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; the first carries the stage stride.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode x (tensor, or pair of tensors sharing a batch size)."""
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)
        if self.training and self.dropout is not None:
            x = self.dropout(x)
        if is_list:
            # Split assumes exactly two inputs of equal batch size.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x

View File

@@ -0,0 +1,100 @@
import math
import torch
from torch import nn
class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalized to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()
        pe = torch.zeros((d_model, *max_shape))
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        # NOTE(review): `-math.log(10000.0) / d_model // 2` parses as
        # ((-log(1e4)) / d_model) // 2, i.e. floor-division by 2 of a small
        # negative number (= -1.0 for typical d_model) — the LoFTR-inherited
        # form; likely intended `/ (d_model // 2)`. Kept as-is because the
        # pretrained checkpoints were trained with this exact table.
        div_term = torch.exp(
            torch.arange(0, d_model // 2, 2).float()
            * (-math.log(10000.0) / d_model // 2)
        )
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        # Interleave x-sin / x-cos / y-sin / y-cos every 4 channels.
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
class LinearPositionEncoding(nn.Module):
    """2-D sinusoidal position encoding over linearly normalized coordinates.

    Row/column indices are scaled to [0, 1) over `max_shape`, then sin/cos
    features at frequencies 0, 2, 4, ... are interleaved along the channel
    axis (x-sin, x-cos, y-sin, y-cos every 4 channels). The table is a
    non-persistent buffer, so it is not stored in checkpoints.
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """d_model: channel count; max_shape: largest (H, W) feature map."""
        super().__init__()
        h, w = max_shape
        # Normalized row/column positions in [0, 1), each shaped [1, H, W].
        y_position = (torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1) / h
        x_position = (torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1) / w
        # Frequencies 0, 2, 4, ... shaped [C//4, 1, 1] for broadcasting.
        div_term = torch.arange(0, d_model // 2, 2).float()[:, None, None]
        pe = torch.zeros((d_model, h, w))
        pe[0::4] = torch.sin(x_position * div_term * math.pi)
        pe[1::4] = torch.cos(x_position * div_term * math.pi)
        pe[2::4] = torch.sin(y_position * div_term * math.pi)
        pe[3::4] = torch.cos(y_position * div_term * math.pi)
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """Add the positional table to x ([N, C, H, W]), cropped to its size."""
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
class LearnedPositionEncoding(nn.Module):
    """Fully learned additive position encoding.

    Stores a trainable table of shape [1, max_shape[0], max_shape[1], d_model]
    and adds it to the input via broadcasting.

    NOTE(review): the table is channels-last ([1, H, W, C]) while the sibling
    encodings document channels-first input; broadcasting only works for
    channels-last inputs (or coincidental shapes) — confirm against callers.
    """

    def __init__(self, d_model, max_shape=(80, 80)):
        """d_model: channel count; max_shape: fixed (H, W) of the table."""
        super().__init__()
        self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model))

    def forward(self, x):
        """Broadcast-add the learned table to x."""
        return x + self.pe

View File

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a hidden state to a 2-channel flow."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on NCHW feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(stacked))  # update gate z
        reset = torch.sigmoid(self.convr(stacked))   # reset gate r
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """Separable ConvGRU: a horizontal (1x5) pass followed by a vertical (5x1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_step(self, h, x, convz, convr, convq):
        # One gated update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal then vertical pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class SmallMotionEncoder(nn.Module):
    """Encodes correlation volume + current flow into an 82-channel motion feature."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # Number of correlation channels across all pyramid levels.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers see it directly.
        return torch.cat([merged, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Encodes correlation volume + current flow into a 128-channel motion feature."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # Number of correlation channels across all pyramid levels.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 output channels; the raw 2-channel flow is appended to make 128.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([merged, flow], dim=1)
class SmallUpdateBlock(nn.Module):
    """GRU-based update step for the small RAFT model (no upsampling mask)."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)
        # The None keeps the return arity aligned with BasicUpdateBlock,
        # which additionally returns an upsampling mask.
        return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
    """GRU-based RAFT update step: refines the flow and predicts an upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        # Predicts 9 convex-combination weights per cell of an 8x8 upsampling grid.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        delta_flow = self.flow_head(net)
        # Scale the mask logits to balance gradients (RAFT reference trick).
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,113 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
    """Pads images so that their spatial dimensions are divisible by 8."""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        # Extra rows/cols needed to reach the next multiple of 8 (0 if aligned).
        extra_h = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        extra_w = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            # Distribute the padding on all four sides.
            self._pad = [
                extra_w // 2,
                extra_w - extra_w // 2,
                extra_h // 2,
                extra_h - extra_h // 2,
            ]
        elif mode == "kitti400":
            # Fixed bottom padding up to a height of 400.
            self._pad = [0, 0, 0, 400 - self.ht]
        else:
            # Pad width on both sides, height only at the bottom.
            self._pad = [extra_w // 2, extra_w - extra_w // 2, 0, extra_h]

    def pad(self, *inputs):
        """Replicate-pad every input tensor; returns a list."""
        return [F.pad(tensor, self._pad, mode="replicate") for tensor in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
def forward_interpolate(flow):
    """Forward-propagate a flow field onto the regular grid.

    Each source pixel is moved to its flow target; the displacements are then
    resampled back onto the integer grid with nearest-neighbor scattering.
    Returns a float CPU tensor of the same (2, H, W) shape.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]
    ht, wd = dx.shape
    grid_x, grid_y = np.meshgrid(np.arange(wd), np.arange(ht))
    # Target positions of every source pixel after applying the flow.
    tgt_x = (grid_x + dx).reshape(-1)
    tgt_y = (grid_y + dy).reshape(-1)
    dx_flat = dx.reshape(-1)
    dy_flat = dy.reshape(-1)
    # Keep only displacements that land strictly inside the frame.
    inside = (tgt_x > 0) & (tgt_x < wd) & (tgt_y > 0) & (tgt_y < ht)
    tgt_x, tgt_y = tgt_x[inside], tgt_y[inside]
    dx_flat, dy_flat = dx_flat[inside], dy_flat[inside]
    flow_x = interpolate.griddata(
        (tgt_x, tgt_y), dx_flat, (grid_x, grid_y), method="nearest", fill_value=0
    )
    flow_y = interpolate.griddata(
        (tgt_x, tgt_y), dy_flat, (grid_x, grid_y), method="nearest", fill_value=0
    )
    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample that takes pixel (not normalized) coordinates.

    Args:
        img: (N, C, H, W) source tensor.
        coords: (N, Ho, Wo, 2) pixel coordinates, last dim is (x, y).
        mode: interpolation mode forwarded to grid_sample.
        mask: if True, also return a float mask of in-bounds samples.

    Returns:
        Sampled tensor, or (sampled, mask) when mask=True.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's [-1, 1] range.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1
    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Bug fix: `mode` was previously ignored (grid_sample always ran with its
    # default "bilinear"); now the argument is honored.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)
    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()
    return img
def indexing(img, coords, mask=False):
    """Nearest-neighbor lookup via grid_sample using pixel coordinates.

    TODO: directly indexing features instead of sampling
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's [-1, 1] range.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1
    sampled = F.grid_sample(
        img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True, mode="nearest"
    )
    if not mask:
        return sampled
    valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, valid.float()
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of pixel coordinates.

    Channel 0 holds x (column index), channel 1 holds y (row index), both as
    floats, repeated over the batch dimension.
    """
    # Explicit indexing="ij" preserves the historical default behavior while
    # silencing the torch.meshgrid deprecation warning.
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack([xs, ys], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors by 8."""
    target_size = (8 * flow.shape[2], 8 * flow.shape[3])
    upsampled = F.interpolate(flow, size=target_size, mode=mode, align_corners=True)
    return 8 * upsampled

View File

@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configs import GIMMConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.fi_components import LateralBlock
from .modules.hyponet import HypoNet
from .modules.fi_utils import warp
from .modules.softsplat import softsplat
class GIMM(nn.Module):
    """Generalized Implicit Motion Modeling.

    Encodes a pair of bidirectional flows into pixel latents, forward-splats
    them to arbitrary timesteps using confidence weights, refines the splatted
    latents residually, and decodes per-timestep motion with a
    coordinate-conditioned HypoNet.
    """

    Config = GIMMConfig

    def __init__(self, config: GIMMConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.coord_sampler = CoordSampler3D(config.coord_range)
        self.fwarp_type = config.fwarp_type
        # Motion Encoder: maps a 2-channel flow map to a 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Latent Refiner: residual correction from the two source latents
        # concatenated with the splatted latent (16 + 16 + 32 = 64 channels).
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Fixed 3x3 Gaussian kernel used for local flow statistics.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        # Learnable temperatures for the two splatting-confidence terms.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel softsplat confidence weights for both directions.

        Combines two cues: local flow variance (low variance -> reliable) and
        forward-backward flow consistency error (small error -> reliable).
        Returns (weights1, weights2) for the 0->1 and 1->0 flows.
        """
        batch_size = raft_flow01.shape[0]
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)
        ## flow variance metric
        # A single conv3d pass Gaussian-filters both flow^2 and flow,
        # yielding E[f^2] and E[f] for the variance computation below.
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        # Std-dev, clamped for numerical stability, averaged over x/y channels.
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]
        ## flow warp (forward-backward consistency) metric
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)
        return weights1, weights2

    def forward(
        self, xs, coord=None, keep_xs_shape=True, ori_flow=None, timesteps=None
    ):
        """Predict motion at the requested timestep(s).

        Args:
            xs: normalized flow pair; the encoder reads xs[:, :, 0] and
                xs[:, :, 1] (B, 2, 2, H, W layout implied by the indexing).
            coord: HypoNet coordinate input(s); a list when `timesteps` is a
                list. NOTE(review): the None fallback calls
                sample_coord_input(xs), which does not match that method's
                (batch_size, s_shape, t_ids, ...) signature -- confirm callers
                always pass coord explicitly.
            keep_xs_shape: if True, permute HypoNet outputs back to
                channel-first layout.
            ori_flow: unnormalized bidirectional flows; ori_flow[:, :, 0] is
                the 0->1 flow, ori_flow[:, :, 1] the 1->0 flow.
            timesteps: a tensor, or a list of tensors for multi-timestep
                prediction (one output per entry).
        """
        coord = self.sample_coord_input(xs) if coord is None else coord
        raft_flow01 = ori_flow[:, :, 0]
        raft_flow10 = ori_flow[:, :, 1]
        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        # b,c,h,w
        pixel_latent_0 = self.cnn_encoder(xs[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(xs[:, :, 1])
        pixel_latent = []
        # No modulation is used in this variant; the HypoNet runs unmodulated.
        modulation_params_dict = None
        strtype = self.fwarp_type
        if isinstance(timesteps, list):
            assert isinstance(coord, list)
            assert len(timesteps) == len(coord)
            for i, cur_t in enumerate(timesteps):
                cur_t = cur_t.reshape(-1, 1, 1, 1)
                # Splat latent 0 forward by t and latent 1 backward by (1 - t).
                tmp_pixel_latent_0 = softsplat(
                    tenIn=pixel_latent_0,
                    tenFlow=raft_flow01 * cur_t,
                    tenMetric=weights1,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent_1 = softsplat(
                    tenIn=pixel_latent_1,
                    tenFlow=raft_flow10 * (1 - cur_t),
                    tenMetric=weights2,
                    strMode=strtype + "-zeroeps",
                )
                tmp_pixel_latent = torch.cat(
                    [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
                )
                # Residual refinement of the splatted latent.
                tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                    torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
                )
                pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))
            all_outputs = []
            for idx, c in enumerate(coord):
                outputs = self.hyponet(
                    c,
                    modulation_params_dict=modulation_params_dict,
                    pixel_latent=pixel_latent[idx],
                )
                if keep_xs_shape:
                    # Move the channel dim back to position 1.
                    permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                    outputs = outputs.permute(0, -1, *permute_idx_range)
                all_outputs.append(outputs)
            return all_outputs
        else:
            # Single-timestep path: same pipeline without the list handling.
            cur_t = timesteps.reshape(-1, 1, 1, 1)
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype + "-zeroeps",
            )
            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent = tmp_pixel_latent.permute(0, 2, 3, 1)
            # predict all pixels of coord after applying the modulation_parms into hyponet
            outputs = self.hyponet(
                coord,
                modulation_params_dict=modulation_params_dict,
                pixel_latent=pixel_latent,
            )
            if keep_xs_shape:
                permute_idx_range = [i for i in range(1, xs.ndim - 1)]
                outputs = outputs.permute(0, -1, *permute_idx_range)
            return outputs

    def compute_loss(self, preds, targets, reduction="mean", single=False):
        """Per-sample MSE (and PSNR) between predictions and targets.

        Args:
            preds, targets: tensors with a singleton time axis at dim 2
                (asserted below).
            reduction: "mean", "sum", or "none".
            single: unused in this implementation.

        Returns:
            dict with "loss_total", "mse" (same value), and "psnr".
        """
        assert reduction in ["mean", "sum", "none"]
        batch_size = preds.shape[0]
        sample_mses = 0
        assert preds.shape[2] == 1
        assert targets.shape[2] == 1
        for i in range(preds.shape[2]):
            sample_mses += torch.reshape(
                (preds[:, :, i] - targets[:, :, i]) ** 2, (batch_size, -1)
            ).mean(dim=-1)
        sample_mses = sample_mses / preds.shape[2]
        if reduction == "mean":
            total_loss = sample_mses.mean()
            psnr = (-10 * torch.log10(sample_mses)).mean()
        elif reduction == "sum":
            total_loss = sample_mses.sum()
            psnr = (-10 * torch.log10(sample_mses)).sum()
        else:
            total_loss = sample_mses
            psnr = -10 * torch.log10(sample_mses)
        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample 3-D (t, y, x) coordinate inputs via the coord sampler.

        `device` is required; `coord_range` must be None (the sampler uses
        the range configured at construction time).
        """
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs

View File

@@ -0,0 +1,471 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configs import GIMMVFIConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.hyponet import HypoNet
from .modules.fi_components import *
#from .flowformer import initialize_Flowformer
from .modules.fi_utils import normalize_flow, unnormalize_flow, warp, resize
from .raft.corr import BidirCorrBlock
from .modules.softsplat import softsplat
class GIMMVFI_F(nn.Module):
    """GIMM-VFI with a FlowFormer flow estimator.

    Pipeline: estimate bidirectional flows, predict continuous-time motion
    with a GIMM module (softsplat + HypoNet), then synthesize the
    intermediate frame with AMT-style decoders and correlation lookups.
    NOTE(review): `self.flow_estimator` is not created here (the
    initialize_Flowformer() line is commented out) -- it appears to be
    attached externally, e.g. by the model loader. Confirm before use.
    """

    Config = GIMMVFIConfig

    def __init__(self, dtype, config: GIMMVFIConfig):
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.raft_iter = config.raft_iter
        self.dtype = dtype
        ######### Encoder and Decoder Settings #########
        #self.flow_estimator = initialize_Flowformer()
        f_dims = [256, 128]
        skip_channels = f_dims[-1] // 2
        # Number of candidate flows produced by the multi-flow decoder.
        self.num_flows = 3
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)
        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)
        # Fuses the num_flows warped candidates into one RGB frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )
        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)
        # Fixed 3x3 Gaussian kernel used for local flow statistics.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        self.fwarp_type = config.fwarp_type
        # Learnable temperatures for the splatting-confidence terms.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        # Motion Encoder: 2-channel flow -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Latent Refiner: residual correction of the splatted latent.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build an AMT BasicUpdateBlock with the shared hyperparameters."""
        return BasicUpdateBlock(
            cdim=cdim,
            hidden_dim=192,
            flow_dim=64,
            corr_dim=256,
            corr_dim2=192,
            fc_dim=188,
            scale_factor=scale_factor,
            corr_levels=4,
            radius=4,
        )

    def cal_bidirection_flow(self, im0, im1):
        """Run the flow estimator in both directions.

        Returns normalized flows (with their scalers), the original flows,
        both feature pyramids, a bidirectional correlation lookup, and the
        raw estimator flows preserved for the output dict.
        """
        f01, features0, fnet0 = self.flow_estimator(
            im0, im1, return_feat=True, iters=None
        )
        f10, features1, fnet1 = self.flow_estimator(
            im1, im0, return_feat=True, iters=None
        )
        # The estimator returns a list of iterates; keep the first entry.
        f01 = f01[0]
        f10 = f10[0]
        corr_fn = BidirCorrBlock(fnet0, fnet1, radius=4)
        flow01 = f01.unsqueeze(2)
        flow10 = f10.unsqueeze(2)
        # Negate 1->0 so both directions point along the positive t-axis.
        noraml_flows = torch.cat([flow01, -flow10], dim=2)
        noraml_flows, flow_scalers = normalize_flow(noraml_flows)
        ori_flows = torch.cat([flow01, flow10], dim=2)
        return (
            noraml_flows,
            ori_flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
        )

    def predict_flow(self, f, coord, t, flows):
        """GIMM motion prediction: splat flow latents to each timestep in `t`
        and decode them with the HypoNet at the matching `coord` entries.
        """
        # Flows are treated as fixed inputs here (no gradient to the estimator).
        raft_flow01 = flows[:, :, 0].detach()
        raft_flow10 = flows[:, :, 1].detach()
        # calculate splatting metrics
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
        strtype = self.fwarp_type + "-zeroeps"
        # b,c,h,w
        pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
        pixel_latent = []
        for i, cur_t in enumerate(t):
            cur_t = cur_t.reshape(-1, 1, 1, 1)
            tmp_pixel_latent_0 = softsplat(
                tenIn=pixel_latent_0,
                tenFlow=raft_flow01 * cur_t,
                tenMetric=weights1,
                strMode=strtype,
            )
            tmp_pixel_latent_1 = softsplat(
                tenIn=pixel_latent_1,
                tenFlow=raft_flow10 * (1 - cur_t),
                tenMetric=weights2,
                strMode=strtype,
            )
            tmp_pixel_latent = torch.cat(
                [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
            )
            tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
                torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
            )
            pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))
        all_outputs = []
        permute_idx_range = [i for i in range(1, f.ndim - 1)]
        for idx, c in enumerate(coord):
            # Sanity check: the coord's time channel must match its timestep.
            assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
            assert isinstance(c, tuple)
            if c[1] is None:
                # Full-grid coords: permute output back to channel-first.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                ).permute(0, -1, *permute_idx_range)
            else:
                # Subsampled coords (used for supervision): keep as-is.
                outputs = self.hyponet(
                    c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
                )
            all_outputs.append(outputs)
        return all_outputs

    def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
        """Backward-warp both frames with upscaled flows and blend by mask."""
        ft0 = scale * resize(ft0, scale_factor=scale)
        ft1 = scale * resize(ft1, scale_factor=scale)
        mask = resize(mask, scale_factor=scale).sigmoid()
        img0_warp = warp(img0, ft0)
        img1_warp = warp(img1, ft1)
        img_warp = mask * img0_warp + (1 - mask) * img1_warp
        return img_warp

    @torch.compiler.disable()
    def frame_synthesize(
        self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
    ):
        """Synthesize the frame at time t from the predicted motion.

        flow_t: b,2,h,w
        cur_t: b,1,1,1
        If `full_img` is given (downscaled inference), flows/masks are
        upscaled back to the full resolution before combination.
        """
        batch_size = img_xs.shape[0]
        # Images are mapped from [0, 1] to [-1, 1] for the decoders.
        img0 = 2 * img_xs[:, :, 0] - 1.0
        img1 = 2 * img_xs[:, :, 1] - 1.0
        ##################### update the predicted flow #####################
        ## initialize coordinates for looking up
        lookup_coord = self.flow_estimator.build_coord(img_xs[:, :, 0]).to(
            img_xs[:, :, 0].device
        )
        # Split the continuous motion into t->0 and t->1 flows.
        flow_t0_fullsize = flow_t * (-cur_t)
        flow_t1_fullsize = flow_t * (1.0 - cur_t)
        inv = 1 / 4
        flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
        flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)
        ############################# scale 1/4 #############################
        # i. Initialize feature t at scale 1/4
        flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
            features0[-1],
            features1[-1],
            flow_t0_inr4,
            flow_t1_inr4,
            img0=img0,
            img1=img1,
        )
        # First channel is the blend mask, the rest the feature.
        mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
        img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
        img_warp_4 = (img_warp_4 + 1.0) / 2
        img_warp_4 = torch.clamp(img_warp_4, 0, 1)
        # ii. residue update with correlation looked up at half resolution
        corr_4, flow_4_lr = self._amt_corr_scale_lookup(
            corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
        )
        delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        flowt0_4 = flowt0_4 + delta_flow0_4
        flowt1_4 = flowt1_4 + delta_flow1_4
        ft_4_ = ft_4_ + delta_ft_4_
        # iii. residue update with lookup corr
        corr_4 = resize(corr_4, scale_factor=2.0)
        flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
        delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
        flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
        flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
        ft_4_ = ft_4_ + delta_ft_4_
        ############################# scale 1/1 #############################
        flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
            ft_4_,
            features0[0],
            features1[0],
            flowt0_4,
            flowt1_4,
            mask=mask_4_,
            img0=img0,
            img1=img1,
        )
        if full_img is not None:
            # Downscaled inference: upscale everything back to full size.
            img0 = 2 * full_img[:, :, 0] - 1.0
            img1 = 2 * full_img[:, :, 1] - 1.0
            inv = img1.shape[2] / flowt0_1.shape[2]
            flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
            flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
            flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
            flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
            mask = resize(mask, scale_factor=inv)
            img_res = resize(img_res, scale_factor=inv)
        imgt_pred = multi_flow_combine(
            self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
        )
        imgt_pred = torch.clamp(imgt_pred, 0, 1)
        ######################################################################
        # Expose the num_flows candidate flows as a separate dimension.
        flowt0_1 = flowt0_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt1_1 = flowt1_1.reshape(
            batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
        )
        flowt0_pred = [flowt0_1, flowt0_4]
        flowt1_pred = [flowt1_1, flowt1_4]
        other_pred = [img_warp_4]
        return imgt_pred, flowt0_pred, flowt1_pred, other_pred

    def forward(self, img_xs, coord=None, t=None, ds_factor=None):
        """Interpolate frames at every timestep in `t`.

        Args:
            img_xs: frame pair in [0, 1]; indexed as img_xs[:, :, 0/1].
            coord: list of HypoNet coordinate tuples, one per timestep.
            t: list of timestep tensors (same length as coord).
            ds_factor: optional downscale factor for flow estimation on
                high-resolution inputs; outputs are synthesized at full size.

        Returns:
            dict of predictions, intermediate flows, and diagnostics.
        """
        assert isinstance(t, list)
        assert isinstance(coord, list)
        assert len(t) == len(coord)
        full_size_img = None
        if ds_factor is not None:
            full_size_img = img_xs.clone()
            img_xs = torch.cat(
                [
                    resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                    resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
                ],
                dim=2,
            )
        # The flow estimator expects inputs in [0, 255].
        (
            normal_flows,
            flows,
            flow_scalers,
            features0,
            features1,
            corr_fn,
            preserved_raft_flows,
        ) = self.cal_bidirection_flow(255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1])
        assert coord is not None
        # List of flows
        normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)
        ############ Unnormalize the predicted/reconstructed flow ############
        start_idx = 0
        if coord[0][1] is not None:
            # Subsampled flows for reconstruction supervision in the GIMM module.
            # In that case the first two coords in the list are the subsampled
            # supervision targets; frame synthesis starts at index 2.
            # Normalized flow_t towards positive t-axis
            assert len(coord) > 2
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(2, len(coord))
            ]
            start_idx = 2
        else:
            flow_t = [
                unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
                for i in range(len(coord))
            ]
        imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []
        for idx in range(start_idx, len(coord)):
            cur_flow_t = flow_t[idx - start_idx]
            cur_t = t[idx].reshape(-1, 1, 1, 1)
            # .squeeze() above drops the batch dim when B == 1; restore it.
            if cur_flow_t.ndim != 4:
                cur_flow_t = cur_flow_t.unsqueeze(0)
            assert cur_flow_t.ndim == 4
            imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
                img_xs,
                cur_flow_t,
                features0,
                features1,
                corr_fn,
                cur_t,
                full_img=full_size_img,
            )
            imgt_preds.append(imgt_pred)
            flowt0_preds.append(flowt0_pred)
            flowt1_preds.append(flowt1_pred)
            all_others.append(others)
        return {
            "imgt_pred": imgt_preds,
            "other_pred": all_others,
            "flowt0_pred": flowt0_preds,
            "flowt1_pred": flowt1_preds,
            "raft_flow": preserved_raft_flows,
            "ninrflow": normal_inr_flows,
            "nflow": normal_flows,
            "flowt": flow_t,
        }

    def warp_frame(self, frame, flow):
        """Backward-warp `frame` by `flow` (thin wrapper over modules.fi_utils.warp)."""
        return warp(frame, flow)

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample 3-D (t, y, x) coordinate inputs via the coord sampler.

        `device` is required; `coord_range` must be None (the sampler uses
        the range configured at construction time).
        """
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel softsplat confidence weights for both directions.

        Combines local flow variance and forward-backward consistency error
        (same formulation as in GIMM).
        """
        batch_size = raft_flow01.shape[0]
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)
        ## flow variance metric
        # One conv3d pass Gaussian-filters flow^2 and flow simultaneously,
        # producing E[f^2] and E[f] for the variance below.
        sqaure_mean, mean_square = torch.split(
            F.conv3d(
                F.pad(
                    torch.cat([raft_flows**2, raft_flows], 1),
                    (1, 1, 1, 1),
                    mode="reflect",
                ).unsqueeze(1),
                self.g_filter,
            ).squeeze(1),
            2,
            dim=1,
        )
        var = (
            (sqaure_mean - mean_square**2)
            .clamp(1e-9, None)
            .sqrt()
            .mean(1)
            .unsqueeze(1)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]
        ## flow warp (forward-backward consistency) metric
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = (
            torch.nn.functional.l1_loss(
                input=f01_warp, target=raft_flow01, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        err02 = (
            torch.nn.functional.l1_loss(
                input=f10_warp, target=raft_flow10, reduction="none"
            )
            .mean(1)
            .unsqueeze(1)
        )
        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)
        return weights1, weights2

    def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-scaled coordinates.

        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        NOTE(review): divides by embt and (1 - embt); assumes 0 < t < 1.
        """
        t0_scale = 1.0 / embt
        t1_scale = 1.0 / (1.0 - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)
        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

View File

@@ -0,0 +1,508 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
from .configs import GIMMVFIConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.hyponet import HypoNet
from .modules.fi_components import *
from .modules.fi_utils import (
normalize_flow,
unnormalize_flow,
warp,
resize,
build_coord,
)
import torch.nn.functional as F
from .raft.corr import BidirCorrBlock
from .modules.softsplat import softsplat
class GIMMVFI_R(nn.Module):
Config = GIMMVFIConfig
def __init__(self, dtype, config: GIMMVFIConfig):
super().__init__()
self.config = config = config.copy()
self.hyponet_config = config.hyponet
self.raft_iter = 20
######### Encoder and Decoder Settings #########
#self.flow_estimator = initialize_RAFT()
cur_f_dims = [128, 96]
f_dims = [256, 128]
self.dtype = dtype
skip_channels = f_dims[-1] // 2
self.num_flows = 3
self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1)
self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1)
self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1)
self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)
self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)
self.amt_comb_block = nn.Sequential(
nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
nn.PReLU(6 * self.num_flows),
nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
)
################ GIMM settings #################
self.coord_sampler = CoordSampler3D(config.coord_range)
self.g_filter = torch.nn.Parameter(
torch.FloatTensor(
[
[1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
[1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
[1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
]
).reshape(1, 1, 1, 3, 3),
requires_grad=False,
)
self.fwarp_type = config.fwarp_type
self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
channel = 32
in_dim = 2
self.cnn_encoder = nn.Sequential(
nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
LateralBlock(channel),
LateralBlock(channel),
LateralBlock(channel),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(
channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
),
)
channel = 64
in_dim = 64
self.res_conv = nn.Sequential(
nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
LateralBlock(channel),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(
channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
),
)
self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)
def _get_updateblock(self, cdim, scale_factor=None):
return BasicUpdateBlock(
cdim=cdim,
hidden_dim=192,
flow_dim=64,
corr_dim=256,
corr_dim2=192,
fc_dim=188,
scale_factor=scale_factor,
corr_levels=4,
radius=4,
)
def cal_bidirection_flow(self, im0, im1, iters=20):
f01, features0, fnet0 = self.flow_estimator(
im0.to(self.dtype), im1.to(self.dtype), return_feat=True, iters=20
)
f10, features1, fnet1 = self.flow_estimator(
im1.to(self.dtype), im0.to(self.dtype), return_feat=True, iters=20
)
corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4)
features0 = [
self.amt_second_last_cproj(features0[0]),
self.amt_last_cproj(features0[1]),
]
features1 = [
self.amt_second_last_cproj(features1[0]),
self.amt_last_cproj(features1[1]),
]
flow01 = f01.unsqueeze(2)
flow10 = f10.unsqueeze(2)
noraml_flows = torch.cat([flow01, -flow10], dim=2)
noraml_flows, flow_scalers = normalize_flow(noraml_flows)
ori_flows = torch.cat([flow01, flow10], dim=2)
return (
noraml_flows,
ori_flows,
flow_scalers,
features0,
features1,
corr_fn,
torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
)
def predict_flow(self, f, coord, t, flows):
    """Forward-splat per-frame latents to each timestep and run the GIMM
    hypo-network to predict normalized motion at those timesteps.

    Args:
        f: normalized bidirectional flows (b, c, 2, h, w), the input to
            the pixel-latent encoder.
        coord: list of (coordinate tensor, optional subsample index)
            tuples, one entry per requested timestep.
        t: list of timestep tensors aligned with ``coord``.
        flows: original (unnormalized) flows; index 0 along dim 2 is the
            0->1 flow, index 1 the 1->0 flow.

    Returns:
        List of hypo-network outputs, one per coordinate set.
    """
    # Splatting is guided by the raw RAFT flows; detach so the GIMM head
    # does not backprop into the flow estimator through the warps.
    raft_flow01 = flows[:, :, 0].detach()
    raft_flow10 = flows[:, :, 1].detach()
    # calculate splatting metrics (confidence weights for softsplat)
    weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
    strtype = self.fwarp_type + "-zeroeps"
    # per-frame pixel latents, b,c,h,w
    pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
    pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
    pixel_latent = []
    for i, cur_t in enumerate(t):
        cur_t = cur_t.reshape(-1, 1, 1, 1)
        # Forward-splat each frame's latent to time t along its flow,
        # scaled linearly by the temporal distance.
        tmp_pixel_latent_0 = softsplat(
            tenIn=pixel_latent_0,
            tenFlow=raft_flow01 * cur_t,
            tenMetric=weights1,
            strMode=strtype,
        )
        tmp_pixel_latent_1 = softsplat(
            tenIn=pixel_latent_1,
            tenFlow=raft_flow10 * (1 - cur_t),
            tenMetric=weights2,
            strMode=strtype,
        )
        tmp_pixel_latent = torch.cat(
            [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
        )
        # Residual correction from the un-warped latents fills holes
        # left by forward splatting.
        tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
            torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
        )
        # Channels-last layout for the hypo-network.
        pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))
    all_outputs = []
    permute_idx_range = [i for i in range(1, f.ndim - 1)]
    for idx, c in enumerate(coord):
        # Sanity check: the first coordinate channel is the timestep.
        assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
        assert isinstance(c, tuple)
        if c[1] is None:
            # Full grid: move channels back to (b, c, t, h, w) layout.
            outputs = self.hyponet(
                c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
            ).permute(0, -1, *permute_idx_range)
        else:
            # Subsampled coordinates: outputs stay flat (b, n, c).
            outputs = self.hyponet(
                c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
            )
        all_outputs.append(outputs)
    return all_outputs
def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
    """Backward-warp both frames with upscaled flows and blend them.

    The flows are resized by ``scale`` (magnitudes rescaled to match);
    ``mask`` is resized and passed through a sigmoid to become the
    per-pixel blending weight for the img0 branch.
    """
    flow_t0 = scale * resize(ft0, scale_factor=scale)
    flow_t1 = scale * resize(ft1, scale_factor=scale)
    blend = resize(mask, scale_factor=scale).sigmoid()
    warped0 = warp(img0, flow_t0)
    warped1 = warp(img1, flow_t1)
    return blend * warped0 + (1 - blend) * warped1
@torch.compiler.disable()
def frame_synthesize(
    self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
):
    """Synthesize the frame at time t from the input pair and GIMM flow.

    Coarse-to-fine AMT-style refinement: initialize flows/features at
    1/4 scale, refine twice with correlation lookups, then decode
    multi-flow candidates at full scale and combine them.

    Args:
        img_xs: (b, c, 2, h, w) frame pair, mapped to [-1, 1] below
            (so inputs are expected in [0, 1]).
        flow_t: b,2,h,w — GIMM-predicted motion at time t.
        cur_t: b,1,1,1 — the timestep.
        features0, features1: two-level context features per frame.
        corr_fn: bidirectional correlation lookup object.
        full_img: optional full-resolution pair; when given, predictions
            are upscaled back to its size before the final combination
            (used together with ds_factor).

    Returns:
        (imgt_pred, flowt0_pred, flowt1_pred, other_pred) where the flow
        lists hold full- and quarter-resolution predictions.
    """
    batch_size = img_xs.shape[0]  # b,c,t,h,w
    # AMT blocks operate on [-1, 1] images.
    img0 = 2 * img_xs[:, :, 0] - 1.0
    img1 = 2 * img_xs[:, :, 1] - 1.0
    ##################### update the predicted flow #####################
    ## initialize coordinates for correlation lookup
    lookup_coord = build_coord(img_xs[:, :, 0]).to(
        img_xs[:, :, 0].device
    )  # H//8,W//8
    # Linear-motion split of the single flow into t->0 and t->1 parts.
    flow_t0_fullsize = flow_t * (-cur_t)
    flow_t1_fullsize = flow_t * (1.0 - cur_t)
    inv = 1 / 4
    flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
    flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)
    ############################# scale 1/4 #############################
    # i. Initialize feature t at scale 1/4
    flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
        features0[-1],
        features1[-1],
        flow_t0_inr4,
        flow_t1_inr4,
        img0=img0,
        img1=img1,
    )
    features0, features1 = features0[:-1], features1[:-1]
    # First output channel is the blending-mask logit, the rest features.
    mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
    img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
    img_warp_4 = (img_warp_4 + 1.0) / 2
    img_warp_4 = torch.clamp(img_warp_4, 0, 1)
    # ii. residue update with correlation lookup at lower resolution
    corr_4, flow_4_lr = self._amt_corr_scale_lookup(
        corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
    )
    delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
    delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
    flowt0_4 = flowt0_4 + delta_flow0_4
    flowt1_4 = flowt1_4 + delta_flow1_4
    ft_4_ = ft_4_ + delta_ft_4_
    # iii. residue update with lookup corr
    corr_4 = resize(corr_4, scale_factor=2.0)
    flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
    delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
    flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
    flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
    ft_4_ = ft_4_ + delta_ft_4_
    ############################# scale 1/1 #############################
    flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
        ft_4_,
        features0[0],
        features1[0],
        flowt0_4,
        flowt1_4,
        mask=mask_4_,
        img0=img0,
        img1=img1,
    )
    if full_img is not None:
        # Upscale everything back to the original (pre-ds_factor) size.
        img0 = 2 * full_img[:, :, 0] - 1.0
        img1 = 2 * full_img[:, :, 1] - 1.0
        inv = img1.shape[2] / flowt0_1.shape[2]
        flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
        flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
        flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
        flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
        mask = resize(mask, scale_factor=inv)
        img_res = resize(img_res, scale_factor=inv)
    imgt_pred = multi_flow_combine(
        self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
    )
    imgt_pred = torch.clamp(imgt_pred, 0, 1)
    ######################################################################
    # Expose multi-flow candidates as (b, num_flows, 2, h, w).
    flowt0_1 = flowt0_1.reshape(
        batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
    )
    flowt1_1 = flowt1_1.reshape(
        batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
    )
    flowt0_pred = [flowt0_1, flowt0_4]
    flowt1_pred = [flowt1_1, flowt1_4]
    other_pred = [img_warp_4]
    return imgt_pred, flowt0_pred, flowt1_pred, other_pred
def forward(self, img_xs, coord=None, t=None, iters=None, ds_factor=None):
    """Interpolate frames at every requested timestep in a single pass.

    Args:
        img_xs: (b, c, 2, h, w) frame pair in [0, 1].
        coord: list of (coords, sub_idx) tuples from the coord sampler.
        t: list of timestep tensors, same length as ``coord``.
        iters: flow-estimator iterations; defaults to self.raft_iter.
        ds_factor: optional downscale factor applied before flow
            estimation; final frames are still produced at full size.

    Returns:
        Dict with per-timestep frame predictions plus the intermediate
        flows used/produced along the way.
    """
    assert isinstance(t, list)
    assert isinstance(coord, list)
    assert len(t) == len(coord)
    full_size_img = None
    if ds_factor is not None:
        # Keep the full-resolution pair for the final synthesis stage.
        full_size_img = img_xs.clone()
        img_xs = torch.cat(
            [
                resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
            ],
            dim=2,
        )
    iters = self.raft_iter if iters is None else iters
    (
        normal_flows,
        flows,
        flow_scalers,
        features0,
        features1,
        corr_fn,
        preserved_raft_flows,
    ) = self.cal_bidirection_flow(
        # Flow estimator expects [0, 255]-range inputs.
        255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters
    )
    assert coord is not None
    # List of flows, one per requested timestep.
    normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)
    ############ Unnormalize the predicted/reconstructed flow ############
    start_idx = 0
    if coord[0][1] is not None:
        # Subsampled flows for reconstruction supervision in the GIMM module.
        # In such case, by default, the first two coords are subsampled for
        # the supervision mentioned above; only the remaining coords get a
        # full flow_t / synthesized frame.
        assert len(coord) > 2
        flow_t = [
            unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
            for i in range(2, len(coord))
        ]
        start_idx = 2
    else:
        flow_t = [
            unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
            for i in range(len(coord))
        ]
    imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []
    for idx in range(start_idx, len(coord)):
        cur_flow_t = flow_t[idx - start_idx]
        cur_t = t[idx].reshape(-1, 1, 1, 1)
        # squeeze() above may have dropped the batch dim for batch size 1.
        if cur_flow_t.ndim != 4:
            cur_flow_t = cur_flow_t.unsqueeze(0)
        assert cur_flow_t.ndim == 4
        imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
            img_xs,
            cur_flow_t,
            features0,
            features1,
            corr_fn,
            cur_t,
            full_img=full_size_img,
        )
        imgt_preds.append(imgt_pred)
        flowt0_preds.append(flowt0_pred)
        flowt1_preds.append(flowt1_pred)
        all_others.append(others)
    return {
        "imgt_pred": imgt_preds,
        "other_pred": all_others,
        "flowt0_pred": flowt0_preds,
        "flowt1_pred": flowt1_preds,
        "raft_flow": preserved_raft_flows,
        "ninrflow": normal_inr_flows,
        "nflow": normal_flows,
        "flowt": flow_t,
    }
def warp_frame(self, frame, flow):
    """Backward-warp ``frame`` by the optical ``flow`` field (pixels)."""
    warped = warp(frame, flow)
    return warped
def compute_psnr(self, preds, targets, reduction="mean"):
    """Per-sample PSNR between ``preds`` and ``targets``.

    All non-batch dimensions are flattened into one MSE per sample.
    ``reduction`` selects "mean" | "sum" | "none" over the batch.
    """
    assert reduction in ["mean", "sum", "none"]
    n = preds.shape[0]
    # Mean squared error per sample over all non-batch elements.
    per_sample_mse = ((preds - targets) ** 2).reshape(n, -1).mean(dim=-1)
    per_sample_psnr = -10 * torch.log10(per_sample_mse)
    if reduction == "none":
        return per_sample_psnr
    return per_sample_psnr.sum() if reduction == "sum" else per_sample_psnr.mean()
def sample_coord_input(
    self,
    batch_size,
    s_shape,
    t_ids,
    coord_range=None,
    upsample_ratio=1.0,
    device=None,
):
    """Delegate coordinate-grid construction to ``self.coord_sampler``.

    ``device`` is mandatory; overriding ``coord_range`` is not supported
    here — the sampler falls back to its own configured range.
    """
    assert coord_range is None
    assert device is not None
    return self.coord_sampler(
        batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
    )
def cal_splatting_weights(self, raft_flow01, raft_flow10):
    """Compute per-pixel confidence metrics for softmax splatting.

    Combines (a) a local flow-variance metric computed via a windowed
    conv (self.g_filter) and (b) a forward-backward flow consistency
    error; lower variance and lower error both yield a higher weight.

    Returns:
        (weights1, weights2): splatting metrics for the 0->1 and 1->0
        directions respectively, each (b, 1, h, w).
    """
    batch_size = raft_flow01.shape[0]
    # Batch both directions so the windowed statistics run in one conv.
    raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)
    ## flow variance metric: Var[x] = E[x^2] - E[x]^2 over a local window
    sqaure_mean, mean_square = torch.split(
        F.conv3d(
            F.pad(
                torch.cat([raft_flows**2, raft_flows], 1),
                (1, 1, 1, 1),
                mode="reflect",
            ).unsqueeze(1),
            self.g_filter,
        ).squeeze(1),
        2,
        dim=1,
    )
    # Clamp before sqrt for numerical safety, then average over x/y.
    var = (
        (sqaure_mean - mean_square**2)
        .clamp(1e-9, None)
        .sqrt()
        .mean(1)
        .unsqueeze(1)
    )
    var01 = var[:batch_size]
    var10 = var[batch_size:]
    ## flow warp metric (forward-backward consistency error)
    f01_warp = -warp(raft_flow10, raft_flow01)
    f10_warp = -warp(raft_flow01, raft_flow10)
    err01 = (
        torch.nn.functional.l1_loss(
            input=f01_warp, target=raft_flow01, reduction="none"
        )
        .mean(1)
        .unsqueeze(1)
    )
    err02 = (
        torch.nn.functional.l1_loss(
            input=f10_warp, target=raft_flow10, reduction="none"
        )
        .mean(1)
        .unsqueeze(1)
    )
    # Learned scales alpha_fe / alpha_v balance the two metrics.
    weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
    weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)
    return weights1, weights2
def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption
t0_scale = 1.0 / embt
t1_scale = 1.0 / (1.0 - embt)
if downsample != 1:
inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv)
flow1 = inv * resize(flow1, scale_factor=inv)
corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
corr = torch.cat([corr0, corr1], dim=1)
flow = torch.cat([flow0, flow1], dim=1)
return corr, flow

View File

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
class CoordSampler3D(nn.Module):
    """Generates (t, y, x) coordinate grids for the GIMM hypo-network.

    Produces tensors of shape (B, T, H, W, 3); with ``t_coord_only`` set,
    only the time channel is kept, yielding (B, T, H, W, 1).
    """

    def __init__(self, coord_range, t_coord_only=False):
        super().__init__()
        self.coord_range = coord_range
        self.t_coord_only = t_coord_only

    def _spatial_axes(self, spatial_shape, coord_range, upsample_ratio, device):
        """Pixel-centered 1-D coordinates mapped into ``coord_range``."""
        lo, hi = coord_range[0], coord_range[1]
        axes = []
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            centers = (0.5 + torch.arange(num_s, device=device)) / num_s
            axes.append(lo + (hi - lo) * centers)
        return axes

    def shape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Grid for a list of timesteps shared across the batch."""
        assert isinstance(t_ids, list)
        t_axis = (torch.tensor(t_ids, device=device) / 1.0).to(torch.float32)
        axes = [t_axis] + self._spatial_axes(
            spatial_shape, coord_range, upsample_ratio, device
        )
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        expand = (1,) * grid.ndim
        return grid.unsqueeze(0).repeat(batch_size, *expand)  # (B,T,H,W,3)

    def batchshape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Grid with one timestep per batch sample (``t_ids`` is a tensor)."""
        # Time axis is a placeholder 1; real per-sample timesteps are
        # multiplied in afterwards.
        t_axis = torch.tensor(1, device=device).to(torch.float32)
        axes = [t_axis] + self._spatial_axes(
            spatial_shape, coord_range, upsample_ratio, device
        )
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        expand = (1,) * grid.ndim
        # Now coords b,1,h,w,3 with coords[...,0] == 1.
        coords = grid.unsqueeze(0).repeat(batch_size, *expand)
        # Assign each sample its own timestep.
        coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
        return coords

    def forward(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        coord_range = self.coord_range if coord_range is None else coord_range
        if isinstance(t_ids, list):
            coords = self.shape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        elif isinstance(t_ids, torch.Tensor):
            coords = self.batchshape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        return coords[..., :1] if self.t_coord_only else coords

View File

@@ -0,0 +1,340 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# --------------------------------------------------------
import torch
import torch.nn as nn
from .fi_utils import warp, resize
class LateralBlock(nn.Module):
    """Residual block: two 3x3 convs with a LeakyReLU in between, plus
    an identity skip connection."""

    def __init__(self, dim):
        super(LateralBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
        )

    def forward(self, x):
        return self.layers(x) + x
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """A Conv2d followed by a channel-wise PReLU, packed as a Sequential."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias=bias,
    )
    return nn.Sequential(conv, nn.PReLU(out_channels))
def multi_flow_combine(
    comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
    """Warp both frames with each multi-flow candidate, blend them, and
    fuse all candidates into a single frame via ``comb_block``.

    Args:
        comb_block: conv block mapping the stacked warped candidates to a
            combination residual.
        img0, img1: frames in [-1, 1], (b, 3, h, w) — the final
            ``(pred + 1) / 2`` remaps the result to [0, 1].
        flow0, flow1: (b, 2*num_flows, h, w) candidate flows t->0 / t->1.
        mask: per-candidate blending weights, (b, num_flows, h, w).
            NOTE(review): a None mask would raise at the blend below —
            callers in this file always pass one.
        img_res: per-candidate image residual, (b, 3*num_flows, h, w).
        mean: must be None (asserted); the mean-branch below is kept for
            upstream parity but is therefore dead code.

    Returns:
        Combined frame in [0, 1], (b, 3, h, w).
    """
    assert mean is None
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the candidate dimension into the batch so warp() handles all
    # candidates at once.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    mask = (
        mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
        if mask is not None
        else None
    )
    img_res = (
        img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
        if img_res is not None
        else 0
    )
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = (
        torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
        if mean is not None
        else 0
    )
    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    # Per-candidate blend + residual.
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Average the candidates and add the learned combination residual.
    res = comb_block(img_warps.view(b, -1, h, w))
    imgt_pred = img_warps.mean(1) + res
    imgt_pred = (imgt_pred + 1.0) / 2
    return imgt_pred
class ResBlock(nn.Module):
    """Residual block that repeatedly refines the last ``side_channels``
    channels with their own convs before re-merging with the main path."""

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        s = self.side_channels
        out = self.conv1(x)
        # Split off the side channels, refine them, merge, repeat.
        main, side = out[:, :-s, ...], out[:, -s:, :, :]
        out = self.conv3(torch.cat([main, self.conv2(side)], 1))
        main, side = out[:, :-s, ...], out[:, -s:, :, :]
        out = self.conv5(torch.cat([main, self.conv4(side)], 1))
        return self.prelu(x + out)
class BasicUpdateBlock(nn.Module):
    """GRU-style refinement block: from (context, flows, correlation) it
    predicts residual updates for both the context features and the four
    flow channels. Optionally operates at a different scale via
    ``scale_factor``."""

    def __init__(
        self,
        cdim,
        hidden_dim,
        flow_dim,
        corr_dim,
        corr_dim2,
        fc_dim,
        corr_levels=4,
        radius=3,
        scale_factor=None,
        out_num=1,
    ):
        super(BasicUpdateBlock, self).__init__()
        # Per-direction correlation channels: levels * window area.
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor
        # Correlation branch (bidirectional -> factor 2).
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        # Flow branch (t->0 and t->1 stacked: 4 channels).
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )
        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )
        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
        )
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        if self.scale_factor is not None:
            # Bring the context down to the working resolution first.
            net = resize(net, 1 / self.scale_factor)
        cor = self.lrelu(self.convc2(self.lrelu(self.convc1(corr))))
        flo = self.lrelu(self.convf2(self.lrelu(self.convf1(flow))))
        inp = self.lrelu(self.conv(torch.cat([cor, flo], dim=1)))
        hidden = self.gru(torch.cat([inp, flow, net], dim=1))
        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)
        if self.scale_factor is not None:
            # Return the deltas at the caller's resolution; flow values
            # are rescaled to match the spatial magnification.
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(
                delta_flow, scale_factor=self.scale_factor
            )
        return delta_net, delta_flow
def get_bn():
    """Return the normalization-layer class used by the decoders."""
    norm_cls = nn.BatchNorm2d
    return norm_cls
class NewInitDecoder(nn.Module):
    """Initializes the intermediate-frame feature/mask and refines the
    t->0 and t->1 flows from the coarsest context features (AMT-style)."""

    def __init__(self, in_ch, skip_ch):
        super().__init__()
        norm_layer = get_bn()
        # PixelShuffle(2) doubles resolution and divides channels by 4.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )
        in_ch = in_ch // 2  # channel count after the upsample stack
        self.convblock = nn.Sequential(
            # +16 = two 2-ch flows + img0/img1/warped0/warped1 (4 x 3 ch).
            convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            # +5 outputs: 2+2 flow deltas and 1 mask logit (kept in ft_).
            nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True),
        )

    def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
        """Returns refined (flow0, flow1) and ft_ whose first channel is
        the blending-mask logit followed by the intermediate features."""
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        f0_warp_ks = warp(f0, flow0_in)
        f1_warp_ks = warp(f1, flow1_in)
        f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)
        assert img0 is not None
        assert img1 is not None
        # Bring the full-resolution images down to the feature scale.
        scale_factor = f_in.shape[2] / img0.shape[2]
        img0 = resize(img0, scale_factor=scale_factor)
        img1 = resize(img1, scale_factor=scale_factor)
        warped_img0 = warp(img0, flow0_in)
        warped_img1 = warp(img1, flow1_in)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
        out = self.convblock(f_in)
        ft_ = out[:, 4:, ...]
        # Residual flow refinement.
        flow0 = flow0_in + out[:, :2, ...]
        flow1 = flow1_in + out[:, 2:4, ...]
        return flow0, flow1, ft_
class NewMultiFlowDecoder(nn.Module):
    """Full-resolution decoder producing ``num_flows`` candidate flow
    pairs, blending masks, and image residuals (AMT-style)."""

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(NewMultiFlowDecoder, self).__init__()
        norm_layer = get_bn()
        # Two PixelShuffle(2) stages: 4x resolution, channels / 16.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            nn.PixelShuffle(2),
            convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )
        self.num_flows = num_flows
        ch_factor = 2
        self.convblock = nn.Sequential(
            # +17 = two 2-ch flows + 1-ch mask + 4 images (4 x 3 ch).
            convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            # 8 per flow candidate: 2+2 flow deltas, 1 mask delta, 3 residual.
            nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
        )

    def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
        """Returns (flow0, flow1, mask, img_res), each with ``num_flows``
        candidates stacked along the channel dimension; mask is sigmoided."""
        # BUG FIX: these preconditions were previously asserted only AFTER
        # the first use (resize(mask, ...) / warp(img0, ...)), so a None
        # argument crashed with an opaque TypeError instead of the assert.
        assert mask is not None
        assert img0 is not None
        assert img1 is not None
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        n = self.num_flows
        # Lift quarter-scale inputs to full resolution (flows rescaled).
        flow0 = 4.0 * resize(flow0, scale_factor=4.0)
        flow1 = 4.0 * resize(flow1, scale_factor=4.0)
        ft_ = resize(ft_, scale_factor=4.0)
        mask = resize(mask, scale_factor=4.0)
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)
        f_in = torch.cat([f_in, mask], 1)
        warped_img0 = warp(img0, flow0)
        warped_img1 = warp(img1, flow1)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
        out = self.convblock(f_in)
        delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
            out, [2 * n, 2 * n, n, 3 * n], 1
        )
        # Each candidate refines the shared coarse prediction.
        mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1)
        mask = torch.sigmoid(mask)
        flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1)
        return flow0, flow1, mask, img_res

View File

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# raft: https://github.com/princeton-vl/RAFT
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
# --------------------------------------------------------
import torch
import torch.nn.functional as F
# Cache of normalized base sampling grids, keyed by (device, flow shape).
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by ``tenFlow`` (in pixels) via grid_sample.

    The identity grid in normalized [-1, 1] coordinates is built once per
    (device, flow shape) and memoized in ``backwarp_tenGrid``.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        n, _, h, w = tenFlow.shape
        xs = (
            torch.linspace(-1.0, 1.0, w, device=tenFlow.device)
            .view(1, 1, 1, w)
            .expand(n, -1, h, -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, h, device=tenFlow.device)
            .view(1, 1, h, 1)
            .expand(n, -1, -1, w)
        )
        backwarp_tenGrid[k] = torch.cat([xs, ys], 1).to(tenFlow.device)
    # Convert pixel displacements to normalized offsets.
    norm_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )
    sample_grid = (backwarp_tenGrid[k] + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
def normalize_flow(flows):
    """Normalize flows per sample into [0, 1] and return the scaler.

    The scaler is the per-sample maximum absolute flow value, reshaped
    for broadcasting over (B, C, T, H, W). Dividing by it maps flows to
    [-1, 1]; the affine step then shifts them to [0, 1].
    """
    per_sample_max = torch.abs(flows).flatten(1).max(dim=-1)[0]
    flow_scaler = per_sample_max.reshape(-1, 1, 1, 1, 1)
    normalized = flows / flow_scaler  # [-1, 1]
    return (normalized + 1.0) / 2.0, flow_scaler
def unnormalize_flow(flows, flow_scaler):
    """Invert normalize_flow: map [0, 1] flows back to pixel units."""
    centered = flows * 2.0 - 1.0
    return centered * flow_scaler
def resize(x, scale_factor):
    """Bilinear rescale of a (B, C, H, W) tensor by ``scale_factor``."""
    interp_opts = dict(mode="bilinear", align_corners=False)
    return F.interpolate(x, scale_factor=scale_factor, **interp_opts)
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of (x, y) pixel coordinates.

    Channel 0 holds x (column) indices, channel 1 holds y (row) indices.
    """
    # Explicit indexing="ij" keeps current behavior while silencing the
    # deprecation warning newer torch versions emit when the argument is
    # omitted; the rest of this package already passes it explicitly.
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def build_coord(img):
    """Build a 1/8-resolution coordinate grid matching the RAFT lookup scale."""
    N, C, H, W = img.shape
    coords = coords_grid(N, H // 8, W // 8)
    return coords

View File

@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from ..configs import HypoNetConfig
from .utils import create_params_with_init, create_activation
class HypoNet(nn.Module):
    r"""
    The Hyponetwork: a coordinate-based MLP whose per-layer weight
    matrices can be multiplicatively modulated at forward time.

    Weights are stored as flat (fan_in, fan_out) matrices in a
    ParameterDict (with the bias folded into the last input row when
    ``use_bias`` is set) rather than as nn.Linear modules, so they can
    be modulated per sample via batched matmul.
    """

    def __init__(self, config: HypoNetConfig, add_coord_dim=32):
        super().__init__()
        self.config = config
        self.use_bias = config.use_bias
        self.init_config = config.initialization
        self.num_layer = config.n_layer
        self.hidden_dims = config.hidden_dim
        # Extra input width reserved for the concatenated pixel latent.
        self.add_coord_dim = add_coord_dim
        # A single hidden dim is broadcast to all hidden layers.
        if len(self.hidden_dims) == 1:
            self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
                self.num_layer - 1
            )  # exclude output layer
        else:
            assert len(self.hidden_dims) == self.num_layer - 1
        # SIREN activations require the matching SIREN init scheme.
        if self.config.activation.type == "siren":
            assert self.init_config.weight_init_type == "siren"
            assert self.init_config.bias_init_type == "siren"
        # After computing the shapes of the trainable parameters,
        # initialize them via build_base_params_dict.
        self.params_dict = None
        self.params_shape_dict = self.compute_params_shape()
        self.activation = create_activation(self.config.activation)
        self.build_base_params_dict(self.config.initialization)
        self.output_bias = config.output_bias
        self.normalize_weight = config.normalize_weight
        # Per-layer switch to replace the base weight with 1.0 (i.e. use
        # only the modulation); all off by default.
        self.ignore_base_param_dict = {name: False for name in self.params_dict}

    @staticmethod
    def subsample_coords(coords, subcoord_idx=None):
        """Gather a per-sample subset of flattened coordinates.

        ``subcoord_idx`` holds one index tensor per batch element; when
        None the coordinates pass through unchanged.
        """
        if subcoord_idx is None:
            return coords
        batch_size = coords.shape[0]
        sub_coords = []
        coords = coords.view(batch_size, -1, coords.shape[-1])
        for idx in range(batch_size):
            sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
        sub_coords = torch.cat(sub_coords, dim=0)
        return sub_coords

    def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
        """Evaluate the (optionally modulated) MLP at the given coordinates.

        Args:
            coord: (B, T, H, W, D) coordinates, or a tuple of
                (coords, subsample index) as produced by the sampler.
            modulation_params_dict: optional per-layer multiplicative
                modulations keyed like the base params.
            pixel_latent: (B, H', W', C) latent, bilinearly resized to the
                coordinate grid and concatenated to each coordinate.

        Returns:
            (B, T, H, W, out_dim) outputs for a full grid, or flat
            (B, N, out_dim) when a subsample index was supplied.
        """
        sub_idx = None
        if isinstance(coord, tuple):
            coord, sub_idx = coord[0], coord[1]
        if modulation_params_dict is not None:
            self.check_valid_param_keys(modulation_params_dict)
        batch_size, coord_shape, input_dim = (
            coord.shape[0],
            coord.shape[1:-1],
            coord.shape[-1],
        )
        coord = coord.view(batch_size, -1, input_dim)  # flatten the coordinates
        assert pixel_latent is not None
        # Resize latent to the spatial grid (coord_shape is (T, H, W)).
        pixel_latent = F.interpolate(
            pixel_latent.permute(0, 3, 1, 2),
            size=(coord_shape[1], coord_shape[2]),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        pixel_latent_dim = pixel_latent.shape[-1]
        pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)
        hidden = coord
        hidden = torch.cat([pixel_latent, hidden], dim=-1)
        hidden = self.subsample_coords(hidden, sub_idx)
        for idx in range(self.config.n_layer):
            param_key = f"linear_wb{idx}"
            base_param = einops.repeat(
                self.params_dict[param_key], "n m -> b n m", b=batch_size
            )
            if (modulation_params_dict is not None) and (
                param_key in modulation_params_dict.keys()
            ):
                modulation_param = modulation_params_dict[param_key]
            else:
                # No modulation for this layer: multiply by ones. When the
                # bias row is folded in, it is excluded from modulation.
                if self.config.use_bias:
                    modulation_param = torch.ones_like(base_param[:, :-1])
                else:
                    modulation_param = torch.ones_like(base_param)
            if self.config.use_bias:
                # Append a constant-1 feature so the last weight row
                # acts as the bias.
                ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device)
                hidden = torch.cat([hidden, ones], dim=-1)
                base_param_w, base_param_b = (
                    base_param[:, :-1, :],
                    base_param[:, -1:, :],
                )
                if self.ignore_base_param_dict[param_key]:
                    base_param_w = 1.0
                param_w = base_param_w * modulation_param
                if self.normalize_weight:
                    # Column-wise normalization (weight-norm style).
                    param_w = F.normalize(param_w, dim=1)
                modulated_param = torch.cat([param_w, base_param_b], dim=1)
            else:
                if self.ignore_base_param_dict[param_key]:
                    base_param = 1.0
                if self.normalize_weight:
                    modulated_param = F.normalize(base_param * modulation_param, dim=1)
                else:
                    modulated_param = base_param * modulation_param
            hidden = torch.bmm(hidden, modulated_param)
            # No activation after the output layer.
            if idx < (self.config.n_layer - 1):
                hidden = self.activation(hidden)
        outputs = hidden + self.output_bias
        if sub_idx is None:
            outputs = outputs.view(batch_size, *coord_shape, -1)
        return outputs

    def compute_params_shape(self):
        """
        Computes the shape of MLP parameters.
        The computed shapes are used to build the initial weights by `build_base_params_dict`.
        """
        config = self.config
        use_bias = self.use_bias
        param_shape_dict = dict()
        fan_in = config.input_dim
        add_dim = self.add_coord_dim
        # Coordinates are concatenated with the pixel latent.
        fan_in = fan_in + add_dim
        # +1 input row for the folded-in bias.
        fan_in = fan_in + 1 if use_bias else fan_in
        for i in range(config.n_layer - 1):
            fan_out = self.hidden_dims[i]
            param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
            fan_in = fan_out + 1 if use_bias else fan_out
        param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
        return param_shape_dict

    def build_base_params_dict(self, init_config):
        """Create and register the base weight matrices per layer shape."""
        assert self.params_shape_dict
        params_dict = nn.ParameterDict()
        for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
            # SIREN init treats the first layer specially.
            is_first = idx == 0
            params = create_params_with_init(
                shape,
                init_type=init_config.weight_init_type,
                include_bias=self.use_bias,
                bias_init_type=init_config.bias_init_type,
                is_first=is_first,
                siren_w0=self.config.activation.siren_w0,  # valid only for siren
            )
            params = nn.Parameter(params)
            params_dict[name] = params
        self.set_params_dict(params_dict)

    def check_valid_param_keys(self, params_dict):
        """Raise KeyError if any key is not a known layer parameter name."""
        predefined_params_keys = self.params_shape_dict.keys()
        for param_key in params_dict.keys():
            if param_key in predefined_params_keys:
                continue
            else:
                raise KeyError

    def set_params_dict(self, params_dict):
        """Validate and install a full replacement parameter dict."""
        self.check_valid_param_keys(params_dict)
        self.params_dict = params_dict

View File

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
from torch import nn
import torch
# define siren layer & Siren model
class Sine(nn.Module):
    """Scaled sine activation, x -> sin(w0 * x), as in SIREN.

    Args:
        w0 (float): Omega_0 frequency scale from the SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return (self.w0 * x).sin()
# Damping activation from http://arxiv.org/abs/2306.15242
class Damping(nn.Module):
    """Sine activation damped by a sublinear sqrt(|x|) factor.

    Inputs are clamped to a tiny positive floor first, so negative
    inputs effectively collapse to ~0.

    Args:
        w0 (float): Omega_0 frequency scale from the SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        clamped = torch.clamp(x, min=1e-30)
        return torch.sin(self.w0 * clamped) * clamped.abs().sqrt()

View File

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
from typing import List, Optional
from dataclasses import dataclass, field
from omegaconf import MISSING
@dataclass
class HypoNetActivationConfig:
    """Activation settings for the hypo-network MLP."""

    # Activation name resolved by create_activation (e.g. "relu", "siren").
    type: str = "relu"
    # Omega_0 frequency scale; only meaningful when type == "siren".
    siren_w0: Optional[float] = 30.0
@dataclass
class HypoNetInitConfig:
    """Weight/bias initialization scheme for the hypo-network."""

    # Weight init; must be "siren" when the siren activation is used.
    weight_init_type: Optional[str] = "kaiming_uniform"
    # Bias init; must be "siren" when the siren activation is used.
    bias_init_type: Optional[str] = "zero"
@dataclass
class HypoNetConfig:
    """Configuration of the modulated coordinate MLP (HypoNet)."""

    # Network family; only "mlp" is implemented here.
    type: str = "mlp"
    # Total layer count, including the output layer.
    n_layer: int = 5
    # Hidden widths; a single entry is broadcast to all hidden layers.
    hidden_dim: List[int] = MISSING
    # Fold a bias row into each weight matrix.
    use_bias: bool = True
    # Raw coordinate dimensionality (before the pixel latent is appended).
    input_dim: int = 2
    output_dim: int = 3
    # Constant added to the final layer output.
    output_bias: float = 0.5
    activation: HypoNetActivationConfig = field(default_factory=HypoNetActivationConfig)
    initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig)
    # Column-normalize the modulated weight matrices (weight-norm style).
    normalize_weight: bool = True
    linear_interpo: bool = False
@dataclass
class CoordSamplerConfig:
    """Configuration of the coordinate sampler."""

    data_type: str = "image"
    # Emit only the time channel instead of full (t, y, x) coordinates.
    t_coord_only: bool = False
    # Range the spatial coordinates are mapped into, e.g. [-1, 1].
    coord_range: List[float] = MISSING
    time_range: List[float] = MISSING
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING
    patch_size: Optional[int] = MISSING

View File

@@ -0,0 +1,672 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# softmax-splatting: https://github.com/sniklaus/softmax-splatting
# --------------------------------------------------------
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
objCudacache = {}
def cuda_int32(intIn: int):
    """Wrap a Python int as a cupy int32 kernel argument."""
    value = cupy.int32(intIn)
    return value
# end
def cuda_float32(fltIn: float):
    """Wrap a Python float as a cupy float32 kernel argument."""
    value = cupy.float32(fltIn)
    return value
# end
def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
    """Specialize a templated CUDA kernel for the given arguments and cache it.

    Builds a cache key from the function name plus the type/shape/stride of
    every variable, then (on a cache miss) rewrites the kernel source:
    ``{{name}}`` / ``{{type}}`` placeholders are substituted, and the
    ``SIZE_n(t)`` / ``OFFSET_n(t, ...)`` / ``VALUE_n(t, ...)`` macros are
    expanded into literal sizes and stride arithmetic for the concrete
    tensors. Returns the cache key to pass to cuda_launch.
    """
    if "device" not in objCudacache:
        objCudacache["device"] = torch.cuda.get_device_name()
    # end

    # Build a key that changes whenever any argument's dtype/shape/stride
    # changes, so each layout gets its own compiled kernel.
    strKey = strFunction

    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        elif True:
            print(strVariable, type(objValue))

            assert False

        # end
    # end

    strKey += objCudacache["device"]

    if strKey not in objCudacache:
        # Cache miss: substitute the template placeholders into the source.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)

            # {{type}} is replaced with the C type of the (first matching) tensor dtype.
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace("{{type}}", "unsigned char")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace("{{type}}", "half")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace("{{type}}", "float")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace("{{type}}", "double")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace("{{type}}", "int")

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace("{{type}}", "long")

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)

                assert False

            elif True:
                print(strVariable, type(objValue))

                assert False

            # end
        # end

        # Expand SIZE_n(tensor): the n-th dimension size as a literal.
        while True:
            objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(
                objMatch.group(),
                str(
                    intSizes[intArg]
                    if torch.is_tensor(intSizes[intArg]) == False
                    else intSizes[intArg].item()
                ),
            )
        # end

        # Expand OFFSET_n(tensor, i0, ..., in-1): flat-index stride arithmetic.
        while True:
            objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel)

            if objMatch is None:
                break
            # end

            # Scan forward to the matching close paren of the macro call.
            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")

            assert intArgs == len(strArgs) - 1

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end

            strKernel = strKernel.replace(
                "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                "(" + str.join("+", strIndex) + ")",
            )
        # end

        # Expand VALUE_n(tensor, i0, ..., in-1): an indexed element load/store.
        while True:
            objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")

            assert intArgs == len(strArgs) - 1

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end

            strKernel = strKernel.replace(
                "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                strTensor + "[" + str.join("+", strIndex) + "]",
            )
        # end

        objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
    # end

    return strKey
# end
@cupy.memoize(for_each_device=True)
@torch.compiler.disable()
def cuda_launch(strKey: str):
    """Compile the cached kernel source for strKey and return the raw CUDA function.

    Memoized per device, so each specialized kernel is JIT-compiled at most once.
    Raises RuntimeError when no CUDA toolkit location can be determined.
    """
    try:
        # Let cupy locate the toolkit; keep a user-provided CUDA_HOME if already set.
        os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
    except Exception:
        if "CUDA_HOME" not in os.environ:
            raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")

    strKernel = objCudacache[strKey]["strKernel"]
    strFunction = objCudacache[strKey]["strFunction"]

    return cupy.RawModule(
        code=strKernel,
        options=(
            "-I " + os.environ["CUDA_HOME"],
            "-I " + os.environ["CUDA_HOME"] + "/include",
        ),
    ).get_function(strFunction)
##########################################################
@torch.compiler.disable()
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
    """Forward-warp (splat) tenIn along tenFlow.

    strMode selects the aggregation: "sum", "avg", "linear" or "softmax",
    optionally suffixed with "-addeps" / "-zeroeps" / "-clipeps" to choose
    how the normalization channel is protected against division by zero.
    "sum"/"avg" require tenMetric to be None; "linear"/"softmax" require a
    metric tensor. With return_norm=True the unnormalized output and the
    normalization map are returned separately. Asserts on any NaN.
    """
    assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]

    if strMode == "sum":
        assert tenMetric is None
    if strMode == "avg":
        assert tenMetric is None
    if strMode.split("-")[0] == "linear":
        assert tenMetric is not None
    if strMode.split("-")[0] == "softmax":
        assert tenMetric is not None

    # Append a normalization channel (and pre-weight the input for the
    # linear/softmax variants) so a single splat produces both numerator
    # and denominator.
    if strMode == "avg":
        tenIn = torch.cat(
            [
                tenIn,
                tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]),
            ],
            1,
        )

    elif strMode.split("-")[0] == "linear":
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)

    elif strMode.split("-")[0] == "softmax":
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    # end

    if torch.isnan(tenIn).any():
        print("NaN values detected during training in tenIn. Exiting.")
        assert False

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if torch.isnan(tenOut).any():
        print("NaN values detected during training in tenOut_1. Exiting.")
        assert False

    if strMode.split("-")[0] in ["avg", "linear", "softmax"]:
        tenNormalize = tenOut[:, -1:, :, :]

        if len(strMode.split("-")) == 1:
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split("-")[1] == "addeps":
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split("-")[1] == "zeroeps":
            # In-place: pixels nothing splatted onto divide by 1 instead of 0.
            tenNormalize[tenNormalize == 0.0] = 1.0

        elif strMode.split("-")[1] == "clipeps":
            tenNormalize = tenNormalize.clip(0.0000001, None)

        # end

        if return_norm:
            return tenOut[:, :-1, :, :], tenNormalize

        tenOut = tenOut[:, :-1, :, :] / tenNormalize

        if torch.isnan(tenOut).any():
            print("NaN values detected during training in tenOut_2. Exiting.")
            assert False
    # end

    return tenOut
# end
class softsplat_func(torch.autograd.Function):
    """Autograd wrapper around the CUDA softmax-splatting kernels.

    forward scatters tenIn into a zero-initialized output at positions
    displaced by tenFlow, distributing each value bilinearly over the four
    neighboring pixels with atomicAdd. backward computes gradients w.r.t.
    both the input and the flow. CUDA tensors only: CPU input asserts.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        # Output accumulator with the same (N, C, H, W) layout as the input.
        tenOut = tenIn.new_zeros(
            [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
        )

        if tenIn.is_cuda == True:
            cuda_launch(
                cuda_kernel(
                    "softsplat_out",
                    """
        extern "C" __global__ void __launch_bounds__(512) softsplat_out(
            const int n,
            const {{type}}* __restrict__ tenIn,
            const {{type}}* __restrict__ tenFlow,
            {{type}}* __restrict__ tenOut
        ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
            const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
            const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
            const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
            const int intX = ( intIndex ) % SIZE_3(tenOut);
            assert(SIZE_1(tenFlow) == 2);
            {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
            {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
            if (isfinite(fltX) == false) { return; }
            if (isfinite(fltY) == false) { return; }
            {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
            int intNorthwestX = (int) (floor(fltX));
            int intNorthwestY = (int) (floor(fltY));
            int intNortheastX = intNorthwestX + 1;
            int intNortheastY = intNorthwestY;
            int intSouthwestX = intNorthwestX;
            int intSouthwestY = intNorthwestY + 1;
            int intSoutheastX = intNorthwestX + 1;
            int intSoutheastY = intNorthwestY + 1;
            {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
            {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
            {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
            {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
            if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
            }
            if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
            }
            if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
            }
            if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
            }
        } }
                    """,
                    {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut},
                )
            )(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenOut.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOut.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )

        elif tenIn.is_cuda != True:
            assert False

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut

    # end

    @staticmethod
    @torch.compiler.disable()
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(self, tenOutgrad):
        tenIn, tenFlow = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous()
        assert tenOutgrad.is_cuda == True

        # Gradient buffers are only allocated for inputs that require grad.
        tenIngrad = (
            tenIn.new_zeros(
                [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
            )
            if self.needs_input_grad[0] == True
            else None
        )
        tenFlowgrad = (
            tenFlow.new_zeros(
                [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]
            )
            if self.needs_input_grad[1] == True
            else None
        )

        if tenIngrad is not None:
            cuda_launch(
                cuda_kernel(
                    "softsplat_ingrad",
                    """
        extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
            const int n,
            const {{type}}* __restrict__ tenIn,
            const {{type}}* __restrict__ tenFlow,
            const {{type}}* __restrict__ tenOutgrad,
            {{type}}* __restrict__ tenIngrad,
            {{type}}* __restrict__ tenFlowgrad
        ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
            const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
            const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
            const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
            const int intX = ( intIndex ) % SIZE_3(tenIngrad);
            assert(SIZE_1(tenFlow) == 2);
            {{type}} fltIngrad = 0.0f;
            {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
            {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
            if (isfinite(fltX) == false) { return; }
            if (isfinite(fltY) == false) { return; }
            int intNorthwestX = (int) (floor(fltX));
            int intNorthwestY = (int) (floor(fltY));
            int intNortheastX = intNorthwestX + 1;
            int intNortheastY = intNorthwestY;
            int intSouthwestX = intNorthwestX;
            int intSouthwestY = intNorthwestY + 1;
            int intSoutheastX = intNorthwestX + 1;
            int intSoutheastY = intNorthwestY + 1;
            {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
            {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
            {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
            {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
            if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
            }
            if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
            }
            if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
            }
            if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
            }
            tenIngrad[intIndex] = fltIngrad;
        } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenIngrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenIngrad.data_ptr(),
                    None,
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end

        if tenFlowgrad is not None:
            cuda_launch(
                cuda_kernel(
                    "softsplat_flowgrad",
                    """
        extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
            const int n,
            const {{type}}* __restrict__ tenIn,
            const {{type}}* __restrict__ tenFlow,
            const {{type}}* __restrict__ tenOutgrad,
            {{type}}* __restrict__ tenIngrad,
            {{type}}* __restrict__ tenFlowgrad
        ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
            const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
            const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
            const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
            const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
            assert(SIZE_1(tenFlow) == 2);
            {{type}} fltFlowgrad = 0.0f;
            {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
            {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
            if (isfinite(fltX) == false) { return; }
            if (isfinite(fltY) == false) { return; }
            int intNorthwestX = (int) (floor(fltX));
            int intNorthwestY = (int) (floor(fltY));
            int intNortheastX = intNorthwestX + 1;
            int intNortheastY = intNorthwestY;
            int intSouthwestX = intNorthwestX;
            int intSouthwestY = intNorthwestY + 1;
            int intSoutheastX = intNorthwestX + 1;
            int intSoutheastY = intNorthwestY + 1;
            {{type}} fltNorthwest = 0.0f;
            {{type}} fltNortheast = 0.0f;
            {{type}} fltSouthwest = 0.0f;
            {{type}} fltSoutheast = 0.0f;
            if (intC == 0) {
                fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
            } else if (intC == 1) {
                fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
            }
            for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
                if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                    fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                }
                if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                    fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                }
                if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                    fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                }
                if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                    fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                }
            }
            tenFlowgrad[intIndex] = fltFlowgrad;
        } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenFlowgrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    None,
                    tenFlowgrad.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end

        return tenIngrad, tenFlowgrad

    # end
# end

View File

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from .layers import Sine, Damping
def convert_int_to_list(size, len_list=2):
    """Broadcast an int to a list of length *len_list*.

    A non-int is assumed to be a sequence, validated to contain exactly
    *len_list* entries, and returned unchanged.
    """
    if not isinstance(size, int):
        assert len(size) == len_list
        return size
    return [size for _ in range(len_list)]
def initialize_params(params, init_type, **kwargs):
    """Initialize a 2-D parameter tensor in place.

    Args:
        params: tensor of shape (fan_in, fan_out), filled in place.
        init_type: one of None / "normal", "kaiming_uniform",
            "uniform_fan_in", "zero", "siren".
        **kwargs: for "siren", must contain siren_w0 (float) and
            is_first (bool).

    Raises:
        NotImplementedError: for an unrecognized init_type.
    """
    # Rows are treated as the input dimension for fan-in based schemes.
    # (The previously computed fan_out was never used.)
    fan_in = params.shape[0]
    if init_type is None or init_type == "normal":
        nn.init.normal_(params)
    elif init_type == "kaiming_uniform":
        nn.init.kaiming_uniform_(params, a=math.sqrt(5))
    elif init_type == "uniform_fan_in":
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(params, -bound, bound)
    elif init_type == "zero":
        nn.init.zeros_(params)
    elif init_type == "siren":
        # SIREN init (Sitzmann et al., 2020): 1/fan_in for the first layer,
        # sqrt(6/fan_in)/w0 for subsequent layers.
        assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys()
        w0 = kwargs["siren_w0"]
        if kwargs["is_first"]:
            w_std = 1 / fan_in
        else:
            w_std = math.sqrt(6.0 / fan_in) / w0
        nn.init.uniform_(params, -w_std, w_std)
    else:
        raise NotImplementedError
def create_params_with_init(
    shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
):
    """Create an initialized (shape[0], shape[1]) parameter matrix.

    When include_bias is True, the final row is a bias initialized with
    bias_init_type while the first shape[0]-1 rows use init_type; extra
    keyword arguments are forwarded to initialize_params.
    """
    rows, cols = shape[0], shape[1]
    if include_bias:
        weight = torch.empty([rows - 1, cols])
        bias = torch.empty([1, cols])
        initialize_params(weight, init_type, **kwargs)
        initialize_params(bias, bias_init_type, **kwargs)
        return torch.cat([weight, bias], dim=0)
    weight = torch.empty([rows, cols])
    initialize_params(weight, init_type, **kwargs)
    return weight
def create_activation(config):
    """Instantiate the activation module selected by config.type.

    Supports "relu", "siren", "silu" and "damp"; anything else raises
    NotImplementedError. The siren/damp variants read config.siren_w0.
    """
    kind = config.type
    if kind == "relu":
        return nn.ReLU()
    if kind == "silu":
        return nn.SiLU()
    if kind == "siren":
        return Sine(config.siren_w0)
    if kind == "damp":
        return Damping(config.siren_w0)
    raise NotImplementedError

View File

@@ -0,0 +1 @@
from .raft import RAFT

View File

@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# raft: https://github.com/princeton-vl/RAFT
# --------------------------------------------------------
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class BidirCorrBlock:
    """Bidirectional all-pairs correlation pyramid (RAFT-style).

    Builds the 6-D correlation volume between fmap1 and fmap2 once, keeps
    both it and its transpose, and serves windowed lookups for the 0->1 and
    1->0 directions in a single __call__.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # The transpose swaps the roles of source and target pixels.
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)
        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        # Coarser levels by 2x average pooling over the target dimensions.
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        # coords0/coords1: (B, 2, H, W) lookup centers for each direction.
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert (
            coords0.shape == coords1.shape
        ), f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1)x(2r+1) sampling window around each (scaled) center.
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

            centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return (
            out.permute(0, 3, 1, 2).contiguous().float(),
            out_T.permute(0, 3, 1, 2).contiguous().float(),
        )

    @staticmethod
    def corr(fmap1, fmap2):
        # All-pairs dot products between feature vectors, scaled by 1/sqrt(dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
    """Memory-efficient correlation lookup via the alt_cuda_corr extension.

    Keeps the raw feature pyramids instead of a precomputed volume and
    computes windowed correlations on demand. Requires the optional
    alt_cuda_corr CUDA extension to be importable at call time.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # Source features are always full resolution; targets come from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
class CorrBlock:
    """Unidirectional all-pairs correlation pyramid (standard RAFT lookup)."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # (2r+1)x(2r+1) window of offsets around each lookup center.
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # All-pairs dot products between feature vectors, scaled by 1/sqrt(dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())

View File

@@ -0,0 +1,293 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two-conv residual block with a selectable normalization layer.

    norm_fn must be one of "group", "batch", "instance" or "none"; any other
    value leaves the norm attributes undefined and fails later in forward.
    When stride != 1, a 1x1 strided conv downsamples the identity branch.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (planes//4 internal width).

    Same norm_fn contract as ResidualBlock: "group", "batch", "instance" or
    "none". A strided 1x1 conv downsamples the identity branch when needed.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder producing 1/8-resolution features.

    With only_feat=True the final 1x1 output conv is skipped and raw stage
    features are returned. forward can also return the intermediate pyramid
    (return_feature=True) and, with mif=True, augment levels 2 and 3 with
    features extracted from 1/2- and 1/4-scale copies of the input.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0, only_feat=False):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.only_feat = only_feat

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer mutates self.in_planes, so stage order matters here.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        if not self.only_feat:
            # output convolution
            self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two stacked residual blocks; the first one performs any striding.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x, return_feature=False, mif=False):
        features = []
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        # Downscaled copies used only for the multi-input-feature (mif) path.
        x_2 = F.interpolate(x, scale_factor=1 / 2, mode="bilinear", align_corners=False)
        x_4 = F.interpolate(x, scale_factor=1 / 4, mode="bilinear", align_corners=False)

        def f1(feat):
            # Stem + first residual stage, shared by all input scales.
            feat = self.conv1(feat)
            feat = self.norm1(feat)
            feat = self.relu1(feat)
            feat = self.layer1(feat)
            return feat

        x = f1(x)
        features.append(x)

        x = self.layer2(x)
        if mif:
            x_2_2 = f1(x_2)
            features.append(torch.cat([x, x_2_2], dim=1))
        else:
            features.append(x)

        x = self.layer3(x)
        if mif:
            x_2_4 = self.layer2(x_2_2)
            x_4_4 = f1(x_4)
            features.append(torch.cat([x, x_2_4, x_4_4], dim=1))
        else:
            features.append(x)

        if not self.only_feat:
            x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split the concatenated pair back into two tensors;
            # assumes the list input contained exactly two same-sized batches.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
            features = [torch.split(f, [batch_dim, batch_dim], dim=0) for f in features]

        if return_feature:
            return x, features
        else:
            return x
class SmallEncoder(nn.Module):
    """Compact RAFT encoder (bottleneck blocks, 32/64/96 channel stages)."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer mutates self.in_planes, so stage order matters here.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two stacked bottleneck blocks; the first one performs any striding.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Assumes the list input contained exactly two same-sized batches.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x

View File

@@ -0,0 +1,238 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import BidirCorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8
try:
    autocast = torch.cuda.amp.autocast
except (AttributeError, ImportError):
    # PyTorch < 1.6 has no torch.cuda.amp; fall back to a no-op context
    # manager so `with autocast(enabled=...)` still parses and runs.
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # never swallowed here.
    class autocast:
        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass
# BiRAFT
class RAFT(nn.Module):
    """Bidirectional RAFT ("BiRAFT") optical-flow estimator.

    Runs two recurrent refinement streams at once -- one producing the flow
    anchored at image1 and one anchored at image2 -- over a shared
    bidirectional correlation volume.  ``forward`` returns both full-resolution
    flows, both 1/8-resolution flows, and the context-encoder feature
    pyramids of the two images.
    """

    def __init__(self, args):
        # `args` is an argparse-style namespace: `small` selects the reduced
        # architecture; `mixed_precision` gates autocast; `dropout` and
        # `alternate_corr` are optional and defaulted below.
        super(RAFT, self).__init__()
        self.args = args
        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
            self.corr_levels = 4
            self.corr_radius = 3
        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4
            self.corr_levels = 4
            self.corr_radius = 4
        # Default the optional namespace entries when the caller omitted them.
        if "dropout" not in args._get_kwargs():
            self.args.dropout = 0
        if "alternate_corr" not in args._get_kwargs():
            self.args.alternate_corr = False
        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(
                output_dim=128, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = SmallEncoder(
                output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
            )
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = BasicEncoder(
                output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
            )
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put all BatchNorm layers into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def build_coord(self, img):
        """Return the identity coordinate grid at 1/8 input resolution."""
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8, device=img.device)
        return coords

    def initialize_flow(self, img, img2):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        assert img.shape == img2.shape
        N, C, H, W = img.shape
        # Two fixed reference grids and two grids that will be refined.
        coords01 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords02 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords2 = coords_grid(N, H // 8, W // 8, device=img.device)
        # optical flow computed as difference: flow = coords1 - coords0
        return coords01, coords02, coords1, coords2

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # `mask` holds, per output pixel, softmax weights over a 3x3 neighborhood.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)
        # Scale flow values by 8 (resolution change) and gather 3x3 patches.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def get_corr_fn(self, image1, image2, projector=None):
        """Build the correlation lookup for an image pair.

        Returns ``(corr_fn, corr_fn1)``: the main correlation block and, when
        ``projector`` is given, a second block built from projected deep
        features; otherwise the main block is returned twice.
        """
        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmaps, feats = self.fnet([image1, image2], return_feature=True)
        fmap1, fmap2 = fmaps
        # Correlation is always computed in fp32 regardless of autocast.
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn1 = None
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
            if projector is not None:
                corr_fn1 = AlternateCorrBlock(
                    projector(feats[-1][0]),
                    projector(feats[-1][1]),
                    radius=self.args.corr_radius,
                )
        else:
            corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
            if projector is not None:
                corr_fn1 = BidirCorrBlock(
                    projector(feats[-1][0]),
                    projector(feats[-1][1]),
                    radius=self.args.corr_radius,
                )
        if corr_fn1 is None:
            return corr_fn, corr_fn
        else:
            return corr_fn, corr_fn1

    def get_corr_fn_from_feat(self, fmap1, fmap2):
        """Build a correlation lookup directly from precomputed feature maps."""
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        return corr_fn

    def forward(
        self,
        image1,
        image2,
        iters=12,
        flow_init=None,
        upsample=True,
        test_mode=False,
        corr_fn=None,
        mif=False,
    ):
        """Estimate optical flow between pair of frames"""
        # Warm-starting is not supported in this bidirectional variant.
        assert flow_init is None
        # Normalize [0, 255] images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0
        image1 = image1.contiguous()
        image2 = image2.contiguous()
        hdim = self.hidden_dim
        cdim = self.context_dim
        if corr_fn is None:
            corr_fn, _ = self.get_corr_fn(image1, image2)
        # # run the feature network
        # with autocast(enabled=self.args.mixed_precision):
        #     fmap1, fmap2 = self.fnet([image1, image2])
        # fmap1 = fmap1.float()
        # fmap2 = fmap2.float()
        # if self.args.alternate_corr:
        #     corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        # else:
        #     corr_fn = BidirCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            # for image1
            cnet1, features1 = self.cnet(image1, return_feature=True, mif=mif)
            net1, inp1 = torch.split(cnet1, [hdim, cdim], dim=1)
            net1 = torch.tanh(net1)
            inp1 = torch.relu(inp1)
            # for image2
            cnet2, features2 = self.cnet(image2, return_feature=True, mif=mif)
            net2, inp2 = torch.split(cnet2, [hdim, cdim], dim=1)
            net2 = torch.tanh(net2)
            inp2 = torch.relu(inp2)
        coords01, coords02, coords1, coords2 = self.initialize_flow(image1, image2)
        # if flow_init is not None:
        #     coords1 = coords1 + flow_init
        # flow_predictions1 = []
        # flow_predictions2 = []
        for itr in range(iters):
            # Detach so gradients are truncated at each refinement iteration.
            coords1 = coords1.detach()
            coords2 = coords2.detach()
            corr1, corr2 = corr_fn(coords1, coords2)  # index correlation volume
            flow1 = coords1 - coords01
            flow2 = coords2 - coords02
            with autocast(enabled=self.args.mixed_precision):
                net1, up_mask1, delta_flow1 = self.update_block(
                    net1, inp1, corr1, flow1
                )
                net2, up_mask2, delta_flow2 = self.update_block(
                    net2, inp2, corr2, flow2
                )
            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow1
            coords2 = coords2 + delta_flow2
        flow_low1 = coords1 - coords01
        flow_low2 = coords2 - coords02
        # upsample predictions
        if up_mask1 is None:
            flow_up1 = upflow8(coords1 - coords01)
            flow_up2 = upflow8(coords2 - coords02)
        else:
            flow_up1 = self.upsample_flow(coords1 - coords01, up_mask1)
            flow_up2 = self.upsample_flow(coords2 - coords02, up_mask2)
        # flow_predictions.append(flow_up)
        # Returns: full-res flows (1->2, 2->1), their 1/8-res counterparts,
        # and the two context feature pyramids.
        return flow_up1, flow_up2, flow_low1, flow_low2, features1, features2
        # if test_mode:
        #     return coords1 - coords0, flow_up
        # return flow_predictions

View File

@@ -0,0 +1,169 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8
try:
    autocast = torch.cuda.amp.autocast
except (AttributeError, ImportError):
    # PyTorch < 1.6 has no torch.cuda.amp; fall back to a no-op context
    # manager so `with autocast(enabled=...)` still parses and runs.
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # never swallowed here.
    class autocast:
        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass
class RAFT(nn.Module):
    """Standard (unidirectional) RAFT optical-flow estimator.

    Vendored upstream implementation; ``forward`` can additionally expose the
    context-encoder feature pyramid and the first feature map via
    ``return_feat``.
    """

    def __init__(self, args):
        # `args` is an argparse-style namespace: `small` selects the reduced
        # architecture; `dropout`/`alternate_corr` are optional and defaulted.
        super(RAFT, self).__init__()
        self.args = args
        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
            self.corr_levels = 4
            self.corr_radius = 3
        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4
            self.corr_levels = 4
            self.corr_radius = 4
        # Default the optional namespace entries when the caller omitted them.
        if "dropout" not in args._get_kwargs():
            self.args.dropout = 0
        if "alternate_corr" not in args._get_kwargs():
            self.args.alternate_corr = False
        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(
                output_dim=128, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = SmallEncoder(
                output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
            )
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn="instance", dropout=args.dropout
            )
            self.cnet = BasicEncoder(
                output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
            )
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put all BatchNorm layers into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # `mask` holds, per output pixel, softmax weights over a 3x3 neighborhood.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)
        # Scale flow values by 8 (resolution change) and gather 3x3 patches.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(
        self,
        image1,
        image2,
        iters=12,
        flow_init=None,
        upsample=True,
        test_mode=False,
        return_feat=True,
    ):
        """Estimate optical flow between pair of frames"""
        # Normalize [0, 255] images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0
        image1 = image1.contiguous()
        image2 = image2.contiguous()
        hdim = self.hidden_dim
        cdim = self.context_dim
        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        # Correlation is always computed in fp32 regardless of autocast.
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet, feats = self.cnet(image1, return_feature=True)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)
        coords0, coords1 = self.initialize_flow(image1)
        if flow_init is not None:
            # Warm-start the refinement from a caller-provided flow.
            coords1 = coords1 + flow_init
        flow_predictions = []
        for itr in range(iters):
            # Detach so gradients are truncated at each refinement iteration.
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume
            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow
            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)
        if test_mode:
            return coords1 - coords0, flow_up
        if return_feat:
            # Final flow, context features past the first level, and fmap1.
            return flow_up, feats[1:], fmap1
        return flow_predictions

View File

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer conv head projecting a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Single convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        update_gate = torch.sigmoid(self.convz(hx))   # z: how much to overwrite
        reset_gate = torch.sigmoid(self.convr(hx))    # r: how much history to keep
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))
        return (1 - update_gate) * h + update_gate * candidate
class SepConvGRU(nn.Module):
    """GRU cell with separable convolutions.

    Performs one GRU update with horizontal (1x5) gate kernels followed by a
    second update with vertical (5x1) kernels, approximating a full 5x5 gate
    at lower cost.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        channels = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(channels, hidden_dim, (5, 1), padding=(2, 0))

    def _step(self, h, x, convz, convr, convq):
        # One GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)  # vertical
        return h
class SmallMotionEncoder(nn.Module):
    """Fuses correlation lookups and current flow into motion features (small variant)."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Fuses correlation lookups and current flow into 128-channel motion features."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + 2 raw flow channels = 128 output channels.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
class SmallUpdateBlock(nn.Module):
    """RAFT update step (small variant): motion encoding -> GRU -> flow head."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        # The small variant has no convex-upsampling mask head (second return
        # is always None).
        return net, None, self.flow_head(net)
class BasicUpdateBlock(nn.Module):
    """RAFT update step (full variant) with a convex-upsampling mask head."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        # Cast flow/corr to the dtype/device of the context features so the
        # concat below is well-typed under autocast.
        motion = self.encoder(flow.to(inp), corr.to(inp))
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)
        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,93 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        # F.pad order: [left, right, top, bottom].
        if mode == "sintel":
            # Split the padding evenly on each axis.
            self._pad = [
                pad_wd // 2,
                pad_wd - pad_wd // 2,
                pad_ht // 2,
                pad_ht - pad_ht // 2,
            ]
        else:
            # Center horizontally; put all vertical padding on one side.
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
def forward_interpolate(flow):
    """Forward-splat a [2, H, W] flow field back onto the regular grid.

    Each source pixel's displacement is carried to its target location, then
    a dense flow is rebuilt by nearest-neighbour interpolation.  Runs on CPU
    via NumPy/SciPy and returns a float CPU tensor.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]
    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
    # Target positions reached by each source pixel.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)
    # Keep only displacements that land strictly inside the frame.
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[valid], y1[valid]
    dx, dy = dx[valid], dy[valid]
    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
    )
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
    )
    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates"""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's normalized [-1, 1] range.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1
    sample_grid = torch.cat([xgrid, ygrid], dim=-1)
    sampled = F.grid_sample(img, sample_grid, align_corners=True)
    if mask:
        # Mark samples whose coordinates fall strictly inside the image.
        inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, inside.float()
    return sampled
def coords_grid(batch, ht, wd, device):
    """Return a [batch, 2, ht, wd] grid of pixel coordinates.

    Channel 0 holds column (x) indices, channel 1 row (y) indices.
    """
    # indexing="ij" matches the historical torch.meshgrid default and
    # silences the deprecation warning emitted when the argument is omitted.
    coords = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device),
        indexing="ij",
    )
    # Reverse to (x, y) channel order before stacking.
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode="bilinear"):
    """Upsample a 1/8-resolution flow field 8x, scaling its values by 8."""
    _, _, ht, wd = flow.shape
    upsampled = F.interpolate(flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True)
    return 8 * upsampled

View File

View File

@@ -0,0 +1,52 @@
from easydict import EasyDict as edict
import torch.nn.functional as F
class InputPadder:
    """Pads images such that dimensions are divisible by divisor"""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # F.pad order: [left, right, top, bottom]; split each axis evenly.
        self._pad = [
            pad_wd // 2,
            pad_wd - pad_wd // 2,
            pad_ht // 2,
            pad_ht - pad_ht // 2,
        ]

    def pad(self, *inputs):
        padded = [F.pad(t, self._pad, mode="replicate") for t in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs):
        cropped = [self._unpad(t) for t in inputs]
        return cropped[0] if len(cropped) == 1 else cropped

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
def easydict_to_dict(obj):
    """Recursively convert an EasyDict tree into plain nested dicts.

    Non-EasyDict values (including plain dicts) are returned unchanged.
    """
    if isinstance(obj, edict):
        return {key: easydict_to_dict(value) for key, value in obj.items()}
    return obj
class RaftArgs:
    """Minimal stand-in for RAFT's argparse namespace.

    RAFT probes ``args._get_kwargs()`` (the argparse.Namespace API) to check
    which optional keys were supplied, so that method is reproduced here.
    """

    def __init__(self, small, mixed_precision, alternate_corr):
        self.small = small
        self.mixed_precision = mixed_precision
        self.alternate_corr = alternate_corr

    def _get_kwargs(self):
        return {
            name: getattr(self, name)
            for name in ("small", "mixed_precision", "alternate_corr")
        }

View File

@@ -441,3 +441,183 @@ class SGMVFIModel:
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)
# ---------------------------------------------------------------------------
# GIMM-VFI model wrapper
# ---------------------------------------------------------------------------
class GIMMVFIModel:
    """Clean inference wrapper around GIMM-VFI for ComfyUI integration.

    Supports two modes:
    - interpolate_batch(): standard single-midpoint interface (compatible with
      recursive _interpolate_frames machinery used by other models)
    - interpolate_multi(): GIMM-VFI's unique single-pass mode, generates all
      N-1 intermediate frames between each pair in one forward pass
    """

    def __init__(self, checkpoint_path, flow_checkpoint_path, variant="auto",
                 ds_factor=1.0, device="cpu"):
        # Imports are deferred to instantiation time so merely importing this
        # module does not pull in omegaconf/yaml/the vendored architecture.
        import os
        import yaml
        from omegaconf import OmegaConf
        from .gimm_vfi_arch import (
            GIMMVFI_R, GIMMVFI_F, GIMMVFIConfig,
            GIMM_RAFT, GIMM_FlowFormer, gimm_get_flowformer_cfg,
            GIMMInputPadder, GIMMRaftArgs, easydict_to_dict,
        )
        import comfy.utils
        self.ds_factor = ds_factor
        self.device = device
        self._InputPadder = GIMMInputPadder
        filename = os.path.basename(checkpoint_path).lower()
        # Detect variant from filename
        if variant == "auto":
            # "gimmvfi_f" marks the FlowFormer-based checkpoint.
            self.is_flowformer = "gimmvfi_f" in filename
        else:
            self.is_flowformer = (variant == "flowformer")
        self.variant_name = "flowformer" if self.is_flowformer else "raft"
        # Load config
        script_dir = os.path.dirname(os.path.abspath(__file__))
        if self.is_flowformer:
            config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_f_arb.yaml")
        else:
            config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_r_arb.yaml")
        with open(config_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        # EasyDict -> plain dict so OmegaConf can ingest it, then merge the
        # architecture defaults with the file's `arch` section.
        config = easydict_to_dict(config)
        config = OmegaConf.create(config)
        arch_defaults = GIMMVFIConfig.create(config.arch)
        config = OmegaConf.merge(arch_defaults, config.arch)
        # Build model + flow estimator
        dtype = torch.float32
        if self.is_flowformer:
            self.model = GIMMVFI_F(dtype, config)
            cfg = gimm_get_flowformer_cfg()
            flow_estimator = GIMM_FlowFormer(cfg.latentcostformer)
            flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
            flow_estimator.load_state_dict(flow_sd, strict=True)
        else:
            self.model = GIMMVFI_R(dtype, config)
            raft_args = GIMMRaftArgs(small=False, mixed_precision=False, alternate_corr=False)
            flow_estimator = GIMM_RAFT(raft_args)
            flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
            flow_estimator.load_state_dict(flow_sd, strict=True)
        # Load main model weights
        sd = comfy.utils.load_torch_file(checkpoint_path)
        # strict=False tolerates key mismatches -- presumably because the flow
        # estimator ships as a separate checkpoint and is attached just below.
        # TODO(review): confirm no other keys are silently skipped.
        self.model.load_state_dict(sd, strict=False)
        self.model.flow_estimator = flow_estimator
        self.model.eval()

    def to(self, device):
        """Move model to device (returns self for chaining)."""
        self.device = device if isinstance(device, str) else str(device)
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_batch(self, frames0, frames1, time_step=0.5):
        """Interpolate a single midpoint frame per pair (standard interface).

        Args:
            frames0: [B, C, H, W] tensor, float32, range [0, 1]
            frames1: [B, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1)

        Returns:
            Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        results = []
        # Pairs are processed one at a time; each pair is padded so both
        # spatial dims are multiples of 32 before entering the network.
        for i in range(frames0.shape[0]):
            I0 = frames0[i:i+1].to(device)
            I2 = frames1[i:i+1].to(device)
            padder = self._InputPadder(I0.shape, 32)
            I0_p, I2_p = padder.pad(I0, I2)
            # Stack the two inputs along a new "time" axis: [B, C, 2, H, W].
            xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
            batch_size = xs.shape[0]
            s_shape = xs.shape[-2:]
            # One (coordinate grid, aux) entry per requested timestep
            # (a single midpoint here).
            coord_inputs = [(
                self.model.sample_coord_input(
                    batch_size, s_shape, [time_step],
                    device=xs.device, upsample_ratio=self.ds_factor,
                ),
                None,
            )]
            timesteps = [
                time_step * torch.ones(xs.shape[0]).to(xs.device)
            ]
            all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
            pred = padder.unpad(all_outputs["imgt_pred"][0])
            results.append(torch.clamp(pred, 0, 1))
        return torch.cat(results, dim=0)

    @torch.no_grad()
    def interpolate_multi(self, frame0, frame1, num_intermediates):
        """Generate all intermediate frames between a pair in one forward pass.

        This is GIMM-VFI's unique capability -- arbitrary timestep interpolation
        without recursive 2x passes.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            num_intermediates: int, number of intermediate frames to generate

        Returns:
            List of [1, C, H, W] tensors, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        I0 = frame0.to(device)
        I2 = frame1.to(device)
        padder = self._InputPadder(I0.shape, 32)
        I0_p, I2_p = padder.pad(I0, I2)
        xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
        batch_size = xs.shape[0]
        s_shape = xs.shape[-2:]
        interp_factor = num_intermediates + 1
        # Evenly spaced timesteps i/interp_factor for i = 1..interp_factor-1,
        # each with its own coordinate grid; all are evaluated in one call.
        coord_inputs = [
            (
                self.model.sample_coord_input(
                    batch_size, s_shape,
                    [1.0 / interp_factor * i],
                    device=xs.device,
                    upsample_ratio=self.ds_factor,
                ),
                None,
            )
            for i in range(1, interp_factor)
        ]
        timesteps = [
            i * 1.0 / interp_factor * torch.ones(xs.shape[0]).to(xs.device)
            for i in range(1, interp_factor)
        ]
        all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
        results = []
        for pred in all_outputs["imgt_pred"]:
            unpadded = padder.unpad(pred)
            results.append(torch.clamp(unpadded, 0, 1))
        return results

396
nodes.py
View File

@@ -8,10 +8,11 @@ import torch
import folder_paths
from comfy.utils import ProgressBar
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
from .gimm_vfi_arch import clear_gimm_caches
logger = logging.getLogger("Tween")
@@ -40,6 +41,17 @@ SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi")
if not os.path.exists(SGM_MODEL_DIR):
os.makedirs(SGM_MODEL_DIR, exist_ok=True)
# GIMM-VFI
# HuggingFace repo hosting Kijai's safetensors conversions of the GIMM-VFI
# checkpoints (main models + flow estimators).
GIMM_HF_REPO = "Kijai/GIMM-VFI_safetensors"
# Known downloadable main checkpoints: RAFT-based (gimmvfi_r) and
# FlowFormer-based (gimmvfi_f) variants.
GIMM_AVAILABLE_MODELS = [
    "gimmvfi_r_arb_lpips_fp32.safetensors",
    "gimmvfi_f_arb_lpips_fp32.safetensors",
]
GIMM_MODEL_DIR = os.path.join(folder_paths.models_dir, "gimm-vfi")
# exist_ok=True already tolerates a pre-existing directory, so the previous
# race-prone os.path.exists() guard is unnecessary.
os.makedirs(GIMM_MODEL_DIR, exist_ok=True)
def get_available_models():
"""List available checkpoint files in the bim-vfi model directory."""
@@ -1113,3 +1125,385 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
result = result[1:] # skip duplicate boundary frame
return (result, model)
# ---------------------------------------------------------------------------
# GIMM-VFI nodes
# ---------------------------------------------------------------------------
def get_available_gimm_models():
    """List available GIMM-VFI checkpoint files in the gimm-vfi model directory."""
    checkpoint_exts = (".safetensors", ".pth", ".pt", ".ckpt")
    flow_prefixes = ("raft-", "flowformer_")
    found = []
    if os.path.isdir(GIMM_MODEL_DIR):
        found = [
            name
            for name in os.listdir(GIMM_MODEL_DIR)
            if name.endswith(checkpoint_exts)
            # Flow-estimator weights share this folder; hide them here.
            and not name.startswith(flow_prefixes)
        ]
    # Fall back to the known downloadable checkpoints when none exist locally.
    if not found:
        found = list(GIMM_AVAILABLE_MODELS)
    return sorted(found)
def download_gimm_model(filename, dest_dir):
    """Download a GIMM-VFI file from HuggingFace.

    Args:
        filename: File name inside the HuggingFace repo (also the local name).
        dest_dir: Local directory to place the file in.

    Raises:
        RuntimeError: If huggingface_hub is missing, or the file is absent
            after the download returns.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise RuntimeError(
            "huggingface_hub is required to auto-download GIMM-VFI models. "
            "Install it with: pip install huggingface_hub"
        )
    # Log messages previously printed the literal "(unknown)" instead of the
    # file name; interpolate `filename` so failures are diagnosable.
    logger.info(f"Downloading {filename} from HuggingFace ({GIMM_HF_REPO})...")
    hf_hub_download(
        repo_id=GIMM_HF_REPO,
        filename=filename,
        local_dir=dest_dir,
        local_dir_use_symlinks=False,
    )
    dest_path = os.path.join(dest_dir, filename)
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download {filename} to {dest_path}")
    logger.info(f"Downloaded {filename}")
class LoadGIMMVFIModel:
    """ComfyUI node: load a GIMM-VFI checkpoint plus its matching flow estimator."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI frontend."""
        return {
            "required": {
                "model_path": (get_available_gimm_models(), {
                    "default": GIMM_AVAILABLE_MODELS[0],
                    "tooltip": "Checkpoint file from models/gimm-vfi/. Auto-downloads from HuggingFace on first use. "
                    "RAFT variant (~80MB) or FlowFormer variant (~123MB) auto-detected from filename.",
                }),
                "ds_factor": ("FLOAT", {
                    "default": 1.0, "min": 0.125, "max": 1.0, "step": 0.125,
                    "tooltip": "Downscale factor for internal processing. 1.0 = full resolution. "
                    "Lower values reduce VRAM usage and speed up inference at the cost of quality. "
                    "Try 0.5 for 4K inputs.",
                }),
            }
        }

    RETURN_TYPES = ("GIMM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/GIMM-VFI"

    def load_model(self, model_path, ds_factor):
        """Resolve/download the checkpoints and return a CPU-resident wrapper."""
        full_path = os.path.join(GIMM_MODEL_DIR, model_path)
        # Auto-download main model if missing
        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_gimm_model(model_path, GIMM_MODEL_DIR)
        # Detect and download matching flow estimator
        # (FlowFormer weights for gimmvfi_f checkpoints, RAFT otherwise).
        if "gimmvfi_f" in model_path.lower():
            flow_filename = "flowformer_sintel_fp32.safetensors"
        else:
            flow_filename = "raft-things_fp32.safetensors"
        flow_path = os.path.join(GIMM_MODEL_DIR, flow_filename)
        if not os.path.exists(flow_path):
            logger.info(f"Flow estimator not found, downloading {flow_filename}...")
            download_gimm_model(flow_filename, GIMM_MODEL_DIR)
        # Model starts on CPU; interpolation nodes move it to GPU as needed.
        wrapper = GIMMVFIModel(
            checkpoint_path=full_path,
            flow_checkpoint_path=flow_path,
            variant="auto",
            ds_factor=ds_factor,
            device="cpu",
        )
        logger.info(f"GIMM-VFI model loaded (variant={wrapper.variant_name}, ds_factor={ds_factor})")
        return (wrapper,)
class GIMMVFIInterpolate:
    """ComfyUI node: multiply an image batch's frame rate with GIMM-VFI."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI frontend."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
                }),
                "model": ("GIMM_VFI_MODEL", {
                    "tooltip": "GIMM-VFI model from the Load GIMM-VFI Model node.",
                }),
                "multiplier": ([2, 4, 8], {
                    "default": 2,
                    "tooltip": "Frame rate multiplier. In single-pass mode, all intermediate frames are generated "
                    "in one forward pass per pair. In recursive mode, uses 2x passes like other models.",
                }),
                "single_pass": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GIMM-VFI's single-pass arbitrary-timestep mode. Generates all intermediate frames "
                    "per pair in one forward pass (no recursive 2x passes). Disable to use the standard "
                    "recursive approach (same as BIM/EMA/SGM).",
                }),
                "clear_cache_after_n_frames": ("INT", {
                    "default": 10, "min": 1, "max": 100, "step": 1,
                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
                }),
                "keep_device": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
                }),
                "all_on_gpu": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                }),
                "batch_size": ("INT", {
                    "default": 1, "min": 1, "max": 64, "step": 1,
                    "tooltip": "Number of frame pairs to process simultaneously in recursive mode. Ignored in single-pass mode (pairs are processed one at a time since each generates multiple frames).",
                }),
                "chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"
def _interpolate_frames_single_pass(self, frames, model, multiplier,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref):
"""Single-pass interpolation using GIMM-VFI's arbitrary timestep capability."""
num_intermediates = multiplier - 1
new_frames = []
num_pairs = frames.shape[0] - 1
pairs_since_clear = 0
for i in range(num_pairs):
frame0 = frames[i:i+1]
frame1 = frames[i+1:i+2]
if not keep_device:
model.to(device)
mids = model.interpolate_multi(frame0, frame1, num_intermediates)
mids = [m.to(storage_device) for m in mids]
if not keep_device:
model.to("cpu")
new_frames.append(frames[i:i+1])
for m in mids:
new_frames.append(m)
step_ref[0] += 1
pbar.update_absolute(step_ref[0])
pairs_since_clear += 1
if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
clear_gimm_caches()
torch.cuda.empty_cache()
pairs_since_clear = 0
new_frames.append(frames[-1:])
result = torch.cat(new_frames, dim=0)
if not all_on_gpu and torch.cuda.is_available():
clear_gimm_caches()
torch.cuda.empty_cache()
return result
def _interpolate_frames(self, frames, model, num_passes, batch_size,
                        device, storage_device, keep_device, all_on_gpu,
                        clear_cache_after_n_frames, pbar, step_ref):
    """Classic recursive 2x interpolation; each pass doubles frame density.

    Args:
        frames: [B, C, H, W] tensor of input frames on ``storage_device``.
        model: loaded GIMM-VFI model exposing ``interpolate_batch``.
        num_passes: number of 2x passes (1 for 2x, 2 for 4x, 3 for 8x).
        batch_size: frame pairs fed to the model simultaneously.
        device / storage_device / keep_device / all_on_gpu /
        clear_cache_after_n_frames / pbar / step_ref: see
        ``_interpolate_frames_single_pass``.

    Returns:
        [B', C, H, W] tensor after all passes.
    """
    for _pass in range(num_passes):
        doubled = []
        pair_count = frames.shape[0] - 1
        since_clear = 0
        start = 0
        while start < pair_count:
            stop = min(start + batch_size, pair_count)
            n_in_batch = stop - start
            lefts = frames[start:stop]
            rights = frames[start + 1:stop + 1]
            if not keep_device:
                model.to(device)
            mids = model.interpolate_batch(lefts, rights, time_step=0.5).to(storage_device)
            if not keep_device:
                # Free VRAM between batches at the cost of transfer time.
                model.to("cpu")
            for offset in range(n_in_batch):
                doubled.append(frames[start + offset:start + offset + 1])
                doubled.append(mids[offset:offset + 1])
            step_ref[0] += n_in_batch
            pbar.update_absolute(step_ref[0])
            since_clear += n_in_batch
            if not all_on_gpu and since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                clear_gimm_caches()
                torch.cuda.empty_cache()
                since_clear = 0
            start = stop
        # The last input frame of this pass has no following pair.
        doubled.append(frames[-1:])
        frames = torch.cat(doubled, dim=0)
    if not all_on_gpu and torch.cuda.is_available():
        clear_gimm_caches()
        torch.cuda.empty_cache()
    return frames
@staticmethod
def _count_steps(num_frames, num_passes):
"""Count total interpolation steps for recursive mode."""
n = num_frames
total = 0
for _ in range(num_passes):
total += n - 1
n = 2 * n - 1
return total
def interpolate(self, images, model, multiplier, single_pass,
                clear_cache_after_n_frames, keep_device, all_on_gpu,
                batch_size, chunk_size):
    """Node entry point: interpolate ``images`` by ``multiplier``.

    Dispatches to single-pass (arbitrary-timestep) or recursive 2x
    interpolation, optionally processing the input in overlapping chunks
    to bound VRAM usage.  Input/output use ComfyUI's [B, H, W, C] layout.

    Returns:
        1-tuple of the interpolated frames as a CPU [B', H, W, C] tensor.
    """
    # Nothing to interpolate with fewer than two frames.
    if images.shape[0] < 2:
        return (images,)
    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    passes = None
    if not single_pass:
        passes = {2: 1, 4: 2, 8: 3}[multiplier]
    if all_on_gpu:
        # Everything lives on the GPU anyway, so shuttling the model is pointless.
        keep_device = True
    storage = compute_device if all_on_gpu else torch.device("cpu")
    # ComfyUI images are [B, H, W, C]; the model consumes [B, C, H, W].
    source = images.permute(0, 3, 1, 2).to(storage)
    frame_total = source.shape[0]
    # Build chunk spans; consecutive spans share one boundary frame.
    if chunk_size < 2 or chunk_size >= frame_total:
        spans = [(0, frame_total)]
    else:
        spans = []
        begin = 0
        while begin < frame_total - 1:
            stop = min(begin + chunk_size, frame_total)
            spans.append((begin, stop))
            if stop == frame_total:
                break
            begin = stop - 1  # overlap by one frame
    # Total progress steps across every chunk.
    if single_pass:
        total_steps = sum(b - a - 1 for a, b in spans)
    else:
        total_steps = sum(self._count_steps(b - a, passes) for a, b in spans)
    pbar = ProgressBar(total_steps)
    progress = [0]  # mutable counter shared with the workers
    if keep_device:
        model.to(compute_device)
    multi_chunk = len(spans) > 1
    pieces = []
    for span_idx, (a, b) in enumerate(spans):
        work = source[a:b].clone()
        if single_pass:
            piece = self._interpolate_frames_single_pass(
                work, model, multiplier,
                compute_device, storage, keep_device, all_on_gpu,
                clear_cache_after_n_frames, pbar, progress,
            )
        else:
            piece = self._interpolate_frames(
                work, model, passes, batch_size,
                compute_device, storage, keep_device, all_on_gpu,
                clear_cache_after_n_frames, pbar, progress,
            )
        if span_idx > 0:
            # First frame duplicates the previous chunk's last frame.
            piece = piece[1:]
        if multi_chunk:
            # Park finished chunks on CPU so chunking actually bounds memory.
            piece = piece.cpu()
        pieces.append(piece)
    stitched = torch.cat(pieces, dim=0)
    # Back to ComfyUI's [B, H, W, C] layout, on CPU.
    return (stitched.cpu().permute(0, 2, 3, 1),)
class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
    """Interpolate one numbered slice of the input batch with GIMM-VFI.

    Chain several instances (with Save nodes between them) to cap peak RAM.
    The model pass-through output forces ComfyUI to run segments in order,
    so each segment's output can be saved and freed before the next starts.
    """

    @classmethod
    def INPUT_TYPES(cls):
        spec = GIMMVFIInterpolate.INPUT_TYPES()
        spec["required"]["segment_index"] = ("INT", {
            "default": 0, "min": 0, "max": 10000, "step": 1,
            "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
            "unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
            "Chain the model output to the next Segment Interpolate to force sequential execution.",
        })
        spec["required"]["segment_size"] = ("INT", {
            "default": 500, "min": 2, "max": 10000, "step": 1,
            "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
            "Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
        })
        return spec

    # Second output passes the model through to chain segments sequentially.
    RETURN_TYPES = ("IMAGE", "GIMM_VFI_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size, segment_index, segment_size):
        frame_total = images.shape[0]
        # Segments overlap by one frame so their outputs stitch seamlessly.
        first = segment_index * (segment_size - 1)
        last = min(first + segment_size, frame_total)
        if first >= frame_total - 1:
            # Segment lies beyond the input — hand back a single frame plus
            # the model so downstream nodes still execute.
            return (images[:1], model)
        out, = super().interpolate(
            images[first:last], model, multiplier, single_pass,
            clear_cache_after_n_frames, keep_device, all_on_gpu,
            batch_size, chunk_size,
        )
        if segment_index > 0:
            out = out[1:]  # drop the frame duplicated from the previous segment
        return (out, model)

View File

@@ -1 +1,6 @@
gdown
omegaconf
yacs
easydict
einops
huggingface_hub