Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# amt: https://github.com/MCG-NKU/AMT
|
||||
# motif: https://github.com/sichun233746/MoTIF
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .fi_utils import warp, resize
|
||||
|
||||
|
||||
class LateralBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(LateralBlock, self).__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.layers(x)
|
||||
return x + res
|
||||
|
||||
|
||||
def convrelu(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(out_channels),
|
||||
)
|
||||
|
||||
|
||||
def multi_flow_combine(
|
||||
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
|
||||
):
|
||||
assert mean is None
|
||||
b, c, h, w = flow0.shape
|
||||
num_flows = c // 2
|
||||
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
|
||||
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
|
||||
|
||||
mask = (
|
||||
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
|
||||
if mask is not None
|
||||
else None
|
||||
)
|
||||
img_res = (
|
||||
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
|
||||
if img_res is not None
|
||||
else 0
|
||||
)
|
||||
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
|
||||
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
|
||||
mean = (
|
||||
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
|
||||
if mean is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
img0_warp = warp(img0, flow0)
|
||||
img1_warp = warp(img1, flow1)
|
||||
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
|
||||
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
|
||||
|
||||
res = comb_block(img_warps.view(b, -1, h, w))
|
||||
imgt_pred = img_warps.mean(1) + res
|
||||
|
||||
imgt_pred = (imgt_pred + 1.0) / 2
|
||||
|
||||
return imgt_pred
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, side_channels, bias=True):
|
||||
super(ResBlock, self).__init__()
|
||||
self.side_channels = side_channels
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
),
|
||||
nn.PReLU(in_channels),
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
side_channels,
|
||||
side_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(side_channels),
|
||||
)
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
),
|
||||
nn.PReLU(in_channels),
|
||||
)
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
side_channels,
|
||||
side_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(side_channels),
|
||||
)
|
||||
self.conv5 = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
)
|
||||
self.prelu = nn.PReLU(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
|
||||
res_feat = out[:, : -self.side_channels, ...]
|
||||
side_feat = out[:, -self.side_channels :, :, :]
|
||||
side_feat = self.conv2(side_feat)
|
||||
out = self.conv3(torch.cat([res_feat, side_feat], 1))
|
||||
|
||||
res_feat = out[:, : -self.side_channels, ...]
|
||||
side_feat = out[:, -self.side_channels :, :, :]
|
||||
side_feat = self.conv4(side_feat)
|
||||
out = self.conv5(torch.cat([res_feat, side_feat], 1))
|
||||
|
||||
out = self.prelu(x + out)
|
||||
return out
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cdim,
|
||||
hidden_dim,
|
||||
flow_dim,
|
||||
corr_dim,
|
||||
corr_dim2,
|
||||
fc_dim,
|
||||
corr_levels=4,
|
||||
radius=3,
|
||||
scale_factor=None,
|
||||
out_num=1,
|
||||
):
|
||||
super(BasicUpdateBlock, self).__init__()
|
||||
cor_planes = corr_levels * (2 * radius + 1) ** 2
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
||||
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
|
||||
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
|
||||
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
|
||||
self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
|
||||
|
||||
self.gru = nn.Sequential(
|
||||
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
self.feat_head = nn.Sequential(
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, cdim, 3, padding=1),
|
||||
)
|
||||
|
||||
self.flow_head = nn.Sequential(
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
|
||||
)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
|
||||
def forward(self, net, flow, corr):
|
||||
net = (
|
||||
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
|
||||
)
|
||||
cor = self.lrelu(self.convc1(corr))
|
||||
cor = self.lrelu(self.convc2(cor))
|
||||
flo = self.lrelu(self.convf1(flow))
|
||||
flo = self.lrelu(self.convf2(flo))
|
||||
cor_flo = torch.cat([cor, flo], dim=1)
|
||||
inp = self.lrelu(self.conv(cor_flo))
|
||||
inp = torch.cat([inp, flow, net], dim=1)
|
||||
|
||||
out = self.gru(inp)
|
||||
delta_net = self.feat_head(out)
|
||||
delta_flow = self.flow_head(out)
|
||||
|
||||
if self.scale_factor is not None:
|
||||
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
||||
delta_flow = self.scale_factor * resize(
|
||||
delta_flow, scale_factor=self.scale_factor
|
||||
)
|
||||
return delta_net, delta_flow
|
||||
|
||||
|
||||
def get_bn():
|
||||
return nn.BatchNorm2d
|
||||
|
||||
|
||||
class NewInitDecoder(nn.Module):
|
||||
def __init__(self, in_ch, skip_ch):
|
||||
super().__init__()
|
||||
norm_layer = get_bn()
|
||||
|
||||
self.upsample = nn.Sequential(
|
||||
nn.PixelShuffle(2),
|
||||
convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 2),
|
||||
nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
|
||||
norm_layer(in_ch // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
in_ch = in_ch // 2
|
||||
self.convblock = nn.Sequential(
|
||||
convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True),
|
||||
)
|
||||
|
||||
def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
|
||||
f0 = self.upsample(f0)
|
||||
f1 = self.upsample(f1)
|
||||
f0_warp_ks = warp(f0, flow0_in)
|
||||
f1_warp_ks = warp(f1, flow1_in)
|
||||
|
||||
f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)
|
||||
|
||||
assert img0 is not None
|
||||
assert img1 is not None
|
||||
scale_factor = f_in.shape[2] / img0.shape[2]
|
||||
img0 = resize(img0, scale_factor=scale_factor)
|
||||
img1 = resize(img1, scale_factor=scale_factor)
|
||||
warped_img0 = warp(img0, flow0_in)
|
||||
warped_img1 = warp(img1, flow1_in)
|
||||
f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
|
||||
|
||||
out = self.convblock(f_in)
|
||||
ft_ = out[:, 4:, ...]
|
||||
flow0 = flow0_in + out[:, :2, ...]
|
||||
flow1 = flow1_in + out[:, 2:4, ...]
|
||||
return flow0, flow1, ft_
|
||||
|
||||
|
||||
class NewMultiFlowDecoder(nn.Module):
|
||||
def __init__(self, in_ch, skip_ch, num_flows=3):
|
||||
super(NewMultiFlowDecoder, self).__init__()
|
||||
norm_layer = get_bn()
|
||||
|
||||
self.upsample = nn.Sequential(
|
||||
nn.PixelShuffle(2),
|
||||
nn.PixelShuffle(2),
|
||||
convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 2),
|
||||
nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
|
||||
norm_layer(in_ch // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.num_flows = num_flows
|
||||
ch_factor = 2
|
||||
self.convblock = nn.Sequential(
|
||||
convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
|
||||
f0 = self.upsample(f0)
|
||||
# print([f1.shape,f0.shape])
|
||||
f1 = self.upsample(f1)
|
||||
n = self.num_flows
|
||||
flow0 = 4.0 * resize(flow0, scale_factor=4.0)
|
||||
flow1 = 4.0 * resize(flow1, scale_factor=4.0)
|
||||
|
||||
ft_ = resize(ft_, scale_factor=4.0)
|
||||
mask = resize(mask, scale_factor=4.0)
|
||||
f0_warp = warp(f0, flow0)
|
||||
f1_warp = warp(f1, flow1)
|
||||
|
||||
f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)
|
||||
|
||||
assert mask is not None
|
||||
f_in = torch.cat([f_in, mask], 1)
|
||||
|
||||
assert img0 is not None
|
||||
assert img1 is not None
|
||||
warped_img0 = warp(img0, flow0)
|
||||
warped_img1 = warp(img1, flow1)
|
||||
f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
|
||||
|
||||
out = self.convblock(f_in)
|
||||
delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
|
||||
out, [2 * n, 2 * n, n, 3 * n], 1
|
||||
)
|
||||
mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1)
|
||||
mask = torch.sigmoid(mask)
|
||||
flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1)
|
||||
flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1)
|
||||
|
||||
return flow0, flow1, mask, img_res
|
||||
Reference in New Issue
Block a user