Add VFI Optimizer node for auto-tuning hardware settings

Benchmarks the user's GPU with the actual model and resolution via a
single calibration frame pair, then outputs optimal batch_size,
chunk_size, keep_device, all_on_gpu, and clear_cache_after_n_frames
as a connectable VFI_SETTINGS type. All 8 Interpolate/SegmentInterpolate
nodes accept the new optional settings input — existing workflows
without the optimizer work unchanged.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-27 22:15:48 +01:00
parent 7257c1aa4d
commit 9f66233b53
2 changed files with 285 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from .nodes import (
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
VFIOptimizer,
)
NODE_CLASS_MAPPINGS = {
@@ -19,6 +20,7 @@ NODE_CLASS_MAPPINGS = {
"LoadGIMMVFIModel": LoadGIMMVFIModel,
"GIMMVFIInterpolate": GIMMVFIInterpolate,
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
"VFIOptimizer": VFIOptimizer,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -35,4 +37,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
"VFIOptimizer": "VFI Optimizer",
}

294
nodes.py
View File

@@ -5,6 +5,7 @@ import logging
import shutil
import subprocess
import tempfile
import time
import torch
import folder_paths
from comfy.utils import ProgressBar
@@ -40,6 +41,51 @@ def _check_cupy(model_name):
)
def _get_system_ram_gb():
"""Return total system RAM in GB."""
try:
import psutil
return psutil.virtual_memory().total / (1024 ** 3)
except ImportError:
pass
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if line.startswith("MemTotal:"):
return int(line.split()[1]) / (1024 ** 2) # kB -> GB
except (OSError, ValueError):
pass
return 16.0 # safe fallback
def _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames):
"""Override manual values with optimizer settings if provided."""
if settings is None:
return batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames
return (
settings["batch_size"],
settings["chunk_size"],
settings["keep_device"],
settings["all_on_gpu"],
settings["clear_cache_after_n_frames"],
)
def _clear_model_cache(model):
"""Clear warp caches based on model type."""
if isinstance(model, BiMVFIModel):
clear_backwarp_cache()
elif isinstance(model, EMAVFIModel):
clear_ema_warp_cache()
elif isinstance(model, SGMVFIModel):
clear_sgm_warp_cache()
elif isinstance(model, GIMMVFIModel):
clear_gimm_caches()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _compute_target_fps_params(source_fps, target_fps):
"""Compute oversampling parameters for target FPS mode.
@@ -222,7 +268,13 @@ class BIMVFIInterpolate:
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
},
"optional": {
"settings": ("VFI_SETTINGS", {
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -297,7 +349,11 @@ class BIMVFIInterpolate:
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
if images.shape[0] < 2:
return (images, images)
@@ -428,7 +484,11 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
@@ -656,6 +716,174 @@ class TweenConcatVideos:
return {"result": (output_path,)}
class VFIOptimizer:
"""Benchmark the user's GPU with the actual model and resolution to compute
optimal batch_size, chunk_size, keep_device, all_on_gpu, and
clear_cache_after_n_frames. Outputs a VFI_SETTINGS dict that can be
connected to any Interpolate or Segment Interpolate node.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input images — only the first 2 frames are used for calibration.",
}),
"model": ("*", {
"tooltip": "Any VFI model (BIM, EMA, SGM, GIMM). Used for benchmark inference.",
}),
"min_free_vram_gb": ("FLOAT", {
"default": 2.0, "min": 0.0, "max": 48.0, "step": 0.5,
"tooltip": "VRAM to keep free for other tasks (ComfyUI, OS, etc). "
"Higher = safer but slower.",
}),
},
"optional": {
"force_batch_size": ("INT", {
"default": 0, "min": 0, "max": 64, "step": 1,
"tooltip": "Override auto-computed batch_size. 0 = auto.",
}),
},
}
RETURN_TYPES = ("VFI_SETTINGS",)
RETURN_NAMES = ("settings",)
FUNCTION = "optimize"
CATEGORY = "video/Tween"
@staticmethod
def _conservative_defaults():
"""Return safe fallback settings."""
return ({
"batch_size": 1,
"chunk_size": 0,
"keep_device": True,
"all_on_gpu": False,
"clear_cache_after_n_frames": 5,
"_info": {"source": "conservative_defaults"},
},)
def optimize(self, images, model, min_free_vram_gb, force_batch_size=0):
if images.shape[0] < 2 or not torch.cuda.is_available():
logger.info("VFI Optimizer: <2 frames or no CUDA, returning conservative defaults")
return self._conservative_defaults()
device = torch.device("cuda")
# --- Static analysis: model VRAM ---
model_params = getattr(model, "model", model)
if hasattr(model_params, "parameters"):
model_vram_bytes = sum(
p.nelement() * p.element_size() for p in model_params.parameters()
)
else:
model_vram_bytes = 0
model_vram_mb = model_vram_bytes / (1024 ** 2)
# --- Calibration: run 1 frame pair ---
frame0 = images[0:1].permute(0, 3, 1, 2) # [1, C, H, W]
frame1 = images[1:2].permute(0, 3, 1, 2)
try:
model.to(device)
torch.cuda.reset_peak_memory_stats(device)
mem_before = torch.cuda.memory_allocated(device)
t0 = time.perf_counter()
with torch.no_grad():
model.interpolate_batch(frame0, frame1, time_step=0.5)
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
peak_mem = torch.cuda.max_memory_allocated(device)
per_pair_vram_bytes = peak_mem - mem_before
except Exception as e:
logger.warning(f"VFI Optimizer: calibration failed ({e}), returning conservative defaults")
try:
_clear_model_cache(model)
model.to("cpu")
except Exception:
pass
return self._conservative_defaults()
finally:
_clear_model_cache(model)
model.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
per_pair_vram_mb = max(per_pair_vram_bytes / (1024 ** 2), 1.0)
# --- Compute settings ---
total_vram_mb = torch.cuda.get_device_properties(device).total_mem / (1024 ** 2)
min_free_mb = min_free_vram_gb * 1024
available_mb = total_vram_mb - min_free_mb - model_vram_mb
# batch_size
if force_batch_size > 0:
batch_size = force_batch_size
else:
batch_size = max(1, min(int(available_mb * 0.85 / per_pair_vram_mb), 64))
# all_on_gpu: estimate output frames for 2x on 2 input frames → 3 output
# More generally, estimate if a modest output (e.g. 100 2x-frames) fits
H, W = images.shape[1], images.shape[2]
frame_mb = H * W * 3 * 4 / (1024 ** 2) # float32 [C,H,W]
estimated_output_frames = 199 # 100 input → 199 output at 2x
output_vram_mb = estimated_output_frames * frame_mb
vram_after_model = total_vram_mb - min_free_mb - model_vram_mb
all_on_gpu = output_vram_mb < vram_after_model * 0.5
# keep_device: True unless VRAM is extremely tight
keep_device = available_mb > model_vram_mb * 0.5
# clear_cache_after_n_frames: scale with headroom ratio
headroom_ratio = available_mb / max(per_pair_vram_mb * batch_size, 1.0)
clear_cache = max(3, min(int(headroom_ratio * 5), 20))
# chunk_size: based on system RAM
system_ram_gb = _get_system_ram_gb()
system_ram_mb = system_ram_gb * 1024
# Estimate if full 2x output fits in 60% of RAM
# Conservative: assume 100 input frames worth of output
if output_vram_mb < system_ram_mb * 0.6:
chunk_size = 0 # fits in RAM, no chunking needed
else:
# How many input frames can we afford per chunk?
# Each input frame produces ~2 output frames at 2x, each frame_mb
frames_per_mb = 1.0 / max(frame_mb * 2, 0.001)
chunk_size = max(4, int(system_ram_mb * 0.4 * frames_per_mb))
settings = {
"batch_size": batch_size,
"chunk_size": chunk_size,
"keep_device": keep_device,
"all_on_gpu": all_on_gpu,
"clear_cache_after_n_frames": clear_cache,
"_info": {
"source": "VFI Optimizer",
"model_vram_mb": round(model_vram_mb, 1),
"per_pair_vram_mb": round(per_pair_vram_mb, 1),
"calibration_time_ms": round(elapsed * 1000, 1),
"total_vram_mb": round(total_vram_mb, 1),
"available_mb": round(available_mb, 1),
"system_ram_gb": round(system_ram_gb, 1),
"resolution": f"{W}x{H}",
},
}
logger.info(
f"VFI Optimizer: batch_size={batch_size}, chunk_size={chunk_size}, "
f"keep_device={keep_device}, all_on_gpu={all_on_gpu}, "
f"clear_cache={clear_cache} | "
f"model={model_vram_mb:.0f}MB, per_pair={per_pair_vram_mb:.0f}MB, "
f"available={available_mb:.0f}MB, "
f"calibration={elapsed*1000:.0f}ms, res={W}x{H}"
)
return (settings,)
# ---------------------------------------------------------------------------
# EMA-VFI nodes
# ---------------------------------------------------------------------------
@@ -780,7 +1008,13 @@ class EMAVFIInterpolate:
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
},
"optional": {
"settings": ("VFI_SETTINGS", {
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -848,7 +1082,11 @@ class EMAVFIInterpolate:
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
if images.shape[0] < 2:
return (images, images)
@@ -978,7 +1216,11 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
@@ -1195,7 +1437,13 @@ class SGMVFIInterpolate:
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
},
"optional": {
"settings": ("VFI_SETTINGS", {
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -1263,7 +1511,11 @@ class SGMVFIInterpolate:
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
if images.shape[0] < 2:
return (images, images)
@@ -1393,7 +1645,11 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
@@ -1627,7 +1883,13 @@ class GIMMVFIInterpolate:
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
},
"optional": {
"settings": ("VFI_SETTINGS", {
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -1741,7 +2003,11 @@ class GIMMVFIInterpolate:
def interpolate(self, images, model, multiplier, single_pass,
clear_cache_after_n_frames, keep_device, all_on_gpu,
batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
if images.shape[0] < 2:
return (images, images)
@@ -1888,7 +2154,11 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
def interpolate(self, images, model, multiplier, single_pass,
clear_cache_after_n_frames, keep_device, all_on_gpu,
batch_size, chunk_size, segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
source_fps=0.0, target_fps=0.0, settings=None):
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
all_on_gpu, clear_cache_after_n_frames)
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0