Files
ComfyUI-Tween/gimm_vfi_arch/generalizable_INR/modules/fi_utils.py
Ethanfel d642255e70 Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 13:11:45 +01:00

82 lines
2.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# raft: https://github.com/princeton-vl/RAFT
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
# --------------------------------------------------------
import torch
import torch.nn.functional as F
# Cache of base sampling grids, keyed by (device, flow size). Entries are
# never evicted, so every distinct flow size seen keeps one grid alive.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by sampling it at positions shifted by ``tenFlow``.

    A base grid spanning [-1, 1] along both spatial axes is built once per
    (device, flow size) and cached in ``backwarp_tenGrid``. The flow is
    converted from pixel units to that normalized range, added to the base
    grid, and handed to ``grid_sample`` (bilinear, border padding,
    ``align_corners=True``).

    Args:
        tenInput: 4-D tensor to sample from; dim 2 is treated as height and
            dim 3 as width when normalizing the flow.
        tenFlow: 4-D displacement tensor. Channel 0 is the horizontal and
            channel 1 the vertical displacement, in pixels of ``tenInput``.

    Returns:
        The tensor produced by ``grid_sample`` for the displaced grid.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        n, h, w = tenFlow.shape[0], tenFlow.shape[2], tenFlow.shape[3]
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, w, device=tenFlow.device)
            .view(1, 1, 1, w)
            .expand(n, -1, h, -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, h, device=tenFlow.device)
            .view(1, 1, h, 1)
            .expand(n, -1, -1, w)
        )
        # Horizontal first: grid_sample reads the last grid axis as (x, y).
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Pixel displacement -> normalized [-1, 1] units (align_corners=True, so
    # the full range of 2.0 covers ``size - 1`` pixels).
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    # The cached grid is built in the default dtype, so the sum can end up in
    # a different dtype than the input (e.g. half input); grid_sample requires
    # both to match. This is a no-op when the dtypes already agree.
    g = g.to(tenInput.dtype)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
def normalize_flow(flows):
    """Scale each sample of ``flows`` by its peak magnitude and map it to [0, 1].

    The peak is the largest absolute value over all non-batch dimensions.
    Dividing by it puts the sample in [-1, 1]; the result is then shifted and
    halved into [0, 1].

    A sample that is entirely zero has a peak of 0; its scaler is replaced by
    1 so the division stays finite (the sample normalizes to 0.5 everywhere
    and maps back to exactly 0).

    Args:
        flows: tensor with the batch in dim 0 and at least one further
            dimension.

    Returns:
        A ``(normalized_flows, flow_scaler)`` tuple. ``flow_scaler`` has shape
        ``(batch, 1, ..., 1)`` with the same rank as ``flows`` so that it
        broadcasts against it.
    """
    peak = torch.max(torch.abs(flows).flatten(1), dim=-1)[0]
    # Guard against division by zero for all-zero samples.
    peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    # One singleton axis per non-batch dimension, whatever the input rank.
    flow_scaler = peak.reshape(-1, *([1] * (flows.dim() - 1)))
    flows = flows / flow_scaler  # [-1, 1]
    flows = (flows + 1.0) / 2.0  # [0, 1]
    return flows, flow_scaler
def unnormalize_flow(flows, flow_scaler):
    """Map ``flows`` from [0, 1] back to [-1, 1] and multiply by ``flow_scaler``."""
    signed = flows * 2.0 - 1.0
    return signed * flow_scaler
def resize(x, scale_factor):
    """Rescale ``x`` by ``scale_factor`` using bilinear interpolation.

    Uses ``align_corners=False``.
    """
    resized = F.interpolate(
        input=x,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
    return resized
def coords_grid(batch, ht, wd, device=None):
    """Build a batch of integer pixel-coordinate grids as a float tensor.

    Args:
        batch: number of copies of the grid along dim 0.
        ht: grid height.
        wd: grid width.
        device: optional device on which to create the grid; ``None`` keeps
            PyTorch's default device.

    Returns:
        Float tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds the
        column (x) index and channel 1 the row (y) index of each position.
    """
    # Explicit "ij" indexing matches the historical default and avoids the
    # deprecation warning raised when ``indexing`` is omitted.
    ys, xs = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing="ij",
    )
    coords = torch.stack([xs, ys], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
def build_coord(img):
    """Return the ``coords_grid`` for ``img`` at one-eighth of its height and width.

    ``img`` must be 4-D; dims 2 and 3 are floor-divided by 8 to get the grid
    size, and dim 0 is used as the batch size.
    """
    batch, _, height, width = img.shape
    return coords_grid(batch, height // 8, width // 8)