Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation

Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 13:11:45 +01:00
parent 3c3d4b2537
commit d642255e70
56 changed files with 9774 additions and 1 deletions

View File

@@ -0,0 +1,77 @@
from yacs.config import CfgNode as CN

# Default FlowFormer configuration (latentcostformer variant), expressed as a
# yacs CfgNode. get_cfg() below hands out clones so this template is never
# mutated by callers.
_CN = CN()
_CN.name = ""
_CN.suffix = ""
_CN.gamma = 0.8
_CN.max_flow = 400
_CN.batch_size = 6
_CN.sum_freq = 100
_CN.val_freq = 5000000
_CN.image_size = [432, 960]
_CN.add_noise = False
_CN.critical_params = []
_CN.transformer = "latentcostformer"
# default checkpoint path (relative to the working directory)
_CN.model = "pretrained_ckpt/flowformer_sintel.pth"
# latentcostformer
_CN.latentcostformer = CN()
_CN.latentcostformer.pe = "linear"  # positional encoding: "linear" or "exp"
_CN.latentcostformer.dropout = 0.0
_CN.latentcostformer.encoder_latent_dim = 256  # in twins, this is 256
_CN.latentcostformer.query_latent_dim = 64
_CN.latentcostformer.cost_latent_input_dim = 64
_CN.latentcostformer.cost_latent_token_num = 8
_CN.latentcostformer.cost_latent_dim = 128
_CN.latentcostformer.arc_type = "transformer"
_CN.latentcostformer.cost_heads_num = 1
# encoder
_CN.latentcostformer.pretrain = True
_CN.latentcostformer.context_concat = False
_CN.latentcostformer.encoder_depth = 3
_CN.latentcostformer.feat_cross_attn = False
_CN.latentcostformer.patch_size = 8
_CN.latentcostformer.patch_embed = "single"
_CN.latentcostformer.no_pe = False
_CN.latentcostformer.gma = "GMA"
_CN.latentcostformer.kernel_size = 9
_CN.latentcostformer.rm_res = True
_CN.latentcostformer.vert_c_dim = 64
_CN.latentcostformer.cost_encoder_res = True
_CN.latentcostformer.cnet = "twins"
_CN.latentcostformer.fnet = "twins"
_CN.latentcostformer.no_sc = False
_CN.latentcostformer.only_global = False
_CN.latentcostformer.add_flow_token = True
_CN.latentcostformer.use_mlp = False
_CN.latentcostformer.vertical_conv = False
# decoder
_CN.latentcostformer.decoder_depth = 32
# keys that must match between a checkpoint's config and this one
_CN.latentcostformer.critical_params = [
    "cost_heads_num",
    "vert_c_dim",
    "cnet",
    "pretrain",
    "add_flow_token",
    "encoder_depth",
    "gma",
    "cost_encoder_res",
]
### TRAINER
# Training-only settings; unused at inference time.
_CN.trainer = CN()
_CN.trainer.scheduler = "OneCycleLR"
_CN.trainer.optimizer = "adamw"
_CN.trainer.canonical_lr = 12.5e-5
_CN.trainer.adamw_decay = 1e-4
_CN.trainer.clip = 1.0
_CN.trainer.num_steps = 120000
_CN.trainer.epsilon = 1e-8
_CN.trainer.anneal_strategy = "linear"
def get_cfg():
    """Return a deep copy of the default FlowFormer config.

    Callers may freely mutate the returned node without affecting the
    module-level template.
    """
    cfg = _CN.clone()
    return cfg

View File

@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops.layers.torch import Rearrange
from einops import rearrange
class BroadMultiHeadAttention(nn.Module):
    """Multi-head attention that broadcasts one (un-batched) query set over a
    batch of key/value sequences.

    NOTE(review): `Q.squeeze()` in attend_with_rpe drops *all* size-1 dims, so
    Q is assumed to arrive as [1, N, dim] (becoming [N, dim]) — confirm callers.
    """

    def __init__(self, dim, heads):
        super(BroadMultiHeadAttention, self).__init__()
        self.dim = dim  # total channel width, split evenly across heads
        self.heads = heads
        self.scale = (dim / heads) ** -0.5  # 1/sqrt(head_dim) score scaling
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        # Q -> [heads, N, d] (no batch axis); K -> [B, heads, M, d].
        # The einsum below broadcasts the shared queries over K's batch.
        Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)
        dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale  # (b hw) heads 1 pointnum
        return self.attend(dots)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        B, _, _ = K.shape  # output batch follows K/V, not Q
        _, N, _ = Q.shape
        V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        # weighted sum of values, then merge heads back into the channel dim
        out = einsum("bhij, bhjd -> bhid", attn, V)
        out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N)
        return out
class MultiHeadAttention(nn.Module):
    """Standard batched multi-head dot-product attention (no relative terms).

    The channel dimension is split into `heads` equal slices; scores are
    scaled by 1/sqrt(head_dim) and softmaxed over the key axis.
    """

    def __init__(self, dim, heads):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def _split_heads(self, tokens):
        # [B, N, heads*d] -> [B, heads, N, d]
        B, N, _ = tokens.shape
        return tokens.view(B, N, self.heads, -1).permute(0, 2, 1, 3)

    def attend_with_rpe(self, Q, K):
        q = self._split_heads(Q)
        k = self._split_heads(K)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        return self.attend(scores)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        v = self._split_heads(V)
        out = torch.matmul(attn, v)
        # [B, heads, N, d] -> [B, N, heads*d]
        B, N, _ = Q.shape
        return out.permute(0, 2, 1, 3).reshape(B, N, -1)
# class MultiHeadAttentionRelative_encoder(nn.Module):
# def __init__(self, dim, heads):
# super(MultiHeadAttentionRelative, self).__init__()
# self.dim = dim
# self.heads = heads
# self.scale = (dim/heads) ** -0.5
# self.attend = nn.Softmax(dim=-1)
# def attend_with_rpe(self, Q, K, Q_r, K_r):
# """
# Q: [BH1W1, H3W3, dim]
# K: [BH1W1, H3W3, dim]
# Q_r: [BH1W1, H3W3, H3W3, dim]
# K_r: [BH1W1, H3W3, H3W3, dim]
# """
# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim]
# # context-context similarity
# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3]
# # context-position similarity
# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3]
# # position-context similarity
# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:])
# p_c = torch.squeeze(p_c, dim=4)
# p_c = p_c.permute(0, 1, 3, 2)
# dots = c_c + c_p + p_c
# return self.attend(dots)
# def forward(self, Q, K, V, Q_r, K_r):
# attn = self.attend_with_rpe(Q, K, Q_r, K_r)
# B, HW, _ = Q.shape
# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads)
# out = einsum('bhij, bhjd -> bhid', attn, V)
# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW)
# return out
class MultiHeadAttentionRelative(nn.Module):
    """Multi-head attention with relative positional terms.

    Scores sum three similarities — context-context (Q·K), context-position
    (Q·K_r) and position-context (Q_r·K) — each scaled by 1/sqrt(head_dim),
    before a single shared softmax over the key axis.
    """

    def __init__(self, dim, heads):
        super(MultiHeadAttentionRelative, self).__init__()
        self.dim = dim  # total channel width, split evenly across heads
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K, Q_r, K_r):
        """
        Q: [BH1W1, 1, dim]
        K: [BH1W1, H3W3, dim]
        Q_r: [BH1W1, H3W3, dim]
        K_r: [BH1W1, H3W3, dim]
        """
        Q = rearrange(
            Q, "b i (heads d) -> b heads i d", heads=self.heads
        )  # [BH1W1, heads, 1, dim]
        K = rearrange(
            K, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]
        K_r = rearrange(
            K_r, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]
        Q_r = rearrange(
            Q_r, "b j (heads d) -> b heads j d", heads=self.heads
        )  # [BH1W1, heads, H3W3, dim]
        # context-context similarity
        c_c = einsum("bhid, bhjd -> bhij", Q, K) * self.scale  # [(B H1W1) heads 1 H3W3]
        # context-position similarity
        c_p = (
            einsum("bhid, bhjd -> bhij", Q, K_r) * self.scale
        )  # [(B H1W1) heads 1 H3W3]
        # position-context similarity; the [:, :, :, None, :] insertions pair
        # each positional query with its own key along a singleton axis
        p_c = (
            einsum("bhijd, bhikd -> bhijk", Q_r[:, :, :, None, :], K[:, :, :, None, :])
            * self.scale
        )
        # drop the singleton pairing axis, then transpose to match c_c / c_p
        p_c = torch.squeeze(p_c, dim=4)
        p_c = p_c.permute(0, 1, 3, 2)
        dots = c_c + c_p + p_c
        return self.attend(dots)

    def forward(self, Q, K, V, Q_r, K_r):
        attn = self.attend_with_rpe(Q, K, Q_r, K_r)
        B, HW, _ = Q.shape
        V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
        # weighted sum of values, then merge heads back into the channel dim
        out = einsum("bhij, bhjd -> bhid", attn, V)
        out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        return out
def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sine/cosine positional encoding with linearly spaced frequencies.

    x: [..., 2] coordinate tensor (last two channels: y then x).
    Returns [..., dim]: dim//4 sin + dim//4 cos terms per coordinate.
    The 1/200 normalisation assumes coordinates of an 8x-downsampled image.
    """
    n_freq = dim // 4
    freqs = torch.linspace(0, n_freq - 1, n_freq).to(x.device)
    phase_y = 3.14 * x[..., -2:-1] * freqs * NORMALIZE_FACOR
    phase_x = 3.14 * x[..., -1:] * freqs * NORMALIZE_FACOR
    return torch.cat(
        [torch.sin(phase_y), torch.cos(phase_y), torch.sin(phase_x), torch.cos(phase_x)],
        dim=-1,
    )
def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1 / 200):
    """Sine/cosine positional encoding with exponentially spaced frequencies.

    x: [..., 2] coordinate tensor (last two channels: y then x).
    Returns [..., dim]: dim//4 sin + dim//4 cos terms per coordinate.
    The 1/200 normalisation assumes coordinates of an 8x-downsampled image.
    """
    n_freq = dim // 4
    freqs = torch.linspace(0, n_freq - 1, n_freq).to(x.device)
    scale = NORMALIZE_FACOR * 2**freqs
    phase_y = x[..., -2:-1] * scale
    phase_x = x[..., -1:] * scale
    return torch.cat(
        [torch.sin(phase_y), torch.cos(phase_y), torch.sin(phase_x), torch.cos(phase_x)],
        dim=-1,
    )

View File

@@ -0,0 +1,649 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
import math
import numpy as np
class ResidualBlock(nn.Module):
    """Two 3x3 convs with normalisation plus a (possibly strided 1x1) skip path."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        factories = {
            "group": lambda: nn.GroupNorm(num_groups=num_groups, num_channels=planes),
            "batch": lambda: nn.BatchNorm2d(planes),
            "instance": lambda: nn.InstanceNorm2d(planes),
            "none": nn.Sequential,  # identity
        }
        # Unknown norm_fn values leave norm1/norm2 unset (matches historical
        # behavior: failure surfaces later as an AttributeError).
        if norm_fn in factories:
            make = factories[norm_fn]
            self.norm1 = make()
            self.norm2 = make()
            if stride != 1:
                self.norm3 = make()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck (quarter-width middle) with a skip path."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def _norm(channels):
            # One normalisation layer of the requested flavour at `channels` width.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == "batch":
                return nn.BatchNorm2d(channels)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(channels)
            return nn.Sequential()  # "none"

        # Unknown norm_fn values leave the norms unset (matches historical
        # behavior: failure surfaces later as an AttributeError).
        if norm_fn in ("group", "batch", "instance", "none"):
            self.norm1 = _norm(planes // 4)
            self.norm2 = _norm(planes // 4)
            self.norm3 = _norm(planes)
            if stride != 1:
                self.norm4 = _norm(planes)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.relu(self.norm3(self.conv3(out)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
class BasicEncoder(nn.Module):
    """RAFT-style feature encoder: 7x7 stem + three residual stages (8x downsample).

    Channel widths scale with input_dim // 3 so multi-image stacks widen the net.
    """

    def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        mul = input_dim // 3  # widen proportionally to the number of RGB groups
        stem_ch = 64 * mul

        norm_map = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=stem_ch),
            "batch": lambda: nn.BatchNorm2d(stem_ch),
            "instance": lambda: nn.InstanceNorm2d(stem_ch),
            "none": nn.Sequential,
        }
        if self.norm_fn in norm_map:
            self.norm1 = norm_map[self.norm_fn]()

        self.conv1 = nn.Conv2d(input_dim, stem_ch, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = stem_ch
        self.layer1 = self._make_layer(64 * mul, stride=1)
        self.layer2 = self._make_layer(96 * mul, stride=2)
        self.layer3 = self._make_layer(128 * mul, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may stride.
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def compute_params(self):
        """Total parameter count."""
        return sum(np.prod(p.size()) for p in self.parameters())

    def forward(self, x):
        # An (img1, img2) pair may be passed; run as one batch and split after.
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
class SmallEncoder(nn.Module):
    """Compact RAFT-style encoder built from bottleneck blocks (8x downsample)."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        norm_map = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=32),
            "batch": lambda: nn.BatchNorm2d(32),
            "instance": lambda: nn.InstanceNorm2d(32),
            "none": nn.Sequential,
        }
        if self.norm_fn in norm_map:
            self.norm1 = norm_map[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may stride.
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # An (img1, img2) pair may be passed; run as one batch and split after.
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
class ConvNets(nn.Module):
    """Conv stem -> `depth` un-normalised residual blocks -> conv head."""

    def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1):
        super(ConvNets, self).__init__()
        self.conv_first = nn.Conv2d(
            in_dim, inter_dim, kernel_size=3, padding=1, stride=stride
        )
        self.conv_last = nn.Conv2d(
            inter_dim, out_dim, kernel_size=3, padding=1, stride=stride
        )
        self.relu = nn.ReLU(inplace=True)
        self.inter_convs = nn.ModuleList(
            [
                ResidualBlock(inter_dim, inter_dim, norm_fn="none", stride=1)
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        feat = self.relu(self.conv_first(x))
        for block in self.inter_convs:
            feat = block(feat)
        return self.conv_last(feat)
class FlowHead(nn.Module):
    """Two-layer conv head projecting a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        combined = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(combined))  # z: how much to overwrite
        reset = torch.sigmoid(self.convr(combined))   # r: how much history to keep
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """ConvGRU with separable kernels: a horizontal (1x5) update followed by a
    vertical (5x1) update, each with its own gate convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        in_ch = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(in_ch, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(in_ch, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(in_ch, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(in_ch, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(in_ch, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(in_ch, hidden_dim, (5, 1), padding=(2, 0))

    def _step(self, h, x, convz, convr, convq):
        # One GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)  # vertical
        return h
class BasicMotionEncoder(nn.Module):
    """Encode (flow, correlation) into a 128-channel motion feature.

    The last two output channels are the raw flow itself, concatenated back in.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.motion_feature_dim
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + 2 raw flow channels = 128 total
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
class BasicFuseMotion(nn.Module):
    """Fuse flow with (feature, context) maps into a query_latent_dim feature."""

    def __init__(self, args):
        super(BasicFuseMotion, self).__init__()
        cor_planes = args.motion_feature_dim
        out_planes = args.query_latent_dim

        # flow branch
        self.normf1 = nn.InstanceNorm2d(128)
        self.normf2 = nn.InstanceNorm2d(128)
        self.convf1 = nn.Conv2d(2, 128, 3, padding=1)
        self.convf2 = nn.Conv2d(128, 128, 3, padding=1)
        self.convf3 = nn.Conv2d(128, 64, 3, padding=1)

        # feature branch (input concatenated with a 128-channel context map)
        s = 1
        self.normc1 = nn.InstanceNorm2d(256 * s)
        self.normc2 = nn.InstanceNorm2d(256 * s)
        self.normc3 = nn.InstanceNorm2d(256 * s)
        self.convc1 = nn.Conv2d(cor_planes + 128, 256 * s, 1, padding=0)
        self.convc2 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc3 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)
        self.convc4 = nn.Conv2d(256 * s, 256 * s, 3, padding=1)

        self.conv = nn.Conv2d(256 * s + 64, out_planes, 1, padding=0)

    def forward(self, flow, feat, context1=None):
        flow_feat = F.relu(self.normf1(self.convf1(flow)))
        flow_feat = F.relu(self.normf2(self.convf2(flow_feat)))
        flow_feat = self.convf3(flow_feat)

        fused = torch.cat([feat, context1], dim=1)
        fused = F.relu(self.normc1(self.convc1(fused)))
        fused = F.relu(self.normc2(self.convc2(fused)))
        fused = F.relu(self.normc3(self.convc3(fused)))
        fused = self.convc4(fused)

        fused = torch.cat([flow_feat, fused], dim=1)
        return F.relu(self.conv(fused))
class BasicUpdateBlock(nn.Module):
    """One GRU refinement step: motion features -> (new hidden state,
    convex-upsampling mask, flow delta)."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)
        # 0.25 keeps the mask head's gradients balanced against the flow head
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
class DirectMeanMaskPredictor(nn.Module):
    """Predict (convex-upsampling mask, flow delta) directly from motion features."""

    def __init__(self, args):
        super(DirectMeanMaskPredictor, self).__init__()
        self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(args.predictor_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, motion_features):
        delta_flow = self.flow_head(motion_features)
        # 0.25 keeps the mask head's gradients balanced against the flow head
        mask = 0.25 * self.mask(motion_features)
        return mask, delta_flow
class BaiscMeanPredictor(nn.Module):
    """Predict (convex-upsampling mask, flow delta) from a latent via a motion
    encoder.

    NOTE: the misspelled class name ("Baisc") is kept as-is — external callers
    and serialized checkpoints reference it.
    """

    def __init__(self, args, hidden_dim=128):
        super(BaiscMeanPredictor, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, latent, flow):
        motion_features = self.encoder(flow, latent)
        delta_flow = self.flow_head(motion_features)
        # 0.25 keeps the mask head's gradients balanced against the flow head
        mask = 0.25 * self.mask(motion_features)
        return mask, delta_flow
class BasicRPEEncoder(nn.Module):
    """MLP lifting 2-D relative position offsets to query_latent_dim tokens."""

    def __init__(self, args):
        super(BasicRPEEncoder, self).__init__()
        self.args = args
        dim = args.query_latent_dim
        layers = [
            nn.Linear(2, dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, rpe_tokens):
        return self.encoder(rpe_tokens)
from .twins import Block, CrossBlock
class TwinsSelfAttentionLayer(nn.Module):
    """Twins-style self-attention: a locally-grouped block (ws=7) followed by a
    global block (ws=1), applied to source and target maps with shared weights."""

    def __init__(self, args):
        super(TwinsSelfAttentionLayer, self).__init__()
        self.args = args
        shared = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **shared)
        self.global_block = Block(ws=1, **shared)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Twins convention: trunc-normal linears, unit norms, fan-out convs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        # Same (shared-weight) blocks process both feature maps independently.
        x = self.global_block(self.local_block(x, size), size)
        tgt = self.global_block(self.local_block(tgt, size), size)
        return x, tgt
class TwinsCrossAttentionLayer(nn.Module):
    """Twins-style cross-attention: a shared locally-grouped self-attention block
    (ws=7) on each map, then a global cross block exchanging information."""

    def __init__(self, args):
        super(TwinsCrossAttentionLayer, self).__init__()
        self.args = args
        shared = dict(
            dim=256,
            num_heads=8,
            mlp_ratio=4,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            with_rpe=True,
        )
        self.local_block = Block(ws=7, **shared)
        self.global_block = CrossBlock(ws=1, **shared)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Twins convention: trunc-normal linears, unit norms, fan-out convs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()

    def forward(self, x, tgt, size):
        x = self.local_block(x, size)
        tgt = self.local_block(tgt, size)
        x, tgt = self.global_block(x, tgt, size)
        return x, tgt

View File

@@ -0,0 +1,98 @@
#from turtle import forward
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class ConvNextLayer(nn.Module):
    """Stack of `depth` ConvNeXt blocks at a fixed channel width."""

    def __init__(self, dim, depth=4):
        super().__init__()
        blocks = [ConvNextBlock(dim=dim) for _ in range(depth)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

    def compute_params(self):
        """Total parameter count."""
        return sum(np.prod(p.size()) for p in self.parameters())
class ConvNextBlock(nn.Module):
    r"""ConvNeXt block (channels_last variant).

    DwConv -> permute to (N, H, W, C) -> LayerNorm -> Linear -> GELU -> Linear
    -> permute back, with a learnable per-channel scale (layer scale) and a
    residual connection.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale; a value
            <= 0 disables the scale entirely. Default: 1e-6.
    """

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs implemented as linears (tensors are channels-last here)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x):
        residual = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + y
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (default) and channels_first layouts.

    channels_last expects (batch, height, width, channels) and defers to
    F.layer_norm; channels_first expects (batch, channels, height, width) and
    normalises over dim 1 manually (biased variance, matching F.layer_norm).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalise across the channel axis by hand
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]

View File

@@ -0,0 +1,316 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ...utils.utils import coords_grid, bilinear_sampler
from .attention import (
MultiHeadAttention,
LinearPositionEmbeddingSine,
ExpPositionEmbeddingSine,
)
from timm.models.layers import DropPath
from .gru import BasicUpdateBlock, GMAUpdateBlock
from .gma import Attention
def initialize_flow(img):
    """Return two identical pixel-coordinate grids on img's device.

    Flow is represented as the difference between the two grids
    (flow = mean1 - mean0); both start at the identity mapping.
    """
    N, _, H, W = img.shape
    coords0 = coords_grid(N, H, W).to(img.device)
    coords1 = coords_grid(N, H, W).to(img.device)
    return coords0, coords1
class CrossAttentionLayer(nn.Module):
    """Cross-attend per-pixel flow query tokens to the latent cost memory.

    Each query position attends over the cost tokens; its current target
    coordinate is injected via a sine positional encoding and (optionally,
    when add_flow_token is True) summed with the running flow token before
    the query projection. Key/value projections of the memory are computed
    once and returned so callers can reuse them across decoder iterations.
    """

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        add_flow_token=True,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
        pe="linear",
    ):
        super(CrossAttentionLayer, self).__init__()
        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5
        self.query_token_dim = query_token_dim
        self.pe = pe  # positional encoding flavour: "linear" or "exp"

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        self.proj = nn.Linear(v_dim * 2, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )
        self.add_flow_token = add_flow_token
        self.dim = qk_dim

    def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3):
        """
        query:       [B*H1*W1, 1, C] flow token
        key/value:   cached memory projections, or None on the first iteration
        memory:      [B*H1*W1, H2'*W2', C] cost tokens
        query_coord: [B, 2, H1, W1] current flow target coordinates

        Returns (updated query, key, value); k/v are returned so they can be
        reused without re-projecting the memory each iteration.
        """
        B, _, H1, W1 = query_coord.shape

        # Project the memory only once; later iterations pass the cache back in.
        if key is None and value is None:
            key = self.k(memory)
            value = self.v(memory)

        # [B, 2, H1, W1] -> [BH1W1, 1, 2]
        query_coord = query_coord.contiguous()
        query_coord = (
            query_coord.view(B, 2, -1)
            .permute(0, 2, 1)[:, :, None, :]
            .contiguous()
            .view(B * H1 * W1, 1, 2)
        )
        if self.pe == "linear":
            query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim)
        elif self.pe == "exp":
            query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim)
        else:
            # Previously any other value fell through to an UnboundLocalError
            # on query_coord_enc; fail loudly with a clear message instead.
            raise ValueError(f"unknown positional encoding: {self.pe!r}")

        short_cut = query
        query = self.norm1(query)

        if self.add_flow_token:
            q = self.q(query + query_coord_enc)
        else:
            q = self.q(query_coord_enc)
        k, v = key, value

        x = self.multi_head_attn(q, k, v)
        x = self.proj(torch.cat([x, short_cut], dim=2))
        x = short_cut + self.proj_drop(x)

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x, k, v
class MemoryDecoderLayer(nn.Module):
    """One decoder step: cross-attend flow queries to the cost memory and
    reshape the attended tokens back into a [B, C, H1, W1] feature map."""

    def __init__(self, dim, cfg):
        super(MemoryDecoderLayer, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size  # for converting coords into H2', W2' space
        query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.cross_attend = CrossAttentionLayer(
            qk_dim,
            v_dim,
            query_token_dim,
            tgt_token_dim,
            add_flow_token=cfg.add_flow_token,
            dropout=cfg.dropout,
        )

    def forward(self, query, key, value, memory, coords1, size, size_h3w3):
        """
        query:  [B*H1*W1, 1, C]   memory: [B*H1*W1, H2'*W2', C]
        coords1: [B, 2, H2, W2] (in H2, W2 space — the attention layer converts
        them into H2', W2' patch space; the upper-left point is [0, 0])
        size: (B, C, H1, W1)
        Returns ([B, C, H1, W1] attended map, cached key, cached value).
        """
        attended, k, v = self.cross_attend(
            query, key, value, memory, coords1, self.patch_size, size_h3w3
        )
        B, _, H1, W1 = size
        C = self.cfg.query_latent_dim
        x_global = attended.view(B, H1, W1, C).permute(0, 3, 1, 2)
        return x_global, k, v
class ReverseCostExtractor(nn.Module):
    """Extracts backward-direction correlation: samples the cost volume at
    coords1, transposes it to index by target pixel, then gathers a local
    window around coords0."""

    def __init__(self, cfg):
        super(ReverseCostExtractor, self).__init__()
        self.cfg = cfg

    def forward(self, cost_maps, coords0, coords1):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords    - B, 2, H1, W1
        """
        BH1W1, heads, H2, W2 = cost_maps.shape
        B, _, H1, W1 = coords1.shape

        # Requires square all-pairs cost (source and target grids match).
        assert (H1 == H2) and (W1 == W2)
        assert BH1W1 == B * H1 * W1

        cost_maps = cost_maps.reshape(B, H1 * W1 * heads, H2, W2)
        coords = coords1.permute(0, 2, 3, 1)
        corr = bilinear_sampler(cost_maps, coords)  # [B, H1*W1*heads, H2, W2]
        # Transpose the correlation volume so it is indexed by target pixel.
        corr = rearrange(
            corr,
            "b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1",
            b=B,
            heads=heads,
            h1=H1,
            w1=W1,
            h2=H2,
            w2=W2,
        )

        # Gather a (2r+1)x(2r+1) window centred on coords0.
        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device)
        centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(corr, coords)
        corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2)
        return corr
class MemoryDecoder(nn.Module):
    """Iteratively decodes optical flow from the cost memory via RAFT-style
    GRU updates, optionally with GMA global motion aggregation."""

    def __init__(self, cfg):
        super(MemoryDecoder, self).__init__()
        dim = self.dim = cfg.query_latent_dim
        self.cfg = cfg

        # Encodes the 9x9 local cost window (81 taps per head) into a flow token.
        self.flow_token_encoder = nn.Sequential(
            nn.Conv2d(81 * cfg.cost_heads_num, dim, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1, 1),
        )
        self.proj = nn.Conv2d(256, 256, 1)
        self.depth = cfg.decoder_depth
        self.decoder_layer = MemoryDecoderLayer(dim, cfg)

        if self.cfg.gma:
            self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128)
            self.att = Attention(
                args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128
            )
        else:
            self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128)

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow magnitudes are scaled by 8 to match the upsampled resolution.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def encode_flow_token(self, cost_maps, coords):
        """
        cost_maps - B*H1*W1, cost_heads_num, H2, W2
        coords    - B, 2, H1, W1

        Samples a 9x9 correlation window around each coordinate.
        """
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        r = 4
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)
        delta = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords = centroid + delta
        corr = bilinear_sampler(cost_maps, coords)
        corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2)
        return corr

    def forward(self, cost_memory, context, data={}, flow_init=None, iters=None):
        """
        memory:  [B*H1*W1, H2'*W2', C]
        context: [B, D, H1, W1]

        Returns (final upsampled flow, final low-resolution flow).

        NOTE(review): the mutable default ``data={}`` is shared across calls;
        callers appear to pass `data` explicitly — confirm before relying on
        the default.
        """
        cost_maps = data["cost_maps"]
        coords0, coords1 = initialize_flow(context)

        if flow_init is not None:
            # Warm start: initialise coords1 from a previous flow estimate.
            coords1 = coords1 + flow_init

        flow_predictions = []

        # Split the projected context into GRU hidden state (net) and input (inp).
        context = self.proj(context)
        net, inp = torch.split(context, [128, 128], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)
        if self.cfg.gma:
            attention = self.att(inp)

        size = net.shape
        key, value = None, None

        if iters is None:
            iters = self.depth
        for idx in range(iters):
            # Stop gradients flowing through the previous iterate.
            coords1 = coords1.detach()

            cost_forward = self.encode_flow_token(cost_maps, coords1)

            query = self.flow_token_encoder(cost_forward)
            query = (
                query.permute(0, 2, 3, 1)
                .contiguous()
                .view(size[0] * size[2] * size[3], 1, self.dim)
            )
            # Cross-attend to the cost memory; key/value are cached after iter 0.
            cost_global, key, value = self.decoder_layer(
                query, key, value, cost_memory, coords1, size, data["H3W3"]
            )
            if self.cfg.only_global:
                corr = cost_global
            else:
                corr = torch.cat([cost_global, cost_forward], dim=1)

            flow = coords1 - coords0

            if self.cfg.gma:
                net, up_mask, delta_flow = self.update_block(
                    net, inp, corr, flow, attention
                )
            else:
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)

        # Inference-style return: only the last prediction plus coarse flow.
        return flow_predictions[-1], coords1 - coords0

View File

@@ -0,0 +1,534 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
import numpy as np
from einops import rearrange
from ...utils.utils import coords_grid
from .attention import (
BroadMultiHeadAttention,
MultiHeadAttention,
LinearPositionEmbeddingSine,
ExpPositionEmbeddingSine,
)
from ..encoders import twins_svt_large
from typing import Tuple
from .twins import Size_
from .cnn import BasicEncoder
from .mlpmixer import MLPMixerLayer
from .convnext import ConvNextLayer
from timm.models.layers import DropPath
class PatchEmbed(nn.Module):
    """Embeds per-pixel cost maps into patch tokens with sinusoidal positional
    information appended (output channel count is 2 * embed_dim)."""

    def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe="linear"):
        super().__init__()
        self.patch_size = patch_size
        self.dim = embed_dim
        self.pe = pe

        # Strided conv stack downsampling by `patch_size`:
        # three /2 stages for 8, two for 4.
        if patch_size == 8:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2
                ),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        elif patch_size == 4:
            self.proj = nn.Sequential(
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2),
                nn.ReLU(),
                nn.Conv2d(
                    embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2
                ),
            )
        else:
            # NOTE(review): unsupported sizes only print; forward would then
            # fail on the missing self.proj.
            print(f"patch size = {patch_size} is unacceptable.")

        self.ffn_with_coord = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(embed_dim * 2)

    def forward(self, x) -> Tuple[torch.Tensor, Size_]:
        """Return (tokens [B, H'*W', 2*embed_dim], output spatial size)."""
        B, C, H, W = x.shape  # C == 1

        # Pad right/bottom so H and W are multiples of patch_size.
        pad_l = pad_t = 0
        pad_r = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_b = (self.patch_size - H % self.patch_size) % self.patch_size
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        x = self.proj(x)
        out_size = x.shape[2:]

        # Patch-centre coordinates in the input feature coordinate space.
        patch_coord = (
            coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size
            + self.patch_size / 2
        )  # in feature coordinate space
        patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1)
        if self.pe == "linear":
            patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim)
        elif self.pe == "exp":
            patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim)
        patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(
            B, -1, out_size[0], out_size[1]
        )

        x_pe = torch.cat([x, patch_coord_enc], dim=1)
        x = self.ffn_with_coord(x_pe)
        x = self.norm(x.flatten(2).transpose(1, 2))

        return x, out_size
from .twins import Block, CrossBlock
class GroupVerticalSelfAttentionLayer(nn.Module):
    """Single grouped Twins attention block applied along the cost-volume
    spatial ('vertical') axis."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(GroupVerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Fixed Twins-block hyperparameters: window size 7, sr_ratio 4,
        # MLP expansion x4, no stochastic depth.
        self.block = Block(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,
            drop=dropout,
            attn_drop=0.0,
            drop_path=0.0,
            sr_ratio=4,
            ws=7,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
            groupattention=True,
            cfg=self.cfg,
        )

    def forward(self, x, size, context=None):
        """Apply the grouped attention block to tokens `x` of spatial `size`."""
        return self.block(x, size, context)
class VerticalSelfAttentionLayer(nn.Module):
    """Twins attention along the cost-volume spatial axis: a locally-windowed
    block followed by a global block (ws=1)."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(VerticalSelfAttentionLayer, self).__init__()
        self.cfg = cfg
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        embed_dim = dim
        mlp_ratio = 4
        ws = 7  # local attention window size
        sr_ratio = 4
        dpr = 0.0  # no stochastic depth
        drop_rate = dropout
        attn_drop_rate = 0.0

        # Locally-windowed attention block (ws=7).
        self.local_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=ws,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )
        # Global attention block (ws=1 selects the global path in Block).
        self.global_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            sr_ratio=sr_ratio,
            ws=1,
            with_rpe=True,
            vert_c_dim=cfg.vert_c_dim,
        )

    def forward(self, x, size, context=None):
        x = self.local_block(x, size, context)
        x = self.global_block(x, size, context)
        return x

    def compute_params(self):
        # Total parameter count of this layer.
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())
        return num
class SelfAttentionLayer(nn.Module):
    """Pre-norm transformer self-attention block: multi-head attention and an
    FFN, each with a residual connection."""

    def __init__(
        self,
        dim,
        cfg,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(SelfAttentionLayer, self).__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.multi_head_attn = MultiHeadAttention(dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
            nn.Linear(dim, dim, bias=True),
        )

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        x: [BH1W1, H3W3, D]
        """
        short_cut = x
        x = self.norm1(x)
        q, k, v = self.q(x), self.k(x), self.v(x)
        x = self.multi_head_attn(q, k, v)
        x = self.proj(x)
        x = short_cut + self.proj_drop(x)
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x

    def compute_params(self):
        # Total parameter count of this layer.
        num = 0
        for param in self.parameters():
            num += np.prod(param.size())
        return num
class CrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention block where query tokens attend to target
    tokens.

    Query Token:  [N, C] -> [N, qk_dim] (Q)
    Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
    """

    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(CrossAttentionLayer, self).__init__()
        assert (
            qk_dim % num_heads == 0
        ), f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert (
            v_dim % num_heads == 0
        ), f"dim {v_dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )

    def forward(self, query, tgt_token):
        """
        query: [BH1W1, K, C]; tgt_token: [BH1W1, M, D]
        """
        short_cut = query
        query = self.norm1(query)
        q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token)
        x = self.multi_head_attn(q, k, v)
        x = short_cut + self.proj_drop(self.proj(x))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x
class CostPerceiverEncoder(nn.Module):
    """Perceiver-style encoder that compresses the 4-D cost volume into a
    small set of latent tokens per pixel, alternating per-pixel layers with
    attention (or convolution) along the spatial ('vertical') axis."""

    def __init__(self, cfg):
        super(CostPerceiverEncoder, self).__init__()
        self.cfg = cfg
        self.patch_size = cfg.patch_size
        self.patch_embed = PatchEmbed(
            in_chans=self.cfg.cost_heads_num,
            patch_size=self.patch_size,
            embed_dim=cfg.cost_latent_input_dim,
            pe=cfg.pe,
        )

        self.depth = cfg.encoder_depth

        # Learned latent query tokens, shared across all pixels.
        self.latent_tokens = nn.Parameter(
            torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)
        )

        query_token_dim, tgt_token_dim = (
            cfg.cost_latent_dim,
            cfg.cost_latent_input_dim * 2,
        )
        qk_dim, v_dim = query_token_dim, query_token_dim
        self.input_layer = CrossAttentionLayer(
            qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout
        )

        # Per-pixel token-mixing layers: MLP-Mixer or self-attention.
        if cfg.use_mlp:
            self.encoder_layers = nn.ModuleList(
                [
                    MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )
        else:
            self.encoder_layers = nn.ModuleList(
                [
                    SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout)
                    for idx in range(self.depth)
                ]
            )

        # Spatial-axis layers: ConvNext blocks or Twins vertical attention.
        if self.cfg.vertical_conv:
            self.vertical_encoder_layers = nn.ModuleList(
                [ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]
            )
        else:
            self.vertical_encoder_layers = nn.ModuleList(
                [
                    VerticalSelfAttentionLayer(
                        cfg.cost_latent_dim, cfg, dropout=cfg.dropout
                    )
                    for idx in range(self.depth)
                ]
            )

        # Optional training-time augmentation: randomly rescale cost maps.
        self.cost_scale_aug = None
        if "cost_scale_aug" in cfg.keys():
            self.cost_scale_aug = cfg.cost_scale_aug
            print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug))

    def forward(self, cost_volume, data, context=None):
        """cost_volume: [B, heads, H1, W1, H2, W2] -> latent tokens
        [B*H1*W1, K, D]. Side effect: stores 'cost_maps' and 'H3W3' in
        `data` for the decoder."""
        B, heads, H1, W1, H2, W2 = cost_volume.shape
        cost_maps = (
            cost_volume.permute(0, 2, 3, 1, 4, 5)
            .contiguous()
            .view(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
        )
        data["cost_maps"] = cost_maps

        if self.cost_scale_aug is not None:
            scale_factor = (
                torch.FloatTensor(B * H1 * W1, self.cfg.cost_heads_num, H2, W2)
                .uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1])
                .to(cost_maps.device)
            )
            cost_maps = cost_maps * scale_factor

        x, size = self.patch_embed(cost_maps)  # B*H1*W1, size[0]*size[1], C
        data["H3W3"] = size
        H3, W3 = size

        # Compress patch tokens into the learned latent tokens.
        x = self.input_layer(self.latent_tokens, x)

        short_cut = x

        for idx, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if self.cfg.vertical_conv:
                # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 3, 1, 2)
                    .reshape(B * self.cfg.cost_latent_token_num, -1, H1, W1)
                )
                x = self.vertical_encoder_layers[idx](x)
                # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, -1, H1 * W1)
                    .permute(0, 2, 3, 1)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )
            else:
                # Regroup so each latent token index attends across the H1*W1 axis.
                x = (
                    x.view(B, H1 * W1, self.cfg.cost_latent_token_num, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * self.cfg.cost_latent_token_num, H1 * W1, -1)
                )
                x = self.vertical_encoder_layers[idx](x, (H1, W1), context)
                x = (
                    x.view(B, self.cfg.cost_latent_token_num, H1 * W1, -1)
                    .permute(0, 2, 1, 3)
                    .reshape(B * H1 * W1, self.cfg.cost_latent_token_num, -1)
                )

        # Optional global residual around the whole encoder stack.
        if self.cfg.cost_encoder_res is True:
            x = x + short_cut
        return x
class MemoryEncoder(nn.Module):
    """Extracts features from both frames, builds the all-pairs correlation
    volume, and encodes it into cost memory via CostPerceiverEncoder."""

    def __init__(self, cfg):
        super(MemoryEncoder, self).__init__()
        self.cfg = cfg

        if cfg.fnet == "twins":
            self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.fnet == "basicencoder":
            self.feat_encoder = BasicEncoder(output_dim=256, norm_fn="instance")
        else:
            # NOTE(review): bare exit() on an unknown fnet; an exception with a
            # message would be clearer.
            exit()
        self.channel_convertor = nn.Conv2d(
            cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False
        )
        self.cost_perceiver_encoder = CostPerceiverEncoder(cfg)

    def corr(self, fmap1, fmap2):
        # All-pairs dot-product correlation, split across cost_heads_num heads.
        batch, dim, ht, wd = fmap1.shape
        fmap1 = rearrange(
            fmap1, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        fmap2 = rearrange(
            fmap2, "b (heads d) h w -> b heads (h w) d", heads=self.cfg.cost_heads_num
        )
        corr = einsum("bhid, bhjd -> bhij", fmap1, fmap2)
        corr = corr.permute(0, 2, 1, 3).view(
            batch * ht * wd, self.cfg.cost_heads_num, ht, wd
        )
        corr = corr.view(batch, ht * wd, self.cfg.cost_heads_num, ht * wd).permute(
            0, 2, 1, 3
        )
        corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd)

        return corr

    def forward(self, img1, img2, data, context=None, return_feat=False):
        # Encode both frames in a single batched pass instead of two separate
        # encoder calls.
        imgs = torch.cat([img1, img2], dim=0)
        feats = self.feat_encoder(imgs)
        feats = self.channel_convertor(feats)
        B = feats.shape[0] // 2

        feat_s = feats[:B]
        if return_feat:
            ffeat = feats[:B]
        feat_t = feats[B:]
        B, C, H, W = feat_s.shape
        size = (H, W)

        if self.cfg.feat_cross_attn:
            # NOTE(review): self.layers is never defined in __init__, so this
            # branch would raise AttributeError; the shipped config sets
            # feat_cross_attn = False. Confirm before enabling.
            feat_s = feat_s.flatten(2).transpose(1, 2)
            feat_t = feat_t.flatten(2).transpose(1, 2)

            for layer in self.layers:
                feat_s, feat_t = layer(feat_s, feat_t, size)

            feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

        cost_volume = self.corr(feat_s, feat_t)
        x = self.cost_perceiver_encoder(cost_volume, data, context)

        if return_feat:
            return x, ffeat
        return x

View File

@@ -0,0 +1,123 @@
import torch
from torch import nn, einsum
from einops import rearrange
class RelPosEmb(nn.Module):
    """Decomposed 2-D relative positional embedding: separate learned tables
    for height and width offsets, combined into per-pair attention scores."""

    def __init__(self, max_pos_size, dim_head):
        super().__init__()
        self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
        self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)

        # rel_ind[i, j] = (j - i) shifted by max_pos_size - 1 so indices are
        # non-negative; used to index the embedding tables.
        deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(
            max_pos_size
        ).view(-1, 1)
        rel_ind = deltas + max_pos_size - 1
        self.register_buffer("rel_ind", rel_ind)

    def forward(self, q):
        """q: [batch, heads, h, w, dim_head] -> positional scores
        [batch, heads, h, w, h, w] (sum of height and width terms)."""
        batch, heads, h, w, c = q.shape
        height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
        width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))

        height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h)
        width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w)

        height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb)
        width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb)

        return height_score + width_score
class Attention(nn.Module):
    """GMA attention: produces an (h*w) x (h*w) attention map over context
    features. `args` is stored but not read in this implementation."""

    def __init__(
        self,
        *,
        args,
        dim,
        max_pos_size=100,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # Relative positional embedding is constructed but kept frozen; the
        # position-based similarity variants are disabled in forward.
        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
        for param in self.pos_emb.parameters():
            param.requires_grad = False

    def forward(self, fmap):
        heads, b, c, h, w = self.heads, *fmap.shape

        q, k = self.to_qk(fmap).chunk(2, dim=1)

        q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k))
        q = self.scale * q

        # Content-only similarity (pos_emb is unused in this code path).
        sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)

        sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
        attn = sim.softmax(dim=-1)

        return attn
class Aggregate(nn.Module):
    """Aggregates motion features with a precomputed attention map, blended in
    through a learnable scalar gate initialised to 0 (starts as identity)."""

    def __init__(
        self,
        args,
        dim,
        heads=4,
        dim_head=128,
    ):
        super().__init__()
        self.args = args
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)

        # Zero-initialised gate on the aggregated residual.
        self.gamma = nn.Parameter(torch.zeros(1))

        # Projection back to `dim` only when the inner dim differs.
        if dim != inner_dim:
            self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
        else:
            self.project = None

    def forward(self, attn, fmap):
        heads, b, c, h, w = self.heads, *fmap.shape

        v = self.to_v(fmap)
        v = rearrange(v, "b (h d) x y -> b h (x y) d", h=heads)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        if self.project is not None:
            out = self.project(out)

        out = fmap + self.gamma * out

        return out
if __name__ == "__main__":
    # Smoke test for the GMA attention module.
    # Fix: Attention declares `args` as a required keyword-only parameter, so
    # the original call `Attention(dim=128, heads=1)` raised TypeError. `args`
    # is only stored on the module, so None is a safe placeholder here.
    att = Attention(args=None, dim=128, heads=1)
    fmap = torch.randn(2, 128, 40, 90)
    out = att(fmap)

    print(out.shape)

View File

@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer conv head mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        # 3x3 conv -> ReLU -> 3x3 conv producing (dx, dy) per pixel.
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # update gate
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # reset gate
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)  # candidate

    def forward(self, h, x):
        """One GRU step: `h` is the hidden state, `x` the input features."""
        hx = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """ConvGRU with separable kernels: a horizontal (1x5) update followed by a
    vertical (5x1) update."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        cin = hidden_dim + input_dim
        # Horizontal (1x5) gate convolutions.
        self.convz1 = nn.Conv2d(cin, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(cin, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(cin, hidden_dim, (1, 5), padding=(0, 2))
        # Vertical (5x1) gate convolutions.
        self.convz2 = nn.Conv2d(cin, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(cin, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(cin, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_step(self, h, x, convz, convr, convq):
        # Standard GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        # vertical pass
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and the current flow into 128-channel motion
    features (126 learned channels + the raw 2-channel flow)."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        if args.only_global:
            print("[Decoding with only global cost]")
            cor_planes = args.query_latent_dim
        else:
            cor_planes = 81 + args.query_latent_dim
        # Correlation branch.
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        # Flow branch.
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # Fusion conv; output leaves 2 channels free for the raw flow.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
class BasicUpdateBlock(nn.Module):
    """One RAFT-style update step: motion encoding -> separable ConvGRU ->
    flow delta plus an 8x convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts 9-tap convex-combination weights for each 8x8 output patch.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # scale mask to balence gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
from .gma import Aggregate
class GMAUpdateBlock(nn.Module):
    """RAFT-style update block extended with GMA global motion aggregation."""

    def __init__(self, args, hidden_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # GRU input: context (inp) + local motion features + aggregated features.
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim
        )
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

        self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1)

    def forward(self, net, inp, corr, flow, attention):
        motion_features = self.encoder(flow, corr)
        # Globally aggregate motion features using the precomputed attention.
        motion_features_global = self.aggregator(attention, motion_features)
        inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)

        # Attentional update
        net = self.gru(net, inp_cat)

        delta_flow = self.flow_head(net)

        # scale mask to balence gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,55 @@
from torch import nn
from einops.layers.torch import Rearrange, Reduce
from functools import partial
import numpy as np
class PreNormResidual(nn.Module):
    """Residual wrapper that layer-normalises the input before applying `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.fn(self.norm(x))
def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    """Build a two-layer MLP (dim -> dim*expansion_factor -> dim) with GELU
    and dropout.

    `dense` selects the layer type: nn.Linear for channel mixing, or a 1x1
    Conv1d partial for token mixing.
    """
    hidden = dim * expansion_factor
    layers = [
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
class MLPMixerLayer(nn.Module):
    """One MLP-Mixer block over latent cost tokens: token mixing (across the K
    tokens via a 1x1 Conv1d) then channel mixing, each as a pre-norm residual."""

    def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
        super(MLPMixerLayer, self).__init__()
        token_count = cfg.cost_latent_token_num
        expansion = cfg.mlp_expansion_factor
        token_dense = partial(nn.Conv1d, kernel_size=1)  # mixes across tokens
        channel_dense = nn.Linear  # mixes across channels

        self.mlpmixer = nn.Sequential(
            PreNormResidual(
                dim, FeedForward(token_count, expansion, dropout, token_dense)
            ),
            PreNormResidual(
                dim, FeedForward(dim, expansion, dropout, channel_dense)
            ),
        )

    def compute_params(self):
        """Return the total parameter count of the mixer."""
        total = 0
        for p in self.mlpmixer.parameters():
            total += np.prod(p.size())
        return total

    def forward(self, x):
        """
        x: [BH1W1, K, D]
        """
        return self.mlpmixer(x)

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from ...utils.utils import coords_grid
from ..encoders import twins_svt_large
from .encoder import MemoryEncoder
from .decoder import MemoryDecoder
from .cnn import BasicEncoder
class FlowFormer(nn.Module):
    """FlowFormer optical-flow network: the memory encoder builds cost memory
    from an image pair, the memory decoder iteratively regresses flow."""

    def __init__(self, cfg):
        super(FlowFormer, self).__init__()
        self.cfg = cfg

        self.memory_encoder = MemoryEncoder(cfg)
        self.memory_decoder = MemoryDecoder(cfg)
        if cfg.cnet == "twins":
            self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
        elif cfg.cnet == "basicencoder":
            self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")

    def build_coord(self, img):
        # Coordinate grid at 1/8 resolution (the decoder's working resolution).
        N, C, H, W = img.shape
        coords = coords_grid(N, H // 8, W // 8)
        return coords

    def forward(
        self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
    ):
        """Images are expected in [0, 255]; `output` is unused (kept for API
        compatibility)."""
        # Normalise to [-1, 1], following https://github.com/princeton-vl/RAFT/
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        # Shared dict populated by the encoder (cost maps) for the decoder.
        data = {}

        if self.cfg.context_concat:
            context = self.context_encoder(torch.cat([image1, image2], dim=1))
        else:
            if return_feat:
                context, cfeat = self.context_encoder(image1, return_feat=return_feat)
            else:
                context = self.context_encoder(image1)
        if return_feat:
            cost_memory, ffeat = self.memory_encoder(
                image1, image2, data, context, return_feat=return_feat
            )
        else:
            cost_memory = self.memory_encoder(image1, image2, data, context)

        flow_predictions = self.memory_decoder(
            cost_memory, context, data, flow_init=flow_init, iters=iters
        )

        if return_feat:
            return flow_predictions, cfeat, ffeat
        return flow_predictions

View File

@@ -0,0 +1,7 @@
def build_flowformer(cfg):
    """Instantiate a FlowFormer from `cfg`; only the 'latentcostformer'
    architecture is supported."""
    name = cfg.transformer
    if name != "latentcostformer":
        raise ValueError(f"FlowFormer = {name} is not a valid architecture!")
    from .LatentCostFormer.transformer import FlowFormer

    return FlowFormer(cfg[name])

View File

@@ -0,0 +1,562 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from ..utils.utils import bilinear_sampler, indexing
def nerf_encoding(x, L=6, NORMALIZE_FACOR=1 / 300):
    """NeRF-style positional encoding of 2-D coordinates.

    `x` has shape [*, 2] (last dimension is x and y). Returns the scaled raw
    coordinates concatenated with sin/cos features of each coordinate at
    L - 1 frequency bands, along the last dimension.

    NOTE(review): ``torch.linspace(0, L, L - 1)`` yields L-1 non-integer
    exponents (0, 1.5, 3, 4.5, 6 for L=6); conventional NeRF uses integer
    exponents 0..L-1 — confirm this spacing is intentional.
    """
    scale = NORMALIZE_FACOR
    freqs = torch.pow(2.0, torch.linspace(0, L, L - 1)).to(x.device)
    xs = x[..., -2:-1]
    ys = x[..., -1:]
    parts = [
        x * scale,
        torch.sin(3.14 * xs * freqs * scale),
        torch.cos(3.14 * xs * freqs * scale),
        torch.sin(3.14 * ys * freqs * scale),
        torch.cos(3.14 * ys * freqs * scale),
    ]
    return torch.cat(parts, dim=-1)
def sampler_gaussian(latent, mean, std, image_size, point_num=25):
    """Sample latent features on a Gaussian-scaled grid around per-pixel means.

    NOTE(review): this definition is shadowed by the later redefinition of
    ``sampler_gaussian`` (with ``return_deltaXY``) below and is therefore
    dead code; kept for reference.

    latent [B, H*W, D]; mean [B, 2, H, W]; std [B, 1, H, W]
    """
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Unit grid of sqrt(point_num) x sqrt(point_num) offsets in [-1, 1].
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )

    # Scale offsets to 3*sigma, where sigma = sigmoid(std) * STD_MAX.
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    # Attention bias: offsets closer to the centre get higher weight.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))
    return sampled_latents, sampled_weights
def sampler_gaussian_zy(
    latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1
):
    """Variant of sampler_gaussian: `std` is used directly (no sigmoid
    squashing) and the distance weights are tempered by `beta`.

    latent [B, H*W, D]; mean [B, 2, H, W]; std [B, 1, H, W]
    """
    H, W = image_size
    B, HW, D = latent.shape
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Unit grid of sqrt(point_num) x sqrt(point_num) offsets in [-1, 1].
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )

    # Offsets scaled to 3*sigma, with std used as-is.
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    # Temperature-scaled distance weights.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta
    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False):
    """Sample latent features on a Gaussian-scaled grid around per-pixel means.

    This redefinition shadows the earlier ``sampler_gaussian`` and adds the
    ``return_deltaXY`` option, which also returns the scaled offsets.

    latent [B, H*W, D]; mean [B, 2, H, W]; std [B, 1, H, W]
    """
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Unit grid of sqrt(point_num) x sqrt(point_num) offsets in [-1, 1].
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )

    # Scale offsets to 3*sigma, where sigma = sigmoid(std) * STD_MAX.
    delta_3sigma = (
        F.sigmoid(std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1))
        * STD_MAX
        * delta
        * 3
    )
    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta_3sigma
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    # Attention bias: offsets closer to the centre get higher weight.
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1))
    if return_deltaXY:
        return sampled_latents, sampled_weights, delta_3sigma
    else:
        return sampled_latents, sampled_weights
def sampler_gaussian_fix(latent, mean, image_size, point_num=49):
    """Sample latent features on a fixed integer-radius grid around per-pixel
    means (no learned std).

    latent [B, H*W, D]; mean [B, 2, H, W]
    """
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # NOTE(review): unused in this fixed-grid variant
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Integer offsets on a (2*radius+1)^2 grid.
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )

    centroid = mean.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    sampled_latents = bilinear_sampler(
        latent, coords
    )
    sampled_latents = sampled_latents.permute(0, 2, 3, 1)

    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
def sampler_gaussian_fix_pyramid(
    latent, feat_pyramid, scale_weight, mean, image_size, point_num=25
):
    """Sample fixed square windows from every feature-pyramid level and fuse
    the levels with softmax scale weights.

    Args:
        latent: [B, H*W, D] tokens (used only for its shape here).
        feat_pyramid: list of [B, C, H/2^i, W/2^i] feature maps.
        scale_weight: [B, H*W, layer_num] per-pixel level logits.
        mean: [B, 2, H, W] window centers in pixel coordinates.
        image_size: (H, W) of the token grid.
        point_num: points per window (must be a perfect square).

    Returns:
        (weighted_latent [B, H*W, point_num, dim],
         sampled_weights [ws, ws] distance-based smoothness term,
         vis_out [B, H*W, layer_num] softmaxed weights for visualization)
    """
    H, W = image_size
    B, HW, D = latent.shape
    STD_MAX = 20  # unused here; kept for parity with the gaussian variants
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [ws, ws, 2] shared (dy, dx) offset grid
    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        # centers AND offsets shrink together with the 2x pyramid downsampling
        coords = (centroid + delta) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights, vis_out
def sampler_gaussian_pyramid(
    latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25
):
    """Gaussian-scaled window sampling over a feature pyramid, fused with
    softmax scale weights.

    The per-pixel window spans +-3*std around the mean; both centers and
    offsets are rescaled for each coarser (2x-downsampled) pyramid level.

    Args:
        latent: [B, H*W, D] tokens (used only for its shape here).
        feat_pyramid: list of [B, C, H/2^i, W/2^i] feature maps.
        scale_weight: [B, H*W, layer_num] per-pixel level logits.
        mean: [B, 2, H, W] window centers in pixel coordinates.
        std: per-pixel window scale, reshaped to [B*H*W, 1, 1, 1].
        image_size: (H, W) of the token grid.
        point_num: points per window (must be a perfect square).

    Returns:
        (weighted_latent [B, H*W, point_num, dim],
         sampled_weights [ws, ws] distance-based smoothness term,
         vis_out [B, H*W, layer_num] softmaxed weights for visualization)
    """
    H, W = image_size
    B, HW, D = latent.shape
    # NOTE: the unused STD_MAX and radius locals of the sibling samplers were
    # removed; this variant builds a unit grid in [-1, 1] scaled by std below.
    latent = rearrange(
        latent, "b (h w) c -> b c h w", h=H, w=W
    )  # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2)
    mean = mean.permute(0, 2, 3, 1)  # [B, H, W, 2]
    dx = torch.linspace(-1, 1, int(point_num**0.5))
    dy = torch.linspace(-1, 1, int(point_num**0.5))
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [ws, ws, 2] unit offset grid
    delta_3sigma = (
        std.permute(0, 2, 3, 1).reshape(B * HW, 1, 1, 1) * delta * 3
    )  # [B*H*W, ws, ws, 2] offsets scaled to +-3 sigma
    sampled_latents = []
    for i in range(len(feat_pyramid)):
        centroid = mean.reshape(B * H * W, 1, 1, 2)
        coords = (centroid + delta_3sigma) / 2**i
        coords = rearrange(
            coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W
        )
        sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords))
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W, point_num, dim, layer_num]
    scale_weight = F.softmax(scale_weight, dim=2)  # [B, H*W, layer_num]
    vis_out = scale_weight
    scale_weight = torch.unsqueeze(
        torch.unsqueeze(scale_weight, dim=2), dim=2
    )  # [B, HW, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_weight, dim=-1
    )  # [B, H*W, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights, vis_out
def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25):
    """Fixed-window sampling where each attention head has its own mean.

    Args:
        latent: [B, H*W, D] tokens.
        mean: [B, 2, H, W, HEADS] per-head window centers in pixels.
        image_size: (H, W) of the token grid.
        point_num: points per window (must be a perfect square).

    Returns:
        (sampled_latents [B, H*W*HEADS, point_num, dim],
         sampled_weights [HEADS, ws, ws] distance-based smoothness term)
    """
    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    STD_MAX = 20  # unused here; kept for parity with the gaussian variants
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = (
        torch.stack(torch.meshgrid(dy, dx), axis=-1)
        .to(mean.device)
        .repeat(HEADS, 1, 1, 1)
    )  # [HEADS, ws, ws, 2] — the same offset grid replicated per head
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(
        coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
    )
    sampled_latents = bilinear_sampler(latent, coords)  # [B, dim, H*W*HEADS, pointnum]
    sampled_latents = sampled_latents.permute(
        0, 2, 3, 1
    )  # [B, H*W*HEADS, pointnum, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return sampled_latents, sampled_weights
def sampler_gaussian_fix_pyramid_MH(
    latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25
):
    """Per-head fixed-window sampling over a feature pyramid, fused with
    softmax scale/head weights.

    Args:
        latent: [B, H*W, D] tokens (used only for its shape here).
        feat_pyramid: list of [B, C, H/2^i, W/2^i] feature maps.
        scale_head_weight: [B, H*W, layer_num*heads] fusion logits.
        mean: [B, 2, H, W, HEADS] per-head window centers in pixels.
        image_size: (H, W) of the token grid.
        point_num: points per window (must be a perfect square).

    Returns:
        (weighted_latent [B, H*W*HEADS, point_num, dim],
         sampled_weights [ws, ws] distance-based smoothness term)
    """
    H, W = image_size
    B, HW, D = latent.shape
    _, _, _, _, HEADS = mean.shape
    latent = rearrange(latent, "b (h w) c -> b c h w", h=H, w=W)
    mean = mean.permute(0, 2, 3, 4, 1)  # [B, H, W, heads, 2]
    radius = int((int(point_num**0.5) - 1) / 2)
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        mean.device
    )  # [ws, ws, 2] shared (dy, dx) offset grid
    sampled_latents = []
    centroid = mean.reshape(B * H * W, HEADS, 1, 1, 2)
    for i in range(len(feat_pyramid)):
        # NOTE(review): only the centroid is rescaled per level; the offsets
        # keep a constant radius in each level's own pixels (unlike
        # sampler_gaussian_fix_pyramid, which divides (centroid + delta)).
        # Presumably intentional — confirm against upstream FlowFormer.
        coords = (centroid) / 2**i + delta
        coords = rearrange(
            coords, "(b h w) H r1 r2 c -> b (h w H) (r1 r2) c", b=B, h=H, w=W, H=HEADS
        )
        sampled_latents.append(
            bilinear_sampler(feat_pyramid[i], coords)
        )  # [B, dim, H*W*HEADS, point_num]
    sampled_latents = torch.stack(
        sampled_latents, dim=1
    )  # [B, layer_num, dim, H*W*HEADS, point_num]
    sampled_latents = sampled_latents.permute(
        0, 3, 4, 2, 1
    )  # [B, H*W*HEADS, point_num, dim, layer_num]
    scale_head_weight = scale_head_weight.reshape(B, H * W * HEADS, -1)
    scale_head_weight = F.softmax(scale_head_weight, dim=2)  # [B, H*W*HEADS, layer_num]
    scale_head_weight = torch.unsqueeze(
        torch.unsqueeze(scale_head_weight, dim=2), dim=2
    )  # [B, H*W*HEADS, 1, 1, layer_num]
    weighted_latent = torch.sum(
        sampled_latents * scale_head_weight, dim=-1
    )  # [B, H*W*HEADS, point_num, dim]
    sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num  # smooth term
    return weighted_latent, sampled_weights
def sampler(feat, center, window_size):
    """Bilinearly sample `feat` on a square window around each per-pixel center.

    feat:   [B, C, H, W] feature map.
    center: [B, 2, H, W] window centers in pixel coordinates.
    Returns [B, C, H*W, window_size**2] sampled latents.
    """
    centers = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    B, H, W, C = centers.shape
    half = window_size // 2
    offs = torch.linspace(-half, half, 2 * half + 1)
    # shared (dy, dx) offset grid applied to every window
    window = torch.stack(torch.meshgrid(offs, offs), axis=-1).to(centers.device)
    coords = centers.reshape(B * H * W, 1, 1, 2) + window
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    return bilinear_sampler(feat, coords)
def retrieve_tokens(feat, center, window_size, sampler):
    """Sample a square window of tokens from `feat` around each center.

    Args:
        feat: [B, C, H, W] feature map.
        center: [B, H, W, 2] per-pixel window centers in pixel coordinates
            (callers permute from [B, 2, H, W] before calling).
        window_size: odd window side; the window spans +-(window_size // 2).
        sampler: "nn" (nearest-neighbour) or "bilinear".
            NOTE(review): this parameter shadows the module-level `sampler`
            function; kept as-is since it is part of the call signature.

    Returns:
        Sampled latents of shape [B, dim, H*W, window_size**2].

    Raises:
        ValueError: if `sampler` is neither "nn" nor "bilinear".
    """
    radius = window_size // 2
    dx = torch.linspace(-radius, radius, 2 * radius + 1)
    dy = torch.linspace(-radius, radius, 2 * radius + 1)
    delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(
        center.device
    )  # [ws, ws, 2] shared (dy, dx) offset grid
    B, H, W, C = center.shape
    centroid = center.reshape(B * H * W, 1, 1, 2)
    coords = centroid + delta
    coords = rearrange(coords, "(b h w) r1 r2 c -> b (h w) (r1 r2) c", b=B, h=H, w=W)
    if sampler == "nn":
        sampled_latents = indexing(feat, coords)
    elif sampler == "bilinear":
        sampled_latents = bilinear_sampler(feat, coords)
    else:
        raise ValueError("invalid sampler")
    # [B, dim, H*W, point_num]
    return sampled_latents
def pyramid_retrieve_tokens(
    feat_pyramid, center, image_size, window_sizes, sampler="bilinear"
):
    """Retrieve window tokens from each pyramid level and concatenate them.

    Centers are halved after every level to track the pyramid's 2x
    downsampling; results are concatenated along the last (token) axis.
    """
    centers = center.permute(0, 2, 3, 1)  # [B, H, W, 2]
    level_tokens = []
    for lvl, win in enumerate(window_sizes):
        level_tokens.append(retrieve_tokens(feat_pyramid[lvl], centers, win, sampler))
        centers = centers / 2
    return torch.cat(level_tokens, dim=-1)
class FeedForward(nn.Module):
    """Position-wise two-layer MLP (dim -> dim) with GELU and dropout."""

    def __init__(self, dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP token-wise; output has the same shape as the input."""
        return self.net(x)
class MLP(nn.Module):
    """LeakyReLU MLP: in_dim -> innter_dim -> (depth hidden layers) -> out_dim.

    NOTE(review): "innter_dim" is a typo for "inner_dim"; kept because callers
    may pass it by keyword.
    """

    def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5):
        super().__init__()
        self.FC1 = nn.Linear(in_dim, innter_dim)
        self.FC_out = nn.Linear(innter_dim, out_dim)
        self.relu = torch.nn.LeakyReLU(0.2)
        self.FC_inter = nn.ModuleList(
            [nn.Linear(innter_dim, innter_dim) for _ in range(depth)]
        )

    def forward(self, x):
        """Run the MLP; every layer except the output is followed by LeakyReLU."""
        hidden = self.relu(self.FC1(x))
        for layer in self.FC_inter:
            hidden = self.relu(layer(hidden))
        return self.FC_out(hidden)
class MultiHeadAttention(nn.Module):
    """Multi-head attention over a fixed set of key/value tokens with several
    optional relative-position-encoding (rpe) flavours selected by cfg.rpe.

    rpe modes (only used when use_rpe=True):
        "element-wise"       learned [heads, num_kv_tokens, dim/heads] bias
                             dotted against Q and added to the logits.
        "head-wise"          learned additive [1, heads, 1, num_kv_tokens] bias.
        "token-wise"         learned additive [1, 1, 1, num_kv_tokens] bias.
        "implicit"           no parameter here; the bias is supplied per call.
        "element-wise-value" like element-wise, plus a learned rpe_value
                             parameter (created here but not applied below).
    """

    def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False):
        super(MultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.num_kv_tokens = num_kv_tokens
        self.scale = (dim / heads) ** -0.5  # 1/sqrt(head_dim) logit scaling
        self.rpe = cfg.rpe
        self.attend = nn.Softmax(dim=-1)
        self.use_rpe = use_rpe
        if use_rpe:
            if rpe_bias is None:
                if self.rpe == "element-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                elif self.rpe == "head-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, heads, 1, self.num_kv_tokens)
                    )
                elif self.rpe == "token-wise":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(1, 1, 1, self.num_kv_tokens)
                    )  # 81 is point_num
                elif self.rpe == "implicit":
                    pass
                    # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2)
                    # raise ValueError('Implicit Encoding Not Implemented')
                elif self.rpe == "element-wise-value":
                    self.rpe_bias = nn.Parameter(
                        torch.zeros(heads, self.num_kv_tokens, dim // heads)
                    )
                    self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim))
                else:
                    raise ValueError("Not Implemented")
            else:
                # share an externally created bias parameter between modules
                self.rpe_bias = rpe_bias

    def attend_with_rpe(self, Q, K, rpe_bias):
        """Compute attention weights, adding the rpe bias when enabled.

        Returns (softmaxed attention, raw logits), both [b, heads, i, j].
        """
        Q = rearrange(Q, "b i (heads d) -> b heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)

        dots = (
            einsum("bhid, bhjd -> bhij", Q, K) * self.scale
        )  # (b hw) heads 1 pointnum

        if self.use_rpe:
            if self.rpe == "element-wise":
                # bias dotted against Q, same scaling as the K logits
                rpe_bias_weight = (
                    einsum("bhid, hjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "implicit":
                pass
                # caller-supplied per-batch bias, already split per head
                rpe_bias_weight = (
                    einsum("bhid, bhjd -> bhij", Q, rpe_bias) * self.scale
                )  # (b hw) heads 1 pointnum
                dots = dots + rpe_bias_weight
            elif self.rpe == "head-wise" or self.rpe == "token-wise":
                dots = dots + rpe_bias

        return self.attend(dots), dots

    def forward(self, Q, K, V, rpe_bias=None):
        """Attend Q over K/V; V may be None to obtain only the logits.

        Returns (out [b, hw, dim] or None, raw attention logits).
        """
        if self.use_rpe:
            # NOTE(review): for "element-wise", any caller-supplied rpe_bias is
            # ignored in favour of the learned parameter — confirm intended.
            if rpe_bias is None or self.rpe == "element-wise":
                rpe_bias = self.rpe_bias
            else:
                rpe_bias = rearrange(
                    rpe_bias, "b hw pn (heads d) -> (b hw) heads pn d", heads=self.heads
                )
            attn, dots = self.attend_with_rpe(Q, K, rpe_bias)
        else:
            attn, dots = self.attend_with_rpe(Q, K, None)
        B, HW, _ = Q.shape

        if V is not None:
            V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)
            out = einsum("bhij, bhjd -> bhid", attn, V)
            out = rearrange(out, "b heads hw d -> b hw (heads d)", b=B, hw=HW)
        else:
            out = None

        # dots = torch.squeeze(dots, 2)
        # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW)
        return out, dots

View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import timm
import numpy as np
class twins_svt_large(nn.Module):
    """Twins-SVT-Large backbone truncated to its first two stages.

    The classification head and the last two stages are deleted so the module
    yields a 1/8-resolution feature map for the flow encoder; the final norm
    layer is frozen.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large", pretrained=pretrained)

        del self.svt.head
        # deleting index 2 twice removes the original stages 2 and 3
        del self.svt.patch_embeds[2]
        del self.svt.patch_embeds[2]
        del self.svt.blocks[2]
        del self.svt.blocks[2]
        del self.svt.pos_block[2]
        del self.svt.pos_block[2]
        self.svt.norm.weight.requires_grad = False
        self.svt.norm.bias.requires_grad = False

    def forward(self, x, data=None, layer=2, return_feat=False):
        """Run the first `layer` stages of the backbone.

        Returns the final [B, C, H', W'] feature map, plus (when
        return_feat=True) a list of the reshaped per-stage feature maps.
        """
        B = x.shape[0]
        if return_feat:
            feat = []
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # conditional position encoding after the first block
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # tokens -> [B, C, H, W] for the next stage's patch embedding
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
                if return_feat:
                    feat.append(x)

            if i == layer - 1:
                break
        if return_feat:
            return x, feat
        return x

    def compute_params(self, layer=2):
        """Count parameters of the first `layer` stages (plus the head, if one
        is still present)."""
        num = 0
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            for param in embed.parameters():
                num += np.prod(param.size())
            for param in drop.parameters():
                num += np.prod(param.size())
            for param in blocks.parameters():
                num += np.prod(param.size())
            for param in pos_blk.parameters():
                num += np.prod(param.size())
            if i == layer - 1:
                break
        # __init__ deletes self.svt.head, so the previous unconditional access
        # here always raised AttributeError; only count a head if one exists.
        if hasattr(self.svt, "head"):
            for param in self.svt.head.parameters():
                num += np.prod(param.size())
        return num
class twins_svt_large_context(nn.Module):
    """Twins-SVT context-encoder variant built straight from timm.

    NOTE(review): the name "twins_svt_large_context" is passed verbatim to
    timm.create_model — confirm such a model is actually registered with timm
    (the sibling class uses the standard "twins_svt_large").
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained)

    def forward(self, x, data=None, layer=2):
        """Run the first `layer` stages and return the [B, C, H', W'] features."""
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(
            zip(
                self.svt.patch_embeds,
                self.svt.pos_drops,
                self.svt.blocks,
                self.svt.pos_block,
            )
        ):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # conditional position encoding after the first block
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # tokens -> [B, C, H, W] for the next stage's patch embedding
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if i == layer - 1:
                break
        return x
if __name__ == "__main__":
    # Smoke test: push a dummy batch through the truncated twins encoder.
    m = twins_svt_large()
    dummy = torch.randn(2, 3, 400, 800)
    # The class defines forward(), not extract_feature(); the previous
    # `m.extract_feature(input)` call raised AttributeError.
    out = m(dummy)
    print(out.shape)

View File

@@ -0,0 +1,90 @@
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is an optional, manually compiled CUDA extension; only
    # AlternateCorrBlock needs it, so a missing build is fine here. Catch
    # ImportError specifically instead of a bare `except:` so genuine errors
    # (e.g. KeyboardInterrupt) are not swallowed.
    pass
class CorrBlock:
    """RAFT all-pairs correlation volume with an average-pooled pyramid.

    Builds the full fmap1 x fmap2 correlation once, then answers lookups of a
    (2*radius+1)^2 neighbourhood around given coordinates at every level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            # halve the target-frame resolution at each coarser level
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Look up correlation features around `coords` ([B, 2, H, W] pixels).

        Returns [B, num_levels*(2r+1)^2, H, W].
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            # lookup centers shrink with the pyramid level
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dot-product correlation of all pixel pairs, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
    """Correlation lookup via the optional `alt_cuda_corr` CUDA extension.

    Stores the feature-map pyramid and computes local correlations on demand
    instead of materializing the full all-pairs volume; calling it without the
    compiled extension fails with a NameError.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Look up correlations around `coords` ([B, 2, H, W] pixel coords)."""
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # full-resolution source features, level-i target features
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())

View File

@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3-conv residual block with a configurable normalization layer.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the first conv; stride > 1 adds a 1x1 downsample
            path so the residual shapes match.

    Raises:
        ValueError: if `norm_fn` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            # Previously an unknown norm_fn silently skipped norm creation and
            # crashed later with AttributeError; fail fast instead.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Return relu(x + F(x)), downsampling the identity path if needed."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with configurable norm.

    Args:
        in_planes: input channel count.
        planes: output channel count (the bottleneck width is planes // 4).
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the middle conv; stride > 1 adds a 1x1 downsample
            path so the residual shapes match.

    Raises:
        ValueError: if `norm_fn` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()
        else:
            # Previously an unknown norm_fn silently skipped norm creation and
            # failed later with AttributeError; fail fast instead.
            raise ValueError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        """Return relu(x + F(x)), downsampling the identity path if needed."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder: 3-channel image -> 1/8-resolution features."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # stem normalization, selected by name ("none" -> identity)
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two residual blocks and advance self.in_planes."""
        stage = nn.Sequential(
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return stage

    def forward(self, x):
        # A tuple/list input is the two-frame case: batch both frames,
        # encode once, then split the result back apart.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
class SmallEncoder(nn.Module):
    """Compact RAFT encoder (bottleneck blocks): image -> 1/8-resolution features."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # stem normalization, selected by name ("none" -> identity)
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two bottleneck blocks and advance self.in_planes."""
        stage = nn.Sequential(
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return stage

    def forward(self, x):
        # A tuple/list input is the two-frame case: batch both frames,
        # encode once, then split the result back apart.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x

View File

@@ -0,0 +1,100 @@
import math
import torch
from torch import nn
class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalized to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            d_model: channel count of the encoded feature map.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()
        pe = torch.zeros((d_model, *max_shape))
        # 1-based row / column position grids
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        # NOTE(review): `/ d_model // 2` parses as `(... / d_model) // 2`, not
        # `... / (d_model // 2)` — a precedence quirk inherited from upstream.
        # Left unchanged so checkpoints trained with it remain compatible.
        div_term = torch.exp(
            torch.arange(0, d_model // 2, 2).float()
            * (-math.log(10000.0) / d_model // 2)
        )
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        # interleave sin/cos of x and y along the channel axis
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]

        Returns x plus the encoding cropped to x's spatial size.
        """
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
class LinearPositionEncoding(nn.Module):
    """2-D sinusoidal position encoding whose frequencies grow linearly with
    the channel index and whose positions are normalized to [0, 1).
    """

    def __init__(self, d_model, max_shape=(256, 256)):
        """
        Args:
            d_model: channel count of the encoded feature map.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()
        pe = torch.zeros((d_model, *max_shape))
        # normalized row / column positions in [0, 1)
        y_position = (
            torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1
        ) / max_shape[0]
        x_position = (
            torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1
        ) / max_shape[1]
        div_term = torch.arange(0, d_model // 2, 2).float()
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        # interleave sin/cos of x and y along the channel axis
        pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi)
        pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi)
        pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi)
        pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi)
        # non-persistent: the table is recomputed on construction, not loaded
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]

        Returns x plus the encoding cropped to x's spatial size.
        """
        # assert x.shape[2] == 80 and x.shape[3] == 80
        return x + self.pe[:, :, : x.size(2), : x.size(3)]
class LearnedPositionEncoding(nn.Module):
    """Learned (free-parameter) 2-D position encoding of fixed spatial size."""

    def __init__(self, d_model, max_shape=(80, 80)):
        """
        Args:
            d_model: channel count of the encoded feature map.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
        """
        super().__init__()
        # learned [1, max_h, max_w, d_model] table added to the input as-is
        self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model))

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]

        NOTE(review): self.pe is laid out [1, H, W, C]; the addition only
        broadcasts if x actually arrives channels-last (or the dims happen to
        coincide) — confirm the caller's layout.
        """
        # assert x.shape[2] == 80 and x.shape[3] == 80
        return x + self.pe

View File

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Predict a 2-channel flow field from a hidden state via two 3x3 convs."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return a [B, 2, H, W] flow prediction for input [B, input_dim, H, W]."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on [B, C, H, W] feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step: h is the hidden state, x the input; returns new h."""
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(stacked))
        reset = torch.sigmoid(self.convr(stacked))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """Separable convolutional GRU: a horizontal (1x5) pass then a vertical
    (5x1) pass, each a full GRU update."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convr1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convq1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )

        self.convz2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convr2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convq2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )

    def forward(self, h, x):
        """Two chained GRU updates (horizontal kernels, then vertical)."""
        for convz, convr, convq in (
            (self.convz1, self.convr1, self.convq1),
            (self.convz2, self.convr2, self.convq2),
        ):
            hx = torch.cat([h, x], dim=1)
            z = torch.sigmoid(convz(hx))
            r = torch.sigmoid(convr(hx))
            q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
            h = (1 - z) * h + z * q
        return h
class SmallMotionEncoder(nn.Module):
    """Encode correlation features + current flow into 82-channel motion features.

    `args` must provide corr_levels and corr_radius, which determine the
    correlation-volume channel count.
    """

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """flow: [B, 2, H, W]; corr: [B, cor_planes, H, W] -> [B, 82, H, W]."""
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # append the raw flow so downstream layers can read it directly
        return torch.cat([fused, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Encode correlation features + current flow into 128-channel motion features.

    `args` must provide corr_levels and corr_radius, which determine the
    correlation-volume channel count.
    """

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + the 2 raw flow channels = 128 output channels
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        """flow: [B, 2, H, W]; corr: [B, cor_planes, H, W] -> [B, 128, H, W]."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # append the raw flow so downstream layers can read it directly
        return torch.cat([fused, flow], dim=1)
class SmallUpdateBlock(nn.Module):
    """Small RAFT update cell: motion encoder -> ConvGRU -> flow head."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """One refinement step; returns (new hidden state, None, delta_flow).

        The None slot mirrors BasicUpdateBlock's upsampling mask, which the
        small model does not produce.
        """
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        return net, None, self.flow_head(net)
class BasicUpdateBlock(nn.Module):
    """RAFT update cell: motion encoder -> SepConvGRU -> flow head + up-mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # predicts 9 convex-combination weights per pixel of an 8x8 upsample
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        """One refinement step; returns (new hidden state, up-mask, delta_flow)."""
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,113 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
    """Pads images so their height and width become divisible by 8."""

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_h = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_w = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            # center the padding on both axes
            self._pad = [
                pad_w // 2,
                pad_w - pad_w // 2,
                pad_h // 2,
                pad_h - pad_h // 2,
            ]
        elif mode == "kitti400":
            # pad the bottom up to a fixed height of 400
            self._pad = [0, 0, 0, 400 - self.ht]
        else:
            # center horizontally, pad only the top vertically
            self._pad = [pad_w // 2, pad_w - pad_w // 2, 0, pad_h]

    def pad(self, *inputs):
        """Replicate-pad every given tensor to the padded size (returns a list)."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        rows = (self._pad[2], ht - self._pad[3])
        cols = (self._pad[0], wd - self._pad[1])
        return x[..., rows[0] : rows[1], cols[0] : cols[1]]
def forward_interpolate(flow):
    """Forward-warp a [2, H, W] flow field onto a regular grid.

    Each flow vector is splatted to its target location and every grid cell is
    filled with its nearest splatted value (used to warm-start the next
    frame's flow estimate). Returns a [2, H, W] float CPU tensor.
    """
    np_flow = flow.detach().cpu().numpy()
    dx, dy = np_flow[0], np_flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # target positions after applying the flow
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only vectors that land strictly inside the frame
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
    )
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
    )
    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    Args:
        img: [B, C, H, W] source tensor.
        coords: [..., 2] sampling locations in (x, y) pixel coordinates.
        mode: grid_sample interpolation mode ("bilinear" or "nearest").
        mask: when True, also return a float mask of in-bounds samples.

    Returns:
        The sampled tensor, or (sampled, mask) when mask=True.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # normalize pixel coordinates to [-1, 1] (align_corners=True convention)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Bug fix: `mode` was previously accepted but never forwarded, so
    # grid_sample always used its default bilinear interpolation.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
def indexing(img, coords, mask=False):
    """Nearest-neighbour lookup of `img` at pixel `coords` via grid_sample.

    TODO: directly indexing features instead of sampling
    """
    H, W = img.shape[-2:]
    xg, yg = coords.split([1, 1], dim=-1)
    # map pixel coordinates to the [-1, 1] range grid_sample expects
    xg = 2 * xg / (W - 1) - 1
    yg = 2 * yg / (H - 1) - 1

    sampled = F.grid_sample(
        img, torch.cat([xg, yg], dim=-1), align_corners=True, mode="nearest"
    )
    if mask:
        valid = (xg > -1) & (yg > -1) & (xg < 1) & (yg < 1)
        return sampled, valid.float()
    return sampled
def coords_grid(batch, ht, wd):
    """Return a [batch, 2, ht, wd] grid of pixel coordinates.

    Channel 0 holds the x (column) coordinate, channel 1 the y (row) one.
    """
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    base = torch.stack((xs, ys), dim=0).float()
    # repeat (not expand) so each batch entry owns its own writable copy
    return base[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors by 8."""
    target = (8 * flow.shape[2], 8 * flow.shape[3])
    upsampled = F.interpolate(flow, size=target, mode=mode, align_corners=True)
    return 8 * upsampled