Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
from yacs.config import CfgNode as CN

# Default configuration template for the FlowFormer optical-flow backbone.
# Callers obtain a private, mutable copy via get_cfg(); the module-level
# _CN instance itself is never handed out.
_CN = CN()

# Experiment bookkeeping.
_CN.name = ""
_CN.suffix = ""
# Loss weighting factor — presumably the per-iteration exponential decay
# used by RAFT-style sequence losses; confirm against the trainer.
_CN.gamma = 0.8
# Flow-magnitude cutoff — presumably excludes extreme ground-truth flow
# from the loss; confirm against the loss implementation.
_CN.max_flow = 400
_CN.batch_size = 6
_CN.sum_freq = 100
_CN.val_freq = 5000000
# Training crop size as [height, width].
_CN.image_size = [432, 960]
_CN.add_noise = False
_CN.critical_params = []

# Transformer variant selector and the default pretrained checkpoint path.
_CN.transformer = "latentcostformer"
_CN.model = "pretrained_ckpt/flowformer_sintel.pth"

# latentcostformer
_CN.latentcostformer = CN()
_CN.latentcostformer.pe = "linear"
_CN.latentcostformer.dropout = 0.0
_CN.latentcostformer.encoder_latent_dim = 256  # in twins, this is 256
_CN.latentcostformer.query_latent_dim = 64
_CN.latentcostformer.cost_latent_input_dim = 64
_CN.latentcostformer.cost_latent_token_num = 8
_CN.latentcostformer.cost_latent_dim = 128
_CN.latentcostformer.arc_type = "transformer"
_CN.latentcostformer.cost_heads_num = 1
# encoder
_CN.latentcostformer.pretrain = True
_CN.latentcostformer.context_concat = False
_CN.latentcostformer.encoder_depth = 3
_CN.latentcostformer.feat_cross_attn = False
_CN.latentcostformer.patch_size = 8
_CN.latentcostformer.patch_embed = "single"
_CN.latentcostformer.no_pe = False
_CN.latentcostformer.gma = "GMA"
_CN.latentcostformer.kernel_size = 9
_CN.latentcostformer.rm_res = True
_CN.latentcostformer.vert_c_dim = 64
_CN.latentcostformer.cost_encoder_res = True
_CN.latentcostformer.cnet = "twins"
_CN.latentcostformer.fnet = "twins"
_CN.latentcostformer.no_sc = False
_CN.latentcostformer.only_global = False
_CN.latentcostformer.add_flow_token = True
_CN.latentcostformer.use_mlp = False
_CN.latentcostformer.vertical_conv = False

# decoder
_CN.latentcostformer.decoder_depth = 32
# Option keys that must agree between a checkpoint's saved config and
# this one — NOTE(review): presumably checked at load time; verify.
_CN.latentcostformer.critical_params = [
    "cost_heads_num",
    "vert_c_dim",
    "cnet",
    "pretrain",
    "add_flow_token",
    "encoder_depth",
    "gma",
    "cost_encoder_res",
]

### TRAINER
_CN.trainer = CN()
_CN.trainer.scheduler = "OneCycleLR"
_CN.trainer.optimizer = "adamw"
_CN.trainer.canonical_lr = 12.5e-5
_CN.trainer.adamw_decay = 1e-4
_CN.trainer.clip = 1.0
_CN.trainer.num_steps = 120000
_CN.trainer.epsilon = 1e-8
_CN.trainer.anneal_strategy = "linear"


def get_cfg():
    """Return an independent clone of the default FlowFormer config."""
    return _CN.clone()
|
||||
@@ -0,0 +1,197 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class BroadMultiHeadAttention(nn.Module):
    """Multi-head attention that broadcasts one un-batched query set over a
    batch of key/value tensors (the query has no batch dimension)."""

    def __init__(self, dim, heads):
        super(BroadMultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        # Standard 1/sqrt(head_dim) attention temperature.
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        # Q collapses to [i, heads*d] after squeeze; split heads out front.
        q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
        # K: [b, j, heads*d] -> [b, heads, j, d].
        k = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        # Batch-free query is broadcast against every batch element.
        similarity = einsum("hid, bhjd -> bhij", q, k) * self.scale  # (b hw) heads 1 pointnum
        return self.attend(similarity)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        batch = K.shape[0]
        n_queries = Q.shape[1]

        v = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        mixed = einsum("bhij, bhjd -> bhid", attn, v)
        # Merge heads back into the channel dimension.
        return rearrange(mixed, "b heads n d -> b n (heads d)", b=batch, n=n_queries)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    """Plain batched multi-head attention (no relative-position terms)."""

    def __init__(self, dim, heads):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        # Standard 1/sqrt(head_dim) attention temperature.
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        k = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        similarity = einsum("bhid, bhjd -> bhij", q, k) * self.scale  # (b hw) heads 1 pointnum
        return self.attend(similarity)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        batch, n_tokens = Q.shape[0], Q.shape[1]

        v = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        mixed = einsum("bhij, bhjd -> bhid", attn, v)
        # Merge heads back into the channel dimension.
        return rearrange(mixed, "b heads hw d -> b hw (heads d)", b=batch, hw=n_tokens)
|
||||
|
||||
|
||||
# class MultiHeadAttentionRelative_encoder(nn.Module):
|
||||
# def __init__(self, dim, heads):
|
||||
# super(MultiHeadAttentionRelative, self).__init__()
|
||||
# self.dim = dim
|
||||
# self.heads = heads
|
||||
# self.scale = (dim/heads) ** -0.5
|
||||
# self.attend = nn.Softmax(dim=-1)
|
||||
|
||||
# def attend_with_rpe(self, Q, K, Q_r, K_r):
|
||||
# """
|
||||
# Q: [BH1W1, H3W3, dim]
|
||||
# K: [BH1W1, H3W3, dim]
|
||||
# Q_r: [BH1W1, H3W3, H3W3, dim]
|
||||
# K_r: [BH1W1, H3W3, H3W3, dim]
|
||||
# """
|
||||
|
||||
# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||
# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||
# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||
# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
|
||||
|
||||
# # context-context similarity
|
||||
# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3]
|
||||
# # context-position similarity
|
||||
# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3]
|
||||
# # position-context similarity
|
||||
# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:])
|
||||
# p_c = torch.squeeze(p_c, dim=4)
|
||||
# p_c = p_c.permute(0, 1, 3, 2)
|
||||
# dots = c_c + c_p + p_c
|
||||
# return self.attend(dots)
|
||||
|
||||
# def forward(self, Q, K, V, Q_r, K_r):
|
||||
# attn = self.attend_with_rpe(Q, K, Q_r, K_r)
|
||||
# B, HW, _ = Q.shape
|
||||
|
||||
# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
|
||||
|
||||
# out = einsum('bhij, bhjd -> bhid', attn, V)
|
||||
# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
|
||||
|
||||
# return out
|
||||
|
||||
|
||||
class MultiHeadAttentionRelative(nn.Module):
    """Multi-head attention with additive relative-position logits:
    context-context, context-position and position-context similarities
    are summed before the softmax."""

    def __init__(self, dim, heads):
        super(MultiHeadAttentionRelative, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K, Q_r, K_r):
        """
        Q: [BH1W1, 1, dim]
        K: [BH1W1, H3W3, dim]
        Q_r: [BH1W1, H3W3, dim]
        K_r: [BH1W1, H3W3, dim]
        """
        # Split heads out of the channel dimension for all four inputs.
        q = rearrange(Q, "b j (heads d) -> b heads j d", heads=self.heads)
        k = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        q_r = rearrange(Q_r, "b j (heads d) -> b heads j d", heads=self.heads)
        k_r = rearrange(K_r, "b j (heads d) -> b heads j d", heads=self.heads)

        # context-context similarity: [(B H1W1) heads 1 H3W3]
        c_c = einsum("bhid, bhjd -> bhij", q, k) * self.scale
        # context-position similarity: [(B H1W1) heads 1 H3W3]
        c_p = einsum("bhid, bhjd -> bhij", q, k_r) * self.scale
        # position-context similarity
        p_c = (
            einsum("bhijd, bhikd -> bhijk", q_r[:, :, :, None, :], k[:, :, :, None, :])
            * self.scale
        )
        p_c = torch.squeeze(p_c, dim=4).permute(0, 1, 3, 2)

        return self.attend(c_c + c_p + p_c)

    def forward(self, Q, K, V, Q_r, K_r):
        attn = self.attend_with_rpe(Q, K, Q_r, K_r)
        batch, n_tokens = Q.shape[0], Q.shape[1]

        v = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        mixed = einsum("bhij, bhjd -> bhid", attn, v)
        # Merge heads back into the channel dimension.
        return rearrange(mixed, "b heads hw d -> b hw (heads d)", b=batch, hw=n_tokens)
|
||||
|
||||
|
||||
def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sinusoidal position embedding with linearly spaced frequencies.

    Takes the last two channels of `x` as (x, y) coordinates and returns a
    `dim`-channel embedding laid out as [sin(x), cos(x), sin(y), cos(y)],
    each chunk `dim // 4` wide.
    """
    # 200 should be enough for a 8x downsampled image
    # assume x to be [_, _, 2]
    # NOTE: 3.14 (not math.pi) and the FACOR misspelling are kept — changing
    # either would break parity with pretrained checkpoints/callers.
    freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    chunks = []
    for coord in (x[..., -2:-1], x[..., -1:]):
        phase = 3.14 * coord * freq_bands * NORMALIZE_FACOR
        chunks.append(torch.sin(phase))
        chunks.append(torch.cos(phase))
    return torch.cat(chunks, dim=-1)
|
||||
|
||||
|
||||
def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sinusoidal position embedding with exponentially spaced frequencies.

    Same layout as LinearPositionEmbeddingSine — [sin(x), cos(x), sin(y),
    cos(y)], each chunk `dim // 4` wide — but frequencies grow as powers
    of two instead of linearly.
    """
    # 200 should be enough for a 8x downsampled image
    # assume x to be [_, _, 2]
    freq_bands = torch.linspace(0, dim // 4 - 1, dim // 4).to(x.device)
    chunks = []
    for coord in (x[..., -2:-1], x[..., -1:]):
        phase = coord * (NORMALIZE_FACOR * 2**freq_bands)
        chunks.append(torch.sin(phase))
        chunks.append(torch.cos(phase))
    return torch.cat(chunks, dim=-1)
|
||||
@@ -0,0 +1,649 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with normalization and a residual connection.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the first conv; when != 1 the identity path is
            downsampled with a strided 1x1 conv so shapes match.

    Raises:
        ValueError: if norm_fn is not one of the supported options.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        else:
            # Previously an unknown norm_fn left norm1/norm2 undefined and
            # failed later with an opaque AttributeError; fail fast instead.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Apply conv-norm-relu twice, then add the (possibly downsampled)
        identity and apply a final ReLU."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (inner width planes // 4).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the middle 3x3 conv; when != 1 the identity path
            is downsampled with a strided 1x1 conv so shapes match.

    Raises:
        ValueError: if norm_fn is not one of the supported options.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        else:
            # Previously an unknown norm_fn left the norm layers undefined
            # and failed later with an opaque AttributeError; fail fast.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        """Apply the three conv-norm-relu stages, then add the (possibly
        downsampled) identity and apply a final ReLU."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
    """Residual CNN feature encoder producing features at 1/8 resolution
    (one stride-2 stem conv plus two stride-2 residual stages)."""

    def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        # Widen every stage proportionally when the input carries more than
        # one RGB image's worth of channels (e.g. two frames -> mul == 2).
        mul = input_dim // 3

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64 * mul)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64 * mul)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64 * mul
        self.layer1 = self._make_layer(64 * mul, stride=1)
        self.layer2 = self._make_layer(96 * mul, stride=2)
        self.layer3 = self._make_layer(128 * mul, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def compute_params(self):
        """Total number of parameters in the encoder."""
        return sum(np.prod(p.size()) for p in self.parameters())

    def forward(self, x):
        # A list/tuple of images is processed as one batch and split back.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        # Dropout only applies in training mode.
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
||||
|
||||
class SmallEncoder(nn.Module):
    """Compact bottleneck-block encoder producing 1/8-resolution features
    from a 3-channel image."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may downsample.
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A list/tuple of images is processed as one batch and split back.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        # Dropout only applies in training mode.
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
||||
|
||||
class ConvNets(nn.Module):
    """Conv head: a 3x3 entry conv, `depth` un-normalized residual blocks,
    then a 3x3 exit conv."""

    def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1):
        super(ConvNets, self).__init__()

        self.conv_first = nn.Conv2d(
            in_dim, inter_dim, kernel_size=3, padding=1, stride=stride
        )
        self.conv_last = nn.Conv2d(
            inter_dim, out_dim, kernel_size=3, padding=1, stride=stride
        )
        self.relu = nn.ReLU(inplace=True)
        self.inter_convs = nn.ModuleList(
            [
                ResidualBlock(inter_dim, inter_dim, norm_fn="none", stride=1)
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        x = self.relu(self.conv_first(x))
        for block in self.inter_convs:
            x = block(x)
        return self.conv_last(x)
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Predicts a 2-channel flow field from a hidden feature map via a
    3x3 conv, ReLU, and a final 3x3 conv."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        # Standard GRU blend of old state and candidate.
        return (1 - update) * h + update * candidate
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
    """Separable ConvGRU: a GRU update with 1x5 gate convolutions
    (horizontal) followed by another with 5x1 convolutions (vertical)."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim

        # Horizontal (1x5) gates.
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))

        # Vertical (5x1) gates.
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_step(self, h, x, convz, convr, convq):
        # One GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Encodes correlation features plus current flow into a 128-channel
    motion feature map (126 learned channels with the raw 2-channel flow
    appended)."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.motion_feature_dim

        # Correlation branch.
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        # Flow branch.
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # Fusion; outputs 126 channels so the raw flow fills it to 128.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers can read it directly.
        return torch.cat([fused, flow], dim=1)
|
||||
|
||||
|
||||
class BasicFuseMotion(nn.Module):
    """Fuses a flow field with feature + context maps into a
    query_latent_dim-channel motion representation."""

    def __init__(self, args):
        super(BasicFuseMotion, self).__init__()
        cor_planes = args.motion_feature_dim
        out_planes = args.query_latent_dim

        # Flow branch.
        self.normf1 = nn.InstanceNorm2d(128)
        self.normf2 = nn.InstanceNorm2d(128)
        self.convf1 = nn.Conv2d(2, 128, 3, padding=1)
        self.convf2 = nn.Conv2d(128, 128, 3, padding=1)
        self.convf3 = nn.Conv2d(128, 64, 3, padding=1)

        s = 1  # width multiplier kept for parity with the upstream code
        self.normc1 = nn.InstanceNorm2d(256 * s)
        self.normc2 = nn.InstanceNorm2d(256 * s)
        self.normc3 = nn.InstanceNorm2d(256 * s)

        # Feature branch (input is feat concatenated with a 128-ch context).
        self.convc1 = nn.Conv2d(cor_planes + 128, 256 * s, 1, padding=0)
        self.convc2 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc3 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc4 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.conv = nn.Conv2d(256 * s + 64, out_planes, 1, padding=0)

    def forward(self, flow, feat, context1=None):
        flo = F.relu(self.normf1(self.convf1(flow)))
        flo = F.relu(self.normf2(self.convf2(flo)))
        flo = self.convf3(flo)

        # NOTE(review): context1 is effectively required here despite its
        # None default — torch.cat would fail without it.
        feat = torch.cat([feat, context1], dim=1)
        feat = F.relu(self.normc1(self.convc1(feat)))
        feat = F.relu(self.normc2(self.convc2(feat)))
        feat = F.relu(self.normc3(self.convc3(feat)))
        feat = self.convc4(feat)

        merged = torch.cat([flo, feat], dim=1)
        return F.relu(self.conv(merged))
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """One recurrent update step: encode motion, advance the GRU hidden
    state, and emit (new hidden state, upsampling mask, flow delta)."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts a 64*9 convex-combination mask from the hidden state.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||
|
||||
|
||||
class DirectMeanMaskPredictor(nn.Module):
    """Predicts (upsampling mask, flow delta) directly from motion
    features, with no recurrent hidden state."""

    def __init__(self, args):
        super(DirectMeanMaskPredictor, self).__init__()
        self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(args.predictor_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, motion_features):
        delta_flow = self.flow_head(motion_features)
        # 0.25 scaling keeps mask gradients balanced against the flow head.
        mask = 0.25 * self.mask(motion_features)
        return mask, delta_flow
|
||||
|
||||
|
||||
class BaiscMeanPredictor(nn.Module):
    """Motion-encoder variant of the mask/flow predictor.

    NOTE: the misspelled class name ("Baisc") is kept deliberately —
    renaming it would break imports and checkpoint loading.
    """

    def __init__(self, args, hidden_dim=128):
        super(BaiscMeanPredictor, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, latent, flow):
        motion_features = self.encoder(flow, latent)
        delta_flow = self.flow_head(motion_features)
        # 0.25 scaling keeps mask gradients balanced against the flow head.
        mask = 0.25 * self.mask(motion_features)
        return mask, delta_flow
|
||||
|
||||
|
||||
class BasicRPEEncoder(nn.Module):
    """Three-layer MLP lifting 2-D relative position offsets to
    query_latent_dim channels."""

    def __init__(self, args):
        super(BasicRPEEncoder, self).__init__()
        self.args = args
        dim = args.query_latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(2, dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim),
        )

    def forward(self, rpe_tokens):
        return self.encoder(rpe_tokens)
|
||||
|
||||
|
||||
from .twins import Block, CrossBlock
|
||||
|
||||
|
||||
class TwinsSelfAttentionLayer(nn.Module):
    """Twins-style self-attention layer: a windowed local block followed
    by a global (sub-sampled, ws=1) block, applied with shared weights to
    both input streams."""

    def __init__(self, args):
        super(TwinsSelfAttentionLayer, self).__init__()
        self.args = args
        shared = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **shared)
        self.global_block = Block(ws=1, **shared)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init scaled by the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Both streams go through the same (shared-weight) blocks.
        x = self.global_block(self.local_block(x, size), size)
        tgt = self.global_block(self.local_block(tgt, size), size)
        return x, tgt
|
||||
|
||||
|
||||
class TwinsCrossAttentionLayer(nn.Module):
    """Twins-style cross-attention layer: shared windowed local
    self-attention on each stream, then a global CrossBlock that exchanges
    information between the two streams."""

    def __init__(self, args):
        super(TwinsCrossAttentionLayer, self).__init__()
        self.args = args
        shared = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **shared)
        self.global_block = CrossBlock(ws=1, **shared)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init scaled by the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Shared local self-attention on each stream, then cross-attention.
        x = self.local_block(x, size)
        tgt = self.local_block(tgt, size)
        x, tgt = self.global_block(x, tgt, size)
        return x, tgt
|
||||
@@ -0,0 +1,98 @@
|
||||
#from turtle import forward
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ConvNextLayer(nn.Module):
    """Sequential stack of `depth` ConvNeXt blocks at a fixed width."""

    def __init__(self, dim, depth=4):
        super().__init__()
        self.net = nn.Sequential(*[ConvNextBlock(dim=dim) for _ in range(depth)])

    def forward(self, x):
        return self.net(x)

    def compute_params(self):
        """Total number of parameters in this layer."""
        return sum(np.prod(p.size()) for p in self.parameters())
|
||||
|
||||
|
||||
class ConvNextBlock(nn.Module):
    r"""ConvNeXt residual block (channels_last variant).

    DwConv -> permute to (N, H, W, C) -> LayerNorm -> Linear -> GELU ->
    Linear -> optional layer scale -> permute back -> residual add.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale; a
            non-positive value disables the learnable scale. Default: 1e-6.
    """

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise 7x7 conv: one filter per channel (groups == dim).
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs implemented as Linear layers in channels_last layout.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        """Apply the block; input and output are (N, C, H, W)."""
        residual = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv1(self.norm(out))
        out = self.pwconv2(self.act(out))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + out
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (default) or channels_first input.

    channels_last expects inputs shaped (batch, height, width, channels);
    channels_first expects (batch, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over the channel dimension with learned affine."""
        if self.data_format == "channels_first":
            # Manual normalization over dim 1, then per-channel affine.
            mu = x.mean(1, keepdim=True)
            var = (x - mu).pow(2).mean(1, keepdim=True)
            normed = (x - mu) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        # channels_last: defer to the fused functional implementation.
        return F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
|
||||
@@ -0,0 +1,316 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from ...utils.utils import coords_grid, bilinear_sampler
|
||||
from .attention import (
|
||||
MultiHeadAttention,
|
||||
LinearPositionEmbeddingSine,
|
||||
ExpPositionEmbeddingSine,
|
||||
)
|
||||
|
||||
from timm.models.layers import DropPath
|
||||
|
||||
from .gru import BasicUpdateBlock, GMAUpdateBlock
|
||||
from .gma import Attention
|
||||
|
||||
|
||||
def initialize_flow(img):
    """Flow is represented as difference between two means flow = mean1 - mean0"""
    N, _, H, W = img.shape
    # Two identical pixel-coordinate grids on the input's device; the decoder
    # updates the second one iteratively so that flow = mean1 - mean0.
    base_grid = coords_grid(N, H, W).to(img.device)
    init_grid = coords_grid(N, H, W).to(img.device)
    return base_grid, init_grid
|
||||
|
||||
|
||||
class CrossAttentionLayer(nn.Module):
    """Cross-attention from per-pixel flow query tokens to the cost memory.

    Each query token is augmented with a positional encoding of its current
    flow target coordinate; keys/values over the memory are computed once and
    returned so the caller can reuse them on later iterations.
    """

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        add_flow_token=True,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
        pe="linear",
    ):
        super(CrossAttentionLayer, self).__init__()

        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5
        self.query_token_dim = query_token_dim
        self.pe = pe  # positional-encoding type: "linear" or "exp"

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        # Projection takes [attended, shortcut] concatenated, hence v_dim * 2.
        self.proj = nn.Linear(v_dim * 2, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )
        self.add_flow_token = add_flow_token
        self.dim = qk_dim

    def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3):
        """
        query_coord [B, 2, H1, W1]

        Returns (updated query tokens, key, value); key/value are computed
        from ``memory`` only when both are passed in as None.
        """
        B, _, H1, W1 = query_coord.shape

        if key is None and value is None:
            key = self.k(memory)
            value = self.v(memory)

        # [B, 2, H1, W1] -> [BH1W1, 1, 2]
        query_coord = query_coord.contiguous()
        query_coord = (
            query_coord.view(B, 2, -1)
            .permute(0, 2, 1)[:, :, None, :]
            .contiguous()
            .view(B * H1 * W1, 1, 2)
        )
        # NOTE(review): if self.pe is neither "linear" nor "exp",
        # query_coord_enc is never bound and the next use raises — confirm
        # configs only ever set those two values.
        if self.pe == "linear":
            query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim)
        elif self.pe == "exp":
            query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim)

        short_cut = query
        query = self.norm1(query)

        # Query is either content + position or position only.
        if self.add_flow_token:
            q = self.q(query + query_coord_enc)
        else:
            q = self.q(query_coord_enc)
        k, v = key, value

        x = self.multi_head_attn(q, k, v)

        # Residual attention branch, then residual FFN branch.
        x = self.proj(torch.cat([x, short_cut], dim=2))
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x, k, v
|
||||
|
||||
|
||||
class MemoryDecoderLayer(nn.Module):
    """One decoder step: cross-attend flow query tokens to the cost memory
    and reshape the result back to a (B, C, H1, W1) feature map."""

    def __init__(self, dim, cfg):
        super(MemoryDecoderLayer, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size  # for converting coords into H2', W2' space

        query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.cross_attend = CrossAttentionLayer(
            qk_dim,
            v_dim,
            query_token_dim,
            tgt_token_dim,
            add_flow_token=cfg.add_flow_token,
            dropout=cfg.dropout,
        )

    def forward(self, query, key, value, memory, coords1, size, size_h3w3):
        """
        x: [B*H1*W1, 1, C]
        memory: [B*H1*W1, H2'*W2', C]
        coords1 [B, 2, H2, W2]
        size: B, C, H1, W1
        1. Note that here coords0 and coords1 are in H2, W2 space.
           Should first convert it into H2', W2' space.
        2. We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0]
        """
        x_global, k, v = self.cross_attend(
            query, key, value, memory, coords1, self.patch_size, size_h3w3
        )
        B, C, H1, W1 = size
        # The attended output lives in the query latent dim, not the context dim.
        C = self.cfg.query_latent_dim
        x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2)
        return x_global, k, v
|
||||
|
||||
|
||||
class ReverseCostExtractor(nn.Module):
    """Sample the cost volume in the reverse direction: for each target-side
    pixel, gather a 9x9 cost window around its source coordinate."""

    def __init__(self, cfg):
        super(ReverseCostExtractor, self).__init__()
        self.cfg = cfg

    def forward(self, cost_maps, coords0, coords1):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1
        """
        BH1W1, heads, H2, W2 = cost_maps.shape
        B, _, H1, W1 = coords1.shape

        assert (H1 == H2) and (W1 == W2)
        assert BH1W1 == B * H1 * W1

        cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2)
        coords = coords1.permute(0, 2, 3, 1)
        corr = bilinear_sampler(cost_maps, coords)  # [B, H1*W1*heads, H2, W2]
        # Transpose the correlation so it is indexed by target position.
        corr = rearrange(
            corr,
            "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1",
            b=B,
            heads=heads,
            h1=H1,
            w1=W1,
            h2=H2,
            w2=W2,
        )

        # Gather a (2r+1)x(2r+1) window around each source coordinate.
        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        # NOTE(review): torch.meshgrid without indexing= relies on the legacy
        # "ij" default — confirm against the supported torch version.
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device)
        centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(corr, coords)
        corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2)
        return corr
|
||||
|
||||
|
||||
class MemoryDecoder(nn.Module):
    """Recurrent flow decoder: iteratively refines flow by querying the cost
    memory (via cross-attention) and running a GRU-based update block."""

    def __init__(self, cfg):
        super(MemoryDecoder, self).__init__()
        dim = self.dim = cfg.query_latent_dim
        self.cfg = cfg

        # Encodes the 9x9 (=81) local cost patch per head into a query token.
        self.flow_token_encoder = nn.Sequential(
            nn.Conv2d(81 * cfg.cost_heads_num, dim, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1, 1),
        )
        self.proj = nn.Conv2d(256, 256, 1)
        self.depth = cfg.decoder_depth
        self.decoder_layer = MemoryDecoderLayer(dim, cfg)

        if self.cfg.gma:
            self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128)
            self.att = Attention(
                args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128
            )
        else:
            self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128)

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # mask holds, per fine pixel, 9 softmax weights over the 3x3
        # neighborhood of coarse flow values (one 8x8 block per coarse cell).
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Scale flow by 8 (coarse -> fine units) and gather 3x3 neighborhoods.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def encode_flow_token(self, cost_maps, coords):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords - B, 2, H1, W1

        Bilinearly sample a 9x9 cost window around each flow target coordinate.
        """
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(cost_maps, coords)
        corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2)
        return corr

    def forward(self, cost_memory, context, data=None, flow_init=None, iters=None):
        """
        memory: [B*H1*W1, H2'*W2', C]
        context: [B, D, H1, W1]
        data: dict produced by the encoder; must contain "cost_maps" and "H3W3".
        flow_init: optional warm-start flow added to the initial coordinates.
        iters: number of refinement iterations (defaults to cfg.decoder_depth).

        Returns (upsampled flow of the last iteration, final coarse flow).
        """
        # Fix: the original signature used a mutable default (data={}), which
        # is evaluated once and shared across calls; allocate per call instead.
        if data is None:
            data = {}
        cost_maps = data["cost_maps"]
        coords0, coords1 = initialize_flow(context)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []

        # Split projected context into GRU hidden state and input features.
        context = self.proj(context)
        net, inp = torch.split(context, [128, 128], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)
        if self.cfg.gma:
            attention = self.att(inp)

        size = net.shape
        key, value = None, None
        if iters is None:
            iters = self.depth
        for idx in range(iters):
            coords1 = coords1.detach()  # truncate gradients between iterations

            cost_forward = self.encode_flow_token(cost_maps, coords1)

            query = self.flow_token_encoder(cost_forward)
            query = (
                query.permute(0, 2, 3, 1)
                .contiguous()
                .view(size[0] * size[2] * size[3], 1, self.dim)
            )
            # key/value are computed on the first iteration and reused after.
            cost_global, key, value = self.decoder_layer(
                query, key, value, cost_memory, coords1, size, data["H3W3"]
            )
            if self.cfg.only_global:
                corr = cost_global
            else:
                corr = torch.cat([cost_global, cost_forward], dim=1)

            flow = coords1 - coords0

            if self.cfg.gma:
                net, up_mask, delta_flow = self.update_block(
                    net, inp, corr, flow, attention
                )
            else:
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        return flow_predictions[-1], coords1 - coords0
|
||||
@@ -0,0 +1,534 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
import numpy as np
|
||||
|
||||
from einops import rearrange
|
||||
from ...utils.utils import coords_grid
|
||||
from .attention import (
|
||||
BroadMultiHeadAttention,
|
||||
MultiHeadAttention,
|
||||
LinearPositionEmbeddingSine,
|
||||
ExpPositionEmbeddingSine,
|
||||
)
|
||||
from ..encoders import twins_svt_large
|
||||
from typing import Tuple
|
||||
from .twins import Size_
|
||||
from .cnn import BasicEncoder
|
||||
from .mlpmixer import MLPMixerLayer
|
||||
from .convnext import ConvNextLayer
|
||||
|
||||
from timm.models.layers import DropPath
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """Embed 2D cost maps into patch tokens with sine positional encoding.

    A strided conv stack downsamples by ``patch_size`` (8 or 4); each patch
    token is concatenated with an encoding of its center coordinate and mixed
    by a 1x1-conv FFN before layer normalization.
    """

    def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe="linear"):
        super().__init__()
        self.patch_size = patch_size
        self.dim = embed_dim
        self.pe = pe

        # NOTE(review): any patch_size other than 8 or 4 only prints and
        # leaves self.proj undefined (later AttributeError) — confirm configs
        # restrict patch_size accordingly.
        if patch_size == 8:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2
                ),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        elif patch_size == 4:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        else:
            print(f"patch size = {patch_size} is unacceptable.")

        self.ffn_with_coord = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(embed_dim * 2)

    def forward(self, x) -> Tuple[torch.Tensor, Size_]:
        """Embed ``x`` (B, C, H, W); returns (tokens [B, H'*W', 2*dim], (H', W'))."""
        B, C, H, W = x.shape  # C == 1

        # Right/bottom pad so H and W are multiples of patch_size.
        pad_l = pad_t = 0
        pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        x = self.proj(x)
        out_size = x.shape[2:]

        # Patch-center coordinates in the original (pre-downsampling) space.
        patch_coord = (
            coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size
            + self.patch_size / 2
        )  # in feature coordinate space
        patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
        if self.pe == "linear":
            patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == "exp":
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(
            B, -1, out_size[0], out_size[1]
        )

        # Fuse content and position channels, then flatten to tokens.
        x_pe = torch.cat([x, patch_coord_enc], dim=1)
        x = self.ffn_with_coord(x_pe)
        x = self.norm(x.flatten(2).transpose(1, 2))

        return x, out_size
|
||||
|
||||
|
||||
from .twins import Block, CrossBlock
|
||||
|
||||
|
||||
class GroupVerticalSelfAttentionLayer(nn.Module):
    """Grouped vertical self-attention: a single Twins ``Block`` run in
    group-attention mode over the cost-latent tokens."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(GroupVerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Fixed Twins-block hyperparameters (mirror the plain vertical layer):
        # mlp_ratio=4, window size 7, sr_ratio=4, no stochastic depth.
        self.block = Block(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            ws=7,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
            groupattention=True,
            cfg=self.cfg,
        )

    def forward(self, x, size, context=None):
        """Apply the grouped-attention block; output shape matches ``x``."""
        return self.block(x, size, context)
|
||||
|
||||
|
||||
class VerticalSelfAttentionLayer(nn.Module):
    """Vertical self-attention over cost-latent tokens: a windowed (ws=7)
    Twins block followed by a global (ws=1) one."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(VerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Shared hyperparameters; only the window size differs between blocks.
        common = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )
        self.local_block = Block(ws=7, **common)
        self.global_block = Block(ws=1, **common)

    def forward(self, x, size, context=None):
        """Run local then global attention; shape is preserved."""
        out = self.local_block(x, size, context)
        return self.global_block(out, size, context)

    def compute_params(self):
        """Return the total number of parameters in this layer."""
        return sum(np.prod(p.size()) for p in self.parameters())
|
||||
|
||||
|
||||
class SelfAttentionLayer(nn.Module):
    """Pre-norm multi-head self-attention layer with a two-layer FFN and
    residual connections (Transformer encoder style)."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(SelfAttentionLayer, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        self.q = nn.Linear(dim, dim, bias=True)
        self.k = nn.Linear(dim, dim, bias=True)
        self.v = nn.Linear(dim, dim, bias=True)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        x: [BH1W1, H3W3, D]
        """
        residual = x
        normed = self.norm1(x)
        attn_out = self.multi_head_attn(self.q(normed), self.k(normed), self.v(normed))
        x = residual + self.proj_drop(self.proj(attn_out))
        return x + self.drop_path(self.ffn(self.norm2(x)))

    def compute_params(self):
        """Return the total number of parameters in this layer."""
        return sum(np.prod(p.size()) for p in self.parameters())
|
||||
|
||||
|
||||
class CrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention: latent query tokens attend to target tokens
    (broadcast attention), followed by a residual FFN."""

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(CrossAttentionLayer, self).__init__()
        assert (
            qk_dim % num_heads == 0
        ), f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert (
            v_dim % num_heads == 0
        ), f"dim {v_dim} should be divided by num_heads {num_heads}."
        """
        Query Token: [N, C] -> [N, qk_dim] (Q)
        Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
        """
        self.num_heads = num_heads
        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )

    def forward(self, query, tgt_token):
        """
        x: [BH1W1, H3W3, D]

        ``query``: latent tokens; ``tgt_token``: embedded target tokens.
        Returns updated query tokens.
        """
        short_cut = query
        query = self.norm1(query)

        q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token)

        x = self.multi_head_attn(q, k, v)

        # Residual attention branch, then residual FFN branch.
        x = short_cut + self.proj_drop(self.proj(x))

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x
|
||||
|
||||
|
||||
class CostPerceiverEncoder(nn.Module):
    """Compress the 4D cost volume into a small set of latent tokens
    (Perceiver-style): patch-embed each 2D cost map, cross-attend learnable
    latent tokens to it, then alternate self-attention over tokens with
    "vertical" attention across spatial positions."""

    def __init__(self, cfg):
        super(CostPerceiverEncoder, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size
        self.patch_embed = PatchEmbed(
            in_chans=self.cfg.cost_heads_num,
            patch_size=self.patch_size,
            embed_dim=cfg.cost_latent_input_dim,
            pe=cfg.pe,
        )

        self.depth = cfg.encoder_depth

        # Learnable latent query tokens, shared across all source pixels.
        self.latent_tokens = nn.Parameter(
            torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)
        )

        query_token_dim, tgt_token_dim = (
            cfg.cost_latent_dim,
            cfg.cost_latent_input_dim * 2,
        )
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.input_layer = CrossAttentionLayer(
            qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout
        )

        # Token mixers: MLP-Mixer layers or self-attention, per config.
        if cfg.use_mlp:
            self.encoder_layers = nn.ModuleList(
                [
                    MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )
        else:
            self.encoder_layers = nn.ModuleList(
                [
                    SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )

        # Spatial mixers: ConvNeXt stacks or vertical attention, per config.
        if self.cfg.vertical_conv:
            self.vertical_encoder_layers = nn.ModuleList(
                [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]
            )
        else:
            self.vertical_encoder_layers = nn.ModuleList(
                [
                    VerticalSelfAttentionLayer(
                        cfg.cost_latent_dim, cfg, dropout=cfg.dropout
                    )
                    for idx in range(self.depth)
                ]
            )
        # Optional train-time augmentation: random per-element cost scaling.
        self.cost_scale_aug = None
        if "cost_scale_aug" in cfg.keys():
            self.cost_scale_aug = cfg.cost_scale_aug
            print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug))

    def forward(self, cost_volume, data, context=None):
        """Encode ``cost_volume`` (B, heads, H1, W1, H2, W2) into latent
        tokens [B*H1*W1, token_num, latent_dim]; also stashes "cost_maps"
        and "H3W3" into ``data`` for the decoder."""
        B, heads, H1, W1, H2, W2 = cost_volume.shape
        # Flatten source positions into the batch: one (heads, H2, W2) cost
        # map per source pixel.
        cost_maps = (
            cost_volume.permute(0, 2, 3, 1, 4, 5)
            .contiguous()
            .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
        )
        data["cost_maps"] = cost_maps

        if self.cost_scale_aug is not None:
            scale_factor = (
                torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
                .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1])
                .to(cost_maps.device)
            )
            cost_maps = cost_maps * scale_factor

        x, size = self.patch_embed(cost_maps)  # B*H1*W1, size[0]*size[1], C
        data["H3W3"] = size
        H3, W3 = size

        # Latent tokens attend once to the embedded cost map.
        x = self.input_layer(self.latent_tokens, x)

        short_cut = x

        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if self.cfg.vertical_conv:
                # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 3, 1, 2)
                    .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1)
                )
                x = self.vertical_encoder_layers[idx](x)
                # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1)
                    .permute(0, 2, 3, 1)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )
            else:
                # Regroup so attention runs across spatial positions per token.
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1)
                )
                x = self.vertical_encoder_layers[idx](x, (H1, W1), context)
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )

        # Optional long residual from right after the input cross-attention.
        if self.cfg.cost_encoder_res is True:
            x = x + short_cut
        return x
|
||||
|
||||
|
||||
class MemoryEncoder(nn.Module):
    """Build the cost-volume memory: extract features for both frames with a
    shared encoder, correlate them, and run the cost perceiver encoder."""

    def __init__(self, cfg):
        super(MemoryEncoder, self).__init__()
        self.cfg = cfg

        if cfg.fnet == "twins":
            self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.fnet == "basicencoder":
            self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance")
        else:
            # NOTE(review): hard process exit on an unknown fnet value; a
            # raised ValueError would be friendlier to library callers.
            exit()
        self.channel_convertor = nn.Conv2d(
            cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False
        )
        self.cost_perceiver_encoder = CostPerceiverEncoder(cfg)

    def corr(self, fmap1, fmap2):
        """All-pairs multi-head correlation between two feature maps;
        returns (batch, heads, ht, wd, ht, wd)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = rearrange(
            fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        fmap2 = rearrange(
            fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2)
        corr = corr.permute(0, 2, 1, 3).view(
            batch * ht * wd, self.cfg.cost_heads_num, ht, wd
        )
        corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute(
            0, 2, 1, 3
        )
        corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd)

        return corr

    def forward(self, img1, img2, data, context=None, return_feat=False):
        """Encode both images, correlate, and run the cost perceiver.

        Returns latent cost tokens; additionally returns the frame-1 features
        when ``return_feat`` is True.
        """
        # Both frames go through the shared encoder in one batched pass
        # (replaces two separate encoder calls in the upstream implementation).
        imgs = torch.cat([img1, img2], dim=0)
        feats = self.feat_encoder(imgs)
        feats = self.channel_convertor(feats)
        B = feats.shape[0] // 2
        feat_s = feats[:B]
        if return_feat:
            ffeat = feats[:B]
        feat_t = feats[B:]

        B, C, H, W = feat_s.shape
        size = (H, W)

        if self.cfg.feat_cross_attn:
            # NOTE(review): self.layers is not defined in __init__ here, so
            # this branch would raise AttributeError — confirm
            # feat_cross_attn is always False for this vendored config.
            feat_s = feat_s.flatten(2).transpose(1, 2)
            feat_t = feat_t.flatten(2).transpose(1, 2)

            for layer in self.layers:
                feat_s, feat_t = layer(feat_s, feat_t, size)

            feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

        cost_volume = self.corr(feat_s, feat_t)
        x = self.cost_perceiver_encoder(cost_volume, data, context)

        if return_feat:
            return x, ffeat
        return x
|
||||
@@ -0,0 +1,123 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class RelPosEmb(nn.Module):
    """Decomposed 2D relative positional embedding (separate height and
    width tables), producing additive attention logits."""

    def __init__(self, max_pos_size, dim_head):
        super().__init__()
        self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
        self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)

        # Pairwise offsets shifted into [0, 2*max_pos_size-2] for table lookup.
        deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(
            max_pos_size
        ).view(-1, 1)
        rel_ind = deltas + max_pos_size - 1
        self.register_buffer("rel_ind", rel_ind)

    def forward(self, q):
        """Compute relative-position attention logits for queries ``q`` of
        shape (batch, heads, h, w, dim_head); returns (b, h, x, y, u, v)."""
        batch, heads, h, w, c = q.shape
        height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
        width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))

        height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h)
        width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w)

        # Separable scores: queries dotted with row offsets and column offsets.
        height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb)
        width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb)

        return height_score + width_score
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Content-based attention over a feature map (GMA); returns only the
    attention matrix — values are aggregated separately (see Aggregate)."""

    def __init__(
        self,
        *,
        args,
        dim,
        max_pos_size=100,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # Relative positional embedding kept frozen (not trained here).
        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
        for param in self.pos_emb.parameters():
            param.requires_grad = False

    def forward(self, fmap):
        """Return attention weights (b, heads, h*w, h*w) for ``fmap``
        of shape (b, c, h, w)."""
        heads, b, c, h, w = self.heads, *fmap.shape

        q, k = self.to_qk(fmap).chunk(2, dim=1)

        q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k))
        q = self.scale * q

        # Position-aware variants kept from upstream for reference:
        # if self.args.position_only:
        #     sim = self.pos_emb(q)
        # elif self.args.position_and_content:
        #     sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
        #     sim_pos = self.pos_emb(q)
        #     sim = sim_content + sim_pos
        # else:
        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)

        sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
        attn = sim.softmax(dim=-1)

        return attn
|
||||
|
||||
|
||||
class Aggregate(nn.Module):
    """Aggregate values under a precomputed attention map with a learned,
    zero-initialised residual gate.

    Output is ``fmap + gamma * project(attn @ V)``; since ``gamma`` starts at
    zero, the module is initially an identity mapping.
    """

    def __init__(
        self,
        args,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        # 1x1-conv value projection.
        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)

        # Residual gate, zero at init.
        self.gamma = nn.Parameter(torch.zeros(1))

        # Project back only when multi-head widening changed the channel count.
        if dim != inner_dim:
            self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
        else:
            self.project = None

    def forward(self, attn, fmap):
        """attn: [B, heads, H*W, H*W]; fmap: [B, C, H, W] -> [B, C, H, W]."""
        heads, b, c, h, w = self.heads, *fmap.shape

        values = rearrange(self.to_v(fmap), "b (h d) x y -> b h (x y) d", h=heads)
        aggregated = einsum("b h i j, b h j d -> b h i d", attn, values)
        aggregated = rearrange(aggregated, "b h (x y) d -> b (h d) x y", x=h, y=w)

        if self.project is not None:
            aggregated = self.project(aggregated)

        return fmap + self.gamma * aggregated
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: run self-attention on a random feature map.
    # Bug fix: ``args`` is a required keyword-only parameter of Attention,
    # so the previous call ``Attention(dim=128, heads=1)`` always raised
    # TypeError. ``args`` is only stored, never dereferenced in forward,
    # so None is a safe placeholder here.
    att = Attention(args=None, dim=128, heads=1)
    fmap = torch.randn(2, 128, 40, 90)
    out = att(fmap)

    print(out.shape)
|
||||
@@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Two-layer convolutional head regressing a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """x: [B, input_dim, H, W] -> flow delta [B, 2, H, W]."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Advance hidden state ``h`` one step given input ``x``."""
        stacked = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(stacked))
        reset = torch.sigmoid(self.convr(stacked))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
    """GRU cell with separable (1x5 then 5x1) convolutions.

    Performs two GRU updates per call: one with horizontal kernels, one with
    vertical kernels.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        # Horizontal pass (1x5 kernels).
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))

        # Vertical pass (5x1 kernels).
        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        """Two sequential GRU updates (horizontal, then vertical)."""
        passes = (
            (self.convz1, self.convr1, self.convq1),
            (self.convz2, self.convr2, self.convq2),
        )
        for convz, convr, convq in passes:
            hx = torch.cat([h, x], dim=1)
            z = torch.sigmoid(convz(hx))
            r = torch.sigmoid(convr(hx))
            q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
            h = (1 - z) * h + z * q

        return h
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    Output has 128 channels: 126 fused channels concatenated with the raw
    2-channel flow.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        if args.only_global:
            print("[Decoding with only global cost]")
            cor_planes = args.query_latent_dim
        else:
            # 81 local correlation channels plus the global query latent.
            cor_planes = 81 + args.query_latent_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        """flow: [B, 2, H, W]; corr: [B, cor_planes, H, W] -> [B, 128, H, W]."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep the raw flow alongside the learned features.
        return torch.cat([fused, flow], dim=1)
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """RAFT-style update block: motion encoder + SepConvGRU + flow head,
    plus a convex-upsampling mask predictor."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts the 8x convex-upsampling mask (9 weights per output pixel).
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        """One refinement step.

        Returns (new hidden state, upsampling mask, flow delta).
        """
        motion = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # Scale the mask to balance gradients (as in RAFT).
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||
|
||||
|
||||
from .gma import Aggregate
|
||||
|
||||
|
||||
class GMAUpdateBlock(nn.Module):
    """GMA update block: like BasicUpdateBlock, but additionally aggregates
    motion features globally via attention and feeds both local and global
    motion features to the GRU."""

    def __init__(self, args, hidden_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # Input = context (128) + local motion + global motion features.
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim
        )
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts the 8x convex-upsampling mask (9 weights per output pixel).
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

        self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1)

    def forward(self, net, inp, corr, flow, attention):
        """One attentional refinement step.

        Returns (new hidden state, upsampling mask, flow delta).
        """
        motion = self.encoder(flow, corr)
        motion_global = self.aggregator(attention, motion)
        gru_input = torch.cat([inp, motion, motion_global], dim=1)

        # Attentional update
        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # Scale the mask to balance gradients (as in RAFT).
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||
@@ -0,0 +1,55 @@
|
||||
from torch import nn
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns ``x + fn(LayerNorm(x))``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return x + self.fn(normed)
|
||||
|
||||
|
||||
def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    """Build an MLP block: expand -> GELU -> dropout -> contract -> dropout.

    ``dense`` lets the caller substitute the linear layer type (e.g. a
    1x1 Conv1d for token mixing).
    """
    expanded = dim * expansion_factor
    layers = [
        dense(dim, expanded),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(expanded, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class MLPMixerLayer(nn.Module):
    """One MLP-Mixer layer over the K cost-latent tokens.

    First mixes across tokens (Conv1d with kernel_size=1 acting along the
    token axis), then across channels (Linear); both paths are pre-norm
    residual blocks.
    """

    def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
        super(MLPMixerLayer, self).__init__()
        token_count = cfg.cost_latent_token_num
        expansion = cfg.mlp_expansion_factor
        # chan_first mixes along tokens, chan_last along channels.
        chan_first = partial(nn.Conv1d, kernel_size=1)
        chan_last = nn.Linear

        self.mlpmixer = nn.Sequential(
            PreNormResidual(
                dim, FeedForward(token_count, expansion, dropout, chan_first)
            ),
            PreNormResidual(dim, FeedForward(dim, expansion, dropout, chan_last)),
        )

    def compute_params(self):
        """Return the total parameter count of the mixer."""
        num = 0
        for param in self.mlpmixer.parameters():
            num += np.prod(param.size())

        return num

    def forward(self, x):
        """
        x: [BH1W1, K, D]
        """
        return self.mlpmixer(x)
|
||||
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils.utils import coords_grid
|
||||
from ..encoders import twins_svt_large
|
||||
from .encoder import MemoryEncoder
|
||||
from .decoder import MemoryDecoder
|
||||
from .cnn import BasicEncoder
|
||||
|
||||
|
||||
class FlowFormer(nn.Module):
    """FlowFormer optical-flow estimator.

    A cost-volume memory encoder plus a recurrent memory decoder, with a
    Twins-SVT or basic CNN context encoder selected by ``cfg.cnet``.
    """

    def __init__(self, cfg):
        super(FlowFormer, self).__init__()
        self.cfg = cfg

        self.memory_encoder = MemoryEncoder(cfg)
        self.memory_decoder = MemoryDecoder(cfg)
        if cfg.cnet == "twins":
            self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.cnet == "basicencoder":
            self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")

    def build_coord(self, img):
        """Build a coordinate grid at 1/8 of the resolution of ``img``."""
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8)
        return coords

    def forward(
        self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
    ):
        """Estimate flow from image1 to image2 (inputs expected in [0, 255]).

        Returns the decoder's flow predictions; with ``return_feat=True``
        also returns the context-encoder and memory-encoder features.
        """
        # Following https://github.com/princeton-vl/RAFT/
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        data = {}

        if self.cfg.context_concat:
            context_input = torch.cat([image1, image2], dim=1)
            # Bug fix: previously this branch never assigned ``cfeat``, so
            # context_concat=True combined with return_feat=True raised a
            # NameError at the final return. Mirror the non-concat path.
            if return_feat:
                context, cfeat = self.context_encoder(
                    context_input, return_feat=return_feat
                )
            else:
                context = self.context_encoder(context_input)
        else:
            if return_feat:
                context, cfeat = self.context_encoder(image1, return_feat=return_feat)
            else:
                context = self.context_encoder(image1)
        if return_feat:
            cost_memory, ffeat = self.memory_encoder(
                image1, image2, data, context, return_feat=return_feat
            )
        else:
            cost_memory = self.memory_encoder(image1, image2, data, context)

        flow_predictions = self.memory_decoder(
            cost_memory, context, data, flow_init=flow_init, iters=iters
        )

        if return_feat:
            return flow_predictions, cfeat, ffeat
        return flow_predictions
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
||||
def build_flowformer(cfg):
    """Instantiate the FlowFormer variant named by ``cfg.transformer``.

    Raises ValueError for unknown architecture names.
    """
    name = cfg.transformer
    if name != "latentcostformer":
        raise ValueError(f"FlowFormer = {name} is not a valid architecture!")

    # Import lazily so unrelated configs never pull in the transformer stack.
    from .LatentCostFormer.transformer import FlowFormer

    return FlowFormer(cfg[name])
|
||||
@@ -0,0 +1,562 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from ..utils.utils import bilinear_sampler, indexing
|
||||
|
||||
|
||||
def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300):
    """NeRF-style sinusoidal positional encoding of 2D coordinates.

    x is of shape [*, 2]; the last dimension holds the (x, y) coordinates.
    Returns a tensor of shape [*, 2 + 4 * (L - 1)]: the scaled coordinates
    followed by sin/cos features of each coordinate at L - 1 frequencies.
    """
    # L - 1 frequency bands spanning [2^0, 2^L].
    freq_bands = 2.0 ** torch.linspace(0, L, L - 1).to(x.device)
    xs = x[..., -2:-1]
    ys = x[..., -1:]
    features = [
        x * NORMALIZE_FACOR,
        torch.sin(3.14 * xs * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * xs * freq_bands * NORMALIZE_FACOR),
        torch.sin(3.14 * ys * freq_bands * NORMALIZE_FACOR),
        torch.cos(3.14 * ys * freq_bands * NORMALIZE_FACOR),
    ]
    return torch.cat(features, dim=-1)
|
||||
|
||||
|
||||
def sampler_gaussian(latent, mean, std, image_size, point_num=25):
    """Sample latent features inside a Gaussian-scaled window around ``mean``.

    ``std`` is squashed through a sigmoid and scaled by STD_MAX to set the
    window radius. Returns (sampled latents, distance-based weights).

    NOTE(review): this definition is shadowed by the identically named
    function defined later in this file; only the later one is in effect
    at import time.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Normalised sampling offsets in [-1, 1] on a sqrt(point_num) grid.
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    # Window extent = 3 sigma, with sigma = sigmoid(std) * STD_MAX.
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Weights decay with squared offset from the window centre.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))

    return sampled_latents, sampled_weights
|
||||
|
||||
|
||||
def sampler_gaussian_zy(
    latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1
):
    """Variant of ``sampler_gaussian`` with raw (unsquashed) ``std``.

    The window radius is 3 * std directly (no sigmoid/STD_MAX), and the
    distance-based weights are divided by ``beta``. Optionally also returns
    the per-sample pixel offsets.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Normalised sampling offsets in [-1, 1] on a sqrt(point_num) grid.
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Temperature-scaled distance weights.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta

    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
|
||||
|
||||
|
||||
def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False):
    """Sample latent features inside a Gaussian-scaled window around ``mean``.

    Supersedes the earlier identically named definition above (this one adds
    ``return_deltaXY``). ``std`` is squashed through a sigmoid and scaled by
    STD_MAX; the window extends to 3 sigma.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # std [B, 1, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Normalised sampling offsets in [-1, 1] on a sqrt(point_num) grid.
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    # Weights decay with squared offset from the window centre.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))

    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
|
||||
|
||||
|
||||
def sampler_gaussian_fix(latent, mean, image_size, point_num=49):
    """Sample latents on a fixed integer-offset window centred at ``mean``.

    No per-pixel std: offsets are the integer grid in
    [-radius, radius] x [-radius, radius] with radius derived from
    ``point_num``. Weights are divided by ``point_num`` as a smooth term.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-window variant
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    # Integer pixel offsets covering the (2*radius+1)^2 window.
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )  # [B*H*W, dim, point_num**0.5, point_num**0.5]
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return sampled_latents, sampled_weights
|
||||
|
||||
|
||||
def sampler_gaussian_fix_pyramid(
    latent, feat_pyramid, scale_weight, mean, image_size, point_num=25
):
    """Sample a fixed window from every pyramid level and blend the levels.

    Coordinates are halved per level (2^-i); per-position softmax over
    ``scale_weight`` blends the levels. Returns (weighted latents, smooth
    weights, softmax scale weights for visualisation).

    NOTE(review): ``latent`` is reshaped below but never used afterwards —
    only ``feat_pyramid`` is sampled; confirm against upstream.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]

    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-window variant
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    # Integer pixel offsets covering the (2*radius+1)^2 window.
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        # Centre and offsets are both scaled into level-i resolution.
        coords = (centroid + delta) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights, vis_out
|
||||
|
||||
|
||||
def sampler_gaussian_pyramid(
    latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25
):
    """Pyramid sampling with a Gaussian (std-scaled) window per position.

    Like ``sampler_gaussian_fix_pyramid`` but offsets are scaled by
    3 * std instead of fixed integers. Returns (weighted latents, smooth
    weights, softmax scale weights for visualisation).

    NOTE(review): ``latent``, ``radius`` and ``STD_MAX`` are computed but
    unused below; confirm against upstream.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W]
    # scale weight [B, H*W, layer_num]

    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    # Normalised sampling offsets in [-1, 1] on a sqrt(point_num) grid.
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        # Centre and offsets are both scaled into level-i resolution.
        coords = (centroid + delta_3sigma) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights, vis_out
|
||||
|
||||
|
||||
def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25):
    """different heads have different mean"""
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]

    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-window variant
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    # Integer pixel offsets, replicated once per head.
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = (
        torch.stack(torch.meshgrid(dy, dx), axis=-1)
        .to(mean.device)
        .repeat(HEADS, 1, 1, 1)
    )  # [HEADS, point_num**0.5, point_num**0.5, 2]

    # One window per head, centred on that head's mean.
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(
        coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
    )
    sampled_latents = bilinear_sampler(latent, coords)  # [B, dim, H*W*HEADS, pointnum]
    sampled_latents = sampled_latents.permute(
        0, 2, 3, 1
    )  # [B, H*W*HEADS, pointnum, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
|
||||
|
||||
|
||||
def sampler_gaussian_fix_pyramid_MH(
    latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25
):
    """Multi-head pyramid sampling with per-head means.

    Each head's centre is scaled into each pyramid level; levels are blended
    per head/position via softmax over ``scale_head_weight``.

    NOTE(review): unlike the non-MH pyramid variant, only the centroid is
    divided by 2**i here (the offsets are added after scaling) — confirm
    whether this asymmetry is intentional.
    """
    # latent [B, H*W, D]
    # mean [B, 2, H, W, heands]
    # scale_head weight [B, H*W, layer_num*heads]

    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape

    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]

    radius = int((int(point_num**0.5) - 1) / 2)

    # Integer pixel offsets covering the (2*radius+1)^2 window.
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [B*H*W, point_num**0.5, point_num**0.5, 2]

    sampled_latents = []
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    for i in range(len(feat_pyramid)):
        coords = (centroid) / 2**i + delta
        coords = rearrange(
            coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
        )
        sampled_latents.append(
            bilinear_sampler(feat_pyramid[i], coords)
        )  # [B, dim, H*W*HEADS, point_num]

    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W*HEADS, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W*HEADS, point_num, dim, layer_num]

    scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1)
    scale_head_weight = F.softmax(scale_head_weight, dim=2)  # [B, H*W*HEADS, layer_num]
    scale_head_weight = torch.unsqueeze(
        torch.unsqueeze(scale_head_weight, dim=2), dim=2
    )  # [B, H*W*HEADS, 1, 1, layer_num]

    weighted_latent = torch.sum(
        sampled_latents * scale_head_weight, dim=-1
    )  # [B, H*W*HEADS, point_num, dim]

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term

    return weighted_latent, sampled_weights
|
||||
|
||||
|
||||
def sampler(feat, center, window_size, sampler):
    """Bilinearly sample a fixed square window of ``feat`` around ``center``.

    NOTE(review): the ``sampler`` parameter is unused here (bilinear is
    always used); the shadowing of the module-level name is pre-existing.
    """
    # feat [B, C, H, W]
    # center [B, 2, H, W]
    center = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    B, H, W, C = center.shape

    # Integer pixel offsets covering the window.
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [B*H*W, window_size, point_num**0.5, 2]

    center = center.reshape(B * H * W, 1, 1, 2)
    coords = center + delta

    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        feat, coords
    )  # [B*H*W, dim, window_size, window_size]
    # sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    return sampled_latents
|
||||
|
||||
|
||||
def retrieve_tokens(feat, center, window_size, sampler):
    """Sample a (window_size x window_size) neighbourhood of ``feat`` around
    each centre coordinate.

    feat: [B, C, H, W]; center: [B, H, W, 2] (already permuted by the
    caller). ``sampler`` selects nearest-neighbour ("nn") or bilinear
    interpolation; any other value raises ValueError.
    """
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    # Integer pixel offsets covering the window.
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(center.device)

    B, H, W, C = center.shape
    anchors = center.reshape(B * H * W, 1, 1, 2)
    coords = anchors + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)

    # [B, dim, H*W, point_num]
    if sampler == "nn":
        return indexing(feat, coords)
    if sampler == "bilinear":
        return bilinear_sampler(feat, coords)
    raise ValueError("invalid sampler")
|
||||
|
||||
|
||||
def pyramid_retrieve_tokens(
    feat_pyramid, center, image_size, window_sizes, sampler="bilinear"
):
    """Retrieve windows from every pyramid level and concatenate the tokens.

    The centre coordinates are halved before each successively coarser level.
    """
    center = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    level_tokens = []
    for idx, win in enumerate(window_sizes):
        level_tokens.append(retrieve_tokens(feat_pyramid[idx], center, win, sampler))
        center = center / 2

    return torch.cat(level_tokens, dim=-1)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Two-layer MLP (dim -> dim) with GELU and dropout; no expansion."""

    def __init__(self, dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Plain fully connected network with LeakyReLU(0.2) activations.

    in_dim -> innter_dim -> (``depth`` hidden layers) -> out_dim; the final
    projection is linear (no activation).
    """

    def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5):
        super().__init__()
        # Attribute names kept for state-dict compatibility.
        self.FC1 = nn.Linear(in_dim, innter_dim)
        self.FC_out = nn.Linear(innter_dim, out_dim)
        self.relu = torch.nn.LeakyReLU(0.2)
        self.FC_inter = nn.ModuleList(
            [nn.Linear(innter_dim, innter_dim) for i in range(depth)]
        )

    def forward(self, x):
        out = self.relu(self.FC1(x))
        for hidden_layer in self.FC_inter:
            out = self.relu(hidden_layer(out))
        return self.FC_out(out)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    """Multi-head attention with several relative-positional-encoding modes.

    ``cfg.rpe`` selects how positional bias enters the attention logits:
    "element-wise" (learned per-head/token/channel bias multiplied with Q),
    "head-wise" / "token-wise" (additive logit bias), "implicit"
    (bias supplied by the caller at forward time), or
    "element-wise-value" (also learns a positional value table —
    NOTE(review): ``rpe_value`` is never read in this class; presumably
    consumed by a caller — confirm).
    """

    def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.num_kv_tokens = num_kv_tokens
        self.scale = (dim / heads) ** -0.5
        self.rpe = cfg.rpe
        self.attend = nn.Softmax(dim=-1)
        self.use_rpe = use_rpe

        if use_rpe:
            # When no shared bias is injected, create the mode-specific
            # learnable bias; otherwise reuse the caller-provided one.
            if rpe_bias is None:
                if self.rpe == "element-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                elif self.rpe == "head-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, heads, 1, self.num_kv_tokens)
                    )
                elif self.rpe == "token-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, 1, 1, self.num_kv_tokens)
                    )  # 81 is point_num
                elif self.rpe == "implicit":
                    pass
                    # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2)
                    # raise ValueError('Implicit Encoding Not Implemented')
                elif self.rpe == "element-wise-value":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                    self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim))
                else:
                    raise ValueError("Not Implemented")
            else:
                self.rpe_bias = rpe_bias

    def attend_with_rpe(self, Q, K, rpe_bias):
        """Return (softmax attention, raw logits) for Q against K.

        Q: [b, i, dim]; K: [b, j, dim]; the positional bias is folded into
        the logits according to ``self.rpe``.
        """
        Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)

        dots = (
            einsum("bhid, bhjd -> bhij", Q, K) * self.scale
        )  # (b hw) heads 1 pointnum
        if self.use_rpe:
            if self.rpe == "element-wise":
                rpe_bias_weight = (
                    einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "implicit":
                pass
                # "implicit" bias is batched (per-sample), hence the extra b axis.
                rpe_bias_weight = (
                    einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "head-wise" or self.rpe == "token-wise":
                dots = dots + rpe_bias

        return self.attend(dots), dots

    def forward(self, Q, K, V, rpe_bias=None):
        """Attend Q over K and mix V; returns (output or None, raw logits).

        When V is None, only the logits are computed (used by callers that
        need attention scores without value mixing).
        """
        if self.use_rpe:
            if rpe_bias is None or self.rpe == "element-wise":
                rpe_bias = self.rpe_bias
            else:
                rpe_bias = rearrange(
                    rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads
                )
            attn, dots = self.attend_with_rpe(Q, K, rpe_bias)
        else:
            attn, dots = self.attend_with_rpe(Q, K, None)
        B, HW, _ = Q.shape

        if V is not None:
            V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)

            out = einsum("bhij, bhjd -> bhid", attn, V)
            out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        else:
            out = None

        # dots = torch.squeeze(dots, 2)
        # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW)

        return out, dots
|
||||
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
import numpy as np
|
||||
|
||||
|
||||
class twins_svt_large(nn.Module):
    """Truncated Twins-SVT-Large backbone used as a feature/context encoder.

    Keeps only the first two stages of the timm model (the classification
    head and the last two stages are deleted) and freezes the final norm.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large", pretrained=pretrained)

        # Drop the classification head and the last two stages; deleting
        # index 2 twice removes the original stages 2 and 3 in turn.
        del self.svt.head
        del self.svt.patch_embeds[2]
        del self.svt.patch_embeds[2]
        del self.svt.blocks[2]
        del self.svt.blocks[2]
        del self.svt.pos_block[2]
        del self.svt.pos_block[2]
        self.svt.norm.weight.requires_grad = False
        self.svt.norm.bias.requires_grad = False

    def forward(self, x, data=None, layer=2, return_feat=False):
        """Run the first ``layer`` stages of the backbone.

        Returns the final feature map, plus the per-stage maps when
        ``return_feat`` is True. ``data`` is accepted for interface
        compatibility and unused.
        """
        B = x.shape[0]
        if return_feat:
            feat = []
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # Positional block is applied once, after the first block.
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # Restore [B, C, H, W] layout between stages.
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if return_feat:
                feat.append(x)
            if i == layer - 1:
                break
        if return_feat:
            return x, feat
        return x

    def compute_params(self, layer=2):
        """Return the parameter count of the first ``layer`` stages.

        Bug fix: the classification head is deleted in __init__, so the old
        unconditional ``self.svt.head`` access here always raised
        AttributeError; it is now guarded.
        """
        num = 0
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            for param in embed.parameters():
                num += np.prod(param.size())

            for param in drop.parameters():
                num += np.prod(param.size())

            for param in blocks.parameters():
                num += np.prod(param.size())

            for param in pos_blk.parameters():
                num += np.prod(param.size())

            if i == layer - 1:
                break

        # The head only exists if a subclass/caller restored it.
        if hasattr(self.svt, "head"):
            for param in self.svt.head.parameters():
                num += np.prod(param.size())

        return num
|
||||
class twins_svt_large_context(nn.Module):
    """Twins-SVT-Large used as a context encoder (first `layer` stages only).

    Unlike twins_svt_large, the timm model is kept intact; truncation happens
    purely by breaking out of the stage loop in forward().
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # NOTE(review): "twins_svt_large_context" is not a standard timm model
        # name — presumably registered elsewhere in this package; confirm.
        self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained)

    def forward(self, x, data=None, layer=2):
        """Run the first `layer` stages and return the (B, C, H, W) feature map.

        Args:
            x: input image tensor (B, 3, H, W).
            data: unused; kept for interface compatibility.
            layer: number of stages to execute.
        """
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # Conditional position encoding after each stage's first block.
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # Back to (B, C, H, W) so the next patch embed can consume it.
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

            if i == layer - 1:
                break

        return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: the truncated backbone returns a 1/8-resolution feature map.
    m = twins_svt_large()
    input = torch.randn(2, 3, 400, 800)
    # BUGFIX: the class defines no `extract_feature` method — the original
    # call raised AttributeError. Invoke the module (forward) instead.
    out = m(input)
    print(out.shape)
|
||||
90
gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py
Normal file
90
gimm_vfi_arch/generalizable_INR/flowformer/core/corr.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .utils.utils import bilinear_sampler, coords_grid
|
||||
|
||||
try:
|
||||
import alt_cuda_corr
|
||||
except:
|
||||
# alt_cuda_corr is not compiled
|
||||
pass
|
||||
|
||||
|
||||
class CorrBlock:
    """All-pairs 4D correlation volume with an average-pooled pyramid.

    __call__ looks up a (2r+1)x(2r+1) window of correlations around the
    given coordinates at every pyramid level (RAFT-style lookup).
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation windows at `coords` (B, 2, H, W) for all levels."""
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The lookup-window offsets are level-independent; build them once on
        # the right device/dtype. (The original rebuilt them on the CPU inside
        # the level loop and copied to the GPU every iteration, and always in
        # float32 even for fp16 coords.)
        dx = torch.linspace(
            -r, r, 2 * r + 1, device=coords.device, dtype=coords.dtype
        )
        dy = torch.linspace(
            -r, r, 2 * r + 1, device=coords.device, dtype=coords.dtype
        )
        # indexing="ij" matches the historical torch.meshgrid default and
        # silences the deprecation warning.
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Centroids shrink by 2x per level to match the pooled volume.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Full all-pairs correlation, scaled by 1/sqrt(dim).

        Returns a (batch, ht, wd, 1, ht, wd) volume.
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class AlternateCorrBlock:
    """Memory-efficient correlation lookup backed by the alt_cuda_corr extension.

    Instead of materialising the full all-pairs volume (as CorrBlock does),
    correlations are computed on the fly inside the CUDA kernel.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Feature pyramid: level 0 is full resolution. Note this stores
        # num_levels + 1 entries (one more level than CorrBlock keeps).
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Look up correlation features at `coords` (B, 2, H, W) for every level.

        Raises NameError if the optional alt_cuda_corr extension was not built
        (the guarded import at the top of this file silently swallowed it).
        """
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 is always taken at full resolution; only fmap2 and the
            # lookup coordinates are downscaled — presumably the kernel's
            # expected calling convention (TODO confirm against alt_cuda_corr).
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        # Same 1/sqrt(dim) scaling as CorrBlock.corr.
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
267
gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py
Normal file
267
gimm_vfi_arch/generalizable_INR/flowformer/core/extractor.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convs with normalization and an identity/projection skip.

    When stride != 1 the skip path is a strided 1x1 conv + norm so shapes match.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # One factory instead of four copy-pasted if-chains.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == "batch":
                return nn.BatchNorm2d(planes)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(planes)
            if norm_fn == "none":
                return nn.Sequential()
            # BUGFIX: an unknown norm_fn previously left the norm attributes
            # undefined and surfaced later as a confusing AttributeError.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        if not stride == 1:
            self.norm3 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Residual forward: relu(norm(conv)) x2, then skip-add and relu."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, with skip.

    When stride != 1 the skip path is a strided 1x1 conv + norm so shapes match.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        squeezed = planes // 4
        self.conv1 = nn.Conv2d(in_planes, squeezed, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            squeezed, squeezed, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(squeezed, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # Dispatch table replaces the four copy-pasted branches; channel
        # count varies per norm, so each factory takes it as an argument.
        norm_factories = {
            "group": lambda ch: nn.GroupNorm(num_groups=num_groups, num_channels=ch),
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "none": lambda ch: nn.Sequential(),
        }
        if norm_fn in norm_factories:
            make = norm_factories[norm_fn]
            self.norm1 = make(squeezed)
            self.norm2 = make(squeezed)
            self.norm3 = make(planes)
            if not stride == 1:
                self.norm4 = make(planes)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        """Bottleneck forward: three relu(norm(conv)) stages, skip-add, relu."""
        residual = self.relu(self.norm1(self.conv1(x)))
        residual = self.relu(self.norm2(self.conv2(residual)))
        residual = self.relu(self.norm3(self.conv3(residual)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
    """Residual feature encoder producing a 1/8-resolution feature map.

    Stride-2 stem, three residual stages (64 -> 96 -> 128 channels), then a
    1x1 projection to `output_dim`.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two stacked ResidualBlocks; the first one carries the stride."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode one image tensor, or a list/tuple of equally-batched tensors.

        A list input is concatenated along the batch dim for a single pass
        and split back into a tuple of per-input feature maps.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Generalized: the original hard-coded a 2-way split
            # ([batch_dim, batch_dim]); splitting by chunk size handles any
            # number of inputs (all sharing the same batch size) and is
            # identical for the 2-input case.
            x = torch.split(x, batch_dim, dim=0)

        return x
|
||||
|
||||
|
||||
class SmallEncoder(nn.Module):
    """Compact feature encoder (bottleneck blocks) producing a 1/8-res map.

    Stride-2 stem, three bottleneck stages (32 -> 64 -> 96 channels), then a
    1x1 projection to `output_dim`.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two stacked BottleneckBlocks; the first one carries the stride."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode one image tensor, or a list/tuple of equally-batched tensors.

        A list input is concatenated along the batch dim for a single pass
        and split back into a tuple of per-input feature maps.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Generalized: the original hard-coded a 2-way split
            # ([batch_dim, batch_dim]); splitting by chunk size handles any
            # number of inputs (all sharing the same batch size) and is
            # identical for the 2-input case.
            x = torch.split(x, batch_dim, dim=0)

        return x
|
||||
@@ -0,0 +1,100 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalized to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            d_model (int): number of channels; filled in interleaved groups of
                four as x-sin / x-cos / y-sin / y-cos.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        # 1-based row / column indices of every spatial position.
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        # NOTE(review): `-math.log(10000.0) / d_model // 2` parses as
        # `(-log(10000) / d_model) // 2` — a floor division by 2 — rather than
        # the presumably intended `/ (d_model // 2)`. This quirk comes from
        # upstream FlowFormer and must be kept as-is: `pe` is registered as a
        # persistent buffer, so released checkpoints bake in these exact values.
        div_term = torch.exp(
            torch.arange(0, d_model // 2, 2).float()
            * (-math.log(10000.0) / d_model // 2)
        )
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        Returns:
            x plus the positional encoding cropped to x's spatial size.
        """
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
|
||||
|
||||
|
||||
class LinearPositionEncoding(nn.Module):
    """
    Sinusoidal 2-D position encoding with linearly spaced frequencies and
    positions normalised to [0, 1) (contrast with PositionEncodingSine's
    exponentially decayed frequencies on raw pixel indices).
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            d_model (int): number of channels; filled in interleaved groups of
                four as x-sin / x-cos / y-sin / y-cos.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        # 0-based positions normalised by the table extent -> [0, 1).
        y_position = (
            torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1
        ) / max_shape[0]
        x_position = (
            torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1
        ) / max_shape[1]
        # Linearly spaced frequencies 0, 2, 4, ...
        div_term = torch.arange(0, d_model // 2, 2).float()
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi)
        pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi)
        pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi)
        pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi)

        # persistent=False: deterministic, recomputed at construction, so it
        # is deliberately kept out of checkpoints (unlike PositionEncodingSine).
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        Returns:
            x plus the positional encoding cropped to x's spatial size.
        """
        # assert x.shape[2] == 80 and x.shape[3] == 80

        return x + self.pe[:, :, : x.size(2), : x.size(3)]
|
||||
|
||||
|
||||
class LearnedPositionEncoding(nn.Module):
    """
    Learned 2-D position encoding: a trainable table added to the input.
    """

    def __init__(self, d_model, max_shape=(80, 80)):
        """
        Args:
            d_model (int): channel dimension of the learned table.
            max_shape (tuple): spatial extent of the learned table.
        """
        super().__init__()

        # Shape [1, H, W, C] — note: channels-last, unlike the sinusoidal
        # variants, which register a [1, C, H, W] buffer.
        self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model))

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        # NOTE(review): self.pe is [1, H, W, C] while x is documented as
        # [N, C, H, W]; the addition only broadcasts if callers actually pass
        # a channels-last tensor (or H == W == C). The table is also not
        # cropped to x's size, unlike the other encodings — confirm callers
        # always use max_shape inputs.
        # assert x.shape[2] == 80 and x.shape[3] == 80

        return x + self.pe
|
||||
154
gimm_vfi_arch/generalizable_INR/flowformer/core/update.py
Normal file
154
gimm_vfi_arch/generalizable_INR/flowformer/core/update.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Predicts a 2-channel flow delta from a feature map via two 3x3 convs."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map (B, input_dim, H, W) features to a (B, 2, H, W) flow field."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell with 3x3 gate convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step. h: hidden state (B, hidden, H, W); x: input features."""
        hx = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
    """Separable convolutional GRU: a horizontal (1x5) pass then a vertical
    (5x1) pass, each a full GRU update of the hidden state."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        # Pass 1 uses 1x5 kernels, pass 2 uses 5x1; gate registration order
        # (z1, r1, q1, z2, r2, q2) matches the reference implementation.
        for idx, (kernel, pad) in enumerate(
            [((1, 5), (0, 2)), ((5, 1), (2, 0))], start=1
        ):
            for gate in "zrq":
                setattr(
                    self,
                    f"conv{gate}{idx}",
                    nn.Conv2d(combined, hidden_dim, kernel, padding=pad),
                )

    def _half_step(self, h, x, convz, convr, convq):
        """Standard GRU update using the given gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Two chained GRU updates: horizontal kernels, then vertical."""
        h = self._half_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._half_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||
|
||||
|
||||
class SmallMotionEncoder(nn.Module):
    """Fuses correlation features and current flow into motion features
    (small variant; output is 80 channels + the raw 2-channel flow)."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # Channels of the stacked correlation pyramid lookup.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """flow: (B, 2, H, W); corr: (B, cor_planes, H, W) -> (B, 82, H, W)."""
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers can read it directly.
        return torch.cat([merged, flow], dim=1)
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and current flow into motion features
    (basic variant; output is 126 channels + the raw 2-channel flow = 128)."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # Channels of the stacked correlation pyramid lookup.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        """flow: (B, 2, H, W); corr: (B, cor_planes, H, W) -> (B, 128, H, W)."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers can read it directly.
        return torch.cat([merged, flow], dim=1)
|
||||
|
||||
|
||||
class SmallUpdateBlock(nn.Module):
    """GRU-based iterative update (small variant): encodes motion, updates
    the hidden state, and predicts a flow delta. No upsampling mask."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        # 82 motion-feature channels + 64 context channels.
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """Returns (new hidden state, None, delta_flow)."""
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        # The small variant predicts no convex upsampling mask.
        return net, None, self.flow_head(net)
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """GRU-based iterative update: encodes motion, updates the hidden state,
    and predicts both a flow delta and a convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts 9 convex-combination weights for each pixel of the 8x8
        # upsampling neighbourhood (64 * 9 output channels).
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        """Returns (new hidden state, upsampling mask, delta_flow)."""
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients (as in RAFT)
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
|
||||
113
gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py
Normal file
113
gimm_vfi_arch/generalizable_INR/flowformer/core/utils/utils.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        extra_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        extra_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        # _pad is [left, right, top, bottom] in F.pad order.
        if mode == "sintel":
            # Split the padding evenly between both sides.
            self._pad = [
                extra_wd // 2,
                extra_wd - extra_wd // 2,
                extra_ht // 2,
                extra_ht - extra_ht // 2,
            ]
        elif mode == "kitti400":
            # Fixed-height mode: pad the bottom up to exactly 400 rows.
            self._pad = [0, 0, 0, 400 - self.ht]
        else:
            # Center horizontally; pad only the bottom vertically.
            self._pad = [extra_wd // 2, extra_wd - extra_wd // 2, 0, extra_ht]

    def pad(self, *inputs):
        """Replicate-pad every input tensor to the padded size."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||
|
||||
|
||||
def forward_interpolate(flow):
    """Forward-warp a flow field: splat each vector to its target location,
    then fill the full grid with nearest-neighbour interpolation.

    Args:
        flow: (2, H, W) tensor with (dx, dy) channels.
    Returns:
        (2, H, W) float CPU tensor of forward-interpolated flow.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # Target positions after applying the flow, flattened for griddata.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only vectors landing strictly inside the image bounds.
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
    )

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
    )

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    coords holds (x, y) pixel positions in its last dim; they are mapped to
    grid_sample's normalized [-1, 1] range. When mask=True, also returns a
    float mask marking samples that fell strictly inside the image.

    NOTE(review): the `mode` parameter is accepted but never forwarded to
    grid_sample (always bilinear) — inherited verbatim from upstream RAFT.
    """
    height, width = img.shape[-2:]
    xs, ys = coords.split([1, 1], dim=-1)
    xs = 2 * xs / (width - 1) - 1
    ys = 2 * ys / (height - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xs, ys], dim=-1), align_corners=True)

    if mask:
        inside = (xs > -1) & (ys > -1) & (xs < 1) & (ys < 1)
        return sampled, inside.float()

    return sampled
|
||||
|
||||
|
||||
def indexing(img, coords, mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    Nearest-neighbour variant of bilinear_sampler: looks up img at (x, y)
    pixel positions via grid_sample with mode="nearest". When mask=True,
    also returns a float mask of samples strictly inside the image.

    TODO: directly indexing features instead of sampling
    """
    height, width = img.shape[-2:]
    xs, ys = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's normalized [-1, 1] range.
    xs = 2 * xs / (width - 1) - 1
    ys = 2 * ys / (height - 1) - 1

    grid = torch.cat([xs, ys], dim=-1)
    sampled = F.grid_sample(img, grid, align_corners=True, mode="nearest")

    if mask:
        inside = (xs > -1) & (ys > -1) & (xs < 1) & (ys < 1)
        return sampled, inside.float()

    return sampled
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of pixel coordinates.

    Channel 0 holds x (column) indices and channel 1 holds y (row) indices.
    """
    # indexing="ij" is the historical torch.meshgrid default; stating it
    # explicitly removes the deprecation warning with identical output.
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors to match."""
    ht, wd = flow.shape[2], flow.shape[3]
    upsampled = F.interpolate(
        flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True
    )
    # Flow magnitudes scale with resolution, hence the factor of 8.
    return 8 * upsampled
|
||||
Reference in New Issue
Block a user