SGM-VFI combines local flow estimation with sparse global matching (built on GMFlow) to handle large motion and occlusion-heavy scenes. This adds three new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, and SGM-VFI Segment Interpolate. The architecture files are vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), imports converted to relative form, and debug code removed. The README is updated with a model comparison table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
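The device-awareness fix refers to moving input tensors to whatever device the model currently lives on rather than calling .cuda() on them. A minimal sketch of that pattern, with hypothetical names (run_interpolation is illustrative, not a function in this repo):

import torch

def run_interpolation(model: torch.nn.Module, img0: torch.Tensor, img1: torch.Tensor):
    # Follow whatever device the model was loaded onto (CPU, CUDA, MPS, ...)
    # instead of hardcoding .cuda() on the inputs.
    device = next(model.parameters()).device
    img0 = img0.to(device)
    img1 = img1.to(device)
    with torch.no_grad():
        return model(img0, img1)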
88 lines · 3.0 KiB · Python
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import CNNEncoder
from .transformer import FeatureTransformer, FeatureFlowAttention
from .utils import feature_add_position


class GMFlow(nn.Module):
    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

    def extract_feature(self, img0, img1):
        # run both frames through the backbone as a single batch
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            # split the batch back into per-frame feature maps
            chunks = torch.chunk(feature, 2, 0)  # tuple of two [B, C, H, W] tensors
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def forward(self, img0, img1,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                pred_bidir_flow=False,
                **kwargs,
                ):
        results_dict = {}
        # these accumulators are not used below; only transformer features
        # are collected and returned
        flow_preds = []
        flow_s_matching = []
        flow_s_prop = []
        transformer_features = []

        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            attn_splits = attn_splits_list[scale_idx]

            # add positional information before cross-frame attention
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
            transformer_features.append(torch.cat([feature0, feature1], 0))

        results_dict.update({'trans_feat': transformer_features})

        return results_dict
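As a rough usage sketch of the forward pass above (assumptions: the vendored module is importable under the hypothetical path shown, inputs are [B, 3, H, W] frames whose sides divide evenly by the feature stride, and attn_splits_list=[2] as in common GMFlow configurations), the call returns only transformer features under the 'trans_feat' key:

import torch
from vfi_arch.gmflow import GMFlow  # hypothetical import path for the vendored module

model = GMFlow(num_scales=1, upsample_factor=8, feature_channels=128,
               attention_type='swin', num_transformer_layers=6)
model.eval()

img0 = torch.randn(1, 3, 256, 448)  # frame 0 (shape chosen for illustration)
img1 = torch.randn(1, 3, 256, 448)  # frame 1

with torch.no_grad():
    out = model(img0, img1, attn_splits_list=[2])

# one entry per scale; roughly [2B, C, H/8, W/8] if the backbone stride is 8 (assumption)
trans_feat = out['trans_feat'][0]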