diff --git a/__init__.py b/__init__.py index 4eaca51..0faca17 100644 --- a/__init__.py +++ b/__init__.py @@ -30,14 +30,16 @@ def _auto_install_deps(): _auto_install_deps() -from .nodes import LoadBIMVFIModel, BIMVFIInterpolate +from .nodes import LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate NODE_CLASS_MAPPINGS = { "LoadBIMVFIModel": LoadBIMVFIModel, "BIMVFIInterpolate": BIMVFIInterpolate, + "BIMVFISegmentInterpolate": BIMVFISegmentInterpolate, } NODE_DISPLAY_NAME_MAPPINGS = { "LoadBIMVFIModel": "Load BIM-VFI Model", "BIMVFIInterpolate": "BIM-VFI Interpolate", + "BIMVFISegmentInterpolate": "BIM-VFI Segment Interpolate", } diff --git a/nodes.py b/nodes.py index 8b57caf..5190263 100644 --- a/nodes.py +++ b/nodes.py @@ -261,3 +261,60 @@ class BIMVFIInterpolate: # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) + + +class BIMVFISegmentInterpolate(BIMVFIInterpolate): + """Process a numbered segment of the input batch. + + Chain multiple instances with Save nodes between them to bound peak RAM. + The model pass-through output forces sequential execution so each segment + saves and frees from RAM before the next starts. + """ + + @classmethod + def INPUT_TYPES(cls): + base = BIMVFIInterpolate.INPUT_TYPES() + base["required"]["segment_index"] = ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Which segment to process (0-based). " + "Segments overlap by 1 frame for seamless stitching. " + "Connect the model output to the next Segment Interpolate's model input to chain execution.", + }) + base["required"]["segment_size"] = ("INT", { + "default": 500, "min": 2, "max": 10000, "step": 1, + "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame. " + "Output is identical to processing all frames at once with BIM-VFI Interpolate.", + }) + return base + + RETURN_TYPES = ("IMAGE", "BIM_VFI_MODEL") + RETURN_NAMES = ("images", "model") + FUNCTION = "interpolate" + CATEGORY = "video/BIM-VFI" + + def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size, + segment_index, segment_size): + total_input = images.shape[0] + + # Compute segment boundaries (1-frame overlap) + start = segment_index * (segment_size - 1) + end = min(start + segment_size, total_input) + + if start >= total_input - 1: + # Past the end — return empty single frame + model + return (images[:1], model) + + segment_images = images[start:end] + is_continuation = segment_index > 0 + + # Delegate to the parent interpolation logic + (result,) = super().interpolate( + segment_images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size, + ) + + if is_continuation: + result = result[1:] # skip duplicate boundary frame + + return (result, model)