Add SGM-VFI (CVPR 2024) frame interpolation support
SGM-VFI combines local flow estimation with sparse global matching (GMFlow) to handle large motion and occlusion-heavy scenes. Adds 3 new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, SGM-VFI Segment Interpolate. Architecture files vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), relative imports, and debug code removed. README updated with model comparison table. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
96
sgm_vfi_arch/geometry.py
Normal file
96
sgm_vfi_arch/geometry.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def coords_grid(b, h, w, homogeneous=False, device=None):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
||||
|
||||
stacks = [x, y]
|
||||
|
||||
if homogeneous:
|
||||
ones = torch.ones_like(x) # [H, W]
|
||||
stacks.append(ones)
|
||||
|
||||
grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
|
||||
|
||||
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
||||
|
||||
if device is not None:
|
||||
grid = grid.to(device)
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
|
||||
assert device is not None
|
||||
|
||||
x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
|
||||
torch.linspace(h_min, h_max, len_h, device=device)],
|
||||
)
|
||||
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def normalize_coords(coords, h, w):
|
||||
# coords: [B, H, W, 2]
|
||||
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
|
||||
return (coords - c) / c # [-1, 1]
|
||||
|
||||
|
||||
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
|
||||
# img: [B, C, H, W]
|
||||
# sample_coords: [B, 2, H, W] in image scale
|
||||
if sample_coords.size(1) != 2: # [B, H, W, 2]
|
||||
sample_coords = sample_coords.permute(0, 3, 1, 2)
|
||||
|
||||
b, _, h, w = sample_coords.shape
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
|
||||
y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
|
||||
|
||||
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
|
||||
|
||||
img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
|
||||
|
||||
if return_mask:
|
||||
mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
|
||||
|
||||
return img, mask
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
|
||||
b, c, h, w = feature.size()
|
||||
assert flow.size(1) == 2
|
||||
|
||||
grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
|
||||
|
||||
return bilinear_sample(feature, grid, padding_mode=padding_mode,
|
||||
return_mask=mask)
|
||||
|
||||
|
||||
def forward_backward_consistency_check(fwd_flow, bwd_flow,
|
||||
alpha=0.01,
|
||||
beta=0.5
|
||||
):
|
||||
# fwd_flow, bwd_flow: [B, 2, H, W]
|
||||
# alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
|
||||
assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
|
||||
assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
|
||||
flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
|
||||
|
||||
warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
|
||||
warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
|
||||
|
||||
diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
|
||||
diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
|
||||
|
||||
threshold = alpha * flow_mag + beta
|
||||
|
||||
fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
|
||||
bwd_occ = (diff_bwd > threshold).float()
|
||||
|
||||
return fwd_occ, bwd_occ
|
||||
Reference in New Issue
Block a user