Add segment-based processing for long videos to reduce RAM usage
Process videos in overlapping segments (25% overlap with linear crossfade blending) so peak memory is bounded by one segment rather than the full video. New segment_size parameter on the Super-Resolution node (default 0 = all at once, recommended 16-32 for long videos). Also update README clone URL to GitHub mirror. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
106
star_pipeline.py
106
star_pipeline.py
@@ -130,6 +130,112 @@ def _move(module, device):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def run_star_inference_segmented(
    star_model: dict,
    images: torch.Tensor,
    segment_size: int,
    upscale: int = 4,
    steps: int = 15,
    guide_scale: float = 7.5,
    prompt: str = "",
    solver_mode: str = "fast",
    max_chunk_len: int = 32,
    seed: int = 0,
    color_fix: str = "adain",
) -> torch.Tensor:
    """Run STAR inference in overlapping segments to bound peak RAM usage.

    The input is split along dim 0 into windows of ``segment_size`` frames.
    Each window is passed independently to ``run_star_inference`` and the
    frames shared by two consecutive windows are blended with a linear
    crossfade to avoid temporal seam artifacts.

    The overlap is 25% of ``segment_size`` (minimum 2 frames), clamped to at
    most half a segment so that every window advances by at least one frame
    and never overlaps more than its direct neighbour.

    Results are written straight into one preallocated output buffer, so the
    peak footprint of this function is the final output plus a single
    segment result (no per-segment views kept alive, no final concatenation
    copy).

    Args:
        star_model: Model bundle, forwarded unchanged to ``run_star_inference``.
        images: Input frames; dim 0 is the frame axis.
        segment_size: Number of input frames per segment. A non-positive
            value, or a value >= the number of frames, processes the whole
            input in a single ``run_star_inference`` call.
        upscale, steps, guide_scale, prompt, solver_mode, max_chunk_len,
        seed, color_fix: Forwarded unchanged to ``run_star_inference`` for
            every segment (note: the same ``seed`` is used for each segment).

    Returns:
        A tensor with ``images.shape[0]`` frames along dim 0. Trailing
        dimensions, dtype and device are taken from the first segment's
        ``run_star_inference`` result (expected ``[F, H, W, 3]``).

    Raises:
        RuntimeError: If ``run_star_inference`` returns a different number
            of frames than it was given, which would misalign the blend.
    """
    total_frames = images.shape[0]

    # Everything except `images` is identical for every inference call.
    inference_kwargs = dict(
        star_model=star_model,
        upscale=upscale,
        steps=steps,
        guide_scale=guide_scale,
        prompt=prompt,
        solver_mode=solver_mode,
        max_chunk_len=max_chunk_len,
        seed=seed,
        color_fix=color_fix,
    )

    # Fall back to single-shot if segmentation is disabled (<= 0) or the
    # video fits in one segment. Without the <= 0 guard the stride below
    # would be negative and the segment planner would never terminate.
    if segment_size <= 0 or total_frames <= segment_size:
        return run_star_inference(images=images, **inference_kwargs)

    # 25% overlap, at least 2 frames, but never more than half a segment:
    # this guarantees stride >= max(1, overlap), so the planner always makes
    # progress and a frame is shared by at most two segments. For
    # segment_size >= 4 this equals max(2, segment_size // 4).
    overlap = min(max(2, segment_size // 4), segment_size // 2)
    stride = segment_size - overlap

    # A new segment starts every `stride` frames for as long as the previous
    # segment has not yet reached the end, i.e. while start + overlap < total.
    segments = [
        (start, min(start + segment_size, total_frames))
        for start in range(0, total_frames - overlap, stride)
    ]

    print(f"[STAR] Segmented processing: {total_frames} frames, "
          f"segment_size={segment_size}, overlap={overlap}, "
          f"{len(segments)} segment(s)")

    result: torch.Tensor | None = None  # allocated after the first segment

    for seg_idx, (seg_start, seg_end) in enumerate(segments):
        print(f"[STAR] Processing segment {seg_idx + 1}/{len(segments)}: "
              f"frames {seg_start}-{seg_end - 1}")

        seg_result = run_star_inference(
            images=images[seg_start:seg_end], **inference_kwargs
        )
        # seg_result is expected to be [F_seg, H, W, 3] with one output
        # frame per input frame.
        expected_frames = seg_end - seg_start
        if seg_result.shape[0] != expected_frames:
            raise RuntimeError(
                f"run_star_inference returned {seg_result.shape[0]} frames "
                f"for a segment of {expected_frames} frames"
            )

        if result is None:
            # Output geometry is only known once the first segment is done.
            result = torch.empty(
                (total_frames, *seg_result.shape[1:]),
                dtype=seg_result.dtype,
                device=seg_result.device,
            )
            n_blend = 0
        else:
            # Every later segment starts `overlap` frames before the
            # previous one ended; those frames are already in `result`.
            n_blend = overlap

        if n_blend:
            # Weight of the NEW segment across the overlap. Only interior
            # points of the ramp are used: including the 0 and 1 endpoints
            # would turn a 2-frame overlap into weights [0, 1], i.e. a hard
            # cut with no blending at all.
            ramp = torch.linspace(
                0, 1, n_blend + 2,
                dtype=seg_result.dtype, device=seg_result.device,
            )[1:-1]
            # Reshape to broadcast over all non-frame dimensions.
            ramp = ramp.view(-1, *([1] * (seg_result.dim() - 1)))

            # In-place: prev * (1 - w) + new * w on the shared frames.
            shared = result[seg_start:seg_start + n_blend]
            shared.mul_(1.0 - ramp).add_(seg_result[:n_blend] * ramp)

        # Copy (not view) the non-shared frames so the segment's storage
        # can actually be released below.
        result[seg_start + n_blend:seg_end] = seg_result[n_blend:]

        # Free the segment tensor before the next inference call.
        del seg_result
        torch.cuda.empty_cache()

    return result
|
||||
|
||||
|
||||
def run_star_inference(
|
||||
star_model: dict,
|
||||
images: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user