From 3e8148b7e2ba0b31184b4c42e8df1d4c3fbacf76 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 12 Feb 2026 19:08:42 +0100 Subject: [PATCH] Add chunk_size for long video support, fix cache clearing, add README MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - chunk_size input splits input into overlapping segments processed independently then stitched, bounding memory for 1000+ frame videos while producing identical results to processing all at once - Fix cache clearing logic: use counter instead of modulo so it triggers regardless of batch_size value - Replace inefficient torch.cat gather with direct tensor slicing - Add README with usage guide, VRAM recommendations, and full attribution to BiM-VFI (Seo, Oh, Kim — CVPR 2025, KAIST VIC Lab) Co-Authored-By: Claude Opus 4.6 --- README.md | 88 +++++++++++++++++++++++++++++++++++ nodes.py | 134 ++++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 183 insertions(+), 39 deletions(-) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..9c7ab68 --- /dev/null +++ b/README.md @@ -0,0 +1,88 @@ +# ComfyUI BIM-VFI + +ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM. + +## Nodes + +### Load BIM-VFI Model + +Loads the BiM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/bim-vfi/`. + +| Input | Description | +|-------|-------------| +| **model_path** | Checkpoint file from `models/bim-vfi/` | +| **auto_pyr_level** | Auto-select pyramid level by resolution (<540p=3, 540p=5, 1080p=6, 4K=7) | +| **pyr_level** | Manual pyramid level (3-7), only used when auto is off | + +### BIM-VFI Interpolate + +Interpolates frames from an image batch. + +| Input | Description | +|-------|-------------| +| **images** | Input image batch | +| **model** | Model from the loader node | +| **multiplier** | 2x, 4x, or 8x frame rate (recursive 2x passes) | +| **batch_size** | Frame pairs processed simultaneously (higher = faster, more VRAM) | +| **chunk_size** | Process in segments of N input frames (0 = disabled). Bounds memory for very long videos. Result is identical to processing all at once | +| **keep_device** | Keep model on GPU between pairs (faster, ~200MB constant VRAM) | +| **all_on_gpu** | Keep all intermediate frames on GPU (fast, needs large VRAM) | +| **clear_cache_after_n_frames** | Clear CUDA cache every N pairs to prevent VRAM buildup | + +**Output frame count:** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 + +## Installation + +Clone into your ComfyUI `custom_nodes/` directory: + +```bash +cd ComfyUI/custom_nodes +git clone https://github.com/your-user/Comfyui-BIM-VFI.git +``` + +Dependencies (`gdown`, `cupy`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version. + +To install manually: + +```bash +cd Comfyui-BIM-VFI +python install.py +``` + +### Requirements + +- PyTorch with CUDA +- `cupy` (matching your CUDA version) +- `gdown` (for model auto-download) + +## VRAM Guide + +| VRAM | Recommended settings | +|------|---------------------| +| 8 GB | batch_size=1, chunk_size=500 | +| 24 GB | batch_size=2-4, chunk_size=1000 | +| 48 GB+ | batch_size=4-16, all_on_gpu=true | +| 96 GB+ | batch_size=8-16, all_on_gpu=true, chunk_size=0 | + +## Acknowledgments + +This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab). The model architecture files in `bim_vfi_arch/` are vendored from their repository with minimal modifications (relative imports, inference-only paths). + +**Paper:** +> Wonyong Seo, Jihyong Oh, and Munchurl Kim. +> "BiM-VFI: Bidirectional Motion Field-Guided Frame Interpolation for Video with Non-uniform Motions." +> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2025. +> [[arXiv]](https://arxiv.org/abs/2412.11365) [[Project Page]](https://kaist-viclab.github.io/BiM-VFI_site/) [[GitHub]](https://github.com/KAIST-VICLab/BiM-VFI) + +```bibtex +@inproceedings{seo2025bimvfi, + title={BiM-VFI: Bidirectional Motion Field-Guided Frame Interpolation for Video with Non-uniform Motions}, + author={Seo, Wonyong and Oh, Jihyong and Kim, Munchurl}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2025} +} +``` + +## License + +The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details. diff --git a/nodes.py b/nodes.py index a2bce66..8b57caf 100644 --- a/nodes.py +++ b/nodes.py @@ -123,6 +123,10 @@ class BIMVFIInterpolate: "default": 1, "min": 1, "max": 64, "step": 1, "tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full. Recommended: 1 for 8GB, 2-4 for 24GB, 4-16 for 48GB+.", }), + "chunk_size": ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Process input frames in chunks of this size (0=disabled). Each chunk runs all interpolation passes independently then results are stitched seamlessly. Use for very long videos (1000+ frames) to bound memory. Result is identical to processing all at once.", + }), } } @@ -131,47 +135,28 @@ class BIMVFIInterpolate: FUNCTION = "interpolate" CATEGORY = "video/BIM-VFI" - def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size): - if images.shape[0] < 2: - return (images,) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] - - # all_on_gpu implies keep_device - if all_on_gpu: - keep_device = True - - # Where to store intermediate frames - storage_device = device if all_on_gpu else torch.device("cpu") - - # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W] - frames = images.permute(0, 3, 1, 2).to(storage_device) - - # After each 2x pass, frame count = 2*N - 1, so compute total pairs across passes - n = frames.shape[0] - total_steps = 0 - for _ in range(num_passes): - total_steps += n - 1 - n = 2 * n - 1 - - pbar = ProgressBar(total_steps) - step = 0 - - if keep_device: - model.to(device) + def _interpolate_frames(self, frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref): + """Run all interpolation passes on a chunk of frames. + Args: + frames: [N, C, H, W] tensor on storage_device + step_ref: list with single int, mutable counter for progress bar + Returns: + Interpolated frames as [M, C, H, W] tensor on storage_device + """ for pass_idx in range(num_passes): new_frames = [] num_pairs = frames.shape[0] - 1 + pairs_since_clear = 0 for i in range(0, num_pairs, batch_size): batch_end = min(i + batch_size, num_pairs) actual_batch = batch_end - i - # Gather batch of pairs - frames0 = torch.cat([frames[j:j+1] for j in range(i, batch_end)], dim=0) - frames1 = torch.cat([frames[j+1:j+2] for j in range(i, batch_end)], dim=0) + frames0 = frames[i:batch_end] + frames1 = frames[i + 1:batch_end + 1] if not keep_device: model.to(device) @@ -182,19 +167,19 @@ class BIMVFIInterpolate: if not keep_device: model.to("cpu") - # Interleave: original frame, then interpolated frame for j in range(actual_batch): new_frames.append(frames[i + j:i + j + 1]) new_frames.append(mids[j:j+1]) - step += actual_batch - pbar.update_absolute(step, total_steps) + step_ref[0] += actual_batch + pbar.update_absolute(step_ref[0]) - if not all_on_gpu and (batch_end) % clear_cache_after_n_frames == 0 and torch.cuda.is_available(): + pairs_since_clear += actual_batch + if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available(): clear_backwarp_cache() torch.cuda.empty_cache() + pairs_since_clear = 0 - # Append last frame new_frames.append(frames[-1:]) frames = torch.cat(new_frames, dim=0) @@ -202,6 +187,77 @@ class BIMVFIInterpolate: clear_backwarp_cache() torch.cuda.empty_cache() - # Convert back to ComfyUI [B, H, W, C], on CPU for ComfyUI - result = frames.cpu().permute(0, 2, 3, 1) + return frames + + @staticmethod + def _count_steps(num_frames, num_passes): + """Count total interpolation steps for a given input frame count.""" + n = num_frames + total = 0 + for _ in range(num_passes): + total += n - 1 + n = 2 * n - 1 + return total + + def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size): + if images.shape[0] < 2: + return (images,) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + + if all_on_gpu: + keep_device = True + + storage_device = device if all_on_gpu else torch.device("cpu") + + # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W] + all_frames = images.permute(0, 3, 1, 2).to(storage_device) + total_input = all_frames.shape[0] + + # Build chunk boundaries (1-frame overlap between consecutive chunks) + if chunk_size < 2 or chunk_size >= total_input: + chunks = [(0, total_input)] + else: + chunks = [] + start = 0 + while start < total_input - 1: + end = min(start + chunk_size, total_input) + chunks.append((start, end)) + start = end - 1 # overlap by 1 frame + if end == total_input: + break + + # Calculate total progress steps across all chunks + total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks) + pbar = ProgressBar(total_steps) + step_ref = [0] + + if keep_device: + model.to(device) + + result_chunks = [] + for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks): + chunk_frames = all_frames[chunk_start:chunk_end].clone() + + chunk_result = self._interpolate_frames( + chunk_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + + # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame) + if chunk_idx > 0: + chunk_result = chunk_result[1:] + + # Move completed chunk to CPU to bound memory when chunking + if len(chunks) > 1: + chunk_result = chunk_result.cpu() + + result_chunks.append(chunk_result) + + result = torch.cat(result_chunks, dim=0) + # Convert back to ComfyUI [B, H, W, C], on CPU + result = result.cpu().permute(0, 2, 3, 1) return (result,)