Add segment-based processing for long videos to reduce RAM usage

Process videos in overlapping segments (25% overlap with linear crossfade
blending) so peak memory is bounded by one segment rather than the full
video. New segment_size parameter on the Super-Resolution node (default 0
= all at once, recommended 16-32 for long videos). Also update README
clone URL to GitHub mirror.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 23:28:01 +01:00
parent 5f9287cfac
commit f7021e95f4
3 changed files with 122 additions and 4 deletions

View File

@@ -21,7 +21,7 @@ Search for `ComfyUI-STAR` in ComfyUI Manager and install.
```bash ```bash
cd ComfyUI/custom_nodes cd ComfyUI/custom_nodes
git clone --recursive git@192.168.1.1:Ethanfel/Comfyui-STAR.git git clone --recursive https://github.com/ethanfel/Comfyui-STAR.git
cd Comfyui-STAR cd Comfyui-STAR
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@@ -56,6 +56,7 @@ Runs the STAR diffusion pipeline on an image batch.
| **max_chunk_len** | Max frames per chunk (4128, default 32). Lower = less VRAM for long videos. | | **max_chunk_len** | Max frames per chunk (4128, default 32). Lower = less VRAM for long videos. |
| **seed** | Random seed for reproducibility. | | **seed** | Random seed for reproducibility. |
| **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. | | **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. |
| **segment_size** | Process video in segments of this many frames to reduce RAM usage (0256, default 0). 0 = process all at once. Recommended: 1632 for long videos. Segments overlap by 25% with linear crossfade blending. |
## VRAM Requirements ## VRAM Requirements

View File

@@ -205,6 +205,10 @@ class STARVideoSuperResolution:
"default": "adain", "default": "adain",
"tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.", "tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.",
}), }),
"segment_size": ("INT", {
"default": 0, "min": 0, "max": 256,
"tooltip": "Process video in segments of this many frames to reduce RAM usage. 0 = process all at once. Recommended: 16-32 for long videos.",
}),
} }
} }
@@ -226,10 +230,9 @@ class STARVideoSuperResolution:
max_chunk_len, max_chunk_len,
seed, seed,
color_fix, color_fix,
segment_size=0,
): ):
from .star_pipeline import run_star_inference kwargs = dict(
result = run_star_inference(
star_model=star_model, star_model=star_model,
images=images, images=images,
upscale=upscale, upscale=upscale,
@@ -241,6 +244,14 @@ class STARVideoSuperResolution:
seed=seed, seed=seed,
color_fix=color_fix, color_fix=color_fix,
) )
if segment_size > 0:
from .star_pipeline import run_star_inference_segmented
result = run_star_inference_segmented(segment_size=segment_size, **kwargs)
else:
from .star_pipeline import run_star_inference
result = run_star_inference(**kwargs)
return (result,) return (result,)

View File

@@ -130,6 +130,112 @@ def _move(module, device):
torch.cuda.empty_cache() torch.cuda.empty_cache()
def run_star_inference_segmented(
star_model: dict,
images: torch.Tensor,
segment_size: int,
upscale: int = 4,
steps: int = 15,
guide_scale: float = 7.5,
prompt: str = "",
solver_mode: str = "fast",
max_chunk_len: int = 32,
seed: int = 0,
color_fix: str = "adain",
) -> torch.Tensor:
"""Run STAR inference in overlapping segments to bound peak RAM usage.
Each segment of `segment_size` frames is processed independently through
the full pipeline. Overlap regions (25% of segment_size, minimum 2 frames)
are blended with a linear crossfade to avoid temporal seam artifacts.
"""
total_frames = images.shape[0]
# Fall back to single-shot if the video fits in one segment.
if total_frames <= segment_size:
return run_star_inference(
star_model=star_model, images=images, upscale=upscale, steps=steps,
guide_scale=guide_scale, prompt=prompt, solver_mode=solver_mode,
max_chunk_len=max_chunk_len, seed=seed, color_fix=color_fix,
)
overlap = max(2, segment_size // 4)
stride = segment_size - overlap
# Build list of (start, end) frame indices for each segment.
segments = []
start = 0
while start < total_frames:
end = min(start + segment_size, total_frames)
segments.append((start, end))
if end == total_frames:
break
start += stride
print(f"[STAR] Segmented processing: {total_frames} frames, "
f"segment_size={segment_size}, overlap={overlap}, "
f"{len(segments)} segment(s)")
result_chunks: list[torch.Tensor] = [] # each on CPU, [F_i, H, W, 3]
prev_tail: torch.Tensor | None = None # overlap tail from previous segment
for seg_idx, (seg_start, seg_end) in enumerate(segments):
print(f"[STAR] Processing segment {seg_idx + 1}/{len(segments)}: "
f"frames {seg_start}-{seg_end - 1}")
seg_images = images[seg_start:seg_end]
seg_result = run_star_inference(
star_model=star_model,
images=seg_images,
upscale=upscale,
steps=steps,
guide_scale=guide_scale,
prompt=prompt,
solver_mode=solver_mode,
max_chunk_len=max_chunk_len,
seed=seed,
color_fix=color_fix,
)
# seg_result: [F_seg, H, W, 3] float32 on CPU
if prev_tail is not None:
# Blend the overlap region between previous segment's tail and
# this segment's head using a linear ramp.
n_overlap = prev_tail.shape[0]
head = seg_result[:n_overlap]
# Linear ramp: weight for new segment goes from 0→1
weight = torch.linspace(0, 1, n_overlap, dtype=seg_result.dtype)
weight = weight.view(n_overlap, 1, 1, 1) # broadcast over H, W, C
blended = prev_tail * (1.0 - weight) + head * weight
result_chunks.append(blended)
# Append the non-overlapping portion of this segment.
remainder = seg_result[n_overlap:]
else:
remainder = seg_result
if seg_idx < len(segments) - 1:
# Save the tail for blending with the next segment, and only
# append the non-tail portion now.
prev_tail = remainder[-overlap:].clone()
result_chunks.append(remainder[:-overlap])
else:
# Last segment — append everything remaining.
result_chunks.append(remainder)
prev_tail = None
# Free segment tensors.
del seg_result, seg_images
torch.cuda.empty_cache()
result = torch.cat(result_chunks, dim=0)
del result_chunks
return result
def run_star_inference( def run_star_inference(
star_model: dict, star_model: dict,
images: torch.Tensor, images: torch.Tensor,