Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
341 lines
11 KiB
Python
341 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------
|
|
# References:
|
|
# amt: https://github.com/MCG-NKU/AMT
|
|
# motif: https://github.com/sichun233746/MoTIF
|
|
# --------------------------------------------------------
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from .fi_utils import warp, resize
|
|
|
|
|
|
class LateralBlock(nn.Module):
    """Channel-preserving residual block.

    The input goes through conv3x3 -> LeakyReLU(0.1) -> conv3x3 (stride 1,
    padding 1, so the spatial size is kept) and the result is added back
    onto the input.

    Args:
        dim: number of input and output channels of both convolutions.
    """

    def __init__(self, dim):
        super().__init__()
        first_conv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, bias=True
        )
        activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        second_conv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, bias=True
        )
        # The positions inside the Sequential determine the state_dict names
        # ("layers.0.*" and "layers.2.*"), so the order must stay as is.
        self.layers = nn.Sequential(first_conv, activation, second_conv)

    def forward(self, x):
        """Return ``layers(x) + x``; the output has the same shape as ``x``."""
        return self.layers(x) + x
|
|
|
|
|
|
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Build a ``Conv2d`` followed by a per-channel ``PReLU``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution; also the number
            of learnable PReLU slopes.
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` with the convolution at index 0 and the PReLU at
        index 1.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    activation = nn.PReLU(out_channels)
    return nn.Sequential(conv, activation)
|
|
|
|
|
|
def multi_flow_combine(
    comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
    """Blend the frames warped by several flow candidates into one image.

    Every flow candidate is folded into the batch axis, both input frames
    are warped with it, the two warps are mixed with ``mask`` and the
    per-candidate results are merged as ``mean over candidates +
    comb_block(all candidates)``.

    Args:
        comb_block: callable applied to the candidates concatenated along
            the channel axis, i.e. a ``(b, 3 * num_flows, h, w)`` tensor.
        img0, img1: the two input frames; each is replicated ``num_flows``
            times and reshaped to ``(-1, 3, h, w)``
            (presumably ``(b, 3, h, w)`` inputs -- TODO confirm).
        flow0, flow1: ``(b, 2 * num_flows, h, w)`` flow candidates.
        mask: blend weights holding one channel per candidate. It defaults
            to ``None``, but the blend below multiplies by it, so a tensor
            is required in practice.
        img_res: optional residual with three channels per candidate that
            is added to the blended warps; treated as ``0`` when ``None``.
        mean: must be ``None`` (asserted).

    Returns:
        ``(prediction + 1.0) / 2``
        (presumably a [-1, 1] -> [0, 1] remapping -- TODO confirm).
    """
    assert mean is None
    b, c, h, w = flow0.shape
    num_flows = c // 2

    def per_candidate(tensor, channels):
        # (b, num_flows * channels, h, w) -> (b * num_flows, channels, h, w)
        grouped = tensor.reshape(b, num_flows, channels, h, w)
        return grouped.reshape(-1, channels, h, w)

    def replicated(frame):
        # One copy of the frame per candidate, folded into the batch axis.
        return torch.stack([frame] * num_flows, 1).reshape(-1, 3, h, w)

    flow0 = per_candidate(flow0, 2)
    flow1 = per_candidate(flow1, 2)

    if mask is not None:
        mask = per_candidate(mask, 1)
    img_res = 0 if img_res is None else per_candidate(img_res, 3)
    img0 = replicated(img0)
    img1 = replicated(img1)
    if mean is None:
        mean = 0
    else:
        mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)

    # NOTE(review): `warp` is imported from fi_utils and not visible here;
    # it is assumed to resample an image according to a flow field.
    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    blended = mask * img0_warp + (1 - mask) * img1_warp
    img_warps = (blended + mean + img_res).reshape(b, num_flows, 3, h, w)

    res = comb_block(img_warps.view(b, -1, h, w))
    imgt_pred = img_warps.mean(1) + res

    return (imgt_pred + 1.0) / 2
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block whose trailing ``side_channels`` get extra convolutions.

    After each of the two full-width conv+PReLU stages the last
    ``side_channels`` channels are passed through an additional conv+PReLU
    and concatenated back behind the untouched leading channels. A final
    convolution is added to the input and passed through a PReLU.

    Args:
        in_channels: channel count of the input and of the output.
        side_channels: number of trailing channels given the extra
            convolutions.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super().__init__()
        self.side_channels = side_channels

        def conv3x3(channels):
            return nn.Conv2d(
                channels, channels, kernel_size=3, stride=1, padding=1, bias=bias
            )

        def conv3x3_prelu(channels):
            return nn.Sequential(conv3x3(channels), nn.PReLU(channels))

        # Attribute names and creation order are kept so that state_dict
        # keys and seeded initialisation stay the same.
        self.conv1 = conv3x3_prelu(in_channels)
        self.conv2 = conv3x3_prelu(side_channels)
        self.conv3 = conv3x3_prelu(in_channels)
        self.conv4 = conv3x3_prelu(side_channels)
        self.conv5 = conv3x3(in_channels)
        self.prelu = nn.PReLU(in_channels)

    def _refine_side(self, feat, side_conv):
        """Run ``side_conv`` on the trailing side channels and re-attach them."""
        side = self.side_channels
        kept = feat[:, :-side]
        refined = side_conv(feat[:, -side:])
        return torch.cat([kept, refined], 1)

    def forward(self, x):
        """Return ``prelu(x + block(x))``; the shape of ``x`` is preserved."""
        out = self.conv1(x)
        out = self.conv3(self._refine_side(out, self.conv2))
        out = self.conv5(self._refine_side(out, self.conv4))
        return self.prelu(x + out)
|
|
|
|
|
|
class BasicUpdateBlock(nn.Module):
    """Update block predicting a feature delta and a flow delta.

    Correlation and flow inputs are encoded separately, fused, concatenated
    with the raw flow and the current features, and fed to a small
    convolutional stack (``gru``; despite the name it is a plain
    conv -> LeakyReLU -> conv stack). Two heads then emit the deltas.

    Args:
        cdim: channels of ``net`` and of the returned feature delta.
        hidden_dim: width of ``gru`` and of both heads.
        flow_dim: width of the flow encoder output.
        corr_dim: width after the first (1x1) correlation convolution.
        corr_dim2: width after the second correlation convolution.
        fc_dim: width of the fused correlation/flow features.
        corr_levels: number of correlation levels.
        radius: lookup radius; each level contributes
            ``(2 * radius + 1) ** 2`` channels.
        scale_factor: when not ``None``, ``net`` is resized by
            ``1 / scale_factor`` on entry and both outputs are resized by
            ``scale_factor`` on exit.
        out_num: the flow head emits ``4 * out_num`` channels.
    """

    def __init__(
        self,
        cdim,
        hidden_dim,
        flow_dim,
        corr_dim,
        corr_dim2,
        fc_dim,
        corr_levels=4,
        radius=3,
        scale_factor=None,
        out_num=1,
    ):
        super().__init__()
        window = 2 * radius + 1
        cor_planes = corr_levels * window**2

        def conv_lrelu_conv(in_ch, mid_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(mid_ch, out_ch, 3, padding=1),
            )

        self.scale_factor = scale_factor

        # Correlation encoder: expects 2 * cor_planes input channels.
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        # Flow encoder: expects a 4-channel flow input.
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        # Fusion of the encoded correlation and flow features.
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

        # Input of `gru`: fused features + raw 4-channel flow + `net`.
        self.gru = conv_lrelu_conv(fc_dim + 4 + cdim, hidden_dim, hidden_dim)
        self.feat_head = conv_lrelu_conv(hidden_dim, hidden_dim, cdim)
        self.flow_head = conv_lrelu_conv(hidden_dim, hidden_dim, 4 * out_num)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return ``(delta_net, delta_flow)`` for the current state.

        Args:
            net: current features with ``cdim`` channels.
            flow: current flow with 4 channels.
            corr: correlation features with ``2 * cor_planes`` channels.
        """
        scale = self.scale_factor
        if scale is not None:
            # NOTE(review): `resize` comes from fi_utils; assumed to be a
            # spatial interpolation by the given factor -- confirm.
            net = resize(net, 1 / scale)

        corr_feat = self.lrelu(self.convc1(corr))
        corr_feat = self.lrelu(self.convc2(corr_feat))
        flow_feat = self.lrelu(self.convf1(flow))
        flow_feat = self.lrelu(self.convf2(flow_feat))
        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        hidden = self.gru(torch.cat([fused, flow, net], dim=1))
        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)

        if scale is None:
            return delta_net, delta_flow

        delta_net = resize(delta_net, scale_factor=scale)
        # The flow values are multiplied by the same factor as the grid.
        delta_flow = scale * resize(delta_flow, scale_factor=scale)
        return delta_net, delta_flow
|
|
|
|
|
|
def get_bn():
    """Return the normalization layer class to instantiate (``nn.BatchNorm2d``).

    The class itself is returned, not an instance; callers construct it
    with the channel count they need.
    """
    norm_layer_cls = nn.BatchNorm2d
    return norm_layer_cls
|
|
|
|
|
|
class NewInitDecoder(nn.Module):
    """Decoder stage that refines two input flows and emits features.

    Both feature maps are upsampled 2x via ``PixelShuffle`` and a conv
    stack, warped with the incoming flows, concatenated with the flows,
    the resized images and the warped images, and passed through a
    residual conv block that predicts flow corrections plus features.

    Args:
        in_ch: channels of the incoming feature maps ``f0`` / ``f1``
            (``PixelShuffle(2)`` requires it to be divisible by 4).
        skip_ch: ``side_channels`` handed to every ``ResBlock``.
    """

    def __init__(self, in_ch, skip_ch):
        super().__init__()
        norm_layer = get_bn()

        # PixelShuffle(2): in_ch -> in_ch // 4 channels at twice the
        # resolution; the stack ends with in_ch // 2 channels.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )

        in_ch = in_ch // 2
        # Input: two warped feature maps (2 * in_ch) plus 16 extra channels
        # (presumably 2 + 2 flow channels and four 3-channel images).
        # Output: in_ch + 5 channels, split in forward() into two
        # 2-channel flow corrections and the remaining features.
        self.convblock = nn.Sequential(
            convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            nn.Conv2d(
                in_ch,
                in_ch + 5,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1,
                groups=1,
                bias=True,
            ),
        )

    def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
        """Refine ``flow0_in`` / ``flow1_in`` and produce features.

        Args:
            f0, f1: feature maps with ``in_ch`` channels.
            flow0_in, flow1_in: flows at the upsampled feature resolution.
            img0, img1: input images; required even though they default
                to ``None``.

        Returns:
            ``(flow0, flow1, ft_)``: the input flows plus the predicted
            corrections, and the remaining output channels (index 4
            onwards) as features.

        Raises:
            AssertionError: if ``img0`` or ``img1`` is ``None``.
        """
        # Validate before any work: the upsample stack contains BatchNorm,
        # which updates its running statistics in training mode, so a
        # late failure would leave side effects behind.
        assert img0 is not None, "img0 is required"
        assert img1 is not None, "img1 is required"

        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        f0_warp_ks = warp(f0, flow0_in)
        f1_warp_ks = warp(f1, flow1_in)

        f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)

        # Bring the images to the feature resolution (ratio of heights).
        scale_factor = f_in.shape[2] / img0.shape[2]
        img0 = resize(img0, scale_factor=scale_factor)
        img1 = resize(img1, scale_factor=scale_factor)
        warped_img0 = warp(img0, flow0_in)
        warped_img1 = warp(img1, flow1_in)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)

        out = self.convblock(f_in)
        ft_ = out[:, 4:, ...]
        flow0 = flow0_in + out[:, :2, ...]
        flow1 = flow1_in + out[:, 2:4, ...]
        return flow0, flow1, ft_
|
|
|
|
|
|
class NewMultiFlowDecoder(nn.Module):
    """Final decoder stage predicting ``num_flows`` flow/mask/residual sets.

    Feature maps are upsampled 4x (two ``PixelShuffle(2)`` steps plus a conv
    stack); the incoming flows, features and mask are resized by 4. All
    inputs are concatenated and a residual conv block predicts, per
    candidate, two flow corrections, a mask correction and an image
    residual.

    Args:
        in_ch: channels of the incoming feature maps ``f0`` / ``f1``
            (two ``PixelShuffle(2)`` steps require divisibility by 16).
        skip_ch: ``side_channels`` handed to every ``ResBlock``.
        num_flows: number of candidates predicted per pixel.
    """

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super().__init__()
        norm_layer = get_bn()

        # Two PixelShuffle(2) steps: in_ch -> in_ch // 16 channels at 4x
        # the resolution; the stack ends with in_ch // 2 channels.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            nn.PixelShuffle(2),
            convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )

        self.num_flows = num_flows
        ch_factor = 2
        # Output: 8 channels per candidate = 2 + 2 (flows) + 1 (mask)
        # + 3 (image residual), matching the split in forward().
        self.convblock = nn.Sequential(
            convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
        )

    def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
        """Predict per-candidate flows, masks and image residuals.

        Args:
            ft_: features from the previous stage (resized by 4 here).
            f0, f1: feature maps with ``in_ch`` channels.
            flow0, flow1: flows from the previous stage; resized by 4 and
                multiplied by 4.0 here.
            mask: blend mask from the previous stage (resized by 4 here).
            img0, img1: input images at the output resolution.
            ``mask``, ``img0`` and ``img1`` are required even though they
            default to ``None``.

        Returns:
            ``(flow0, flow1, mask, img_res)`` with ``2 * num_flows``,
            ``2 * num_flows``, ``num_flows`` and ``3 * num_flows`` channels;
            ``mask`` has been passed through a sigmoid.

        Raises:
            AssertionError: if ``mask``, ``img0`` or ``img1`` is ``None``.
        """
        # Validate first: the mask check used to sit after `resize(mask)`,
        # so it could never fire for a missing mask.
        assert mask is not None, "mask is required"
        assert img0 is not None, "img0 is required"
        assert img1 is not None, "img1 is required"

        n = self.num_flows
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)

        # NOTE(review): flow values are multiplied by the resize factor,
        # presumably because they are pixel displacements -- confirm.
        flow0 = 4.0 * resize(flow0, scale_factor=4.0)
        flow1 = 4.0 * resize(flow1, scale_factor=4.0)

        ft_ = resize(ft_, scale_factor=4.0)
        mask = resize(mask, scale_factor=4.0)
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        warped_img0 = warp(img0, flow0)
        warped_img1 = warp(img1, flow1)

        f_in = torch.cat(
            [
                ft_,
                f0_warp,
                f1_warp,
                flow0,
                flow1,
                mask,
                img0,
                img1,
                warped_img0,
                warped_img1,
            ],
            dim=1,
        )

        out = self.convblock(f_in)
        delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
            out, [2 * n, 2 * n, n, 3 * n], 1
        )
        # The single incoming mask/flows are tiled across the candidates
        # before the predicted corrections are added.
        mask = torch.sigmoid(delta_mask + mask.repeat(1, n, 1, 1))
        flow0 = delta_flow0 + flow0.repeat(1, n, 1, 1)
        flow1 = delta_flow1 + flow1.repeat(1, n, 1, 1)

        return flow0, flow1, mask, img_res
|