Files
ComfyUI-Tween/gimm_vfi_arch/generalizable_INR/gimm.py
Ethanfel d642255e70 Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 13:11:45 +01:00

254 lines
9.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# motif: https://github.com/sichun233746/MoTIF
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configs import GIMMConfig
from .modules.coord_sampler import CoordSampler3D
from .modules.fi_components import LateralBlock
from .modules.hyponet import HypoNet
from .modules.fi_utils import warp
from .modules.softsplat import softsplat
class GIMM(nn.Module):
    """Motion model that warps encoded flow latents to a timestep and decodes them.

    The visible pipeline is:

    1. ``cnn_encoder`` encodes each of the two 2-channel inputs in ``xs`` into
       a 16-channel latent.
    2. Both latents are forward-warped with ``softsplat`` along the supplied
       flows, scaled by the requested timestep and weighted by the reliability
       maps from :meth:`cal_splatting_weights`.
    3. ``res_conv`` adds a residual refinement computed from the original and
       the warped latents (16 + 16 + 32 = 64 input channels, 32 output).
    4. ``hyponet`` is queried at the given coordinates with the refined
       per-pixel latent.
    """

    Config = GIMMConfig

    def __init__(self, config: GIMMConfig):
        """Build the encoder, refiner, splatting parameters and hyponet.

        Args:
            config: Model configuration. It is copied before use; the fields
                read here are ``hyponet``, ``coord_range`` and ``fwarp_type``.
        """
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        self.coord_sampler = CoordSampler3D(config.coord_range)
        self.fwarp_type = config.fwarp_type

        # Motion Encoder: 2 input channels -> 16-channel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Latent Refiner: 64 input channels (two 16-channel source latents plus
        # the 32-channel warped latent) -> 32-channel residual.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Fixed 3x3 smoothing kernel (weights sum to 1). It is shaped
        # (1, 1, 1, 3, 3) so that conv3d applies it to every channel
        # independently, and it is kept as a frozen Parameter so it stays in
        # the state dict under the same key.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        # Learnable scales for the variance and flow-consistency terms of the
        # splatting weights.
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)

    def cal_splatting_weights(self, raft_flow01, raft_flow10):
        """Compute per-pixel reliability weights for both flow directions.

        Each weight map is the sum of two terms in (0, 1]: one that decreases
        with the local spread of the flow and one that decreases with the
        disagreement between a flow and the negated, warped opposite flow.

        Args:
            raft_flow01: Flow tensor for direction 0 -> 1. The channel split
                below assumes 2 channels on dim 1, i.e. (B, 2, H, W).
            raft_flow10: Flow tensor for direction 1 -> 0, same layout.

        Returns:
            Tuple ``(weights1, weights2)`` with a singleton channel dimension,
            for ``raft_flow01`` and ``raft_flow10`` respectively.
        """
        batch_size = raft_flow01.shape[0]
        raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

        # Flow variance metric: smooth x^2 and x with the same kernel in one
        # pass, then use E[x^2] - E[x]^2.
        stacked = torch.cat([raft_flows**2, raft_flows], 1)
        padded = F.pad(stacked, (1, 1, 1, 1), mode="reflect")
        smoothed = F.conv3d(padded.unsqueeze(1), self.g_filter).squeeze(1)
        mean_of_squares, mean = torch.split(smoothed, 2, dim=1)
        # Clamp before sqrt so tiny negative values from rounding cannot
        # produce NaNs. The result is a local standard deviation averaged
        # over the flow channels.
        var = (
            (mean_of_squares - mean**2)
            .clamp(min=1e-9)
            .sqrt()
            .mean(1, keepdim=True)
        )
        var01 = var[:batch_size]
        var10 = var[batch_size:]

        # Flow warp metric: compare each flow with the negated opposite flow
        # after warping (presumably a backward warp; see ``fi_utils.warp``).
        f01_warp = -warp(raft_flow10, raft_flow01)
        f10_warp = -warp(raft_flow01, raft_flow10)
        err01 = F.l1_loss(
            input=f01_warp, target=raft_flow01, reduction="none"
        ).mean(1, keepdim=True)
        err10 = F.l1_loss(
            input=f10_warp, target=raft_flow10, reduction="none"
        ).mean(1, keepdim=True)

        weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
        weights2 = 1 / (1 + err10 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)
        return weights1, weights2

    def _warp_latents(
        self,
        latent_0,
        latent_1,
        raft_flow01,
        raft_flow10,
        weights1,
        weights2,
        timestep,
    ):
        """Splat both latents to ``timestep``, refine, and return channels-last."""
        cur_t = timestep.reshape(-1, 1, 1, 1)
        mode = self.fwarp_type + "-zeroeps"
        warped_0 = softsplat(
            tenIn=latent_0,
            tenFlow=raft_flow01 * cur_t,
            tenMetric=weights1,
            strMode=mode,
        )
        warped_1 = softsplat(
            tenIn=latent_1,
            tenFlow=raft_flow10 * (1 - cur_t),
            tenMetric=weights2,
            strMode=mode,
        )
        warped = torch.cat([warped_0, warped_1], dim=1)
        refined = warped + self.res_conv(
            torch.cat([latent_0, latent_1, warped], dim=1)
        )
        # (B, C, H, W) -> (B, H, W, C), the layout handed to the hyponet.
        return refined.permute(0, 2, 3, 1)

    def _decode(self, coord, pixel_latent, keep_xs_shape, xs_ndim):
        """Query the hyponet and optionally move the last dim to position 1."""
        outputs = self.hyponet(
            coord,
            modulation_params_dict=None,
            pixel_latent=pixel_latent,
        )
        if keep_xs_shape:
            outputs = outputs.permute(0, -1, *range(1, xs_ndim - 1))
        return outputs

    def forward(
        self, xs, coord=None, keep_xs_shape=True, ori_flow=None, timesteps=None
    ):
        """Predict hyponet outputs at one or several timesteps.

        Args:
            xs: Input tensor; ``xs[:, :, 0]`` and ``xs[:, :, 1]`` are fed to
                the 2-channel motion encoder, so dim 1 must have 2 channels
                and dim 2 at least 2 entries.
            coord: Query coordinates for the hyponet. Required. Pass a list
                (one entry per timestep) when ``timesteps`` is a list,
                otherwise a single coordinate input.
            keep_xs_shape: If true, permute each output so its last dimension
                becomes dimension 1 (channels-first, matching ``xs``).
            ori_flow: Required flow tensor; ``ori_flow[:, :, 0]`` is used as
                the 0 -> 1 flow and ``ori_flow[:, :, 1]`` as the 1 -> 0 flow.
            timesteps: Either a list of tensors or a single tensor. Each is
                reshaped to ``(-1, 1, 1, 1)`` and scales the flows (``t`` for
                0 -> 1, ``1 - t`` for 1 -> 0).

        Returns:
            A list of output tensors when ``timesteps`` is a list, otherwise a
            single output tensor.

        Raises:
            TypeError: If ``coord`` or ``ori_flow`` is ``None``.
        """
        if coord is None:
            # The former fallback called ``sample_coord_input(xs)``, which can
            # never succeed because that method needs batch_size, s_shape,
            # t_ids and a device. Fail with an actionable message instead.
            raise TypeError(
                "GIMM.forward() requires `coord`; build it with "
                "sample_coord_input(batch_size, s_shape, t_ids, device=...)"
            )
        if ori_flow is None:
            raise TypeError("GIMM.forward() requires `ori_flow`")

        raft_flow01 = ori_flow[:, :, 0]
        raft_flow10 = ori_flow[:, :, 1]

        # Calculate splatting metrics.
        weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)

        # Encoded latents, (B, C, H, W).
        pixel_latent_0 = self.cnn_encoder(xs[:, :, 0])
        pixel_latent_1 = self.cnn_encoder(xs[:, :, 1])

        if isinstance(timesteps, list):
            assert isinstance(coord, list)
            assert len(timesteps) == len(coord)
            # Build every latent first, then decode, as in the original order
            # of operations.
            pixel_latents = [
                self._warp_latents(
                    pixel_latent_0,
                    pixel_latent_1,
                    raft_flow01,
                    raft_flow10,
                    weights1,
                    weights2,
                    cur_t,
                )
                for cur_t in timesteps
            ]
            return [
                self._decode(c, latent, keep_xs_shape, xs.ndim)
                for c, latent in zip(coord, pixel_latents)
            ]

        pixel_latent = self._warp_latents(
            pixel_latent_0,
            pixel_latent_1,
            raft_flow01,
            raft_flow10,
            weights1,
            weights2,
            timesteps,
        )
        return self._decode(coord, pixel_latent, keep_xs_shape, xs.ndim)

    def compute_loss(self, preds, targets, reduction="mean", single=False):
        """Compute per-sample MSE and the matching PSNR.

        Args:
            preds: Predictions; dim 0 is the batch and dim 2 must have size 1.
            targets: Targets with the same layout as ``preds``.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``; how per-sample
                values are reduced over the batch.
            single: Unused; kept for interface compatibility.

        Returns:
            Dict with ``"loss_total"`` and ``"mse"`` (the same reduced MSE)
            and ``"psnr"`` (``-10 * log10(mse)`` per sample, reduced the same
            way).
        """
        assert reduction in ["mean", "sum", "none"]
        batch_size = preds.shape[0]
        assert preds.shape[2] == 1
        assert targets.shape[2] == 1

        num_slices = preds.shape[2]
        sample_mses = 0
        for i in range(num_slices):
            squared_error = (preds[:, :, i] - targets[:, :, i]) ** 2
            sample_mses = sample_mses + torch.reshape(
                squared_error, (batch_size, -1)
            ).mean(dim=-1)
        sample_mses = sample_mses / num_slices

        sample_psnr = -10 * torch.log10(sample_mses)
        if reduction == "mean":
            total_loss = sample_mses.mean()
            psnr = sample_psnr.mean()
        elif reduction == "sum":
            total_loss = sample_mses.sum()
            psnr = sample_psnr.sum()
        else:
            total_loss = sample_mses
            psnr = sample_psnr

        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}

    def sample_coord_input(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Sample coordinate inputs through the configured ``CoordSampler3D``.

        Args:
            batch_size: Forwarded to the coordinate sampler.
            s_shape: Forwarded to the coordinate sampler (presumably the
                spatial shape; see ``CoordSampler3D``).
            t_ids: Forwarded to the coordinate sampler (presumably the
                temporal positions; see ``CoordSampler3D``).
            coord_range: Must be left as ``None``; overriding it is rejected.
            upsample_ratio: Forwarded to the coordinate sampler.
            device: Required target device.

        Returns:
            Whatever the coordinate sampler returns for these arguments.
        """
        assert device is not None
        assert coord_range is None
        coord_inputs = self.coord_sampler(
            batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
        )
        return coord_inputs