Add segment-based processing for long videos to reduce RAM usage
Process videos in overlapping segments (25% overlap with linear crossfade blending) so peak memory is bounded by one segment rather than the full video. New segment_size parameter on the Super-Resolution node (default 0 = all at once, recommended 16-32 for long videos). Also update README clone URL to GitHub mirror. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ Search for `ComfyUI-STAR` in ComfyUI Manager and install.
|
||||
|
||||
```bash
|
||||
cd ComfyUI/custom_nodes
|
||||
git clone --recursive git@192.168.1.1:Ethanfel/Comfyui-STAR.git
|
||||
git clone --recursive https://github.com/ethanfel/Comfyui-STAR.git
|
||||
cd Comfyui-STAR
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
@@ -56,6 +56,7 @@ Runs the STAR diffusion pipeline on an image batch.
|
||||
| **max_chunk_len** | Max frames per chunk (4–128, default 32). Lower = less VRAM for long videos. |
|
||||
| **seed** | Random seed for reproducibility. |
|
||||
| **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. |
|
||||
| **segment_size** | Process video in segments of this many frames to reduce RAM usage (0–256, default 0). 0 = process all at once. Recommended: 16–32 for long videos. Segments overlap by 25% with linear crossfade blending. |
|
||||
|
||||
## VRAM Requirements
|
||||
|
||||
|
||||
17
nodes.py
17
nodes.py
@@ -205,6 +205,10 @@ class STARVideoSuperResolution:
|
||||
"default": "adain",
|
||||
"tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.",
|
||||
}),
|
||||
"segment_size": ("INT", {
|
||||
"default": 0, "min": 0, "max": 256,
|
||||
"tooltip": "Process video in segments of this many frames to reduce RAM usage. 0 = process all at once. Recommended: 16-32 for long videos.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,10 +230,9 @@ class STARVideoSuperResolution:
|
||||
max_chunk_len,
|
||||
seed,
|
||||
color_fix,
|
||||
segment_size=0,
|
||||
):
|
||||
from .star_pipeline import run_star_inference
|
||||
|
||||
result = run_star_inference(
|
||||
kwargs = dict(
|
||||
star_model=star_model,
|
||||
images=images,
|
||||
upscale=upscale,
|
||||
@@ -241,6 +244,14 @@ class STARVideoSuperResolution:
|
||||
seed=seed,
|
||||
color_fix=color_fix,
|
||||
)
|
||||
|
||||
if segment_size > 0:
|
||||
from .star_pipeline import run_star_inference_segmented
|
||||
result = run_star_inference_segmented(segment_size=segment_size, **kwargs)
|
||||
else:
|
||||
from .star_pipeline import run_star_inference
|
||||
result = run_star_inference(**kwargs)
|
||||
|
||||
return (result,)
|
||||
|
||||
|
||||
|
||||
106
star_pipeline.py
106
star_pipeline.py
@@ -130,6 +130,112 @@ def _move(module, device):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def run_star_inference_segmented(
|
||||
star_model: dict,
|
||||
images: torch.Tensor,
|
||||
segment_size: int,
|
||||
upscale: int = 4,
|
||||
steps: int = 15,
|
||||
guide_scale: float = 7.5,
|
||||
prompt: str = "",
|
||||
solver_mode: str = "fast",
|
||||
max_chunk_len: int = 32,
|
||||
seed: int = 0,
|
||||
color_fix: str = "adain",
|
||||
) -> torch.Tensor:
|
||||
"""Run STAR inference in overlapping segments to bound peak RAM usage.
|
||||
|
||||
Each segment of `segment_size` frames is processed independently through
|
||||
the full pipeline. Overlap regions (25% of segment_size, minimum 2 frames)
|
||||
are blended with a linear crossfade to avoid temporal seam artifacts.
|
||||
"""
|
||||
total_frames = images.shape[0]
|
||||
|
||||
# Fall back to single-shot if the video fits in one segment.
|
||||
if total_frames <= segment_size:
|
||||
return run_star_inference(
|
||||
star_model=star_model, images=images, upscale=upscale, steps=steps,
|
||||
guide_scale=guide_scale, prompt=prompt, solver_mode=solver_mode,
|
||||
max_chunk_len=max_chunk_len, seed=seed, color_fix=color_fix,
|
||||
)
|
||||
|
||||
overlap = max(2, segment_size // 4)
|
||||
stride = segment_size - overlap
|
||||
|
||||
# Build list of (start, end) frame indices for each segment.
|
||||
segments = []
|
||||
start = 0
|
||||
while start < total_frames:
|
||||
end = min(start + segment_size, total_frames)
|
||||
segments.append((start, end))
|
||||
if end == total_frames:
|
||||
break
|
||||
start += stride
|
||||
|
||||
print(f"[STAR] Segmented processing: {total_frames} frames, "
|
||||
f"segment_size={segment_size}, overlap={overlap}, "
|
||||
f"{len(segments)} segment(s)")
|
||||
|
||||
result_chunks: list[torch.Tensor] = [] # each on CPU, [F_i, H, W, 3]
|
||||
prev_tail: torch.Tensor | None = None # overlap tail from previous segment
|
||||
|
||||
for seg_idx, (seg_start, seg_end) in enumerate(segments):
|
||||
print(f"[STAR] Processing segment {seg_idx + 1}/{len(segments)}: "
|
||||
f"frames {seg_start}-{seg_end - 1}")
|
||||
|
||||
seg_images = images[seg_start:seg_end]
|
||||
|
||||
seg_result = run_star_inference(
|
||||
star_model=star_model,
|
||||
images=seg_images,
|
||||
upscale=upscale,
|
||||
steps=steps,
|
||||
guide_scale=guide_scale,
|
||||
prompt=prompt,
|
||||
solver_mode=solver_mode,
|
||||
max_chunk_len=max_chunk_len,
|
||||
seed=seed,
|
||||
color_fix=color_fix,
|
||||
)
|
||||
# seg_result: [F_seg, H, W, 3] float32 on CPU
|
||||
|
||||
if prev_tail is not None:
|
||||
# Blend the overlap region between previous segment's tail and
|
||||
# this segment's head using a linear ramp.
|
||||
n_overlap = prev_tail.shape[0]
|
||||
head = seg_result[:n_overlap]
|
||||
|
||||
# Linear ramp: weight for new segment goes from 0→1
|
||||
weight = torch.linspace(0, 1, n_overlap, dtype=seg_result.dtype)
|
||||
weight = weight.view(n_overlap, 1, 1, 1) # broadcast over H, W, C
|
||||
|
||||
blended = prev_tail * (1.0 - weight) + head * weight
|
||||
result_chunks.append(blended)
|
||||
|
||||
# Append the non-overlapping portion of this segment.
|
||||
remainder = seg_result[n_overlap:]
|
||||
else:
|
||||
remainder = seg_result
|
||||
|
||||
if seg_idx < len(segments) - 1:
|
||||
# Save the tail for blending with the next segment, and only
|
||||
# append the non-tail portion now.
|
||||
prev_tail = remainder[-overlap:].clone()
|
||||
result_chunks.append(remainder[:-overlap])
|
||||
else:
|
||||
# Last segment — append everything remaining.
|
||||
result_chunks.append(remainder)
|
||||
prev_tail = None
|
||||
|
||||
# Free segment tensors.
|
||||
del seg_result, seg_images
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
result = torch.cat(result_chunks, dim=0)
|
||||
del result_chunks
|
||||
return result
|
||||
|
||||
|
||||
def run_star_inference(
|
||||
star_model: dict,
|
||||
images: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user