Files
ComfyUI-Tween/gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Ethanfel d642255e70 Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 13:11:45 +01:00

341 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# --------------------------------------------------------
import torch
import torch.nn as nn
from .fi_utils import warp, resize
class LateralBlock(nn.Module):
    """Residual block: conv3x3 -> LeakyReLU(0.1) -> conv3x3, plus identity skip.

    Args:
        dim: Channel count used for the input and output of both convolutions.
    """

    def __init__(self, dim):
        super().__init__()
        first_conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True)
        activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        second_conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True)
        # Registered as one Sequential named ``layers`` so checkpoint keys
        # remain ``layers.0.*`` and ``layers.2.*``.
        self.layers = nn.Sequential(first_conv, activation, second_conv)

    def forward(self, x):
        """Return ``self.layers(x) + x`` (same shape as ``x``)."""
        return self.layers(x) + x
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Build a ``Conv2d`` followed by a per-channel ``PReLU``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution; also the number of
            learnable PReLU slopes.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded
            unchanged to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` whose element 0 is the convolution and element 1 is
        the PReLU.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    activation = nn.PReLU(num_parameters=out_channels)
    return nn.Sequential(conv, activation)
def multi_flow_combine(
    comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
    """Warp both frames with several flow fields at once and fuse the results.

    The ``num_flows`` flow fields packed along the channel axis are folded
    into the batch axis so that a single ``warp`` call per frame handles all
    of them.

    Args:
        comb_block: Callable applied to the warped candidates reshaped to
            ``(b, num_flows * 3, h, w)``; its output is added to the mean of
            the candidates.
        img0, img1: Source frames. They are reshaped with 3 channels and the
            flows' spatial size, so ``(b, 3, h, w)`` is expected.
        flow0, flow1: Flow tensors of shape ``(b, 2 * num_flows, h, w)``.
        mask: Optional blending weights with ``num_flows`` channels,
            multiplying the warped ``img0`` (``1 - mask`` multiplies the
            warped ``img1``). ``None`` means a plain 0.5/0.5 average of the
            two warped frames.
        img_res: Optional residual with ``3 * num_flows`` channels added to
            every candidate; ``None`` adds nothing.
        mean: Must be ``None`` (enforced with ``assert``).

    Returns:
        ``(candidates.mean(1) + comb_block(candidates) + 1.0) / 2``.
        NOTE(review): the final rescale presumably maps a [-1, 1] image range
        to [0, 1] -- confirm against the caller's normalisation.
    """
    assert mean is None
    b, c, h, w = flow0.shape
    num_flows = c // 2

    # Fold the per-flow axis into the batch axis: (b, n*2, h, w) -> (b*n, 2, h, w).
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    if mask is None:
        # Previously a None mask reached ``mask * img0_warp`` and raised a
        # TypeError. A scalar 0.5 weight yields the simple average of the two
        # warped frames instead.
        mask = 0.5
    else:
        mask = mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)

    if img_res is None:
        img_res = 0
    else:
        img_res = img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)

    # Repeat each frame once per flow so it lines up with the folded flows.
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)

    # Only reachable with a non-None value when asserts are disabled (-O).
    mean = (
        torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
        if mean is not None
        else 0
    )

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)

    res = comb_block(img_warps.reshape(b, -1, h, w))
    imgt_pred = img_warps.mean(1) + res
    imgt_pred = (imgt_pred + 1.0) / 2
    return imgt_pred
class ResBlock(nn.Module):
    """Residual block that gives the last ``side_channels`` channels extra convs.

    Five 3x3 convolutions are applied. After the first and the third one, the
    trailing ``side_channels`` channels are sliced off, passed through their
    own conv + PReLU, and concatenated back behind the untouched leading
    channels. The input is added before the final PReLU.

    Args:
        in_channels: Channel count of the input and output.
        side_channels: Number of trailing channels routed through the side
            convolutions.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super().__init__()
        self.side_channels = side_channels

        def conv3x3(channels):
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(
                channels, channels, kernel_size=3, stride=1, padding=1, bias=bias
            )

        # Attribute names and creation order are kept so checkpoint keys and
        # random initialisation order are unchanged.
        self.conv1 = nn.Sequential(conv3x3(in_channels), nn.PReLU(in_channels))
        self.conv2 = nn.Sequential(conv3x3(side_channels), nn.PReLU(side_channels))
        self.conv3 = nn.Sequential(conv3x3(in_channels), nn.PReLU(in_channels))
        self.conv4 = nn.Sequential(conv3x3(side_channels), nn.PReLU(side_channels))
        self.conv5 = conv3x3(in_channels)
        self.prelu = nn.PReLU(in_channels)

    def _update_side(self, feat, side_conv):
        """Run ``side_conv`` on the trailing side channels and re-concatenate."""
        width = self.side_channels
        main_part = feat[:, :-width, ...]
        side_part = side_conv(feat[:, -width:, :, :])
        return torch.cat([main_part, side_part], 1)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        hidden = self._update_side(self.conv1(x), self.conv2)
        hidden = self._update_side(self.conv3(hidden), self.conv4)
        return self.prelu(x + self.conv5(hidden))
class BasicUpdateBlock(nn.Module):
    """Convolutional update step producing feature and flow residuals.

    Correlation and flow inputs are encoded separately, fused, concatenated
    with the raw flow and the current features, and passed through a conv
    stack (named ``gru``, although it is a plain ``nn.Sequential`` of
    convolutions). Two heads then predict ``delta_net`` and ``delta_flow``.

    Args:
        cdim: Channels of ``net`` and of the returned ``delta_net``.
        hidden_dim: Width of the ``gru`` stack and of both heads.
        flow_dim: Width of the encoded flow features.
        corr_dim: Output channels of the 1x1 correlation conv.
        corr_dim2: Output channels of the second correlation conv.
        fc_dim: Channels of the fused flow/correlation features.
        corr_levels: With ``radius`` determines the expected correlation
            channels, ``2 * corr_levels * (2 * radius + 1) ** 2``.
        radius: See ``corr_levels``.
        scale_factor: If not ``None``, ``net`` is resized by
            ``1 / scale_factor`` on entry, both outputs are resized by
            ``scale_factor`` on exit, and ``delta_flow`` is also multiplied
            by it.
        out_num: The flow head emits ``4 * out_num`` channels.
    """

    def __init__(
        self,
        cdim,
        hidden_dim,
        flow_dim,
        corr_dim,
        corr_dim2,
        fc_dim,
        corr_levels=4,
        radius=3,
        scale_factor=None,
        out_num=1,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        window = 2 * radius + 1
        cor_planes = corr_levels * window ** 2

        def conv_pair(in_ch, out_ch):
            # conv3x3 -> LeakyReLU(0.1) -> conv3x3 with ``hidden_dim`` in between.
            return nn.Sequential(
                nn.Conv2d(in_ch, hidden_dim, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(hidden_dim, out_ch, 3, padding=1),
            )

        # Correlation encoder.
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        # Flow encoder; the input flow tensor has 4 channels.
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        # Fusion of the encoded flow and correlation.
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

        self.gru = conv_pair(fc_dim + 4 + cdim, hidden_dim)
        self.feat_head = conv_pair(hidden_dim, cdim)
        self.flow_head = conv_pair(hidden_dim, 4 * out_num)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return ``(delta_net, delta_flow)`` for the current state.

        Args:
            net: Feature tensor with ``cdim`` channels.
            flow: 4-channel flow tensor.
            corr: Correlation tensor with ``2 * cor_planes`` channels.
        """
        rescale = self.scale_factor is not None
        if rescale:
            net = resize(net, 1 / self.scale_factor)

        corr_feat = self.lrelu(self.convc1(corr))
        corr_feat = self.lrelu(self.convc2(corr_feat))
        flow_feat = self.lrelu(self.convf1(flow))
        flow_feat = self.lrelu(self.convf2(flow_feat))

        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        hidden = self.gru(torch.cat([fused, flow, net], dim=1))

        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)

        if rescale:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Flow values are multiplied by the same factor as the resolution.
            delta_flow = self.scale_factor * resize(
                delta_flow, scale_factor=self.scale_factor
            )
        return delta_net, delta_flow
def get_bn():
    """Return the normalisation layer class (not an instance): ``nn.BatchNorm2d``."""
    norm_layer_cls = nn.BatchNorm2d
    return norm_layer_cls
class NewInitDecoder(nn.Module):
    """Decoder stage that upsamples features 2x and refines two input flows.

    Args:
        in_ch: Channels of the incoming feature maps ``f0`` / ``f1``. The
            upsample path starts with ``PixelShuffle(2)`` and ends with
            ``in_ch // 2`` channels.
        skip_ch: ``side_channels`` passed to every ``ResBlock``.
    """

    def __init__(self, in_ch, skip_ch):
        super().__init__()
        norm_layer = get_bn()
        quarter_ch = in_ch // 4
        half_ch = in_ch // 2

        upsample_layers = [
            nn.PixelShuffle(2),
            convrelu(quarter_ch, quarter_ch, 5, 1, 2),
            convrelu(quarter_ch, quarter_ch),
            convrelu(quarter_ch, quarter_ch),
            convrelu(quarter_ch, quarter_ch),
            convrelu(quarter_ch, half_ch),
            nn.Conv2d(half_ch, half_ch, kernel_size=1),
            norm_layer(half_ch),
            nn.ReLU(inplace=True),
        ]
        self.upsample = nn.Sequential(*upsample_layers)

        # Input width: two warped feature maps (half_ch each) + 16 extra
        # channels. NOTE(review): the 16 presumably are two 2-channel flows
        # and four 3-channel images -- confirm against the callers.
        refine_layers = [convrelu(half_ch * 2 + 16, half_ch, kernel_size=1, padding=0)]
        refine_layers += [ResBlock(half_ch, skip_ch) for _ in range(3)]
        refine_layers.append(
            nn.Conv2d(
                half_ch,
                half_ch + 5,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1,
                groups=1,
                bias=True,
            )
        )
        self.convblock = nn.Sequential(*refine_layers)

    def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
        """Refine both flows and produce an intermediate feature map.

        Args:
            f0, f1: Feature maps passed through the shared upsample path.
            flow0_in, flow1_in: Flows used for warping; residuals are added
                to them.
            img0, img1: Required despite the ``None`` defaults (asserted).
                They are resized by the height ratio between the decoder
                input and ``img0``.

        Returns:
            ``(flow0, flow1, ft_)``: the refined flows and the output channels
            from index 4 onward.
        """
        feat0 = self.upsample(f0)
        feat1 = self.upsample(f1)
        warped_feat0 = warp(feat0, flow0_in)
        warped_feat1 = warp(feat1, flow1_in)
        f_in = torch.cat([warped_feat0, warped_feat1, flow0_in, flow1_in], dim=1)

        assert img0 is not None
        assert img1 is not None
        # Bring the images to the decoder resolution (ratio taken from heights).
        ratio = f_in.shape[2] / img0.shape[2]
        img0 = resize(img0, scale_factor=ratio)
        img1 = resize(img1, scale_factor=ratio)
        warped_img0 = warp(img0, flow0_in)
        warped_img1 = warp(img1, flow1_in)

        decoder_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
        out = self.convblock(decoder_in)

        # Channels 0-1 and 2-3 are flow residuals; the remainder is features.
        flow0 = flow0_in + out[:, :2, ...]
        flow1 = flow1_in + out[:, 2:4, ...]
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_
class NewMultiFlowDecoder(nn.Module):
    """Final decoder stage predicting ``num_flows`` flow/mask/residual sets at 4x.

    Args:
        in_ch: Channels of the incoming feature maps ``f0`` / ``f1``. Two
            ``PixelShuffle(2)`` layers reduce them to ``in_ch // 16`` before
            the convolutions bring them to ``in_ch // 2``.
        skip_ch: ``side_channels`` passed to every ``ResBlock``.
        num_flows: Number of flow hypotheses; the last conv emits
            ``8 * num_flows`` channels.
    """

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super().__init__()
        norm_layer = get_bn()
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            nn.PixelShuffle(2),
            convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )
        self.num_flows = num_flows
        ch_factor = 2
        # NOTE(review): the "+ 17" presumably covers two 2-channel flows, a
        # 1-channel mask and four 3-channel images -- confirm against callers.
        self.convblock = nn.Sequential(
            convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
        )

    def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
        """Predict per-hypothesis flows, blending masks and image residuals.

        Args:
            ft_: Intermediate features; resized by 4.
            f0, f1: Feature maps passed through the shared upsample path.
            flow0, flow1: Flows; resized by 4 and multiplied by 4.
            mask: Required despite the ``None`` default; resized by 4.
            img0, img1: Required despite the ``None`` defaults; used as given
                (no resizing), so they must already match the 4x resolution.

        Returns:
            ``(flow0, flow1, mask, img_res)`` with ``2n``, ``2n``, ``n`` and
            ``3n`` channels (``n = num_flows``). The mask has passed through
            a sigmoid.

        Raises:
            AssertionError: If ``mask``, ``img0`` or ``img1`` is ``None``.
        """
        # Validate before any work is done. The original checked ``mask`` only
        # after passing it to ``resize``, and checked the images after the
        # upsample path (which contains a BatchNorm layer) had already run.
        assert mask is not None, "mask is required"
        assert img0 is not None, "img0 is required"
        assert img1 is not None, "img1 is required"

        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        n = self.num_flows

        # Flow values are multiplied by the same factor as the resolution.
        flow0 = 4.0 * resize(flow0, scale_factor=4.0)
        flow1 = 4.0 * resize(flow1, scale_factor=4.0)
        ft_ = resize(ft_, scale_factor=4.0)
        mask = resize(mask, scale_factor=4.0)

        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        warped_img0 = warp(img0, flow0)
        warped_img1 = warp(img1, flow1)

        f_in = torch.cat(
            [
                ft_,
                f0_warp,
                f1_warp,
                flow0,
                flow1,
                mask,
                img0,
                img1,
                warped_img0,
                warped_img1,
            ],
            dim=1,
        )
        out = self.convblock(f_in)
        delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
            out, [2 * n, 2 * n, n, 3 * n], 1
        )

        # Repeat the base predictions across hypotheses and add the residuals.
        mask = torch.sigmoid(delta_mask + mask.repeat(1, n, 1, 1))
        flow0 = delta_flow0 + flow0.repeat(1, n, 1, 1)
        flow1 = delta_flow1 + flow1.repeat(1, n, 1, 1)
        return flow0, flow1, mask, img_res