Add VFI Optimizer node for auto-tuning hardware settings
Benchmarks the user's GPU with the actual model and resolution via a single calibration frame pair, then outputs optimal batch_size, chunk_size, keep_device, all_on_gpu, and clear_cache_after_n_frames as a connectable VFI_SETTINGS type. All 8 Interpolate/SegmentInterpolate nodes accept the new optional settings input — existing workflows without the optimizer work unchanged.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from .nodes import (
|
||||
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
||||
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
||||
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
||||
VFIOptimizer,
|
||||
)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -19,6 +20,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LoadGIMMVFIModel": LoadGIMMVFIModel,
|
||||
"GIMMVFIInterpolate": GIMMVFIInterpolate,
|
||||
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
|
||||
"VFIOptimizer": VFIOptimizer,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -35,4 +37,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
|
||||
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
|
||||
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
|
||||
"VFIOptimizer": "VFI Optimizer",
|
||||
}
|
||||
|
||||
294
nodes.py
294
nodes.py
@@ -5,6 +5,7 @@ import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import torch
|
||||
import folder_paths
|
||||
from comfy.utils import ProgressBar
|
||||
@@ -40,6 +41,51 @@ def _check_cupy(model_name):
|
||||
)
|
||||
|
||||
|
||||
def _get_system_ram_gb():
|
||||
"""Return total system RAM in GB."""
|
||||
try:
|
||||
import psutil
|
||||
return psutil.virtual_memory().total / (1024 ** 3)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("MemTotal:"):
|
||||
return int(line.split()[1]) / (1024 ** 2) # kB -> GB
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
return 16.0 # safe fallback
|
||||
|
||||
|
||||
def _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames):
|
||||
"""Override manual values with optimizer settings if provided."""
|
||||
if settings is None:
|
||||
return batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames
|
||||
return (
|
||||
settings["batch_size"],
|
||||
settings["chunk_size"],
|
||||
settings["keep_device"],
|
||||
settings["all_on_gpu"],
|
||||
settings["clear_cache_after_n_frames"],
|
||||
)
|
||||
|
||||
|
||||
def _clear_model_cache(model):
    """Clear the warp caches that belong to ``model``'s type, then release
    cached CUDA memory when CUDA is available.
    """
    cache_clearers = (
        (BiMVFIModel, clear_backwarp_cache),
        (EMAVFIModel, clear_ema_warp_cache),
        (SGMVFIModel, clear_sgm_warp_cache),
        (GIMMVFIModel, clear_gimm_caches),
    )
    for model_cls, clear_fn in cache_clearers:
        if isinstance(model, model_cls):
            clear_fn()
            break  # only the first matching model type is handled
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _compute_target_fps_params(source_fps, target_fps):
|
||||
"""Compute oversampling parameters for target FPS mode.
|
||||
|
||||
@@ -222,7 +268,13 @@ class BIMVFIInterpolate:
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"settings": ("VFI_SETTINGS", {
|
||||
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
|
||||
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
@@ -297,7 +349,11 @@ class BIMVFIInterpolate:
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
if images.shape[0] < 2:
|
||||
return (images, images)
|
||||
|
||||
@@ -428,7 +484,11 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
@@ -656,6 +716,174 @@ class TweenConcatVideos:
|
||||
return {"result": (output_path,)}
|
||||
|
||||
|
||||
class VFIOptimizer:
    """Benchmark the user's GPU with the actual model and resolution to compute
    optimal batch_size, chunk_size, keep_device, all_on_gpu, and
    clear_cache_after_n_frames. Outputs a VFI_SETTINGS dict that can be
    connected to any Interpolate or Segment Interpolate node.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: calibration images, model, VRAM reserve,
        and an optional manual batch-size override."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input images — only the first 2 frames are used for calibration.",
                }),
                "model": ("*", {
                    "tooltip": "Any VFI model (BIM, EMA, SGM, GIMM). Used for benchmark inference.",
                }),
                "min_free_vram_gb": ("FLOAT", {
                    "default": 2.0, "min": 0.0, "max": 48.0, "step": 0.5,
                    "tooltip": "VRAM to keep free for other tasks (ComfyUI, OS, etc). "
                               "Higher = safer but slower.",
                }),
            },
            "optional": {
                "force_batch_size": ("INT", {
                    "default": 0, "min": 0, "max": 64, "step": 1,
                    "tooltip": "Override auto-computed batch_size. 0 = auto.",
                }),
            },
        }

    RETURN_TYPES = ("VFI_SETTINGS",)
    RETURN_NAMES = ("settings",)
    FUNCTION = "optimize"
    CATEGORY = "video/Tween"

    @staticmethod
    def _conservative_defaults():
        """Return safe fallback settings as a 1-tuple (node output format)."""
        return ({
            "batch_size": 1,
            "chunk_size": 0,
            "keep_device": True,
            "all_on_gpu": False,
            "clear_cache_after_n_frames": 5,
            "_info": {"source": "conservative_defaults"},
        },)

    @staticmethod
    def _release_calibration_resources(model):
        """Best-effort cleanup after calibration: clear caches, move the model to CPU."""
        # Each step is guarded on its own so a failure in one does not skip
        # the next, and so cleanup can never replace the node's return value
        # with an exception.
        try:
            _clear_model_cache(model)
        except Exception as e:
            logger.warning(f"VFI Optimizer: cache cleanup failed ({e})")
        try:
            model.to("cpu")
        except Exception as e:
            logger.warning(f"VFI Optimizer: could not move model to CPU ({e})")
        # Release the blocks freed by moving the weights off the GPU.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _derive_settings(available_mb, model_vram_mb, per_pair_vram_mb,
                         height, width, system_ram_gb, force_batch_size=0):
        """Turn measured memory figures into the five tunable settings.

        Args:
            available_mb: VRAM budget in MB (total - reserve - model weights);
                may be negative when the GPU is over-committed.
            model_vram_mb: size of the model parameters in MB.
            per_pair_vram_mb: peak VRAM used by one frame pair in MB; the
                caller clamps this to >= 1.0 so it is safe to divide by.
            height, width: frame resolution in pixels.
            system_ram_gb: total system RAM in GB.
            force_batch_size: values > 0 bypass the computed batch size.

        Returns:
            dict with keys ``batch_size``, ``chunk_size``, ``keep_device``,
            ``all_on_gpu`` and ``clear_cache_after_n_frames``.
        """
        # batch_size: fill 85% of the budget, capped to the 1..64 range.
        if force_batch_size > 0:
            batch_size = force_batch_size
        else:
            batch_size = max(1, min(int(available_mb * 0.85 / per_pair_vram_mb), 64))

        # all_on_gpu: does a modest output (100 input frames -> 199 output
        # frames at 2x) fit in half of the VRAM budget?
        frame_mb = height * width * 3 * 4 / (1024 ** 2)  # float32 [C,H,W]
        estimated_output_frames = 199
        output_mb = estimated_output_frames * frame_mb
        all_on_gpu = output_mb < available_mb * 0.5

        # keep_device: True unless VRAM is extremely tight.
        keep_device = available_mb > model_vram_mb * 0.5

        # clear_cache_after_n_frames: scale with headroom ratio, within 3..20.
        headroom_ratio = available_mb / max(per_pair_vram_mb * batch_size, 1.0)
        clear_cache = max(3, min(int(headroom_ratio * 5), 20))

        # chunk_size: based on system RAM. No chunking when the estimated
        # output fits in 60% of RAM; otherwise size chunks for 40% of RAM.
        system_ram_mb = system_ram_gb * 1024
        if output_mb < system_ram_mb * 0.6:
            chunk_size = 0
        else:
            # Each input frame produces ~2 output frames at 2x, each frame_mb.
            frames_per_mb = 1.0 / max(frame_mb * 2, 0.001)
            chunk_size = max(4, int(system_ram_mb * 0.4 * frames_per_mb))

        return {
            "batch_size": batch_size,
            "chunk_size": chunk_size,
            "keep_device": keep_device,
            "all_on_gpu": all_on_gpu,
            "clear_cache_after_n_frames": clear_cache,
        }

    def optimize(self, images, model, min_free_vram_gb, force_batch_size=0):
        """Run one calibration frame pair and derive hardware settings.

        Args:
            images: image batch; the first two frames are used. Indexed as
                [N, H, W, C] (it is permuted to [1, C, H, W] for the model).
            model: VFI model wrapper exposing ``to()`` and
                ``interpolate_batch()``.
            min_free_vram_gb: VRAM (GB) to leave unused.
            force_batch_size: values > 0 override the computed batch size.

        Returns:
            1-tuple holding the settings dict. Conservative defaults are
            returned when fewer than 2 frames are given, CUDA is unavailable,
            or calibration raises.

        Side effects:
            After a calibration attempt the model is moved to CPU and CUDA
            caches are cleared.
        """
        if images.shape[0] < 2 or not torch.cuda.is_available():
            logger.info("VFI Optimizer: <2 frames or no CUDA, returning conservative defaults")
            return self._conservative_defaults()

        device = torch.device("cuda")

        # --- Static analysis: model VRAM ---
        model_params = getattr(model, "model", model)
        if hasattr(model_params, "parameters"):
            model_vram_bytes = sum(
                p.nelement() * p.element_size() for p in model_params.parameters()
            )
        else:
            model_vram_bytes = 0
        model_vram_mb = model_vram_bytes / (1024 ** 2)

        # --- Calibration: run 1 frame pair ---
        # NOTE(review): the frames stay on whatever device `images` lives on;
        # presumably interpolate_batch moves them to the model's device — confirm.
        frame0 = images[0:1].permute(0, 3, 1, 2)  # [1, C, H, W]
        frame1 = images[1:2].permute(0, 3, 1, 2)

        try:
            model.to(device)
            torch.cuda.reset_peak_memory_stats(device)
            mem_before = torch.cuda.memory_allocated(device)

            t0 = time.perf_counter()
            with torch.no_grad():
                model.interpolate_batch(frame0, frame1, time_step=0.5)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - t0

            peak_mem = torch.cuda.max_memory_allocated(device)
            per_pair_vram_bytes = peak_mem - mem_before
        except Exception as e:
            logger.warning(f"VFI Optimizer: calibration failed ({e}), returning conservative defaults")
            return self._conservative_defaults()
        finally:
            # Single cleanup point for both the success and failure paths.
            self._release_calibration_resources(model)

        # Clamp to >= 1 MB so the divisions below are always safe.
        per_pair_vram_mb = max(per_pair_vram_bytes / (1024 ** 2), 1.0)

        # --- Compute settings ---
        # The device-properties attribute is `total_memory` (bytes).
        total_vram_mb = torch.cuda.get_device_properties(device).total_memory / (1024 ** 2)
        min_free_mb = min_free_vram_gb * 1024
        available_mb = total_vram_mb - min_free_mb - model_vram_mb

        H, W = images.shape[1], images.shape[2]
        system_ram_gb = _get_system_ram_gb()

        settings = self._derive_settings(
            available_mb, model_vram_mb, per_pair_vram_mb,
            H, W, system_ram_gb, force_batch_size,
        )
        batch_size = settings["batch_size"]
        chunk_size = settings["chunk_size"]
        keep_device = settings["keep_device"]
        all_on_gpu = settings["all_on_gpu"]
        clear_cache = settings["clear_cache_after_n_frames"]

        # Diagnostic payload; consumers only read the five setting keys.
        settings["_info"] = {
            "source": "VFI Optimizer",
            "model_vram_mb": round(model_vram_mb, 1),
            "per_pair_vram_mb": round(per_pair_vram_mb, 1),
            "calibration_time_ms": round(elapsed * 1000, 1),
            "total_vram_mb": round(total_vram_mb, 1),
            "available_mb": round(available_mb, 1),
            "system_ram_gb": round(system_ram_gb, 1),
            "resolution": f"{W}x{H}",
        }

        logger.info(
            f"VFI Optimizer: batch_size={batch_size}, chunk_size={chunk_size}, "
            f"keep_device={keep_device}, all_on_gpu={all_on_gpu}, "
            f"clear_cache={clear_cache} | "
            f"model={model_vram_mb:.0f}MB, per_pair={per_pair_vram_mb:.0f}MB, "
            f"available={available_mb:.0f}MB, "
            f"calibration={elapsed*1000:.0f}ms, res={W}x{H}"
        )

        return (settings,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EMA-VFI nodes
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -780,7 +1008,13 @@ class EMAVFIInterpolate:
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"settings": ("VFI_SETTINGS", {
|
||||
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
|
||||
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
@@ -848,7 +1082,11 @@ class EMAVFIInterpolate:
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
if images.shape[0] < 2:
|
||||
return (images, images)
|
||||
|
||||
@@ -978,7 +1216,11 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
@@ -1195,7 +1437,13 @@ class SGMVFIInterpolate:
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"settings": ("VFI_SETTINGS", {
|
||||
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
|
||||
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
@@ -1263,7 +1511,11 @@ class SGMVFIInterpolate:
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
if images.shape[0] < 2:
|
||||
return (images, images)
|
||||
|
||||
@@ -1393,7 +1645,11 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
@@ -1627,7 +1883,13 @@ class GIMMVFIInterpolate:
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"settings": ("VFI_SETTINGS", {
|
||||
"tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
|
||||
"chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
@@ -1741,7 +2003,11 @@ class GIMMVFIInterpolate:
|
||||
def interpolate(self, images, model, multiplier, single_pass,
|
||||
clear_cache_after_n_frames, keep_device, all_on_gpu,
|
||||
batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
if images.shape[0] < 2:
|
||||
return (images, images)
|
||||
|
||||
@@ -1888,7 +2154,11 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
|
||||
def interpolate(self, images, model, multiplier, single_pass,
|
||||
clear_cache_after_n_frames, keep_device, all_on_gpu,
|
||||
batch_size, chunk_size, segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
source_fps=0.0, target_fps=0.0, settings=None):
|
||||
batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
|
||||
_apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
|
||||
all_on_gpu, clear_cache_after_n_frames)
|
||||
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user