Add SGM-VFI (CVPR 2024) frame interpolation support
SGM-VFI combines local flow estimation with sparse global matching (GMFlow) to handle large motion and occlusion-heavy scenes.

Adds 3 new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, and SGM-VFI Segment Interpolate. Architecture files are vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), relative imports, and debug code removed. README updated with a model comparison table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
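For reviewers unfamiliar with the device-awareness fix mentioned above: the upstream code allocates new tensors with hardcoded .cuda() calls, which fails on CPU/MPS and ignores ComfyUI's device management. A minimal sketch of the pattern applied in the vendored files (the helper name and shapes are illustrative, not lifted from the diff):

import torch

# Hypothetical helper, for illustration only.
# Upstream style:  mask = torch.zeros((1, h, w, 1)).cuda()
# Vendored style:  allocate on whatever device the incoming feature already uses.
def new_mask_like(feature: torch.Tensor, h: int, w: int) -> torch.Tensor:
    return torch.zeros((1, h, w, 1), device=feature.device)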
sgm_vfi_arch/transformer.py (new file, 450 lines added)
@@ -0,0 +1,450 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
                               ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0, f'h: {h}, w: {w}, num_splits: {num_splits}'

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
                               ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]

    return feature


def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
        merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
            new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
        merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
            new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]

    return merge


def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
                                    shift_size_h, shift_size_w, device=None):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # calculate attention mask for SW-MSA
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
    h_slices = (slice(0, -window_size_h),
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
    w_slices = (slice(0, -window_size_w),
                slice(-window_size_w, -shift_size_w),
                slice(-shift_size_w, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask


def single_head_split_window_attention(q, k, v,
                                        num_splits=1,
                                        with_shift=False,
                                        h=None,
                                        w=None,
                                        attn_mask=None,
                                        ):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


class TransformerLayer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(query, key, value,
                                                             num_splits=attn_num_splits,
                                                             with_shift=self.with_shift,
                                                             h=height,
                                                             w=width,
                                                             attn_mask=shifted_window_attn_mask,
                                                             )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerBlock, self).__init__()

        self.self_attn = TransformerLayer(d_model=d_model,
                                          nhead=nhead,
                                          attention_type=attention_type,
                                          no_ffn=True,
                                          ffn_dim_expansion=ffn_dim_expansion,
                                          with_shift=with_shift,
                                          )

        self.cross_attn_ffn = TransformerLayer(d_model=d_model,
                                               nhead=nhead,
                                               attention_type=attention_type,
                                               ffn_dim_expansion=ffn_dim_expansion,
                                               with_shift=with_shift,
                                               )

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(source, source,
                                height=height,
                                width=width,
                                shifted_window_attn_mask=shifted_window_attn_mask,
                                attn_num_splits=attn_num_splits,
                                )

        # cross attention and ffn
        source = self.cross_attn_ffn(source, target,
                                     height=height,
                                     width=width,
                                     shifted_window_attn_mask=shifted_window_attn_mask,
                                     attn_num_splits=attn_num_splits,
                                     )

        return source


class FeatureTransformer(nn.Module):
    def __init__(self,
                 num_layers=6,
                 d_model=128,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 **kwargs,
                 ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model,
                             nhead=nhead,
                             attention_type=attention_type,
                             ffn_dim_expansion=ffn_dim_expansion,
                             with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
                             )
            for i in range(num_layers)])

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, feature1,
                attn_num_splits=None,
                **kwargs,
                ):

        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(concat0, concat1,
                            height=h,
                            width=w,
                            shifted_window_attn_mask=shifted_window_attn_mask,
                            attn_num_splits=attn_num_splits,
                            )

            # update feature1
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1


class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)

        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]

        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]

        flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out
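For orientation, a minimal sketch of how the two public classes in this file are typically exercised (a standalone example under assumed settings: the import path, feature sizes, and attn_num_splits value are illustrative, not SGM-VFI's actual configuration):

import torch
from sgm_vfi_arch.transformer import FeatureTransformer, FeatureFlowAttention

# Downsampled feature maps for the two input frames, [B, C, H, W] with C == d_model.
feat0 = torch.randn(1, 128, 32, 32)
feat1 = torch.randn(1, 128, 32, 32)

# attn_num_splits > 1 selects Swin-style shifted-window attention; 1 falls back to full attention.
transformer = FeatureTransformer(num_layers=6, d_model=128)
feat0, feat1 = transformer(feat0, feat1, attn_num_splits=2)  # both remain [1, 128, 32, 32]

# Flow propagation: query/key are projections of feature0, value is the 2-channel flow field.
flow = torch.randn(1, 2, 32, 32)
prop = FeatureFlowAttention(in_channels=128)
flow = prop(feat0, flow, local_window_attn=True, local_window_radius=1)  # [1, 2, 32, 32]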