Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
180
inference.py
180
inference.py
@@ -441,3 +441,183 @@ class SGMVFIModel:
|
||||
pred = self._inference(img0, img1, timestep=time_step)
|
||||
pred = padder.unpad(pred)
|
||||
return torch.clamp(pred, 0, 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GIMM-VFI model wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GIMMVFIModel:
|
||||
"""Clean inference wrapper around GIMM-VFI for ComfyUI integration.
|
||||
|
||||
Supports two modes:
|
||||
- interpolate_batch(): standard single-midpoint interface (compatible with
|
||||
recursive _interpolate_frames machinery used by other models)
|
||||
- interpolate_multi(): GIMM-VFI's unique single-pass mode, generates all
|
||||
N-1 intermediate frames between each pair in one forward pass
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_path, flow_checkpoint_path, variant="auto",
|
||||
ds_factor=1.0, device="cpu"):
|
||||
import os
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from .gimm_vfi_arch import (
|
||||
GIMMVFI_R, GIMMVFI_F, GIMMVFIConfig,
|
||||
GIMM_RAFT, GIMM_FlowFormer, gimm_get_flowformer_cfg,
|
||||
GIMMInputPadder, GIMMRaftArgs, easydict_to_dict,
|
||||
)
|
||||
import comfy.utils
|
||||
|
||||
self.ds_factor = ds_factor
|
||||
self.device = device
|
||||
self._InputPadder = GIMMInputPadder
|
||||
|
||||
filename = os.path.basename(checkpoint_path).lower()
|
||||
|
||||
# Detect variant from filename
|
||||
if variant == "auto":
|
||||
self.is_flowformer = "gimmvfi_f" in filename
|
||||
else:
|
||||
self.is_flowformer = (variant == "flowformer")
|
||||
|
||||
self.variant_name = "flowformer" if self.is_flowformer else "raft"
|
||||
|
||||
# Load config
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if self.is_flowformer:
|
||||
config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_f_arb.yaml")
|
||||
else:
|
||||
config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_r_arb.yaml")
|
||||
|
||||
with open(config_path) as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config = easydict_to_dict(config)
|
||||
config = OmegaConf.create(config)
|
||||
arch_defaults = GIMMVFIConfig.create(config.arch)
|
||||
config = OmegaConf.merge(arch_defaults, config.arch)
|
||||
|
||||
# Build model + flow estimator
|
||||
dtype = torch.float32
|
||||
|
||||
if self.is_flowformer:
|
||||
self.model = GIMMVFI_F(dtype, config)
|
||||
cfg = gimm_get_flowformer_cfg()
|
||||
flow_estimator = GIMM_FlowFormer(cfg.latentcostformer)
|
||||
flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
|
||||
flow_estimator.load_state_dict(flow_sd, strict=True)
|
||||
else:
|
||||
self.model = GIMMVFI_R(dtype, config)
|
||||
raft_args = GIMMRaftArgs(small=False, mixed_precision=False, alternate_corr=False)
|
||||
flow_estimator = GIMM_RAFT(raft_args)
|
||||
flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
|
||||
flow_estimator.load_state_dict(flow_sd, strict=True)
|
||||
|
||||
# Load main model weights
|
||||
sd = comfy.utils.load_torch_file(checkpoint_path)
|
||||
self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
self.model.flow_estimator = flow_estimator
|
||||
self.model.eval()
|
||||
|
||||
def to(self, device):
|
||||
"""Move model to device (returns self for chaining)."""
|
||||
self.device = device if isinstance(device, str) else str(device)
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
||||
"""Interpolate a single midpoint frame per pair (standard interface).
|
||||
|
||||
Args:
|
||||
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
||||
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
||||
time_step: float in (0, 1)
|
||||
|
||||
Returns:
|
||||
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
||||
"""
|
||||
device = next(self.model.parameters()).device
|
||||
results = []
|
||||
|
||||
for i in range(frames0.shape[0]):
|
||||
I0 = frames0[i:i+1].to(device)
|
||||
I2 = frames1[i:i+1].to(device)
|
||||
|
||||
padder = self._InputPadder(I0.shape, 32)
|
||||
I0_p, I2_p = padder.pad(I0, I2)
|
||||
|
||||
xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
|
||||
batch_size = xs.shape[0]
|
||||
s_shape = xs.shape[-2:]
|
||||
|
||||
coord_inputs = [(
|
||||
self.model.sample_coord_input(
|
||||
batch_size, s_shape, [time_step],
|
||||
device=xs.device, upsample_ratio=self.ds_factor,
|
||||
),
|
||||
None,
|
||||
)]
|
||||
timesteps = [
|
||||
time_step * torch.ones(xs.shape[0]).to(xs.device)
|
||||
]
|
||||
|
||||
all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
|
||||
pred = padder.unpad(all_outputs["imgt_pred"][0])
|
||||
results.append(torch.clamp(pred, 0, 1))
|
||||
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_multi(self, frame0, frame1, num_intermediates):
|
||||
"""Generate all intermediate frames between a pair in one forward pass.
|
||||
|
||||
This is GIMM-VFI's unique capability -- arbitrary timestep interpolation
|
||||
without recursive 2x passes.
|
||||
|
||||
Args:
|
||||
frame0: [1, C, H, W] tensor, float32, range [0, 1]
|
||||
frame1: [1, C, H, W] tensor, float32, range [0, 1]
|
||||
num_intermediates: int, number of intermediate frames to generate
|
||||
|
||||
Returns:
|
||||
List of [1, C, H, W] tensors, float32, clamped to [0, 1]
|
||||
"""
|
||||
device = next(self.model.parameters()).device
|
||||
I0 = frame0.to(device)
|
||||
I2 = frame1.to(device)
|
||||
|
||||
padder = self._InputPadder(I0.shape, 32)
|
||||
I0_p, I2_p = padder.pad(I0, I2)
|
||||
|
||||
xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
|
||||
batch_size = xs.shape[0]
|
||||
s_shape = xs.shape[-2:]
|
||||
interp_factor = num_intermediates + 1
|
||||
|
||||
coord_inputs = [
|
||||
(
|
||||
self.model.sample_coord_input(
|
||||
batch_size, s_shape,
|
||||
[1.0 / interp_factor * i],
|
||||
device=xs.device,
|
||||
upsample_ratio=self.ds_factor,
|
||||
),
|
||||
None,
|
||||
)
|
||||
for i in range(1, interp_factor)
|
||||
]
|
||||
timesteps = [
|
||||
i * 1.0 / interp_factor * torch.ones(xs.shape[0]).to(xs.device)
|
||||
for i in range(1, interp_factor)
|
||||
]
|
||||
|
||||
all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
|
||||
|
||||
results = []
|
||||
for pred in all_outputs["imgt_pred"]:
|
||||
unpadded = padder.unpad(pred)
|
||||
results.append(torch.clamp(unpadded, 0, 1))
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user