From 9f66233b5375627ba174d5c02bebf8efaabaacb4 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Fri, 27 Feb 2026 22:15:48 +0100
Subject: [PATCH] Add VFI Optimizer node for auto-tuning hardware settings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Benchmarks the user's GPU with the actual model and resolution via a
single calibration frame pair, then outputs optimal batch_size,
chunk_size, keep_device, all_on_gpu, and clear_cache_after_n_frames as
a connectable VFI_SETTINGS type.

All 8 Interpolate/SegmentInterpolate nodes accept the new optional
settings input — existing workflows without the optimizer work
unchanged.

Co-Authored-By: Claude Opus 4.6
---
 __init__.py |   3 +
 nodes.py    | 294 +++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 285 insertions(+), 12 deletions(-)

diff --git a/__init__.py b/__init__.py
index 671c42b..d150021 100644
--- a/__init__.py
+++ b/__init__.py
@@ -3,6 +3,7 @@ from .nodes import (
     LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
     LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
     LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
+    VFIOptimizer,
 )
 
 NODE_CLASS_MAPPINGS = {
@@ -19,6 +20,7 @@ NODE_CLASS_MAPPINGS = {
     "LoadGIMMVFIModel": LoadGIMMVFIModel,
     "GIMMVFIInterpolate": GIMMVFIInterpolate,
     "GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
+    "VFIOptimizer": VFIOptimizer,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -35,4 +37,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "LoadGIMMVFIModel": "Load GIMM-VFI Model",
     "GIMMVFIInterpolate": "GIMM-VFI Interpolate",
     "GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
+    "VFIOptimizer": "VFI Optimizer",
 }
diff --git a/nodes.py b/nodes.py
index 011d76b..1f5dc50 100644
--- a/nodes.py
+++ b/nodes.py
@@ -5,6 +5,7 @@ import logging
 import shutil
 import subprocess
 import tempfile
+import time
 import torch
 import folder_paths
 from comfy.utils import ProgressBar
@@ -40,6 +41,51 @@ def _check_cupy(model_name):
     )
 
 
+def _get_system_ram_gb():
+    """Return total system RAM in GB."""
+    try:
+        import psutil
+        return psutil.virtual_memory().total / (1024 ** 3)
+    except ImportError:
+        pass
+    try:
+        with open("/proc/meminfo", "r") as f:
+            for line in f:
+                if line.startswith("MemTotal:"):
+                    return int(line.split()[1]) / (1024 ** 2)  # kB -> GB
+    except (OSError, ValueError):
+        pass
+    return 16.0  # safe fallback
+
+
+def _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                        all_on_gpu, clear_cache_after_n_frames):
+    """Override manual values with optimizer settings if provided."""
+    if settings is None:
+        return batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames
+    return (
+        settings["batch_size"],
+        settings["chunk_size"],
+        settings["keep_device"],
+        settings["all_on_gpu"],
+        settings["clear_cache_after_n_frames"],
+    )
+
+
+def _clear_model_cache(model):
+    """Clear warp caches based on model type."""
+    if isinstance(model, BiMVFIModel):
+        clear_backwarp_cache()
+    elif isinstance(model, EMAVFIModel):
+        clear_ema_warp_cache()
+    elif isinstance(model, SGMVFIModel):
+        clear_sgm_warp_cache()
+    elif isinstance(model, GIMMVFIModel):
+        clear_gimm_caches()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def _compute_target_fps_params(source_fps, target_fps):
     """Compute oversampling parameters for target FPS mode.
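NOTE (editor's aside; illustrative, not part of the diff): the override in
_apply_vfi_settings is all-or-nothing: when a VFI_SETTINGS dict is connected,
every manual widget value is replaced at once. A minimal sketch of the
expected behavior, with made-up values:

    manual = dict(batch_size=4, chunk_size=0, keep_device=True,
                  all_on_gpu=False, clear_cache_after_n_frames=10)

    # No optimizer connected: manual widget values pass through unchanged.
    assert _apply_vfi_settings(None, **manual) == (4, 0, True, False, 10)

    # Optimizer connected: its values win wholesale.
    tuned = {"batch_size": 8, "chunk_size": 120, "keep_device": True,
             "all_on_gpu": True, "clear_cache_after_n_frames": 15}
    assert _apply_vfi_settings(tuned, **manual) == (8, 120, True, True, 15)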
@@ -222,7 +268,13 @@ class BIMVFIInterpolate:
                     "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
                     "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
                 }),
-            }
+            },
+            "optional": {
+                "settings": ("VFI_SETTINGS", {
+                    "tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
+                               "chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
+                }),
+            },
         }
 
     RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -297,7 +349,11 @@
 
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         if images.shape[0] < 2:
             return (images, images)
 
@@ -428,7 +484,11 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
                     segment_index, segment_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         total_input = images.shape[0]
 
         use_target_fps = target_fps > 0 and source_fps > 0
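NOTE (editor's aside; illustrative, not part of the diff): the
backward-compatibility claim rests on ComfyUI passing "optional" inputs as
keyword arguments only when a link exists, so the new settings=None default
is exactly what pre-existing workflows fall back to. A hypothetical minimal
node (ExampleNode is not in this codebase) showing the same pattern:

    class ExampleNode:
        @classmethod
        def INPUT_TYPES(cls):
            return {
                "required": {"images": ("IMAGE",)},
                "optional": {"settings": ("VFI_SETTINGS",)},
            }

        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "run"
        CATEGORY = "video/Tween"

        def run(self, images, settings=None):
            # settings stays None whenever nothing is connected, exactly
            # as in the eight interpolate nodes patched here.
            batch = settings["batch_size"] if settings is not None else 1
            return (images,)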
@@ -656,6 +716,174 @@ class TweenConcatVideos:
         return {"result": (output_path,)}
 
 
+class VFIOptimizer:
+    """Benchmark the user's GPU with the actual model and resolution to compute
+    optimal batch_size, chunk_size, keep_device, all_on_gpu, and
+    clear_cache_after_n_frames. Outputs a VFI_SETTINGS dict that can be
+    connected to any Interpolate or Segment Interpolate node.
+    """
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "images": ("IMAGE", {
+                    "tooltip": "Input images — only the first 2 frames are used for calibration.",
+                }),
+                "model": ("*", {
+                    "tooltip": "Any VFI model (BIM, EMA, SGM, GIMM). Used for benchmark inference.",
+                }),
+                "min_free_vram_gb": ("FLOAT", {
+                    "default": 2.0, "min": 0.0, "max": 48.0, "step": 0.5,
+                    "tooltip": "VRAM to keep free for other tasks (ComfyUI, OS, etc.). "
+                               "Higher = safer but slower.",
+                }),
+            },
+            "optional": {
+                "force_batch_size": ("INT", {
+                    "default": 0, "min": 0, "max": 64, "step": 1,
+                    "tooltip": "Override auto-computed batch_size. 0 = auto.",
+                }),
+            },
+        }
+
+    RETURN_TYPES = ("VFI_SETTINGS",)
+    RETURN_NAMES = ("settings",)
+    FUNCTION = "optimize"
+    CATEGORY = "video/Tween"
+
+    @staticmethod
+    def _conservative_defaults():
+        """Return safe fallback settings."""
+        return ({
+            "batch_size": 1,
+            "chunk_size": 0,
+            "keep_device": True,
+            "all_on_gpu": False,
+            "clear_cache_after_n_frames": 5,
+            "_info": {"source": "conservative_defaults"},
+        },)
+
+    def optimize(self, images, model, min_free_vram_gb, force_batch_size=0):
+        if images.shape[0] < 2 or not torch.cuda.is_available():
+            logger.info("VFI Optimizer: <2 frames or no CUDA, returning conservative defaults")
+            return self._conservative_defaults()
+
+        device = torch.device("cuda")
+
+        # --- Static analysis: model VRAM ---
+        model_params = getattr(model, "model", model)
+        if hasattr(model_params, "parameters"):
+            model_vram_bytes = sum(
+                p.nelement() * p.element_size() for p in model_params.parameters()
+            )
+        else:
+            model_vram_bytes = 0
+        model_vram_mb = model_vram_bytes / (1024 ** 2)
+
+        # --- Calibration: run 1 frame pair ---
+        frame0 = images[0:1].permute(0, 3, 1, 2)  # [1, C, H, W]
+        frame1 = images[1:2].permute(0, 3, 1, 2)
+
+        try:
+            model.to(device)
+            torch.cuda.reset_peak_memory_stats(device)
+            mem_before = torch.cuda.memory_allocated(device)
+
+            t0 = time.perf_counter()
+            with torch.no_grad():
+                model.interpolate_batch(frame0, frame1, time_step=0.5)
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - t0
+
+            peak_mem = torch.cuda.max_memory_allocated(device)
+            per_pair_vram_bytes = peak_mem - mem_before
+        except Exception as e:
+            logger.warning(f"VFI Optimizer: calibration failed ({e}), returning conservative defaults")
+            return self._conservative_defaults()
+        finally:
+            try:
+                _clear_model_cache(model)
+                model.to("cpu")
+            except Exception:
+                pass
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        per_pair_vram_mb = max(per_pair_vram_bytes / (1024 ** 2), 1.0)
+
+        # --- Compute settings ---
+        total_vram_mb = torch.cuda.get_device_properties(device).total_memory / (1024 ** 2)
+        min_free_mb = min_free_vram_gb * 1024
+        available_mb = total_vram_mb - min_free_mb - model_vram_mb
+
+        # batch_size
+        if force_batch_size > 0:
+            batch_size = force_batch_size
+        else:
+            batch_size = max(1, min(int(available_mb * 0.85 / per_pair_vram_mb), 64))
+
+        # all_on_gpu: estimate whether a typical 2x job's output fits in VRAM.
+        # Reference workload: 100 input frames → 199 output frames at 2x.
+        H, W = images.shape[1], images.shape[2]
+        frame_mb = H * W * 3 * 4 / (1024 ** 2)  # float32 [C, H, W]
+        estimated_output_frames = 199  # 100 input → 199 output at 2x
+        output_vram_mb = estimated_output_frames * frame_mb
+        vram_after_model = total_vram_mb - min_free_mb - model_vram_mb
+        all_on_gpu = output_vram_mb < vram_after_model * 0.5
+
+        # keep_device: True unless VRAM is extremely tight
+        keep_device = available_mb > model_vram_mb * 0.5
+
+        # clear_cache_after_n_frames: scale with headroom ratio
+        headroom_ratio = available_mb / max(per_pair_vram_mb * batch_size, 1.0)
+        clear_cache = max(3, min(int(headroom_ratio * 5), 20))
+
+        # chunk_size: based on system RAM. Estimate whether the full 2x
+        # output (same 100-frame reference workload) fits in 60% of RAM.
+        system_ram_gb = _get_system_ram_gb()
+        system_ram_mb = system_ram_gb * 1024
+        if output_vram_mb < system_ram_mb * 0.6:
+            chunk_size = 0  # fits in RAM, no chunking needed
+        else:
+            # How many input frames can we afford per chunk? Each input
+            # frame yields ~2 output frames at 2x, each of size frame_mb.
+            frames_per_mb = 1.0 / max(frame_mb * 2, 0.001)
+            chunk_size = max(4, int(system_ram_mb * 0.4 * frames_per_mb))
+
+        settings = {
+            "batch_size": batch_size,
+            "chunk_size": chunk_size,
+            "keep_device": keep_device,
+            "all_on_gpu": all_on_gpu,
+            "clear_cache_after_n_frames": clear_cache,
+            "_info": {
+                "source": "VFI Optimizer",
+                "model_vram_mb": round(model_vram_mb, 1),
+                "per_pair_vram_mb": round(per_pair_vram_mb, 1),
+                "calibration_time_ms": round(elapsed * 1000, 1),
+                "total_vram_mb": round(total_vram_mb, 1),
+                "available_mb": round(available_mb, 1),
+                "system_ram_gb": round(system_ram_gb, 1),
+                "resolution": f"{W}x{H}",
+            },
+        }
+
+        logger.info(
+            f"VFI Optimizer: batch_size={batch_size}, chunk_size={chunk_size}, "
+            f"keep_device={keep_device}, all_on_gpu={all_on_gpu}, "
+            f"clear_cache={clear_cache} | "
+            f"model={model_vram_mb:.0f}MB, per_pair={per_pair_vram_mb:.0f}MB, "
+            f"available={available_mb:.0f}MB, "
+            f"calibration={elapsed*1000:.0f}ms, res={W}x{H}"
+        )
+
+        return (settings,)
+
+
 # ---------------------------------------------------------------------------
 # EMA-VFI nodes
 # ---------------------------------------------------------------------------
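NOTE (editor's aside; illustrative, not part of the diff): the measurement at
the heart of optimize() can be read in isolation. A condensed, self-contained
restatement of the calibration idea, where run_once stands in for the
model.interpolate_batch call:

    import torch

    def estimate_max_batch(run_once, reserve_gb=2.0, model_bytes=0, cap=64):
        """Measure peak VRAM for one unit of work on real data, then
        divide the remaining budget by it, keeping the same 15% safety
        margin (the 0.85 factor) used above."""
        dev = torch.device("cuda")
        torch.cuda.synchronize(dev)
        torch.cuda.reset_peak_memory_stats(dev)
        before = torch.cuda.memory_allocated(dev)
        with torch.no_grad():
            run_once()  # one calibration pass
        torch.cuda.synchronize(dev)
        per_item = max(torch.cuda.max_memory_allocated(dev) - before, 1)
        total = torch.cuda.get_device_properties(dev).total_memory
        budget = total - reserve_gb * 1024 ** 3 - model_bytes
        return max(1, min(int(budget * 0.85 / per_item), cap))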
@@ -780,7 +1008,13 @@ class EMAVFIInterpolate:
                     "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
                     "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
                 }),
-            }
+            },
+            "optional": {
+                "settings": ("VFI_SETTINGS", {
+                    "tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
+                               "chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
+                }),
+            },
         }
 
     RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -848,7 +1082,11 @@
 
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         if images.shape[0] < 2:
             return (images, images)
 
@@ -978,7 +1216,11 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
                     segment_index, segment_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         total_input = images.shape[0]
 
         use_target_fps = target_fps > 0 and source_fps > 0
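NOTE (editor's aside; illustrative, not part of the diff): worked numbers for
the chunk_size heuristic in the VFIOptimizer hunk above, assuming 4K input
frames and the stated RAM sizes:

    H, W = 2160, 3840
    frame_mb = H * W * 3 * 4 / (1024 ** 2)    # ~94.9 MB per float32 frame
    output_mb = 199 * frame_mb                # ~18.4 GB for 100 frames at 2x
    print(output_mb < 32 * 1024 * 0.6)        # True on 32 GB RAM -> chunk_size = 0
    print(max(4, int(16 * 1024 * 0.4 / (frame_mb * 2))))  # 16 GB RAM -> chunks of 34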
@@ -1195,7 +1437,13 @@ class SGMVFIInterpolate:
                     "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
                     "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
                 }),
-            }
+            },
+            "optional": {
+                "settings": ("VFI_SETTINGS", {
+                    "tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
+                               "chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
+                }),
+            },
         }
 
     RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -1263,7 +1511,11 @@
 
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         if images.shape[0] < 2:
             return (images, images)
 
@@ -1393,7 +1645,11 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
     def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                     keep_device, all_on_gpu, batch_size, chunk_size,
                     segment_index, segment_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         total_input = images.shape[0]
 
         use_target_fps = target_fps > 0 and source_fps > 0
@@ -1627,7 +1883,13 @@ class GIMMVFIInterpolate:
                     "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
                     "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
                 }),
-            }
+            },
+            "optional": {
+                "settings": ("VFI_SETTINGS", {
+                    "tooltip": "Auto-tuned settings from VFI Optimizer. Overrides batch_size, "
+                               "chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames.",
+                }),
+            },
         }
 
     RETURN_TYPES = ("IMAGE", "IMAGE")
@@ -1741,7 +2003,11 @@
 
     def interpolate(self, images, model, multiplier, single_pass,
                     clear_cache_after_n_frames, keep_device, all_on_gpu,
                     batch_size, chunk_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         if images.shape[0] < 2:
             return (images, images)
 
@@ -1888,7 +2154,11 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
     def interpolate(self, images, model, multiplier, single_pass,
                     clear_cache_after_n_frames, keep_device, all_on_gpu,
                     batch_size, chunk_size, segment_index, segment_size,
-                    source_fps=0.0, target_fps=0.0):
+                    source_fps=0.0, target_fps=0.0, settings=None):
+        batch_size, chunk_size, keep_device, all_on_gpu, clear_cache_after_n_frames = \
+            _apply_vfi_settings(settings, batch_size, chunk_size, keep_device,
+                                all_on_gpu, clear_cache_after_n_frames)
+
         total_input = images.shape[0]
 
         use_target_fps = target_fps > 0 and source_fps > 0
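NOTE (editor's aside; illustrative, not part of the diff): the target_fps
tooltips repeated above describe the pre-existing power-of-2 oversampling
mode; _compute_target_fps_params itself is not shown in this patch, so the
following is only a sketch of the selection math the tooltips describe, and
the real helper may differ:

    import math

    def oversample_multiplier(source_fps, target_fps):
        # Smallest power-of-2 multiplier whose output rate reaches
        # target_fps; frames are then selected down to the target.
        ratio = target_fps / source_fps
        return 2 ** max(0, math.ceil(math.log2(ratio)))

    print(oversample_multiplier(24.0, 60.0))  # 2.5x needed -> 4x oversample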