Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside the existing BIM/EMA/SGM models. Key feature: it generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).
- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True = arbitrary timestep (default), False = recursive like the other models
- ds_factor parameter for downscaling high-resolution input
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
396
nodes.py
396
nodes.py
@@ -8,10 +8,11 @@ import torch
|
||||
import folder_paths
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel
|
||||
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
|
||||
from .bim_vfi_arch import clear_backwarp_cache
|
||||
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
|
||||
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
|
||||
from .gimm_vfi_arch import clear_gimm_caches
|
||||
|
||||
logger = logging.getLogger("Tween")
|
||||
|
||||
@@ -40,6 +41,17 @@ SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi")
|
||||
if not os.path.exists(SGM_MODEL_DIR):
|
||||
os.makedirs(SGM_MODEL_DIR, exist_ok=True)
|
||||
|
||||
# GIMM-VFI
# HuggingFace repository the GIMM-VFI files are downloaded from.
GIMM_HF_REPO = "Kijai/GIMM-VFI_safetensors"
# Checkpoints that can be auto-downloaded; the first entry is the loader's
# preferred default. The loader treats names containing "gimmvfi_f" as the
# FlowFormer variant and everything else as the RAFT variant.
GIMM_AVAILABLE_MODELS = [
    "gimmvfi_r_arb_lpips_fp32.safetensors",
    "gimmvfi_f_arb_lpips_fp32.safetensors",
]

# Directory (under the ComfyUI models dir) holding GIMM-VFI checkpoints.
GIMM_MODEL_DIR = os.path.join(folder_paths.models_dir, "gimm-vfi")
# exist_ok=True already tolerates an existing directory, so no separate
# existence check is needed (and none can race with another process).
os.makedirs(GIMM_MODEL_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def get_available_models():
|
||||
"""List available checkpoint files in the bim-vfi model directory."""
|
||||
@@ -1113,3 +1125,385 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
|
||||
return (result, model)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GIMM-VFI nodes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_available_gimm_models():
    """List the GIMM-VFI checkpoints selectable in the loader node.

    The result is the union of the known downloadable checkpoints
    (``GIMM_AVAILABLE_MODELS``) and the checkpoint files already present in
    ``GIMM_MODEL_DIR``. Always including the known names keeps every
    auto-downloadable variant selectable even after one of them (or a custom
    checkpoint) has been placed in the directory.

    Returns:
        list[str]: Sorted, de-duplicated checkpoint file names.
    """
    checkpoint_exts = (".safetensors", ".pth", ".pt", ".ckpt")
    # Flow-estimator weights are stored in the same directory but are not
    # interpolation models, so they must not show up in the list.
    flow_estimator_prefixes = ("raft-", "flowformer_")

    models = set(GIMM_AVAILABLE_MODELS)
    if os.path.isdir(GIMM_MODEL_DIR):
        models.update(
            name
            for name in os.listdir(GIMM_MODEL_DIR)
            if name.endswith(checkpoint_exts)
            and not name.startswith(flow_estimator_prefixes)
        )
    return sorted(models)
|
||||
|
||||
|
||||
def download_gimm_model(filename, dest_dir):
    """Download a GIMM-VFI file from HuggingFace into ``dest_dir``.

    Args:
        filename: Name of the file inside the ``GIMM_HF_REPO`` repository.
        dest_dir: Local directory the file is downloaded into.

    Raises:
        RuntimeError: If ``huggingface_hub`` is not installed, or if the file
            is not found at ``dest_dir/filename`` after the download call.
    """
    # Imported lazily so huggingface_hub is only required when a download
    # is actually needed.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise RuntimeError(
            "huggingface_hub is required to auto-download GIMM-VFI models. "
            "Install it with: pip install huggingface_hub"
        ) from err

    logger.info(f"Downloading {filename} from HuggingFace ({GIMM_HF_REPO})...")
    # NOTE(review): local_dir_use_symlinks appears to be deprecated in newer
    # huggingface_hub releases — confirm before dropping or keeping it.
    hf_hub_download(
        repo_id=GIMM_HF_REPO,
        filename=filename,
        local_dir=dest_dir,
        local_dir_use_symlinks=False,
    )

    # Verify the file really landed where callers will look for it.
    dest_path = os.path.join(dest_dir, filename)
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download {filename} to {dest_path}")
    logger.info(f"Downloaded {filename}")
|
||||
|
||||
|
||||
class LoadGIMMVFIModel:
    """ComfyUI node that loads a GIMM-VFI checkpoint plus its flow estimator.

    Missing files (the main checkpoint and the matching flow-estimator
    checkpoint) are downloaded into ``GIMM_MODEL_DIR`` before the
    ``GIMMVFIModel`` wrapper is constructed.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: checkpoint name and downscale factor."""
        choices = get_available_gimm_models()
        # The default must be one of the offered options. Fall back to the
        # first option when the preferred checkpoint is not in the list.
        preferred = GIMM_AVAILABLE_MODELS[0]
        default_model = preferred if preferred in choices else choices[0]
        return {
            "required": {
                "model_path": (choices, {
                    "default": default_model,
                    "tooltip": "Checkpoint file from models/gimm-vfi/. Auto-downloads from HuggingFace on first use. "
                               "RAFT variant (~80MB) or FlowFormer variant (~123MB) auto-detected from filename.",
                }),
                "ds_factor": ("FLOAT", {
                    "default": 1.0, "min": 0.125, "max": 1.0, "step": 0.125,
                    "tooltip": "Downscale factor for internal processing. 1.0 = full resolution. "
                               "Lower values reduce VRAM usage and speed up inference at the cost of quality. "
                               "Try 0.5 for 4K inputs.",
                }),
            }
        }

    RETURN_TYPES = ("GIMM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/GIMM-VFI"

    @staticmethod
    def _flow_checkpoint_for(model_path):
        """Return the flow-estimator checkpoint file name for ``model_path``.

        Names containing ``gimmvfi_f`` (case-insensitive) map to the
        FlowFormer weights; every other name maps to the RAFT weights.
        """
        if "gimmvfi_f" in model_path.lower():
            return "flowformer_sintel_fp32.safetensors"
        return "raft-things_fp32.safetensors"

    def load_model(self, model_path, ds_factor):
        """Load the selected checkpoint, downloading missing files first.

        Args:
            model_path: Checkpoint file name relative to ``GIMM_MODEL_DIR``.
            ds_factor: Downscale factor forwarded to ``GIMMVFIModel``.

        Returns:
            tuple: One-element tuple holding the ``GIMMVFIModel`` wrapper.

        Raises:
            RuntimeError: Propagated from ``download_gimm_model`` when a
                required file cannot be downloaded.
        """
        full_path = os.path.join(GIMM_MODEL_DIR, model_path)

        # Auto-download main model if missing
        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_gimm_model(model_path, GIMM_MODEL_DIR)

        # Detect and download matching flow estimator
        flow_filename = self._flow_checkpoint_for(model_path)
        flow_path = os.path.join(GIMM_MODEL_DIR, flow_filename)
        if not os.path.exists(flow_path):
            logger.info(f"Flow estimator not found, downloading {flow_filename}...")
            download_gimm_model(flow_filename, GIMM_MODEL_DIR)

        # The wrapper is created with device="cpu"; the interpolate nodes
        # move it to the compute device themselves.
        wrapper = GIMMVFIModel(
            checkpoint_path=full_path,
            flow_checkpoint_path=flow_path,
            variant="auto",
            ds_factor=ds_factor,
            device="cpu",
        )

        logger.info(f"GIMM-VFI model loaded (variant={wrapper.variant_name}, ds_factor={ds_factor})")
        return (wrapper,)
|
||||
|
||||
|
||||
class GIMMVFIInterpolate:
    """ComfyUI node that increases frame rate with a GIMM-VFI model.

    Two strategies are supported: a single-pass mode that asks the model for
    all intermediate frames of a pair at once, and a recursive mode that
    repeatedly inserts one midpoint frame between neighbours.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs shown in the ComfyUI editor."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
                }),
                "model": ("GIMM_VFI_MODEL", {
                    "tooltip": "GIMM-VFI model from the Load GIMM-VFI Model node.",
                }),
                "multiplier": ([2, 4, 8], {
                    "default": 2,
                    "tooltip": "Frame rate multiplier. In single-pass mode, all intermediate frames are generated "
                               "in one forward pass per pair. In recursive mode, uses 2x passes like other models.",
                }),
                "single_pass": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GIMM-VFI's single-pass arbitrary-timestep mode. Generates all intermediate frames "
                               "per pair in one forward pass (no recursive 2x passes). Disable to use the standard "
                               "recursive approach (same as BIM/EMA/SGM).",
                }),
                "clear_cache_after_n_frames": ("INT", {
                    "default": 10, "min": 1, "max": 100, "step": 1,
                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
                }),
                "keep_device": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
                }),
                "all_on_gpu": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                }),
                "batch_size": ("INT", {
                    "default": 1, "min": 1, "max": 64, "step": 1,
                    "tooltip": "Number of frame pairs to process simultaneously in recursive mode. Ignored in single-pass mode (pairs are processed one at a time since each generates multiple frames).",
                }),
                "chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    def _interpolate_frames_single_pass(self, frames, model, multiplier,
                                        device, storage_device, keep_device, all_on_gpu,
                                        clear_cache_after_n_frames, pbar, step_ref):
        """Single-pass interpolation using GIMM-VFI's arbitrary timestep capability.

        For every consecutive frame pair the model is asked for
        ``multiplier - 1`` intermediate frames in one call.

        Args:
            frames: Frame batch indexed along dim 0 (channels-first layout,
                as prepared by ``interpolate``).
            model: Object providing ``to(device)`` and
                ``interpolate_multi(frame0, frame1, count)``; the latter is
                iterated, so it presumably returns a sequence of frame
                tensors — confirm against ``GIMMVFIModel``.
            multiplier: Frame-rate multiplier.
            device: Compute device the model is moved to when
                ``keep_device`` is false.
            storage_device: Device the generated frames are moved to.
            keep_device: When false, the model is moved to ``device`` before
                and back to ``"cpu"`` after every pair.
            all_on_gpu: When true, cache clearing is skipped.
            clear_cache_after_n_frames: Number of pairs between cache clears.
            pbar: Progress bar exposing ``update_absolute``.
            step_ref: One-element list holding the running step count, shared
                across chunks.

        Returns:
            Tensor with the input frames and generated frames interleaved
            along dim 0.
        """
        num_intermediates = multiplier - 1
        new_frames = []
        num_pairs = frames.shape[0] - 1
        pairs_since_clear = 0

        for i in range(num_pairs):
            frame0 = frames[i:i + 1]
            frame1 = frames[i + 1:i + 2]

            if not keep_device:
                model.to(device)
            try:
                mids = model.interpolate_multi(frame0, frame1, num_intermediates)
                mids = [m.to(storage_device) for m in mids]
            finally:
                # Return the model to the CPU even when inference raises, so
                # a failure does not leave it on the compute device.
                if not keep_device:
                    model.to("cpu")

            new_frames.append(frame0)
            new_frames.extend(mids)

            step_ref[0] += 1
            pbar.update_absolute(step_ref[0])

            pairs_since_clear += 1
            if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                clear_gimm_caches()
                torch.cuda.empty_cache()
                pairs_since_clear = 0

        # The last input frame is never the first frame of a pair; add it here.
        new_frames.append(frames[-1:])
        result = torch.cat(new_frames, dim=0)

        if not all_on_gpu and torch.cuda.is_available():
            clear_gimm_caches()
            torch.cuda.empty_cache()

        return result

    def _interpolate_frames(self, frames, model, num_passes, batch_size,
                            device, storage_device, keep_device, all_on_gpu,
                            clear_cache_after_n_frames, pbar, step_ref):
        """Recursive 2x interpolation (standard approach, same as other models).

        Each pass inserts one frame (``time_step=0.5``) between every pair of
        neighbouring frames, so ``num_passes`` passes turn N frames into
        ``2**num_passes * (N - 1) + 1`` frames.

        Args:
            frames: Frame batch indexed along dim 0.
            model: Object providing ``to(device)`` and
                ``interpolate_batch(frames0, frames1, time_step=...)``; the
                result is sliced along dim 0, one entry per pair.
            num_passes: Number of doubling passes.
            batch_size: Number of frame pairs sent to the model per call.
            device: Compute device the model is moved to when
                ``keep_device`` is false.
            storage_device: Device the generated frames are moved to.
            keep_device: When false, the model is moved to ``device`` before
                and back to ``"cpu"`` after every batch.
            all_on_gpu: When true, cache clearing is skipped.
            clear_cache_after_n_frames: Number of pairs between cache clears.
            pbar: Progress bar exposing ``update_absolute``.
            step_ref: One-element list holding the running step count.

        Returns:
            Tensor holding the frames after the final pass.
        """
        for _ in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1
            pairs_since_clear = 0

            for i in range(0, num_pairs, batch_size):
                batch_end = min(i + batch_size, num_pairs)
                actual_batch = batch_end - i

                # frames1 is frames0 shifted by one: pair j is (i + j, i + j + 1).
                frames0 = frames[i:batch_end]
                frames1 = frames[i + 1:batch_end + 1]

                if not keep_device:
                    model.to(device)
                try:
                    mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
                    mids = mids.to(storage_device)
                finally:
                    # Return the model to the CPU even when inference raises.
                    if not keep_device:
                        model.to("cpu")

                # Interleave: original frame, then the frame generated after it.
                for j in range(actual_batch):
                    new_frames.append(frames[i + j:i + j + 1])
                    new_frames.append(mids[j:j + 1])

                step_ref[0] += actual_batch
                pbar.update_absolute(step_ref[0])

                pairs_since_clear += actual_batch
                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                    clear_gimm_caches()
                    torch.cuda.empty_cache()
                    pairs_since_clear = 0

            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)

            if not all_on_gpu and torch.cuda.is_available():
                clear_gimm_caches()
                torch.cuda.empty_cache()

        return frames

    @staticmethod
    def _count_steps(num_frames, num_passes):
        """Count total interpolation steps for recursive mode.

        A pass over ``n`` frames processes ``n - 1`` pairs and produces
        ``2n - 1`` frames, which feed the next pass.
        """
        n = num_frames
        total = 0
        for _ in range(num_passes):
            total += n - 1
            n = 2 * n - 1
        return total

    @staticmethod
    def _build_chunks(total_input, chunk_size):
        """Split ``total_input`` frames into ``(start, end)`` index ranges.

        Consecutive ranges overlap by exactly one frame so that every frame
        pair belongs to one chunk. A ``chunk_size`` below 2 (which could
        never advance) or at least ``total_input`` yields a single range
        covering everything.
        """
        if chunk_size < 2 or chunk_size >= total_input:
            return [(0, total_input)]

        chunks = []
        start = 0
        while start < total_input - 1:
            end = min(start + chunk_size, total_input)
            chunks.append((start, end))
            start = end - 1  # overlap by 1 frame
            if end == total_input:
                break
        return chunks

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size):
        """Interpolate an image batch and return the denser batch.

        Args:
            images: ComfyUI image batch in [B, H, W, C] layout.
            model: GIMM-VFI model wrapper from the loader node.
            multiplier: Frame-rate multiplier; recursive mode accepts only
                2, 4 or 8.
            single_pass: Use the single-pass strategy instead of recursive
                doubling.
            clear_cache_after_n_frames: Number of pairs between cache clears.
            keep_device: Keep the model on the compute device for the whole
                run; forced on when ``all_on_gpu`` is set.
            all_on_gpu: Keep frames on the compute device instead of the CPU.
            batch_size: Pairs per model call in recursive mode.
            chunk_size: Input frames per chunk; values below 2 disable
                chunking.

        Returns:
            tuple: One-element tuple with the result in [B, H, W, C] layout
            on the CPU. Batches with fewer than two frames are returned
            unchanged.

        Raises:
            KeyError: In recursive mode when ``multiplier`` is not 2, 4 or 8.
        """
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Only recursive mode needs a pass count (log2 of the multiplier).
        num_passes = None if single_pass else {2: 1, 4: 2, 8: 3}[multiplier]

        if all_on_gpu:
            keep_device = True

        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        all_frames = images.permute(0, 3, 1, 2).to(storage_device)
        total_input = all_frames.shape[0]

        # Build chunk boundaries (1-frame overlap between consecutive chunks)
        chunks = self._build_chunks(total_input, chunk_size)

        # Calculate total progress steps across all chunks
        if single_pass:
            total_steps = sum(ce - cs - 1 for cs, ce in chunks)
        else:
            total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
        pbar = ProgressBar(total_steps)
        # One-element list so the helpers can advance a shared counter.
        step_ref = [0]

        if keep_device:
            model.to(device)

        result_chunks = []
        for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
            chunk_frames = all_frames[chunk_start:chunk_end].clone()

            if single_pass:
                chunk_result = self._interpolate_frames_single_pass(
                    chunk_frames, model, multiplier,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )
            else:
                chunk_result = self._interpolate_frames(
                    chunk_frames, model, num_passes, batch_size,
                    device, storage_device, keep_device, all_on_gpu,
                    clear_cache_after_n_frames, pbar, step_ref,
                )

            # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
            if chunk_idx > 0:
                chunk_result = chunk_result[1:]

            # Move completed chunk to CPU to bound memory when chunking
            if len(chunks) > 1:
                chunk_result = chunk_result.cpu()

            result_chunks.append(chunk_result)

        result = torch.cat(result_chunks, dim=0)
        # Convert back to ComfyUI [B, H, W, C], on CPU
        result = result.cpu().permute(0, 2, 3, 1)
        return (result,)
|
||||
|
||||
|
||||
class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
    """Interpolate one numbered slice of the input batch with GIMM-VFI.

    Intended to be chained: place a Save node after each instance and feed
    the pass-through model output into the next instance, so segments run
    one after another and each result can be written out and released
    before the following segment starts. This keeps peak RAM bounded.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Extend the parent's inputs with the two segment selectors."""
        spec = GIMMVFIInterpolate.INPUT_TYPES()
        required = spec["required"]
        required["segment_index"] = ("INT", {
            "default": 0, "min": 0, "max": 10000, "step": 1,
            "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
                       "unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
                       "Chain the model output to the next Segment Interpolate to force sequential execution.",
        })
        required["segment_size"] = ("INT", {
            "default": 500, "min": 2, "max": 10000, "step": 1,
            "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
                       "Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
        })
        return spec

    RETURN_TYPES = ("IMAGE", "GIMM_VFI_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "interpolate"
    CATEGORY = "video/GIMM-VFI"

    def interpolate(self, images, model, multiplier, single_pass,
                    clear_cache_after_n_frames, keep_device, all_on_gpu,
                    batch_size, chunk_size, segment_index, segment_size):
        """Interpolate segment ``segment_index`` and pass the model through.

        Segments advance by ``segment_size - 1`` frames, so neighbouring
        segments share one boundary frame. For every segment after the first
        that shared frame is dropped from the output.

        Returns:
            tuple: ``(frames, model)``. When the segment starts at or beyond
            the last input frame there is nothing to interpolate and
            ``frames`` is just ``images[:1]``.
        """
        frame_count = images.shape[0]

        # Segments overlap by one frame, hence the stride of segment_size - 1.
        first = segment_index * (segment_size - 1)
        last = min(first + segment_size, frame_count)

        # No frame pair starts here: hand back a one-frame placeholder.
        if first >= frame_count - 1:
            return (images[:1], model)

        # Run the parent's interpolation on just this slice.
        (frames_out,) = super().interpolate(
            images[first:last], model, multiplier, single_pass,
            clear_cache_after_n_frames, keep_device, all_on_gpu,
            batch_size, chunk_size,
        )

        # The previous segment already emitted the shared boundary frame.
        if segment_index > 0:
            frames_out = frames_out[1:]

        return (frames_out, model)
|
||||
|
||||
Reference in New Issue
Block a user