SGM-VFI combines local flow estimation with sparse global matching (GMFlow) to handle large motion and occlusion-heavy scenes. Adds 3 new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, SGM-VFI Segment Interpolate. Architecture files vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), relative imports, and debug code removed. README updated with model comparison table. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
209 lines
9.1 KiB
Python
209 lines
9.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .refine import *
|
|
from .matching import MatchingBlock
|
|
from .gmflow import GMFlow
|
|
from .utils import InputPadder
|
|
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2-D convolution followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels; also the number of learnable
            PReLU slopes.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding applied by the convolution.
        dilation: Convolution dilation.

    Returns:
        An ``nn.Sequential`` whose element 0 is the ``nn.Conv2d`` and whose
        element 1 is the ``nn.PReLU``.
    """
    # The convolution is created first so parameter-initialisation order
    # (and therefore RNG consumption) matches a plain inline construction.
    convolution = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.PReLU(num_parameters=out_planes)
    return nn.Sequential(convolution, activation)
|
|
|
|
|
|
class IFBlock(nn.Module):
    """Convolutional block that predicts flow and a mask at a reduced resolution.

    The block downsamples its inputs by ``scale``, runs a small convolutional
    stack, and emits 5 channels: 4 flow channels and 1 mask channel. Callers in
    this file use flow channels ``[:2]`` to warp the first image and ``[2:4]``
    to warp the second.

    Args:
        in_planes: Channel count of the ``feature`` tensor given to ``forward``.
        c: Width (channels) of the hidden convolution layers.
        layers: Number of ``conv(c, c)`` layers in the residual body.
        scale: Spatial reduction factor at which the block operates.
        in_else: Channel count of everything concatenated in ``forward`` other
            than ``feature`` (``x`` plus, when supplied, the 4-channel ``flow``).
    """

    def __init__(self, in_planes, c=64, layers=4, scale=4, in_else=17):
        super().__init__()
        self.scale = scale

        self.conv0 = nn.Sequential(
            conv(in_planes + in_else, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )

        self.convblock = nn.Sequential(
            *[conv(c, c) for _ in range(layers)]
        )

        self.lastconv = conv(c, 5)

    def forward(self, x, flow=None, feature=None):
        """Predict flow, mask and low-resolution flow.

        Args:
            x: Full-resolution input tensor; bilinearly resized by ``1 / scale``
                when ``scale != 1``.
            flow: Optional full-resolution flow. It is resized by ``1 / scale``
                and its values are divided by ``scale`` before being
                concatenated to ``x``.
            feature: Optional tensor concatenated to ``x`` as-is, so it must
                already have the reduced spatial size.

        Returns:
            A tuple ``(flow, mask, flow_s)``:
            ``flow`` is channels ``[:4]`` of the output upsampled by ``scale``
            with values multiplied by ``scale``; ``mask`` is channel ``4`` of
            the upsampled output; ``flow_s`` is channels ``[:4]`` of the output
            at the reduced resolution, unscaled.
        """
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
        # Identity test instead of `!= None`: comparing a tensor to None with
        # `!=` only works through Python's comparison fallback.
        if flow is not None:
            # Flow vectors are measured in pixels, so shrinking the grid by
            # `scale` also shrinks their magnitude by `scale`.
            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
                                 align_corners=False) * 1. / self.scale
            x = torch.cat((x, flow), 1)
        if feature is not None:
            x = torch.cat((x, feature), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x  # residual connection around the conv body
        tmp = self.lastconv(x)
        flow_s = tmp[:, :4]
        tmp = F.interpolate(tmp, scale_factor=self.scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * self.scale
        mask = tmp[:, 4:5]
        return flow, mask, flow_s
|
|
|
|
|
|
class MultiScaleFlow(nn.Module):
    """Two-stage coarse-to-fine flow estimator with global matching and refinement.

    Stage 0 runs an ``IFBlock`` at 1/8 resolution followed by a
    ``MatchingBlock`` fed with GMFlow features; stage 1 runs an ``IFBlock`` at
    1/4 resolution with no matching block. The resulting flow and mask drive a
    ``Contextnet``/``Unet`` refinement that produces the final frame.

    Args:
        backbone: Feature extractor called as ``backbone(img0, img1)``. Its
            result is indexed from the end (``af[-1]``, ``af[-2]``) and each
            entry is split along dim 0 at the batch size, so it is assumed to
            stack the features of both images along the batch dimension.
        **kargs: Configuration. The keys read here are ``'hidden_dims'``,
            ``'num_key_points'``, ``'embed_dims'`` and ``'c'``.
    """

    def __init__(self, backbone, **kargs):
        super().__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        self.scale = [1, 2, 4, 8]
        self.num_key_points = [kargs['num_key_points']]
        # in_else counts the non-feature input channels of each stage:
        #   stage 0: img0(3) + img1(3) + t(1)                                  = 7
        #   stage 1: img0 + img1 + warped0 + warped1 (12) + mask(1) + t(1) + flow(4) = 18
        self.block = nn.ModuleList(
            [IFBlock(kargs['embed_dims'][-1] * 2, 128, 2, self.scale[-1], in_else=7),  # 1/8
             IFBlock(kargs['embed_dims'][-2] * 2, 128, 2, self.scale[-2], in_else=18)])  # 1/4
        self.contextnet = Contextnet(kargs['c'] * 2)
        self.unet = Unet(kargs['c'] * 2)
        self.gmflow = GMFlow(
            num_scales=1,
            upsample_factor=8,
            feature_channels=128,
            attention_type='swin',
            num_transformer_layers=6,
            ffn_dim_expansion=4,
            num_head=1)

        # Only the 1/8 stage has a matching block; the None entry keeps the
        # list index aligned with the stage index.
        self.matching_block = nn.ModuleList([
            MatchingBlock(scale=8, dim=kargs['embed_dims'][-1], c=kargs['c'] * 4, num_layers=1, gm=True),
            None
        ])

        self.padding_factor = 16

    def _global_matching_features(self, img0, img1, af, **padder_kwargs):
        """Return GMFlow's ``'trans_feat'`` list with entry 0 unpadded to ``af[-1]``'s size.

        ``padder_kwargs`` are forwarded to both ``InputPadder`` constructors so
        each caller keeps its own padding options.
        """
        padder = InputPadder(img0.shape, padding_factor=self.padding_factor, **padder_kwargs)
        img0_p, img1_p = padder.pad(img0, img1)
        results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
        matching_feat = results['trans_feat']
        # The features live at 1/8 resolution, hence the padding factor is
        # divided by the coarsest scale before unpadding.
        padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1],
                               **padder_kwargs)
        matching_feat[0] = padder_8.unpad(matching_feat[0])
        return matching_feat

    def _local_flow_step(self, i, img0, img1, af, t, flow, mask, flow_s, warped_img0, warped_img1):
        """Run IFBlock stage ``i`` and return the updated ``(flow, mask, flow_s)``."""
        B = img0.size(0)
        level = af[-1 - i]
        feature = torch.cat([level[:B], level[B:]], 1)
        if flow is None:
            # First stage: nothing to refine yet, the block predicts from scratch.
            return self.block[i](torch.cat((img0, img1, t), 1), None, feature)

        flow_d, mask_d, flow_s_d = self.block[i](
            torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
            flow,
            feature,
        )
        flow = flow + flow_d
        mask = mask + mask_d
        # The low-resolution flow moves to a grid twice as fine, so both its
        # size and its vector magnitudes double before the residual is added.
        flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
        flow_s = flow_s + flow_s_d
        return flow, mask, flow_s

    def calculate_flow(self, imgs, timestep):
        """Estimate the flow and the (pre-sigmoid) blending mask for ``imgs``.

        Args:
            imgs: Tensor whose channels ``[:3]`` are the first image and
                ``[3:6]`` the second.
            timestep: Value broadcast into a one-channel plane and fed to every
                IFBlock, and passed to the matching block.

        Returns:
            ``(flow, mask)`` after both stages. ``flow[:, :2]`` warps the first
            image and ``flow[:, 2:4]`` the second.
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        flow, mask, flow_s = None, None, None
        warped_img0, warped_img1 = None, None

        af = self.feature_bone(img0, img1)
        matching_feat = []
        if self.gmflow is not None:
            matching_feat = self._global_matching_features(img0, img1, af)

        # Loop-invariant timestep plane. ones_like (rather than img * 0 + 1)
        # keeps NaN/Inf pixels from leaking into the timestep input.
        t = torch.ones_like(img0[:, :1]) * timestep

        for i in range(2):
            flow, mask, flow_s = self._local_flow_step(
                i, img0, img1, af, t, flow, mask, flow_s, warped_img0, warped_img1)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            if self.matching_block[i] is not None:
                matched = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i],
                                                 init_flow=flow, init_flow_s=flow_s, init_mask=mask,
                                                 warped_img0=warped_img0, warped_img1=warped_img1,
                                                 num_key_points=self.num_key_points[i],
                                                 scale_factor=self.scale[-1 - i],
                                                 timestep=timestep)
                flow = flow + matched['flow_t']
                mask = mask + matched['mask_t']

                # Re-warp only when the flow actually changed; without a
                # matching update the warps computed above are still current.
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
        return flow, mask

    def coraseWarp_and_Refine(self, imgs, flow, mask):
        """Warp both images with ``flow``, blend them with ``mask`` and refine.

        The misspelt method name is kept so existing callers continue to work.

        Args:
            imgs: Tensor whose channels ``[:3]`` are the first image and
                ``[3:6]`` the second.
            flow: 4-channel flow; ``[:2]`` warps the first image, ``[2:4]`` the
                second.
            mask: Pre-sigmoid blending mask.

        Returns:
            The blended frame plus the Unet residual, clamped to ``[0, 1]``.
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # maps the first three Unet channels via x * 2 - 1
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred

    def forward(self, x, timestep=0.5):
        """Run both stages and the refinement, keeping intermediate results.

        Inputs to the matching block are detached, so no gradient flows from it
        back into the local branch.

        Args:
            x: Tensor whose channels ``[:3]`` are the first image and ``[3:6]``
                the second.
            timestep: Value fed to the IFBlocks and to the matching block. The
                matching block previously always received ``0.5``; results are
                unchanged at the default.

        Returns:
            ``(flow_list, mask_list, merged, pred, flow_matching_list)``:
            ``flow_list`` holds each stage's flow before the matching update;
            ``mask_list`` holds each stage's sigmoid mask (after the matching
            update where one ran); ``merged`` holds the per-stage blends
            followed by the post-matching blends; ``pred`` is the refined frame
            clamped to ``[0, 1]``; ``flow_matching_list`` is always empty.
        """
        img0, img1 = x[:, :3], x[:, 3:6]
        flow_list, mask_list = [], []
        merged, merged_fine = [], []
        warped_img0, warped_img1 = img0, img1
        flow, mask, flow_s = None, None, None
        flow_matching_list = []  # never filled here; returned to keep the 5-tuple shape
        matching_feat = []
        af = self.feature_bone(img0, img1)
        if self.gmflow is not None:
            matching_feat = self._global_matching_features(img0, img1, af, additional_pad=False)

        # Loop-invariant timestep plane. ones_like (rather than img * 0 + 1)
        # keeps NaN/Inf pixels from leaking into the timestep input.
        t = torch.ones_like(img0[:, :1]) * timestep

        for i in range(2):
            flow, mask, flow_s = self._local_flow_step(
                i, img0, img1, af, t, flow, mask, flow_s, warped_img0, warped_img1)
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
            if self.matching_block[i] is not None:
                matched = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i],
                                                 main_x=af[-1 - i].detach(),
                                                 init_flow=flow.detach(), init_flow_s=flow_s.detach(),
                                                 init_mask=mask.detach(),
                                                 warped_img0=warped_img0.detach(),
                                                 warped_img1=warped_img1.detach(),
                                                 num_key_points=self.num_key_points[i],
                                                 scale_factor=self.scale[-1 - i],
                                                 timestep=timestep)
                flow = flow + matched['flow_t']
                mask = mask + matched['mask_t']
                mask_list[i] = torch.sigmoid(mask)
                # NOTE(review): the source this was reconstructed from had lost
                # its indentation; the post-matching re-warp is placed inside
                # this branch (one `merged_fine` entry per matching block) —
                # confirm against upstream.
                warped_img0_fine = warp(img0, flow[:, 0:2])
                warped_img1_fine = warp(img1, flow[:, 2:4])
                merged_fine.append(warped_img0_fine * mask_list[i] + warped_img1_fine * (1 - mask_list[i]))
                warped_img0, warped_img1 = warped_img0_fine, warped_img1_fine  # NOTE: for next iteration training
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        # merged[-1] is the last stage's blend; merged_fine is appended only afterwards.
        pred = torch.clamp(merged[-1] + res, 0, 1)
        merged.extend(merged_fine)
        return flow_list, mask_list, merged, pred, flow_matching_list
|