Add SGM-VFI (CVPR 2024) frame interpolation support
SGM-VFI combines local flow estimation with sparse global matching (GMFlow) to handle large motion and occlusion-heavy scenes. Adds 3 new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, SGM-VFI Segment Interpolate. Architecture files vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), relative imports, and debug code removed. README updated with model comparison table. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
5
sgm_vfi_arch/__init__.py
Normal file
5
sgm_vfi_arch/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .feature_extractor import feature_extractor
|
||||
from .flow_estimation import MultiScaleFlow
|
||||
from .warplayer import clear_warp_cache
|
||||
|
||||
__all__ = ['feature_extractor', 'MultiScaleFlow', 'clear_warp_cache']
|
||||
116
sgm_vfi_arch/backbone.py
Normal file
116
sgm_vfi_arch/backbone.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from .trident_conv import MultiScaleTridentConv
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3-conv residual block with an optional 1x1 projection shortcut.

    Used by CNNEncoder below. Attribute names (conv1, conv2, norm1..norm3,
    downsample) must stay unchanged so pretrained SGM-VFI checkpoints load.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_layer: normalization constructor (default InstanceNorm2d).
        stride: stride of the first conv (spatial downsampling).
        dilation: dilation (and matching padding) for both 3x3 convs.
    """

    def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
                 ):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, stride=stride, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)
        # A projection shortcut is needed exactly when spatial size or channel
        # count changes. (Idiom fix: the original spelled this as two separate,
        # complementary conditions, the first written `not stride == 1`.)
        if stride != 1 or in_planes != planes:
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
        else:
            self.downsample = None

    def forward(self, x):
        """Return relu(shortcut(x) + F(x)), F = conv-norm-relu applied twice."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class CNNEncoder(nn.Module):
    # GMFlow-style convolutional encoder. With num_output_scales == 1 it emits
    # a single 1/8-resolution feature map; with more scales it stops at 1/4 and
    # a MultiScaleTridentConv produces the multi-scale pyramid.
    # NOTE(review): attribute names must stay unchanged so pretrained
    # checkpoints load.
    def __init__(self, output_dim=128,
                 norm_layer=nn.InstanceNorm2d,
                 num_output_scales=1,
                 **kwargs,
                 ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False)  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer,
                                       )  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        # Multi-branch trident conv is only built when a pyramid is requested;
        # only 2-4 branches are supported.
        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
                                                      kernel_size=3,
                                                      strides=strides,
                                                      paddings=1,
                                                      num_branch=self.num_branch,
                                                      )

        # Kaiming init for convs; unit weight / zero bias for norm layers
        # (InstanceNorm2d has no affine params by default, hence the None check).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        # Two ResidualBlocks; the first may downsample. Tracks self.in_planes
        # so successive calls chain channel counts.
        layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
        layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, 3, H, W). Returns a list of feature maps, length num_branch,
        # ordered high to low resolution.
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
|
||||
459
sgm_vfi_arch/feature_extractor.py
Normal file
459
sgm_vfi_arch/feature_extractor.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
from .position import PositionEmbeddingSine
|
||||
|
||||
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows.

    Returns (num_windows * B, window_size[0] * window_size[1], C), with
    windows ordered row-major over the tile grid and each window flattened
    row-major.
    """
    batch, height, width, channels = x.shape
    grid_h = height // window_size[0]
    grid_w = width // window_size[1]
    tiles = x.view(batch, grid_h, window_size[0], grid_w, window_size[1], channels)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size[0] * window_size[1], channels)
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: fold (nW*B, N, C) windows back to (B, H, W, C)."""
    total_windows, _, channels = windows.shape
    windows_per_image = (H // window_size[0]) * (W // window_size[1])
    batch = int(total_windows // windows_per_image)
    tiles = windows.view(
        batch, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
    )
    return tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
|
||||
|
||||
|
||||
def pad_if_needed(x, size, window_size):
    # Symmetrically pad x (B, H, W, C) so H and W become multiples of
    # window_size, and build the matching additive attention mask that keeps
    # padded border positions from attending to real content.
    # Returns (padded_x, attn_mask); attn_mask is None when no padding is
    # needed. NOTE(review): the mask is created on CPU; callers appear to be
    # responsible for moving it to the feature device (MotionFormerBlock
    # does `.to(x_pad.device)`) — confirm.
    n, h, w, c = size
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    if pad_h > 0 or pad_w > 0:  # center-pad the feature on H and W axes
        img_mask = torch.zeros((1, h + pad_h, w + pad_w, 1))  # 1 H W 1
        # Three bands per axis: leading pad, real content, trailing pad.
        h_slices = (
            slice(0, pad_h // 2),
            slice(pad_h // 2, h + pad_h // 2),
            slice(h + pad_h // 2, None),
        )
        w_slices = (
            slice(0, pad_w // 2),
            slice(pad_w // 2, w + pad_w // 2),
            slice(w + pad_w // 2, None),
        )
        # Assign each (h-band, w-band) combination a distinct region id.
        # NOTE(review): the loop variables shadow the h/w dims unpacked above;
        # harmless only because h/w are not read again past this point.
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, window_size
        )  # nW, window_size*window_size, 1
        mask_windows = mask_windows.squeeze(-1)
        # Pairwise region difference: cross-region pairs get -100 (softmax
        # effectively zero), same-region pairs get 0.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(
            attn_mask != 0, float(-100.0)
        ).masked_fill(attn_mask == 0, float(0.0))
        # F.pad pad spec is (C_before, C_after, W_before, W_after, H_before,
        # H_after); odd padding puts the extra row/column at the end.
        return nn.functional.pad(
            x,
            (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
        ), attn_mask
    return x, None
|
||||
|
||||
|
||||
def depad_if_needed(x, size, window_size):
    """Undo the symmetric padding applied by pad_if_needed.

    `size` is the original (n, h, w, c) shape; the center h-by-w region is
    cropped out of the padded tensor `x`. Returns `x` unchanged when the
    original dims were already multiples of `window_size`.
    """
    _, orig_h, orig_w, _ = size
    extra_h = math.ceil(orig_h / window_size[0]) * window_size[0] - orig_h
    extra_w = math.ceil(orig_w / window_size[1]) * window_size[1] - orig_w
    if extra_h == 0 and extra_w == 0:
        return x
    top = extra_h // 2
    left = extra_w // 2
    return x[:, top: top + orig_h, left: left + orig_w, :].contiguous()
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    # Transformer feed-forward block: Linear -> depthwise conv (adds local
    # spatial mixing to the token sequence) -> activation -> Linear, with
    # dropout after each Linear.
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.relu = nn.ReLU(inplace=True)  # NOTE(review): unused in forward
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # ViT-style init, applied recursively via self.apply above:
        # trunc-normal Linear, unit LayerNorm, fan-out-normal Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        # x: (B, N, C) token sequence; H, W give the spatial grid the tokens
        # came from, needed by the depthwise conv.
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
||||
|
||||
|
||||
class InterFrameAttention(nn.Module):
    # Cross-attention between two frames' windowed tokens: queries from x1,
    # keys/values from x2. Passing x1 == x2 degenerates to self-attention
    # (used by MotionFormerBlock when self_att=True).
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 1/sqrt(d_head) unless overridden

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # ViT-style init: trunc-normal Linear, unit LayerNorm, fan-out Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2, H, W, mask=None):
        # x1: query tokens (B, N, C); x2: key/value source; mask: optional
        # (nW, N, N) additive attention mask (-100 blocks a pair, 0 passes).
        # NOTE(review): H, W are accepted but not used here.
        B, N, C = x1.shape
        # Split heads: (B, num_heads, N, C // num_heads).
        q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Broadcast the per-window mask across the true batch and heads.
            nW = mask.shape[0]  # mask: nW, N, N
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = attn.softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)
        # Merge heads back and project.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class MotionFormerBlock(nn.Module):
    # Windowed (optionally shifted, Swin-style) transformer block that runs
    # inter-frame cross-attention between the two frames stacked along the
    # batch axis, then an MLP; the shifted-window attention mask is built
    # lazily and cached as buffers.
    def __init__(self, dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ):
        super().__init__()
        # Normalize scalar window/shift sizes to (h, w) pairs.
        self.window_size = window_size
        if not isinstance(self.window_size, (tuple, list)):
            self.window_size = to_2tuple(window_size)
        self.shift_size = shift_size
        if not isinstance(self.shift_size, (tuple, list)):
            self.shift_size = to_2tuple(shift_size)
        self.bidirectional = bidirectional  # NOTE(review): stored but not read here
        self.norm1 = norm_layer(dim)
        self.attn = InterFrameAttention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # BEGIN: absolute pos_embed, beneficial to local information extraction in our experiments
        self.pos_embed = PositionEmbeddingSine(dim // 2)
        # END: absolute pos_embed
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # ViT-style init: trunc-normal Linear, unit LayerNorm, fan-out Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, B, self_att=False):
        # x: (2B, H*W, C) with the two frames stacked along the batch axis.
        # self_att=False -> cross-attention between the frame halves;
        # self_att=True  -> plain self-attention within each frame.
        x = x.view(2 * B, H, W, -1)
        x_pad, mask = pad_if_needed(x, x.size(), self.window_size)

        if self.shift_size[0] or self.shift_size[1]:
            _, H_p, W_p, C = x_pad.shape
            x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

            # Reuse the cached shift mask if the padded resolution matches the
            # cached one; otherwise rebuild (Swin region-id scheme) and cache.
            # NOTE(review): register_buffer inside forward adds state_dict
            # entries only after the first shifted call — confirm this matches
            # the shipped checkpoints.
            if hasattr(self, 'HW') and self.HW.item() == H_p * W_p:
                shift_mask = self.attn_mask
            else:
                shift_mask = torch.zeros((1, H_p, W_p, 1))  # 1 H W 1
                h_slices = (slice(0, -self.window_size[0]),
                            slice(-self.window_size[0], -self.shift_size[0]),
                            slice(-self.shift_size[0], None))
                w_slices = (slice(0, -self.window_size[1]),
                            slice(-self.window_size[1], -self.shift_size[1]),
                            slice(-self.shift_size[1], None))
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        shift_mask[:, h, w, :] = cnt
                        cnt += 1

                # Cross-region position pairs get -100, same-region pairs 0.
                mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1)
                shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                shift_mask = shift_mask.masked_fill(shift_mask != 0,
                                                    float(-100.0)).masked_fill(shift_mask == 0,
                                                                               float(0.0))

                if mask is not None:
                    # Fold the padding mask from pad_if_needed into the shift mask.
                    shift_mask = shift_mask.masked_fill(mask != 0,
                                                        float(-100.0))
                self.register_buffer("attn_mask", shift_mask)
                self.register_buffer("HW", torch.Tensor([H_p * W_p]))
        else:
            shift_mask = mask

        if shift_mask is not None:
            # Masks are built on CPU; move alongside the features.
            shift_mask = shift_mask.to(x_pad.device)

        _, Hw, Ww, C = x_pad.shape
        x_win = window_partition(x_pad, self.window_size)

        nwB = x_win.shape[0]
        x_norm = self.norm1(x_win)
        # BEGIN: absolute pos embed, beneficial to local information extraction in our experiments
        x_norm = x_norm.view(nwB, self.window_size[0], self.window_size[1], C).permute(0, 3, 1, 2)
        ape = self.pos_embed(x_norm)
        x_norm = x_norm + ape
        x_norm = x_norm.permute(0, 2, 3, 1).view(nwB, self.window_size[0] * self.window_size[1], C)
        # END: absolute pos embed

        if self_att is False:
            # Swap the two frame halves so each frame's queries attend to the
            # other frame's keys/values.
            x_reverse = torch.cat([x_norm[nwB // 2:], x_norm[:nwB // 2]])
            x_appearence = self.attn(x_norm, x_reverse, H, W, shift_mask)
        else:
            x_appearence = self.attn(x_norm, x_norm, H, W, shift_mask)

        # Residual attention update (applied to the normalized tokens).
        x_norm = x_norm + self.drop_path(x_appearence)

        x_back = x_norm
        x_back_win = window_reverse(x_back, self.window_size, Hw, Ww)

        if self.shift_size[0] or self.shift_size[1]:
            # Undo the cyclic shift.
            x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2 * B, H * W, -1)

        # MLP with residual connection.
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    # Plain stack of `depths` (3x3 conv -> activation) pairs; the default
    # activation is a per-channel PReLU (act_layer is called with out_dim).
    def __init__(self, in_dim, out_dim, depths=2, act_layer=nn.PReLU):
        super().__init__()
        layers = []
        for i in range(depths):
            # Only the first conv changes the channel count.
            if i == 0:
                layers.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
            else:
                layers.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1))
            layers.extend([
                act_layer(out_dim),
            ])
        self.conv = nn.Sequential(*layers)

    def _init_weights(self, m):
        # NOTE(review): dead code — unlike the other modules in this file there
        # is no self.apply(self._init_weights) call, so PyTorch's default conv
        # init is used. Confirm whether this is intentional before "fixing" it:
        # calling it would change initialization behavior.
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # x: (B, in_dim, H, W) -> (B, out_dim, H, W); spatial size unchanged.
        x = self.conv(x)
        return x
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Module):
    # Overlapping patch embedding: a strided conv (kernel > stride, so patches
    # overlap) followed by LayerNorm; returns flattened tokens plus the
    # downsampled spatial dims.
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # ViT-style init: trunc-normal Linear, unit LayerNorm, fan-out Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # x: (B, C, H, W) -> tokens (B, H'*W', embed_dim), plus H', W' after
        # the strided projection.
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W
|
||||
|
||||
|
||||
class MotionFormer(nn.Module):
    # Hybrid backbone: `conv_stages` leading convolutional stages followed by
    # windowed MotionFormerBlock transformer stages. Both frames are stacked
    # along the batch axis and processed jointly; forward returns the
    # per-stage appearance features of the transformer stages.
    # Stage submodules are attached via setattr(f"block{i+1}", ...) etc., so
    # state_dict keys are block1/patch_embed2/norm3/... — keep the naming.
    def __init__(self, in_chans=3, embed_dims=None, num_heads=None,
                 mlp_ratios=None, qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=None, window_sizes=None, **kwarg):
        super().__init__()
        self.depths = depths
        self.num_stages = len(embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        # Stages without an entry in num_heads are plain conv stages.
        self.conv_stages = self.num_stages - len(num_heads)

        for i in range(self.num_stages):
            if i == 0:
                # First stage: conv block straight on the RGB input, no
                # downsampling embed.
                block = ConvBlock(in_chans, embed_dims[i], depths[i])
            else:
                if i < self.conv_stages:
                    # Conv stage: strided-conv downsample + conv block.
                    patch_embed = nn.Sequential(
                        nn.Conv2d(embed_dims[i - 1], embed_dims[i], 3, 2, 1),
                        nn.PReLU(embed_dims[i])
                    )
                    block = ConvBlock(embed_dims[i], embed_dims[i], depths[i])
                else:
                    # Transformer stage: overlapping patch embed + blocks
                    # alternating regular / shifted windows.
                    patch_embed = OverlapPatchEmbed(patch_size=3,
                                                    stride=2,
                                                    in_chans=embed_dims[i - 1],
                                                    embed_dim=embed_dims[i])

                    block = nn.ModuleList([MotionFormerBlock(
                        dim=embed_dims[i], num_heads=num_heads[i - self.conv_stages],
                        window_size=window_sizes[i - self.conv_stages],
                        shift_size=0 if (j % 2) == 0 else window_sizes[i - self.conv_stages] // 2,
                        mlp_ratio=mlp_ratios[i - self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
                        for j in range(depths[i])])

                    norm = norm_layer(embed_dims[i])
                    setattr(self, f"norm{i + 1}", norm)
                setattr(self, f"patch_embed{i + 1}", patch_embed)
            cur += depths[i]

            setattr(self, f"block{i + 1}", block)

        # Cache of normalized coordinate grids, keyed by (shape, device).
        self.cor = {}

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # ViT-style init: trunc-normal Linear, unit LayerNorm, fan-out Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def get_cor(self, shape, device):
        # Return a cached (B, H, W, 2) grid of normalized coordinates in
        # [-1, 1]; built once per (shape, device) pair.
        k = (str(shape), str(device))
        if k not in self.cor:
            tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
                1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
            tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
                1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
            self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
        return self.cor[k]

    def forward(self, x1, x2):
        # x1, x2: the two (B, 3, H, W) input frames. Returns the list of
        # transformer-stage appearance features, each (2B, C_i, H_i, W_i)
        # with the frames stacked along the batch axis.
        B = x1.shape[0]
        x = torch.cat([x1, x2], 0)
        appearence_features = []
        xs = []
        for i in range(self.num_stages):
            # Stage submodules attached in __init__ via setattr; norm and
            # patch_embed may be absent for conv stages (hence default None).
            patch_embed = getattr(self, f"patch_embed{i + 1}", None)
            block = getattr(self, f"block{i + 1}", None)
            norm = getattr(self, f"norm{i + 1}", None)
            if i < self.conv_stages:
                if i > 0:
                    x = patch_embed(x)
                x = block(x)
                xs.append(x)
            else:
                x, H, W = patch_embed(x)
                for j in range(len(block)):
                    x = block[j](x, H, W, B, self_att=False)
                    xs.append(x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous())
                x = norm(x)
                x = x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                appearence_features.append(x)
        return appearence_features
|
||||
|
||||
|
||||
class DWConv(nn.Module):
    """Depthwise 3x3 convolution over a flattened token sequence.

    Tokens of shape (B, N, C) are reshaped onto the (B, C, H, W) spatial grid,
    convolved per-channel (groups=dim), and flattened back to (B, N, C).
    """

    def __init__(self, dim):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).reshape(batch, channels, H, W).contiguous()
        grid = self.dwconv(grid)
        return grid.reshape(batch, channels, -1).transpose(1, 2)
|
||||
|
||||
|
||||
def feature_extractor(**kargs):
    """Factory for the MotionFormer backbone; kwargs are forwarded verbatim."""
    return MotionFormer(**kargs)
|
||||
208
sgm_vfi_arch/flow_estimation.py
Normal file
208
sgm_vfi_arch/flow_estimation.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .refine import *
|
||||
from .matching import MatchingBlock
|
||||
from .gmflow import GMFlow
|
||||
from .utils import InputPadder
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a Conv2d (with bias) followed by per-channel PReLU as one Sequential."""
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
    """Local intermediate-flow estimation block (RIFE-style).

    Downscales its input by `scale`, predicts a residual bidirectional flow
    (4 channels) plus a fusion mask (1 channel) with a small residual conv
    stack, then upsamples the prediction back to input resolution.

    Args:
        in_planes: channels contributed by the `feature` input.
        c: hidden channel width of the conv stack.
        layers: number of hidden conv layers in the residual stack.
        scale: downscaling factor applied before prediction.
        in_else: channels of the non-feature input (images, warps, mask, t, flow).
    """

    def __init__(self, in_planes, c=64, layers=4, scale=4, in_else=17):
        super(IFBlock, self).__init__()
        self.scale = scale

        self.conv0 = nn.Sequential(
            conv(in_planes + in_else, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )

        self.convblock = nn.Sequential(
            *[conv(c, c) for _ in range(layers)]
        )

        self.lastconv = conv(c, 5)

    def forward(self, x, flow=None, feature=None):
        """Return (flow, mask, flow_s): full-resolution flow/mask and the small-scale flow."""
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:  # idiom fix: was `flow != None`
            # Flow values are pixel displacements, so they scale with resolution.
            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
                                 align_corners=False) * 1. / self.scale
            x = torch.cat((x, flow), 1)
        if feature is not None:  # idiom fix: was `feature != None`
            x = torch.cat((x, feature), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x  # residual conv stack
        tmp = self.lastconv(x)
        flow_s = tmp[:, :4]  # flow at the (small) processing scale
        tmp = F.interpolate(tmp, scale_factor=self.scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * self.scale  # rescale displacements with the upsampling
        mask = tmp[:, 4:5]
        return flow, mask, flow_s
|
||||
|
||||
|
||||
class MultiScaleFlow(nn.Module):
    """Coarse-to-fine flow estimation head of SGM-VFI.

    Combines local IFBlock refinement at 1/8 and then 1/4 scale with a GMFlow
    transformer branch whose features drive a sparse global MatchingBlock at
    the coarsest level. Attribute names must stay unchanged so pretrained
    checkpoints load.

    Changes vs. the vendored original (behavior-preserving):
    - local variable `dict` renamed to `match_out` (shadowed the builtin);
    - `!= None` comparisons replaced with `is not None`.
    """

    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        self.scale = [1, 2, 4, 8]
        self.num_key_points = [kargs['num_key_points']]
        self.block = nn.ModuleList(
            [IFBlock(kargs['embed_dims'][-1] * 2, 128, 2, self.scale[-1], in_else=7),  # 1/8
             IFBlock(kargs['embed_dims'][-2] * 2, 128, 2, self.scale[-2], in_else=18)])  # 1/4
        self.contextnet = Contextnet(kargs['c'] * 2)
        self.unet = Unet(kargs['c'] * 2)
        self.gmflow = GMFlow(
            num_scales=1,
            upsample_factor=8,
            feature_channels=128,
            attention_type='swin',
            num_transformer_layers=6,
            ffn_dim_expansion=4,
            num_head=1)

        # Sparse global matching only at the coarsest (1/8) stage.
        self.matching_block = nn.ModuleList([
            MatchingBlock(scale=8, dim=kargs['embed_dims'][-1], c=kargs['c'] * 4, num_layers=1, gm=True),
            None
        ])

        self.padding_factor = 16

    def calculate_flow(self, imgs, timestep):
        """Estimate bidirectional flow and fusion mask for time `timestep`.

        imgs: (B, 6, H, W) — img0 and img1 concatenated on channels.
        Returns (flow, mask) at input resolution, flow (B, 4, H, W) holding
        the t->0 and t->1 displacements.
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        flow_s = None

        af = self.feature_bone(img0, img1)
        if self.gmflow is not None:
            # Global matching features: inputs padded to a multiple of
            # padding_factor, the 1/8-scale feature unpadded accordingly.
            padder = InputPadder(img0.shape, padding_factor=self.padding_factor)
            img0_p, img1_p = padder.pad(img0, img1)
            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
            matching_feat = results['trans_feat']
            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1])
            matching_feat[0] = padder_8.unpad(matching_feat[0])

        for i in range(2):
            # Constant one-channel timestep map.
            t = (img0[:B, :1].clone() * 0 + 1) * timestep
            af0 = af[-1 - i][:B]
            af1 = af[-1 - i][B:]
            if flow is not None:  # refinement stage: feed previous estimate
                flow_d, mask_d, flow_s_d = self.block[i](
                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
                    flow,
                    torch.cat([af0, af1], 1),
                )
                flow = flow + flow_d
                mask = mask + mask_d
                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
                flow_s = flow_s + flow_s_d
            else:
                flow, mask, flow_s = self.block[i](
                    torch.cat((img0, img1, t), 1),
                    None,
                    torch.cat([af0, af1], 1))
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            if self.matching_block[i] is not None:
                # Sparse global matching correction on top of the local flow.
                match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i],
                                                   init_flow=flow, init_flow_s=flow_s, init_mask=mask,
                                                   warped_img0=warped_img0, warped_img1=warped_img1,
                                                   num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
                                                   timestep=timestep)
                flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
                flow = flow + flow_t
                mask = mask + mask_t

                # Re-warp with the corrected flow for the next stage.
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
        return flow, mask

    def coraseWarp_and_Refine(self, imgs, flow, mask):
        """Warp both frames by `flow` and refine the sigmoid-blended result.

        (Method name — "coarse" typo included — kept for API compatibility.)
        Returns the predicted frame clamped to [0, 1].
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # residual in [-1, 1]
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred

    def forward(self, x, timestep=0.5):
        """Training-style forward returning all intermediate products.

        x: (B, 6, H, W) concatenated frame pair.
        Returns (flow_list, mask_list, merged, pred, flow_matching_list);
        flow_matching_list is always empty here but kept for API compatibility.
        """
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list, mask_list = [], []
        merged, merged_fine = [], []
        warped_img0, warped_img1 = img0, img1
        flow, mask, flow_s = None, None, None
        flow_matching_list = []
        matching_feat = []
        af = self.feature_bone(img0, img1)
        if self.gmflow is not None:
            padder = InputPadder(img0.shape, padding_factor=self.padding_factor, additional_pad=False)
            img0_p, img1_p = padder.pad(img0, img1)
            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
            matching_feat = results['trans_feat']
            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1], additional_pad=False)
            matching_feat[0] = padder_8.unpad(matching_feat[0])

        for i in range(2):
            af0 = af[-1 - i][:B]
            af1 = af[-1 - i][B:]
            t = (img0[:B, :1].clone() * 0 + 1) * timestep
            if flow is not None:  # idiom fix: was `flow != None`
                flow_d, mask_d, flow_s_d = self.block[i](
                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
                    flow,
                    torch.cat([af0, af1], 1),
                )
                flow = flow + flow_d
                mask = mask + mask_d
                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
                flow_s = flow_s + flow_s_d
            else:
                flow, mask, flow_s = self.block[i](
                    torch.cat((img0, img1, t), 1),
                    None,
                    torch.cat([af0, af1], 1))
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
            if self.matching_block[i] is not None:
                # Inputs detached: the matching correction trains without
                # backprop into the local branch. NOTE: timestep fixed at 0.5
                # here (training targets the midpoint frame).
                match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i].detach(),
                                                   init_flow=flow.detach(), init_flow_s=flow_s.detach(), init_mask=mask.detach(),
                                                   warped_img0=warped_img0.detach(), warped_img1=warped_img1.detach(),
                                                   num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
                                                   timestep=0.5)
                flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
                flow = flow + flow_t
                mask = mask + mask_t
                mask_list[i] = torch.sigmoid(mask)
                warped_img0_fine = warp(img0, flow[:, 0:2])
                warped_img1_fine = warp(img1, flow[:, 2:4])
                merged_fine.append(warped_img0_fine * mask_list[i] + warped_img1_fine * (1 - mask_list[i]))
                warped_img0, warped_img1 = warped_img0_fine, warped_img1_fine  # NOTE: for next iteration training
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        pred = torch.clamp(merged[-1] + res, 0, 1)
        merged.extend(merged_fine)
        return flow_list, mask_list, merged, pred, flow_matching_list
|
||||
96
sgm_vfi_arch/geometry.py
Normal file
96
sgm_vfi_arch/geometry.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Batched pixel-coordinate grid.

    Returns [B, 2, H, W] (or [B, 3, H, W] with homogeneous=True), where
    channel 0 holds x (column index) and channel 1 holds y (row index).
    """
    # indexing='ij' keeps the original row-major semantics while silencing the
    # torch.meshgrid deprecation warning introduced in torch 1.10.
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid
|
||||
|
||||
|
||||
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Grid of (x, y) sample coordinates spanning the given window.

    Returns [H, W, 2] where grid[i, j] = (w-coordinate j, h-coordinate i),
    with coordinates linearly spaced over [w_min, w_max] / [h_min, h_max].
    """
    assert device is not None

    # indexing='ij' preserves the original semantics and avoids the
    # torch.meshgrid deprecation warning (torch >= 1.10).
    x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
                           torch.linspace(h_min, h_max, len_h, device=device)],
                          indexing='ij')
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid
|
||||
|
||||
|
||||
def normalize_coords(coords, h, w):
    """Map pixel coordinates [B, H, W, 2] into [-1, 1] (grid_sample convention)."""
    center = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
    return (coords - center) / center
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
    """Sample *img* [B, C, H, W] at pixel-space *sample_coords*.

    Coordinates may be given as [B, 2, H, W] or [B, H, W, 2]. Optionally also
    returns a boolean in-bounds mask [B, H, W].
    """
    if sample_coords.size(1) != 2:  # channel-last layout [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # rescale pixel coordinates into grid_sample's [-1, 1] range
    x_norm = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_norm = 2 * sample_coords[:, 1] / (h - 1) - 1
    grid = torch.stack([x_norm, y_norm], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)

    if return_mask:
        # valid where the normalized coordinate stays inside the image
        valid = (x_norm >= -1) & (y_norm >= -1) & (x_norm <= 1) & (y_norm <= 1)
        return sampled, valid
    return sampled
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    """Backward-warp *feature* [B, C, H, W] by a dense *flow* [B, 2, H, W]."""
    b, c, h, w = feature.size()
    assert flow.size(1) == 2
    # absolute sampling positions = base pixel grid plus flow offsets
    sample_coords = coords_grid(b, h, w).to(flow.device) + flow
    return bilinear_sample(feature, sample_coords, padding_mode=padding_mode, return_mask=mask)
def forward_backward_consistency_check(fwd_flow, bwd_flow,
                                       alpha=0.01,
                                       beta=0.5
                                       ):
    """Occlusion estimation via forward/backward flow consistency.

    fwd_flow, bwd_flow: [B, 2, H, W]. alpha and beta follow UnFlow
    (https://arxiv.org/abs/1711.07837). Returns (fwd_occ, bwd_occ), float
    masks [B, H, W] with 1 marking inconsistent (likely occluded) pixels.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    # pull each flow field to the other frame before comparing
    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # consistent flows should cancel out (f + warped b ~ 0)
    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta
    fwd_occ = (diff_fwd > threshold).float()
    bwd_occ = (diff_bwd > threshold).float()
    return fwd_occ, bwd_occ
# ---------------------------------------------------------------------------
# sgm_vfi_arch/gmflow.py
# ---------------------------------------------------------------------------
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .backbone import CNNEncoder
|
||||
from .transformer import FeatureTransformer, FeatureFlowAttention
|
||||
from .utils import feature_add_position
|
||||
|
||||
class GMFlow(nn.Module):
    """Vendored GMFlow feature pipeline: CNN backbone + cross-frame transformer.

    This trimmed forward pass only produces the cross-attended transformer
    features ('trans_feat'); the flow-decoding stages of the original GMFlow
    network are not used here.

    Fix: removed dead locals from forward() — `flow_preds`, the misspelled
    `flow_s_macthing`, `flow_s_prop` and the unused `flow = None` were
    assigned but never read.
    """

    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

    def extract_feature(self, img0, img1):
        """Run the backbone on both frames at once and split the results.

        Returns two lists of per-scale features, ordered from low to high
        resolution.
        """
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)         # list, resolution high -> low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []
        for feature in features:
            f0, f1 = torch.chunk(feature, 2, 0)
            feature0.append(f0)
            feature1.append(f1)
        return feature0, feature1

    def forward(self, img0, img1,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                pred_bidir_flow=False,
                **kwargs,
                ):
        """Return {'trans_feat': [cat(feature0, feature1) per scale]}.

        corr_radius_list / prop_radius_list / pred_bidir_flow are accepted
        for config/checkpoint compatibility but unused in this trimmed
        forward pass.
        """
        results_dict = {}
        transformer_features = []

        feature0_list, feature1_list = self.extract_feature(img0, img1)

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            attn_splits = attn_splits_list[scale_idx]

            # add positional encoding before cross-frame attention
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
            transformer_features.append(torch.cat([feature0, feature1], 0))

        results_dict.update({'trans_feat': transformer_features})

        return results_dict
# ---------------------------------------------------------------------------
# sgm_vfi_arch/matching.py
# ---------------------------------------------------------------------------
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .warplayer import warp as backwarp
|
||||
from .softsplat import softsplat
|
||||
from .geometry import coords_grid
|
||||
|
||||
|
||||
# for random sample ablation
def random_sample(feature, num_points=256):
    """Pick *num_points* random token indices per batch element.

    feature: [B, N, C]. Returns (indices [B, num_points, 1],
    gathered features [B, num_points, C]).
    """
    batch, num_tokens, channels = feature.shape
    rand_ind = torch.randint(low=0, high=num_tokens, size=(batch, num_points))
    rand_ind = rand_ind.unsqueeze(-1).to(feature.device)
    kp = torch.gather(feature, dim=1, index=rand_ind.expand(-1, -1, channels))
    return rand_ind, kp
def sample_key_points(importance_map, feature, num_points=256):
    """Select the *num_points* highest-importance tokens.

    importance_map: [B, 1, H, W]; feature: [B, H*W, C].
    Returns (indices [B, num_points, 1], features [B, num_points, C]).
    """
    flat = importance_map.view(-1, 1, importance_map.shape[2] * importance_map.shape[3])
    flat = flat.permute(0, 2, 1)  # [B, H*W, 1]
    _, kp_ind = torch.topk(flat, num_points, dim=1)
    kp = torch.gather(feature, dim=1, index=kp_ind.expand(-1, -1, feature.shape[2]))
    return kp_ind, kp
def forward_warp(tenIn, tenFlow, z=None):
    """Softmax-splat *tenIn* forward along *tenFlow*.

    When an occupancy map *z* is given, empty cells receive metric -20 so
    they barely contribute after the softmax; otherwise a uniform metric
    of ones is used.
    """
    if z is None:
        z = torch.ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]).to(tenIn.device)
    else:
        z = torch.where(z == 0, -20, 1)
    return softsplat(tenIn, tenFlow, tenMetric=z, strMode='soft')
def warp_twice(imgA, target, flow_tA, flow_tB):
    """Backward-warp *imgA* to time t via *flow_tA*, then forward-splat the
    result toward the other frame via *flow_tB*.

    *target* is unused; kept for signature compatibility with callers.
    """
    warped_to_t = backwarp(imgA, flow_tA)  # backward warp(I1, Ft1)
    ones = torch.ones([imgA.shape[0], 1, imgA.shape[2], imgA.shape[3]]).to(imgA.device)
    return softsplat(tenIn=warped_to_t, tenFlow=flow_tB, tenMetric=ones, strMode='soft')
def build_map(imgA, imgB, flow_tA, flow_tB):
    """Photometric error map on imgB after double-warping imgA.

    Returns [B, 1, H, W]: per-pixel L1 difference summed over channels.
    """
    IB_warp = warp_twice(imgA, imgB, flow_tA, flow_tB)
    difference = IB_warp - imgB  # [B, 3, H, W], difference map on IB
    return torch.abs(difference).sum(dim=1, keepdim=True)  # [B, 1, H, W]
def build_hole_mask(img_template, flow_tA, flow_tB):
    """Binary mask of pixels reachable by the double warp (0 marks holes)."""
    with torch.no_grad():
        b, _, h, w = img_template.shape
        ones = torch.ones(b, 1, h, w).to(img_template.device)
        coverage = warp_twice(ones, ones, flow_tA, flow_tB)
        hole_mask = torch.where(coverage == 0, 0, 1)
    return hole_mask
def gen_importance_map(img0, img1, flow):
    """Estimate where the current flow is unreliable, as sampling importance.

    flow: [B, 4, H, W] holding (flow t->0, flow t->1) in channels 0:2 / 2:4.
    Returns [2B, 1, H, W]: frame-0 importance stacked over frame-1.
    """
    flow_t0, flow_t1 = flow[:, 0:2], flow[:, 2:4]

    # photometric difference maps measured on each frame
    I1_dmap = build_map(img0, img1, flow_t0, flow_t1)
    I0_dmap = build_map(img1, img0, flow_t1, flow_t0)

    # zero out pixels the double warp never reaches (holes)
    I1_dmap = I1_dmap * build_hole_mask(img0, flow_t0, flow_t1)
    I0_dmap = I0_dmap * build_hole_mask(img1, flow_t1, flow_t0)

    # carry each frame's error back to the opposite frame
    I0_prob = warp_twice(I1_dmap, I1_dmap, flow_t1, flow_t0)
    I1_prob = warp_twice(I0_dmap, I0_dmap, flow_t0, flow_t1)

    return torch.cat([I0_prob, I1_prob], dim=0)  # 2B, 1, H, W
def global_matching(key_feature, global_feature, key_index, H, W):
    """Soft-argmax correspondence between key tokens and all target tokens.

    key_feature: [B, K, C] ([B, H*W, C] when key_index is None);
    global_feature: [B, H*W, C]; key_index: [B, K, 1] or None.
    Returns (flow_fix [B, 2, H, W] — zero at non-key positions in the
    sparse case — and the matching probabilities [B, K, H*W]).
    """
    b, n, c = global_feature.shape
    # scaled dot-product correlation between keys and every target token
    correlation = torch.matmul(key_feature, global_feature.permute(0, 2, 1)) / (c ** 0.5)
    prob = F.softmax(correlation, dim=-1)  # [B, K, H*W]

    init_grid = coords_grid(b, H, W, homogeneous=False, device=global_feature.device)
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]
    expected_coords = torch.matmul(prob, grid)  # [B, K, 2], soft-argmax positions

    if key_index is not None:
        # scatter the K matched positions back into a dense field
        flow_fix = torch.zeros_like(grid)
        # key_index: [B, K, 1], expected_coords: [B, K, 2], flow_fix: [B, H*W, 2]
        flow_fix = torch.scatter(flow_fix, dim=1, index=key_index.expand(-1, -1, 2), src=expected_coords)
        flow_fix = flow_fix.view(b, H, W, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

        # subtract base coordinates, but only at the key positions
        key_mask = torch.zeros_like(grid)
        key_mask = torch.scatter(key_mask, dim=1, index=key_index.expand(-1, -1, 2),
                                 src=torch.ones_like(expected_coords))
        base = (grid * key_mask).reshape(b, H, W, 2).permute(0, 3, 1, 2)
        flow_fix = flow_fix - base
    else:
        flow_fix = expected_coords.view(b, H, W, 2).permute(0, 3, 1, 2)
        flow_fix = flow_fix - init_grid
    return flow_fix, prob
def extract_topk(foo, k):
    """Keep only the *k* largest responses of a [B, 1, H, W] map; zero the rest."""
    b, _, h, w = foo.shape
    flat = foo.view(b, 1, h * w).permute(0, 2, 1)  # [B, H*W, 1]
    kp, kp_ind = torch.topk(flat, k, dim=1)
    sparse = torch.zeros(b, h * w, 1).to(foo.device)
    sparse = torch.scatter(sparse, dim=1, index=kp_ind, src=kp)
    return sparse.permute(0, 2, 1).reshape(b, 1, h, w)
def flow_shift(flow_fix, timestep, num_key_points=None, select_topk=False):
    """Move bidirectional matching flows from the endpoint frames to time t.

    flow_fix: [2B, 2, H, W], first half frame-0 flows, second half frame-1.
    Returns [2B, 2, H, W] flows anchored at the intermediate frame; when
    select_topk is set (and sparse matching is active, num_key_points != -1)
    only the *num_key_points* best-covered target pixels are kept.
    """
    B = flow_fix.shape[0] // 2
    # occupancy: 1 where a (2-channel) match exists, 0 elsewhere
    z = torch.where(flow_fix == 0, 0, 1).detach().sum(1, keepdim=True) / 2
    zt0, zt1 = z[B:], z[:B]

    flow_fix_t0 = forward_warp(flow_fix[B:] * timestep, flow_fix[B:] * (1 - timestep), z=zt0)
    flow_fix_t1 = forward_warp(flow_fix[:B] * (1 - timestep), flow_fix[:B] * timestep, z=zt1)
    flow_fix_t = torch.cat([flow_fix_t0, flow_fix_t1], 0)

    if select_topk and num_key_points != -1:
        # splat occupancy to time t and keep only the densest landing spots
        warp_map_t0 = softsplat(zt0, flow_fix[B:] * (1 - timestep), None, 'sum')
        warp_map_t1 = softsplat(zt1, flow_fix[:B] * timestep, None, 'sum')
        warp_map = torch.cat([warp_map_t0, warp_map_t1], 0)

        keep = torch.where(extract_topk(warp_map, num_key_points) != 0, 1, 0)
        flow_fix_t = flow_fix_t * keep
    return flow_fix_t
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by a channel-wise PReLU."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True)
    return nn.Sequential(layer, nn.PReLU(out_planes))
def deconv(in_planes=64, out_planes=64, kernel_size=4, stride=2, padding=1):
    """ConvTranspose2d (2x upsampling with defaults) followed by PReLU."""
    layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=True)
    return nn.Sequential(layer, nn.PReLU(out_planes))
class FlowRefine(nn.Module):
    """Residual conv stack that refines a bidirectional flow and predicts a
    fusion mask, operating internally at 1/scale resolution."""

    def __init__(self, in_planes, scale=4, c=64, n_layers=8):
        super(FlowRefine, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )
        self.convblock = nn.Sequential(*[conv(c, c) for _ in range(n_layers)])
        self.lastconv = conv(c, 5)  # 4 flow channels + 1 mask channel
        self.scale = scale

    def forward(self, x, flow_s, flow):
        """x: image/context stack (downscaled here); flow_s: extras already at
        the working resolution; flow: full-resolution bidirectional flow.
        Returns (flow [B, 4, H, W], mask [B, 1, H, W]) at input resolution."""
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            # downscale the flow field and rescale its magnitudes to match
            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
                                 align_corners=False) * 1. / self.scale
            x = torch.cat((x, flow), 1)
        if flow_s is not None:
            x = torch.cat((x, flow_s), 1)

        x = self.conv0(x)
        x = self.convblock(x) + x  # residual connection
        x = self.lastconv(x)

        up = F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False)
        refined_flow = up[:, :4] * self.scale  # undo the magnitude rescale
        mask = up[:, 4:5]
        return refined_flow, mask
class MergingBlock(nn.Module):
    """Fuse a local init flow with a sparse matching flow by predicting, per
    pixel, softmax weights over a (radius x radius) neighbourhood of both
    flow candidates."""

    def __init__(self, radius=3, input_dim=256, hidden_dim=256):
        super(MergingBlock, self).__init__()
        self.r = radius
        self.rf = radius ** 2
        # input: 4 + 4 flow channels plus both frames' features
        self.conv = nn.Sequential(nn.Conv2d(8 + 2 * input_dim, hidden_dim, 3, 1, 1),
                                  nn.PReLU(hidden_dim),
                                  nn.Conv2d(hidden_dim, 2 * 2 * self.rf, 1, 1, 0))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # small-variance normal init for all conv weights, zero bias
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(0.1 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, feature, init_flow, flow_fix):
        """
        :param feature: [B, C, H, W] -> (local feature) or (local feature + matching feature)
        :param init_flow: [B, 4, H, W] bidirectional local init flow
        :param flow_fix: [B, 4, H, W] bidirectional matching flow (after patching, no hollows)
        :return: merged flow [2B, 2, H, W] (both directions stacked on batch)
        """
        b, flow_channel, h, w = init_flow.shape
        weights = self.conv(torch.cat((init_flow, flow_fix, feature), dim=1))
        assert init_flow.shape == flow_fix.shape, f"different flow shape not implemented yet"
        weights = weights.view(b, 1, 2 * 2 * self.rf, h, w)

        # split per direction, stack on batch, softmax over the 2*rf candidates
        w0 = weights[:, :, :2 * self.rf, :, :]
        w1 = weights[:, :, 2 * self.rf:, :, :]
        weights = torch.softmax(torch.cat([w0, w1], dim=0), dim=2)

        init_flow_all = torch.cat([init_flow[:, 0:2], init_flow[:, 2:4]], dim=0)
        flow_fix_all = torch.cat([flow_fix[:, 0:2], flow_fix[:, 2:4]], dim=0)

        # gather each pixel's r x r neighbourhood of both flow candidates
        init_grid = F.unfold(init_flow_all, [self.r, self.r], padding=self.r // 2)
        init_grid = init_grid.view(2 * b, 2, self.rf, h, w)  # [2B, 2, rf, H, W]
        fix_grid = F.unfold(flow_fix_all, [self.r, self.r], padding=self.r // 2)
        fix_grid = fix_grid.view(2 * b, 2, self.rf, h, w)    # [2B, 2, rf, H, W]

        candidates = torch.cat([init_grid, fix_grid], dim=2)  # [2B, 2, 2*rf, H, W]

        return torch.sum(weights * candidates, dim=2)  # [2B, 2, H, W]
class MatchingBlock(nn.Module):
    """One sparse-global-matching stage: sample unreliable key points, match
    them globally between frames, merge the matched flow with the local
    estimate, then refine flow and fusion mask."""

    def __init__(self, scale, c, dim, num_layers=2, gm=True):
        super(MatchingBlock, self).__init__()
        self.gm = gm
        self.dim = dim
        self.scale = scale
        self.merge = MergingBlock(radius=3, input_dim=dim + 128, hidden_dim=256)
        self.refine_block = FlowRefine(27, scale, c, num_layers)

    def forward(self, img0, img1, x, main_x, init_flow, init_flow_s, init_mask,
                warped_img0, warped_img1, num_key_points, scale_factor, timestep=0.5):
        """x: [2B, C, h, w] matching features; main_x: local features;
        init_flow: full-resolution [B, 4, H, W] flow, init_flow_s its
        low-resolution counterpart. Returns {'flow_t', 'mask_t'} residuals."""
        result_dict = {}

        _, c, h, w = x.shape
        B = main_x.shape[0] // 2
        # NOTE:
        # 1. we stop sparse selecting points when the image resolution
        # becomes too small (1/8 feature map resolution <= 32, i.e., h <= 256)
        # (see `random_rescale` in train_x4k.py)
        # 2. This limitation should be deleted when evaluating on low-resolution images (<=256x256)
        if num_key_points != -1 and h > 32:
            num_key_points = int(num_key_points * (h * w))
        else:
            num_key_points = -1  # -1 stands for global matching

        feature = x.permute(0, 2, 3, 1).reshape(2 * B, h * w, c)
        # pair each frame's tokens with the other frame's tokens
        feature_reverse = torch.cat([feature[B:], feature[:B]], 0)

        if num_key_points == -1:
            flow_fix_norm, _ = global_matching(feature, feature_reverse, None, h, w)
        else:
            # sample key points where the local flow looks unreliable
            imap = gen_importance_map(img0, img1, init_flow)
            imap_s = F.interpolate(imap, size=(h, w), mode="bilinear", align_corners=False)
            kp_ind, kp_feature = sample_key_points(imap_s, feature, num_key_points)
            flow_fix_norm, _ = global_matching(kp_feature, feature_reverse, kp_ind, h, w)

        # shift endpoint-anchored matches to time t, then fall back to the
        # local flow wherever no match landed
        flow_fix = flow_shift(flow_fix_norm, timestep, num_key_points, select_topk=True)
        flow_fix = torch.cat([flow_fix[:B], flow_fix[B:]], 1)
        flow_r = torch.where(flow_fix == 0, init_flow_s, flow_fix)
        flow_merge = self.merge(torch.cat([x[:B], x[B:], main_x[:B], main_x[B:]], dim=1), init_flow_s, flow_r)
        flow_merge = torch.cat([flow_merge[:B], flow_merge[B:]], dim=1)

        # warp the downscaled inputs with the merged flow for refinement
        img0_s = F.interpolate(img0, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
        img1_s = F.interpolate(img1, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
        warped_img0_fine_s_m = backwarp(img0_s, flow_merge[:, 0:2])
        warped_img1_fine_s_m = backwarp(img1_s, flow_merge[:, 2:4])

        flow_t, mask_t = self.refine_block(
            torch.cat((img0, img1, warped_img0, warped_img1, init_mask), 1),
            torch.cat([warped_img0_fine_s_m, warped_img1_fine_s_m, flow_merge], 1),
            init_flow)

        result_dict.update({'flow_t': flow_t})
        result_dict.update({'mask_t': mask_t})
        return result_dict
# ---------------------------------------------------------------------------
# sgm_vfi_arch/position.py
# ---------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x):
        """Return sinusoidal encodings [B, 2*num_pos_feats, H, W] for x [B, C, H, W]."""
        b, c, h, w = x.size()
        # every position is treated as valid (no padding mask in this copy)
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # interleave: sin on even feature dims, cos on odd feature dims
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# ---------------------------------------------------------------------------
# sgm_vfi_arch/refine.py
# ---------------------------------------------------------------------------
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from timm.models.layers import trunc_normal_
|
||||
from .warplayer import warp
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by a channel-wise PReLU."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True)
    return nn.Sequential(layer, nn.PReLU(out_planes))
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """ConvTranspose2d + PReLU upsampling block (2x upsampling with defaults).

    Fix: the original accepted kernel_size/stride/padding but ignored them,
    hard-coding 4/2/1 in the layer. The parameters are now honoured; the
    defaults are unchanged, so every existing caller behaves identically.
    """
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes)
    )
class Conv2(nn.Module):
    """Two stacked conv+PReLU layers; the first may downsample via *stride*."""

    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))
class Contextnet(nn.Module):
    """Four-level feature pyramid; each level is warped by a progressively
    halved flow to provide context features for the refinement U-Net."""

    def __init__(self, c=16):
        super(Contextnet, self).__init__()
        self.conv1 = Conv2(3, c)
        self.conv2 = Conv2(c, 2 * c)
        self.conv3 = Conv2(2 * c, 4 * c)
        self.conv4 = Conv2(4 * c, 8 * c)

    def forward(self, x, flow):
        """x: [B, 3, H, W] image; flow: [B, 2, H, W] at input resolution.
        Returns the warped features of all four pyramid levels."""
        warped = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
            # each stage halves the resolution, so halve the flow field and
            # its magnitudes to match
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
                                 recompute_scale_factor=False) * 0.5
            warped.append(warp(x, flow))
        return warped
class Unet(nn.Module):
    """Encoder-decoder that fuses the images, warps, mask, flow and context
    features into a sigmoid-activated output."""

    def __init__(self, c=16, out=3):
        super(Unet, self).__init__()
        # encoder; input = img0+img1+warped0+warped1 (12) + mask (1) + flow (4) = 17 ch
        self.down0 = Conv2(17, 2 * c)
        self.down1 = Conv2(4 * c, 4 * c)
        self.down2 = Conv2(8 * c, 8 * c)
        self.down3 = Conv2(16 * c, 16 * c)
        # decoder with skip connections
        self.up0 = deconv(32 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)
        self.conv = nn.Conv2d(c, out, 3, 1, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated normal for linear layers, unit LayerNorm, He-style convs
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        """c0, c1: per-level context feature lists (e.g. from Contextnet)."""
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        return torch.sigmoid(self.conv(x))
# ---------------------------------------------------------------------------
# sgm_vfi_arch/softsplat.py
# ---------------------------------------------------------------------------
#!/usr/bin/env python
|
||||
|
||||
import collections
|
||||
import cupy
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import typing
|
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
objCudacache = {}
|
||||
|
||||
|
||||
def cuda_int32(intIn: int):
    """Cast a Python int to a cupy int32 kernel argument."""
    return cupy.int32(intIn)
def cuda_float32(fltIn: float):
    """Cast a Python float to a cupy float32 kernel argument."""
    return cupy.float32(fltIn)
# C type name for each supported tensor dtype ({{type}} substitution).
_DTYPE_TO_CTYPE = {
    torch.uint8: 'unsigned char',
    torch.float16: 'half',
    torch.float32: 'float',
    torch.float64: 'double',
    torch.int32: 'int',
    torch.int64: 'long',
}


def _expand_accessor(strKernel: str, strTag: str, objVariables: typing.Dict, fnRender):
    """Expand every TAG_n(tensor, i0, ..., i{n-1}) macro into stride arithmetic.

    fnRender(strTensor, strOffset) produces the replacement text from the
    tensor name and the '+'-joined linear offset expression.
    """
    while True:
        objMatch = re.search(r'(' + strTag + r'_)([0-4])(\()', strKernel)
        if objMatch is None:
            return strKernel

        # find the matching closing parenthesis of the macro call
        intStart = objMatch.span()[1]
        intStop = objMatch.span()[1]
        intParentheses = 1
        while True:
            intParentheses += 1 if strKernel[intStop] == '(' else 0
            intParentheses -= 1 if strKernel[intStop] == ')' else 0
            if intParentheses == 0:
                break
            intStop += 1

        intArgs = int(objMatch.group(2))
        strArgs = strKernel[intStart:intStop].split(',')
        assert intArgs == len(strArgs) - 1

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        # '{'/'}' in macro arguments stand for parentheses
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*'
                    + str(int(intStrides[intArg])) + ')'
                    for intArg in range(intArgs)]

        strCall = strTag + '_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')'
        strKernel = strKernel.replace(strCall, fnRender(strTensor, str.join('+', strIndex)))


def _specialize_kernel(strKernel: str, objVariables: typing.Dict) -> str:
    """Substitute {{var}}/{{type}} placeholders and expand accessor macros."""
    for strVariable, objValue in objVariables.items():
        if objValue is None:
            continue
        elif type(objValue) in (int, float, bool):
            strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
        elif type(objValue) == str:
            strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
        elif type(objValue) == torch.Tensor:
            if objValue.dtype not in _DTYPE_TO_CTYPE:
                print(strVariable, objValue.dtype)
                assert (False)
            strKernel = strKernel.replace('{{type}}', _DTYPE_TO_CTYPE[objValue.dtype])
        else:
            print(strVariable, type(objValue))
            assert (False)

    # SIZE_n(tensor) -> the literal size of dimension n
    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
        if objMatch is None:
            break
        intArg = int(objMatch.group(2))
        intSizes = objVariables[objMatch.group(4)].size()
        strKernel = strKernel.replace(objMatch.group(), str(int(intSizes[intArg])))

    # OFFSET_n(tensor, ...) -> flat linear offset; VALUE_n(tensor, ...) -> element access
    strKernel = _expand_accessor(strKernel, 'OFFSET', objVariables,
                                 lambda strTensor, strOffset: '(' + strOffset + ')')
    strKernel = _expand_accessor(strKernel, 'VALUE', objVariables,
                                 lambda strTensor, strOffset: strTensor + '[' + strOffset + ']')
    return strKernel


def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
    """Specialize a CUDA kernel template and cache it under a unique key.

    Builds a cache key from the function name, the variables (dtype, shape
    and strides for tensors) and the device name, then — on a cache miss —
    substitutes template placeholders and expands SIZE_/OFFSET_/VALUE_
    accessor macros using the concrete tensor sizes/strides. Returns the
    cache key for cuda_launch().

    Fixes vs. original: regex patterns are raw strings (the originals used
    invalid escape sequences like '\\(' in plain strings, which modern
    Python flags as a SyntaxWarning); the duplicated OFFSET_/VALUE_
    expansion loops are deduplicated into _expand_accessor; the dtype ->
    C-type ladder is a dispatch dict.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()

    # ---- cache key: function + variable signatures + device ----
    strKey = strFunction
    for strVariable, objValue in objVariables.items():
        strKey += strVariable
        if objValue is None:
            continue
        elif type(objValue) in (int, float, bool):
            strKey += str(objValue)
        elif type(objValue) == str:
            strKey += objValue
        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())
        else:
            print(strVariable, type(objValue))
            assert (False)
    strKey += objCudacache['device']

    if strKey not in objCudacache:
        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': _specialize_kernel(strKernel, objVariables)
        }

    return strKey
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey: str):
    """Compile (memoized per device) the cached kernel source for *strKey*."""
    if 'CUDA_HOME' not in os.environ:
        # cupy's NVRTC needs CUDA_HOME to locate the CUDA headers
        os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()

    objEntry = objCudacache[strKey]
    strOptions = ('-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include')
    return cupy.RawKernel(objEntry['strKernel'], objEntry['strFunction'], options=strOptions)
##########################################################
|
||||
|
||||
|
||||
def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str):
    """Forward-warp *tenIn* along *tenFlow* (softmax splatting).

    strMode is 'sum', 'avg', 'linear' or 'soft', optionally suffixed with
    '-addeps' / '-zeroeps' / '-clipeps' to select how empty cells in the
    normalization map are handled. tenMetric must be None for 'sum'/'avg'
    and is required for 'linear'/'soft'.
    """
    strBase = strMode.split('-')[0]
    assert (strBase in ['sum', 'avg', 'linear', 'soft'])

    if strMode == 'sum': assert (tenMetric is None)
    if strMode == 'avg': assert (tenMetric is None)
    if strBase == 'linear': assert (tenMetric is not None)
    if strBase == 'soft': assert (tenMetric is not None)

    # append a weight channel so the splatted output can be renormalized
    if strMode == 'avg':
        tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
    elif strBase == 'linear':
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
    elif strBase == 'soft':
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if strBase in ['avg', 'linear', 'soft']:
        tenNormalize = tenOut[:, -1:, :, :]
        strParts = strMode.split('-')

        if len(strParts) == 1 or strParts[1] == 'addeps':
            tenNormalize = tenNormalize + 0.0000001
        elif strParts[1] == 'zeroeps':
            tenNormalize[tenNormalize == 0.0] = 1.0
        elif strParts[1] == 'clipeps':
            tenNormalize = tenNormalize.clip(0.0000001, None)

        # drop the weight channel and normalize the payload channels
        tenOut = tenOut[:, :-1, :, :] / tenNormalize

    return tenOut
class softsplat_func(torch.autograd.Function):
    """Forward-splatting autograd op backed by runtime-compiled CUDA kernels.

    forward() scatters every input pixel to the four integer neighbours of
    its flow-displaced target position using bilinear weights and atomicAdd;
    backward() returns gradients w.r.t. both the input and the flow.
    CUDA-only: CPU tensors are rejected with an assertion.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)  # splat in fp32 even under autocast
    def forward(self, tenIn, tenFlow):
        # Output accumulator, same shape as the input; the kernel adds into it.
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])

        if tenIn.is_cuda == True:
            # One thread per output element; each thread reads its source pixel,
            # computes the flow-displaced position and bilinear weights, then
            # atomically accumulates into the four surrounding output pixels.
            cuda_launch(cuda_kernel('softsplat_out', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_out(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    {{type}}* __restrict__ tenOut
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                    const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
                    const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                    const int intX = ( intIndex ) % SIZE_3(tenOut);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
                    }
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOut': tenOut
            }))(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )

        elif tenIn.is_cuda != True:
            # CPU execution is intentionally unsupported.
            assert(False)

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut
    # end

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        tenIn, tenFlow = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

        # Only allocate the gradients autograd actually asked for.
        tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
        tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None

        if tenIngrad is not None:
            # d(out)/d(in): gather the output gradient at the four splat
            # targets, weighted by the same bilinear coefficients as forward.
            cuda_launch(cuda_kernel('softsplat_ingrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
                    const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
                    const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
                    const int intX = ( intIndex ) % SIZE_3(tenIngrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltIngrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
                    }

                    tenIngrad[intIndex] = fltIngrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        if tenFlowgrad is not None:
            # d(out)/d(flow): differentiate the bilinear weights w.r.t. the
            # displaced position (channel 0 -> x, channel 1 -> y) and sum the
            # contribution over all input channels.
            cuda_launch(cuda_kernel('softsplat_flowgrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
                    const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
                    const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
                    const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltFlowgrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = 0.0f;
                    {{type}} fltNortheast = 0.0f;
                    {{type}} fltSouthwest = 0.0f;
                    {{type}} fltSoutheast = 0.0f;

                    if (intC == 0) {
                        fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                        fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                        fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                        fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));

                    } else if (intC == 1) {
                        fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                        fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                        fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                        fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));

                    }

                    for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                        {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);

                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                        }

                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                        }

                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                        }

                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                        }
                    }

                    tenFlowgrad[intIndex] = fltFlowgrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        return tenIngrad, tenFlowgrad
    # end
# end
|
||||
450
sgm_vfi_arch/transformer.py
Normal file
450
sgm_vfi_arch/transformer.py
Normal file
@@ -0,0 +1,450 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    """Tile a feature map into num_splits x num_splits non-overlapping windows.

    The windows are stacked into the batch dimension, so a [B, C, H, W] input
    becomes [B*K*K, C, H/K, W/K] (and analogously for channel-last layout).
    """
    k = num_splits
    if channel_last:  # input is [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0

        windows = feature.view(b, k, h // k, k, w // k, c)
        feature = windows.permute(0, 1, 3, 2, 4, 5).reshape(b * k * k, h // k, w // k, c)  # [B*K*K, H/K, W/K, C]
    else:  # input is [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0, f'h: {h}, w: {w}, num_splits: {num_splits}'

        windows = feature.view(b, c, k, h // k, k, w // k)
        feature = windows.permute(0, 2, 4, 1, 3, 5).reshape(b * k * k, c, h // k, w // k)  # [B*K*K, C, H/K, W/K]

    return feature
|
||||
|
||||
|
||||
def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    """Inverse of split_feature: reassemble K*K batched windows into full maps.

    A [B*K*K, C, H/K, W/K] input becomes [B, C, H, W] (and analogously for
    channel-last layout).
    """
    k = num_splits
    if channel_last:  # splits: [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        full_b = b // k // k

        grid = splits.view(full_b, k, k, h, w, c)
        merge = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(full_b, k * h, k * w, c)  # [B, H, W, C]
    else:  # splits: [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        full_b = b // k // k

        grid = splits.view(full_b, k, k, c, h, w)
        merge = grid.permute(0, 3, 1, 4, 2, 5).contiguous().view(full_b, c, k * h, k * w)  # [B, C, H, W]

    return merge
|
||||
|
||||
|
||||
def single_head_full_attention(q, k, v):
    """Plain scaled dot-product attention over full sequences (single head).

    q, k, v: [B, L, C]; returns [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    dim = q.size(2)
    logits = q @ k.transpose(1, 2) / dim ** .5  # [B, L, L]
    weights = logits.softmax(dim=2)  # [B, L, L]
    return weights @ v  # [B, L, C]
|
||||
|
||||
|
||||
def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
                                    shift_size_h, shift_size_w, device=None):
    """Build the additive attention mask for shifted-window (SW-MSA) attention.

    Pixels from different pre-shift regions must not attend to each other, so
    their pairwise entries are set to -100 (effectively -inf after softmax).
    Returns a [K*K, Wh*Ww, Wh*Ww] tensor of 0 / -100 values.
    """
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    h, w = input_resolution
    # Allocate directly on the target device instead of CPU-then-copy.
    img_mask = torch.zeros((1, h, w, 1), device=device)  # 1 H W 1
    h_slices = (slice(0, -window_size_h),
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
    w_slices = (slice(0, -window_size_w),
                slice(-window_size_w, -shift_size_w),
                slice(-shift_size_w, None))
    # Label each of the 9 shift regions with a distinct id. Distinct loop
    # names avoid shadowing h/w (the image size) from the unpacking above.
    cnt = 0
    for h_slice in h_slices:
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = cnt
            cnt += 1

    # Window the label map, then mark any pair of positions with differing
    # region labels as masked (-100); same-region pairs stay at 0.
    mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
|
||||
|
||||
|
||||
def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    """Swin-style windowed attention (single head), optionally with shifted windows.

    q, k, v: [B, L, C] with L == h * w; returns [B, L, C]. When with_shift is
    set, attn_mask (from generate_shift_window_attn_mask) is required.
    """
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_win = b * num_splits * num_splits

    win_h = h // num_splits
    win_w = w // num_splits

    # Restore the spatial layout before windowing.
    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # computed once by the caller
        shift_h = win_h // 2
        shift_w = win_w // 2

        q = torch.roll(q, shifts=(-shift_h, -shift_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_h, -shift_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_h, -shift_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_win, -1, c),
                          k.view(b_win, -1, c).permute(0, 2, 1)) / scale  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        # Mask out cross-region pairs introduced by the cyclic shift.
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_win, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_win, win_h, win_w, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # Undo the cyclic shift.
    if with_shift:
        out = torch.roll(out, shifts=(shift_h, shift_w), dims=(1, 2))

    return out.view(b, -1, c)
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
    """Single-head attention layer (full or Swin-windowed) with optional FFN.

    Used both as a self-attention layer (no_ffn=True) and as a
    cross-attention + FFN layer inside TransformerBlock.
    """

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            # FFN input is [source, message] concatenated, hence d_model * 2.
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        """Attend source -> target and return source plus the residual update.

        source, target: [B, L, C]; returns [B, L, C].
        NOTE(review): attn_num_splits must be an int when attention_type is
        'swin' — the `attn_num_splits > 1` comparison raises on None.
        """
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(query, key, value,
                                                             num_splits=attn_num_splits,
                                                             with_shift=self.with_shift,
                                                             h=height,
                                                             w=width,
                                                             attn_mask=shifted_window_attn_mask,
                                                             )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            # FFN consumes the concatenation of the input and the attended message.
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        # Residual connection.
        return source + message
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerBlock, self).__init__()

        # Self-attention sub-layer: no FFN (no_ffn=True).
        self.self_attn = TransformerLayer(d_model=d_model,
                                          nhead=nhead,
                                          attention_type=attention_type,
                                          no_ffn=True,
                                          ffn_dim_expansion=ffn_dim_expansion,
                                          with_shift=with_shift,
                                          )

        # Cross-attention sub-layer followed by the FFN.
        self.cross_attn_ffn = TransformerLayer(d_model=d_model,
                                               nhead=nhead,
                                               attention_type=attention_type,
                                               ffn_dim_expansion=ffn_dim_expansion,
                                               with_shift=with_shift,
                                               )

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        """Run self-attention on source, then cross-attention from source to target.

        source, target: [B, L, C]; returns the updated source, [B, L, C].
        """
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(source, source,
                                height=height,
                                width=width,
                                shifted_window_attn_mask=shifted_window_attn_mask,
                                attn_num_splits=attn_num_splits,
                                )

        # cross attention and ffn
        source = self.cross_attn_ffn(source, target,
                                     height=height,
                                     width=width,
                                     shifted_window_attn_mask=shifted_window_attn_mask,
                                     attn_num_splits=attn_num_splits,
                                     )

        return source
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks that cross-updates two feature maps.

    Alternating blocks use shifted windows when attention_type == 'swin'.
    Both feature maps are processed in one pass by concatenating them along
    the batch dimension.
    """

    def __init__(self,
                 num_layers=6,
                 d_model=128,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 **kwargs,
                 ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        # Every odd-indexed block uses shifted windows (Swin-style alternation).
        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model,
                             nhead=nhead,
                             attention_type=attention_type,
                             ffn_dim_expansion=ffn_dim_expansion,
                             with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
                             )
            for i in range(num_layers)])

        # Xavier-init all weight matrices (biases and norms keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, feature1,
                attn_num_splits=None,
                **kwargs,
                ):
        """Cross-attend feature0 <-> feature1 through all blocks.

        feature0, feature1: [B, C, H, W] with C == d_model; returns the two
        updated maps with the same shape.
        NOTE(review): attn_num_splits must be an int when attention_type is
        'swin' — the `attn_num_splits > 1` comparison raises on None.
        """

        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(concat0, concat1,
                            height=h,
                            width=w,
                            shifted_window_attn_mask=shifted_window_attn_mask,
                            attn_num_splits=attn_num_splits,
                            )

            # update feature1: swap the two halves so each feature attends to
            # the other's freshly updated version in the next block.
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1
|
||||
|
||||
|
||||
class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        # Xavier-init all weight matrices (biases keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        """Propagate flow via feature self-attention.

        feature0: [B, C, H, W]; flow: [B, 2, H, W]; returns [B, 2, H, W].
        With local_window_attn, attention is restricted to a
        (2R+1)x(2R+1) neighbourhood around each pixel.
        """
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        query = self.q_proj(query)  # [B, H*W, C]
        # Note: the key is projected from the already-q-projected query
        # (k_proj is applied on top of q_proj, not on raw feature0).
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        """Local-window variant: each pixel attends only to its (2R+1)^2 neighbourhood."""
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        # Per-pixel query vectors, one row per spatial location.
        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)

        # Gather the (2R+1)^2 key vectors around every pixel via unfold.
        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2), H*W]

        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        # Gather the corresponding flow values (the attention values).
        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2), H*W]

        flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out
|
||||
90
sgm_vfi_arch/trident_conv.py
Normal file
90
sgm_vfi_arch/trident_conv.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
|
||||
class MultiScaleTridentConv(nn.Module):
    """Trident convolution: one shared weight applied with per-branch stride/padding/dilation.

    Each branch processes its own input with the same kernel, producing a list
    of outputs. At inference time (not training and test_branch_idx != -1)
    only a single branch is evaluated.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        # Broadcast scalar per-branch settings to every branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # All per-branch lists must match num_branch.
        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        # One weight tensor shared across every branch.
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Apply the shared kernel to each branch input.

        inputs: list of tensors, one per active branch; returns a list of
        outputs of the same length.
        """
        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            # Run every branch with its own stride/padding, sharing the weight.
            outputs = [
                F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # NOTE(review): both arms of these conditionals evaluate to the
            # LAST branch's stride/padding (strides[-1] when test_branch_idx
            # == -1, and strides[-1] otherwise) — strides[test_branch_idx] is
            # never selected for 0 <= test_branch_idx < num_branch - 1. This
            # matches the vendored upstream code; verify intent before changing.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
                    self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
|
||||
98
sgm_vfi_arch/utils.py
Normal file
98
sgm_vfi_arch/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .position import PositionEmbeddingSine
|
||||
from .geometry import coords_grid, generate_window_grid, normalize_coords
|
||||
|
||||
|
||||
def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    """Tile a feature map into num_splits x num_splits non-overlapping windows,
    stacking the windows into the batch dimension."""
    k = num_splits
    if channel_last:  # input is [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0

        windows = feature.view(b, k, h // k, k, w // k, c)
        feature = windows.permute(0, 1, 3, 2, 4, 5).reshape(b * k * k, h // k, w // k, c)  # [B*K*K, H/K, W/K, C]
    else:  # input is [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % k == 0 and w % k == 0

        windows = feature.view(b, c, k, h // k, k, w // k)
        feature = windows.permute(0, 2, 4, 1, 3, 5).reshape(b * k * k, c, h // k, w // k)  # [B*K*K, C, H/K, W/K]

    return feature
|
||||
|
||||
def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    """Reassemble a K x K grid of tiles (stacked along batch) into one map.

    Inverse of `split_feature`: [B*K*K, C, H/K, W/K] -> [B, C, H, W]
    (or the channel-last analogue). Tiles are consumed in row-major order.
    """
    k = num_splits
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        nb = b // (k * k)
        # Interleave tile rows/cols back into the spatial axes.
        grid = splits.view(nb, k, k, h, w, c).permute(0, 1, 3, 2, 4, 5)
        return grid.contiguous().view(nb, k * h, k * w, c)  # [B, H, W, C]

    # [B*K*K, C, H/K, W/K]
    b, c, h, w = splits.size()
    nb = b // (k * k)
    grid = splits.view(nb, k, k, c, h, w).permute(0, 3, 1, 4, 2, 5)
    return grid.contiguous().view(nb, c, k * h, k * w)  # [B, C, H, W]
|
||||
|
||||
|
||||
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add sinusoidal position embeddings to both feature maps.

    When `attn_splits > 1` the encoding is computed per local window (the maps
    are tiled with `split_feature`, embedded, then stitched back with
    `merge_splits`) so positions are relative to each window, matching the
    windowed attention that consumes these features.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        # Global encoding over the full map; same grid for both features.
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # Window-local encoding: tile, embed, merge back.
    f0 = split_feature(feature0, num_splits=attn_splits)
    f1 = split_feature(feature1, num_splits=attn_splits)

    position = pos_enc(f0)
    f0 = f0 + position
    f1 = f1 + position

    feature0 = merge_splits(f0, num_splits=attn_splits)
    feature1 = merge_splits(f1, num_splits=attn_splits)
    return feature0, feature1
|
||||
|
||||
|
||||
class InputPadder:
    """Pads images such that dimensions are divisible by 8.

    `pad` replicate-pads every tensor to the next multiple of
    `padding_factor`; `unpad` crops a padded tensor back to the original
    extent. 'sintel' mode centers the padding, otherwise it all goes to the
    bottom (and is still centered horizontally).
    """

    def __init__(self, dims, mode='sintel', padding_factor=8, additional_pad=False):
        # Only the trailing two dims (H, W) of `dims` matter.
        self.ht, self.wd = dims[-2:]
        extra = padding_factor * 2 if additional_pad else 0
        # (-x) % f == amount needed to reach the next multiple of f (0 if aligned).
        pad_ht = (-self.ht) % padding_factor + extra
        pad_wd = (-self.wd) % padding_factor + extra
        left = pad_wd // 2
        right = pad_wd - left
        if mode == 'sintel':
            top, bottom = pad_ht // 2, pad_ht - pad_ht // 2
        else:
            top, bottom = 0, pad_ht
        # F.pad order: (left, right, top, bottom).
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Replicate-pad each input tensor; returns a list in the same order."""
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original H x W."""
        h, w = x.shape[-2:]
        return x[..., self._pad[2]:h - self._pad[3], self._pad[0]:w - self._pad[1]]
|
||||
25
sgm_vfi_arch/warplayer.py
Normal file
25
sgm_vfi_arch/warplayer.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
|
||||
# Per-(device, flow-size) cache of normalized base sampling grids used by warp().
backwarp_tenGrid = {}


def clear_warp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    backwarp_tenGrid.clear()


def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` along the optical flow `tenFlow`.

    `tenFlow` is a [N, 2, H, W] pixel-displacement field (x then y). The base
    identity grid in normalized [-1, 1] coordinates is cached per device and
    flow size; the pixel flow is rescaled to normalized offsets and the sum is
    fed to bilinear `grid_sample` with border padding.
    """
    key = (str(tenFlow.device), str(tenFlow.size()))
    grid = backwarp_tenGrid.get(key)
    if grid is None:
        n, _, h, w = tenFlow.shape
        xs = torch.linspace(-1.0, 1.0, w, device=tenFlow.device)
        xs = xs.view(1, 1, 1, w).expand(n, -1, h, -1)
        ys = torch.linspace(-1.0, 1.0, h, device=tenFlow.device)
        ys = ys.view(1, 1, h, 1).expand(n, -1, -1, w)
        grid = torch.cat([xs, ys], 1)
        backwarp_tenGrid[key] = grid

    # Pixel flow -> normalized offsets; with align_corners=True a full-width
    # step spans (W - 1) pixels, hence the (dim - 1) / 2 divisor.
    norm_flow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                           tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    sample_grid = (grid + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=sample_grid, mode='bilinear',
                                           padding_mode='border', align_corners=True)
|
||||
Reference in New Issue
Block a user