diff --git a/README.md b/README.md
index f30506c..4fa9070 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ Search for `ComfyUI-STAR` in ComfyUI Manager and install.
 ```bash
 cd ComfyUI/custom_nodes
-git clone --recursive git@192.168.1.1:Ethanfel/Comfyui-STAR.git
+git clone --recursive https://github.com/ethanfel/Comfyui-STAR.git
 cd Comfyui-STAR
 pip install -r requirements.txt
 ```
@@ -56,6 +56,7 @@ Runs the STAR diffusion pipeline on an image batch.
 | **max_chunk_len** | Max frames per chunk (4–128, default 32). Lower = less VRAM for long videos. |
 | **seed** | Random seed for reproducibility. |
 | **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. |
+| **segment_size** | Process video in segments of this many frames to reduce RAM usage (0–256, default 0). 0 = process all at once; values 1–3 are clamped to 4. Recommended: 16–32 for long videos. Segments overlap by 25% with linear crossfade blending. |
 
 ## VRAM Requirements
 
diff --git a/nodes.py b/nodes.py
index e34de81..5140a94 100644
--- a/nodes.py
+++ b/nodes.py
@@ -205,6 +205,10 @@ class STARVideoSuperResolution:
                 "default": "adain",
                 "tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.",
             }),
+            "segment_size": ("INT", {
+                "default": 0, "min": 0, "max": 256,
+                "tooltip": "Process video in segments of this many frames to reduce RAM usage. 0 = process all at once; values 1-3 are clamped to 4. Recommended: 16-32 for long videos.",
+            }),
         }
     }
 
@@ -226,10 +230,9 @@ class STARVideoSuperResolution:
         max_chunk_len,
         seed,
         color_fix,
+        segment_size=0,
     ):
-        from .star_pipeline import run_star_inference
-
-        result = run_star_inference(
+        kwargs = dict(
             star_model=star_model,
             images=images,
             upscale=upscale,
@@ -241,6 +244,14 @@ class STARVideoSuperResolution:
             seed=seed,
             color_fix=color_fix,
         )
+
+        if segment_size > 0:
+            from .star_pipeline import run_star_inference_segmented
+            result = run_star_inference_segmented(segment_size=segment_size, **kwargs)
+        else:
+            from .star_pipeline import run_star_inference
+            result = run_star_inference(**kwargs)
+
         return (result,)
 
 
diff --git a/star_pipeline.py b/star_pipeline.py
index d79aec9..643faed 100644
--- a/star_pipeline.py
+++ b/star_pipeline.py
@@ -130,6 +130,116 @@ def _move(module, device):
     torch.cuda.empty_cache()
 
 
+def run_star_inference_segmented(
+    star_model: dict,
+    images: torch.Tensor,
+    segment_size: int,
+    upscale: int = 4,
+    steps: int = 15,
+    guide_scale: float = 7.5,
+    prompt: str = "",
+    solver_mode: str = "fast",
+    max_chunk_len: int = 32,
+    seed: int = 0,
+    color_fix: str = "adain",
+) -> torch.Tensor:
+    """Run STAR inference in overlapping segments to bound peak RAM usage.
+
+    Each segment of `segment_size` frames is processed independently through
+    the full pipeline. Overlap regions (25% of segment_size, minimum 2 frames)
+    are blended with a linear crossfade to avoid temporal seam artifacts.
+    """
+    total_frames = images.shape[0]
+
+    # segment_size values of 1-3 would make `overlap` >= segment_size and
+    # the stride below non-positive (infinite loop), so clamp them.
+    segment_size = max(segment_size, 4)
+
+    # Fall back to single-shot if the video fits in one segment.
+    if total_frames <= segment_size:
+        return run_star_inference(
+            star_model=star_model, images=images, upscale=upscale, steps=steps,
+            guide_scale=guide_scale, prompt=prompt, solver_mode=solver_mode,
+            max_chunk_len=max_chunk_len, seed=seed, color_fix=color_fix,
+        )
+
+    overlap = max(2, segment_size // 4)
+    stride = segment_size - overlap
+
+    # Build list of (start, end) frame indices for each segment.
+    segments = []
+    start = 0
+    while start < total_frames:
+        end = min(start + segment_size, total_frames)
+        segments.append((start, end))
+        if end == total_frames:
+            break
+        start += stride
+
+    print(f"[STAR] Segmented processing: {total_frames} frames, "
+          f"segment_size={segment_size}, overlap={overlap}, "
+          f"{len(segments)} segment(s)")
+
+    result_chunks: list[torch.Tensor] = []  # each on CPU, [F_i, H, W, 3]
+    prev_tail: torch.Tensor | None = None  # overlap tail from previous segment
+
+    for seg_idx, (seg_start, seg_end) in enumerate(segments):
+        print(f"[STAR] Processing segment {seg_idx + 1}/{len(segments)}: "
+              f"frames {seg_start}-{seg_end - 1}")
+
+        seg_images = images[seg_start:seg_end]
+
+        seg_result = run_star_inference(
+            star_model=star_model,
+            images=seg_images,
+            upscale=upscale,
+            steps=steps,
+            guide_scale=guide_scale,
+            prompt=prompt,
+            solver_mode=solver_mode,
+            max_chunk_len=max_chunk_len,
+            seed=seed,
+            color_fix=color_fix,
+        )
+        # seg_result: [F_seg, H, W, 3] float32 on CPU
+
+        if prev_tail is not None:
+            # Blend the overlap region between previous segment's tail and
+            # this segment's head using a linear ramp.
+            n_overlap = prev_tail.shape[0]
+            head = seg_result[:n_overlap]
+
+            # Linear ramp: weight for new segment goes from 0→1
+            weight = torch.linspace(0, 1, n_overlap, dtype=seg_result.dtype)
+            weight = weight.view(n_overlap, 1, 1, 1)  # broadcast over H, W, C
+
+            blended = prev_tail * (1.0 - weight) + head * weight
+            result_chunks.append(blended)
+
+            # Append the non-overlapping portion of this segment.
+            remainder = seg_result[n_overlap:]
+        else:
+            remainder = seg_result
+
+        if seg_idx < len(segments) - 1:
+            # Save the tail for blending with the next segment, and only
+            # append the non-tail portion now.
+            prev_tail = remainder[-overlap:].clone()
+            result_chunks.append(remainder[:-overlap])
+        else:
+            # Last segment — append everything remaining.
+            result_chunks.append(remainder)
+            prev_tail = None
+
+        # Free segment tensors.
+        del seg_result, seg_images
+        torch.cuda.empty_cache()
+
+    result = torch.cat(result_chunks, dim=0)
+    del result_chunks
+    return result
+
+
 def run_star_inference(
     star_model: dict,
     images: torch.Tensor,