Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
82 lines
2.4 KiB
Python
82 lines
2.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------
|
|
# References:
|
|
# raft: https://github.com/princeton-vl/RAFT
|
|
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
|
|
# --------------------------------------------------------
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Cache of identity sampling grids in normalized [-1, 1] coordinates, keyed by
# (device, flow size) so repeated calls with same-shaped flows reuse the grid.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by sampling it at positions displaced by ``tenFlow``.

    Args:
        tenInput: 4-D tensor to sample from; dims 2 and 3 are used as height
            and width.
        tenFlow: 4-D displacement tensor whose channel 0 is the horizontal and
            channel 1 the vertical offset, in pixels of ``tenInput``.
            NOTE(review): the base grid is built from ``tenFlow``'s spatial
            size while the pixel-to-normalized scaling uses ``tenInput``'s
            size; presumably both share the same height/width - confirm
            against callers.

    Returns:
        The bilinearly sampled tensor (border padding, ``align_corners=True``).
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        # Both halves are already created on tenFlow.device, so no extra
        # device transfer is needed here.
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Pixels -> normalized units: with align_corners=True the [-1, 1] range
    # spans (size - 1) pixels. Clamp the span to 1 so a size-1 dimension does
    # not divide by zero (which would turn the grid into NaN/inf).
    half_span_x = max(tenInput.shape[3] - 1.0, 1.0) / 2.0
    half_span_y = max(tenInput.shape[2] - 1.0, 1.0) / 2.0
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / half_span_x,
            tenFlow[:, 1:2, :, :] / half_span_y,
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    # grid_sample requires the grid and the input to share a dtype; the cached
    # grid is float32, so align it when the input is half/double precision.
    if g.dtype != tenInput.dtype:
        g = g.to(tenInput.dtype)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
|
|
|
|
|
def normalize_flow(flows):
    """Scale each sample of ``flows`` by its largest magnitude and map it to [0, 1].

    Args:
        flows: tensor with the batch on dim 0 and at least one further dim.

    Returns:
        A tuple ``(normalized, flow_scaler)``. ``normalized`` has the shape of
        ``flows`` with values in [0, 1] (0.5 corresponds to zero flow).
        ``flow_scaler`` holds the per-sample max absolute value, shaped
        ``(B, 1, ..., 1)`` with the same rank as ``flows`` so it broadcasts
        against it.
    """
    # One trailing singleton per non-batch dim, so any rank works (the
    # previous hard-coded 5-D reshape broadcast incorrectly for other ranks).
    flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape(
        -1, *([1] * (flows.dim() - 1))
    )
    # An all-zero sample would divide 0 by 0 and fill the output with NaN;
    # use a scaler of 1 there so the sample maps cleanly to 0.5.
    flow_scaler = torch.where(
        flow_scaler == 0, torch.ones_like(flow_scaler), flow_scaler
    )
    flows = flows / flow_scaler  # [-1, 1]
    flows = (flows + 1.0) / 2.0  # [0, 1]
    return flows, flow_scaler
|
|
|
|
|
|
def unnormalize_flow(flows, flow_scaler):
    """Map ``flows`` from [0, 1] back to signed values scaled by ``flow_scaler``.

    Algebraically this undoes the mapping applied by ``normalize_flow``:
    first [0, 1] -> [-1, 1], then multiplication by the scaler.
    """
    signed = flows * 2.0 - 1.0
    return signed * flow_scaler
|
|
|
|
|
|
def resize(x, scale_factor):
    """Bilinearly rescale ``x`` by ``scale_factor`` (``align_corners=False``)."""
    interp_options = {
        "scale_factor": scale_factor,
        "mode": "bilinear",
        "align_corners": False,
    }
    return F.interpolate(x, **interp_options)
|
|
|
|
|
|
def coords_grid(batch, ht, wd, device=None):
    """Build a batch of integer pixel-coordinate grids as a float tensor.

    Args:
        batch: number of copies of the grid along dim 0.
        ht: grid height.
        wd: grid width.
        device: optional device on which to create the grid; ``None`` keeps
            the previous behavior of using the default device.

    Returns:
        Float tensor of shape ``(batch, 2, ht, wd)`` where channel 0 holds the
        x (column) index and channel 1 the y (row) index of each position.
    """
    # Pass indexing="ij" explicitly: it matches the historical default and
    # avoids the deprecation warning raised when the argument is omitted.
    rows, cols = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing="ij",
    )
    # Stack as (x, y), i.e. column index first.
    coords = torch.stack((cols, rows), dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
|
|
|
|
|
|
def build_coord(img):
    """Return the coordinate grid for ``img`` at 1/8 of its spatial resolution.

    ``img`` must be 4-D; its batch size and its height and width (each
    floor-divided by 8) are forwarded to ``coords_grid``.
    """
    batch, _channels, height, width = img.shape
    return coords_grid(batch, height // 8, width // 8)
|