Add standalone CLI for memory-efficient video upscaling

Standalone inference script that works outside ComfyUI — just activate
the same Python venv. Streams output frames to ffmpeg so peak RAM stays
bounded regardless of video length. Supports video files, image
sequences, and single images. Audio is automatically preserved from
input videos.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 23:46:58 +01:00
parent 8794f8ddec
commit 6cf314baf4
2 changed files with 594 additions and 0 deletions

View File

@@ -9,7 +9,9 @@ ComfyUI custom nodes for [STAR (Spatial-Temporal Augmentation with Text-to-Video
- **Auto-download**: all models (UNet checkpoint, OpenCLIP text encoder, temporal VAE) download automatically on first use
- **VRAM offloading**: three modes to fit GPUs from 12GB to 40GB+
- **Long video support**: sliding-window chunking with 50% overlap
- **Segment-based processing**: bound peak RAM for long videos
- **Color correction**: AdaIN and wavelet-based post-processing
- **Standalone CLI**: run from the terminal without ComfyUI for long videos

## Installation
@@ -83,6 +85,31 @@ Models are stored in `ComfyUI/models/star/` and auto-downloaded on first use:
The OpenCLIP text encoder and SVD temporal VAE are downloaded automatically by their respective libraries on first load.
## Standalone CLI
For long videos where ComfyUI's RAM usage becomes a bottleneck, use the standalone script directly. It streams output frames to ffmpeg so peak memory stays bounded regardless of video length.
```bash
# Activate your ComfyUI Python environment, then:
python inference.py input.mp4 -o output.mp4
# With model offloading for lower VRAM
python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8
# Image sequence input/output
python inference.py frames_in/ -o frames_out/
# Image sequence to video
python inference.py frames_in/ -o output.mp4 --fps 24
# Single image
python inference.py photo.png -o photo_4x.png
```
Audio is automatically copied from the input video. Use `--no-audio` to disable.
Run `python inference.py --help` for all options.
## Credits

- [STAR](https://github.com/NJU-PCALab/STAR) by Rui Xie, Yinhong Liu et al. (Nanjing University) — ICCV 2025

567
inference.py Executable file
View File

@@ -0,0 +1,567 @@
#!/usr/bin/env python3
"""STAR Video Super-Resolution — Standalone Inference Script
Memory-efficient video upscaling from the command line. Works outside
ComfyUI — just activate the same Python environment.
Examples
--------
# Video → video (audio is preserved automatically)
python inference.py input.mp4 -o output.mp4
# Lower VRAM (model offload + smaller segments)
python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8
# Image sequence → image sequence
python inference.py frames_in/ -o frames_out/
# Image sequence → video
python inference.py frames_in/ -o output.mp4 --fps 24
# Single image
python inference.py photo.png -o photo_4x.png
"""
# ── Comfy module stubs (must run before star_pipeline import) ────────────
import os
import sys
import types
from pathlib import Path

SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
sys.path.insert(0, str(SCRIPT_DIR / "STAR"))

import torch  # noqa: E402 — needed for stub defaults


class _ProgressBar:
    """tqdm-based stand-in for comfy.utils.ProgressBar."""

    def __init__(self, total):
        from tqdm import tqdm
        self._bar = tqdm(total=total, desc="Denoising", unit="step")

    def update(self, n=1):
        self._bar.update(n)

    def __del__(self):
        # Guard: __init__ may not have run if tqdm import failed.
        if hasattr(self, "_bar"):
            self._bar.close()


# Fake just enough of the `comfy` package that star_pipeline imports cleanly
# outside ComfyUI.
_comfy = types.ModuleType("comfy")
_comfy_utils = types.ModuleType("comfy.utils")
_comfy_mm = types.ModuleType("comfy.model_management")

_comfy_utils.ProgressBar = _ProgressBar
_comfy_mm.get_torch_device = lambda: torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
_comfy_mm.soft_empty_cache = lambda: torch.cuda.empty_cache()

_comfy.utils = _comfy_utils
_comfy.model_management = _comfy_mm
for _stub_name, _stub_mod in (
    ("comfy", _comfy),
    ("comfy.utils", _comfy_utils),
    ("comfy.model_management", _comfy_mm),
):
    sys.modules[_stub_name] = _stub_mod
# ── Standard imports ────────────────────────────────────────────────────
import argparse # noqa: E402
import json # noqa: E402
import shutil # noqa: E402
import subprocess # noqa: E402
import numpy as np # noqa: E402
from PIL import Image # noqa: E402
# ── Constants ───────────────────────────────────────────────────────────
# HuggingFace repository hosting the STAR UNet checkpoints.
HF_REPO = "SherryX/STAR"
# Short CLI model name → file path inside HF_REPO (used by resolve_model_path).
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}
# Extensions used to auto-detect input/output kinds (see read_input / make_writer).
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".ts"}
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}
# ── Argument parsing ───────────────────────────────────────────────────
def parse_args():
    """Build and parse the CLI argument set for standalone inference."""

    class _HelpFmt(argparse.ArgumentDefaultsHelpFormatter,
                   argparse.RawDescriptionHelpFormatter):
        """Show defaults while keeping the epilog's raw formatting."""

    parser = argparse.ArgumentParser(
        description="STAR Video Super-Resolution — standalone inference",
        formatter_class=_HelpFmt,
        epilog=__doc__,
    )
    # -- I/O --
    parser.add_argument("input",
                        help="Input video file, image file, or directory of frames")
    parser.add_argument("-o", "--output",
                        help="Output path (video file, image file, or directory). "
                             "Auto-generated with _star suffix if omitted.")
    # -- Model --
    model_g = parser.add_argument_group("model")
    model_g.add_argument("--model", default="light_deg.pt",
                         help="Model name (light_deg.pt / heavy_deg.pt) or path to .pt file")
    model_g.add_argument("--precision", default="fp16",
                         choices=["fp16", "bf16", "fp32"],
                         help="Weight precision")
    model_g.add_argument("--offload", default="model",
                         choices=["disabled", "model", "aggressive"],
                         help="VRAM offloading strategy")
    # -- Processing --
    proc_g = parser.add_argument_group("processing")
    proc_g.add_argument("--upscale", type=int, default=4, help="Upscale factor")
    proc_g.add_argument("--segment-size", type=int, default=16,
                        help="Frames per segment (bounds peak RAM). 0 = all at once")
    proc_g.add_argument("--steps", type=int, default=15, help="Denoising steps")
    proc_g.add_argument("--guide-scale", type=float, default=7.5,
                        help="Guidance scale")
    proc_g.add_argument("--solver-mode", default="fast", choices=["fast", "normal"])
    proc_g.add_argument("--max-chunk-len", type=int, default=32,
                        help="Temporal chunk length inside diffusion loop")
    proc_g.add_argument("--seed", type=int, default=0, help="Random seed")
    proc_g.add_argument("--color-fix", default="adain",
                        choices=["adain", "wavelet", "none"],
                        help="Post-processing color correction")
    proc_g.add_argument("--prompt", default="",
                        help="Text prompt (empty = STAR built-in quality prompt)")
    # -- Video output --
    out_g = parser.add_argument_group("video output")
    out_g.add_argument("--fps", type=float, default=None,
                       help="Output FPS (default: match input, or 24 for image sequences)")
    out_g.add_argument("--codec", default="libx264", help="FFmpeg video codec")
    out_g.add_argument("--crf", type=int, default=18,
                       help="FFmpeg CRF quality (lower = better)")
    out_g.add_argument("--pix-fmt", default="yuv420p", help="FFmpeg pixel format")
    out_g.add_argument("--no-audio", action="store_true",
                       help="Do not copy audio from input video")
    return parser.parse_args()
# ── Model resolution ───────────────────────────────────────────────────
def resolve_model_path(model_arg: str) -> str:
    """Resolve a model name or path to a local checkpoint file.

    Resolution order: explicit file path → local search locations →
    HuggingFace download (known model names only).

    Raises:
        FileNotFoundError: *model_arg* is neither a local file nor a
            known downloadable model name.
    """
    if os.path.isfile(model_arg):
        return model_arg
    search = [
        SCRIPT_DIR / "models" / model_arg,
        # Standard ComfyUI layout: custom_nodes/Comfyui-STAR/../../models/star/
        SCRIPT_DIR / ".." / ".." / "models" / "star" / model_arg,
    ]
    # hf_hub_download(local_dir=...) below preserves the repo's subdirectory
    # layout, so a previous run leaves the file at models/<repo subpath>.
    # Without this entry every run fell through to the download path.
    if model_arg in HF_MODELS:
        search.append(SCRIPT_DIR / "models" / HF_MODELS[model_arg])
    for candidate in search:
        candidate = candidate.resolve()
        if candidate.is_file():
            return str(candidate)
    if model_arg not in HF_MODELS:
        raise FileNotFoundError(
            f"Model '{model_arg}' not found locally and is not a known "
            "downloadable model. Provide a full path, or use "
            "light_deg.pt / heavy_deg.pt."
        )
    from huggingface_hub import hf_hub_download
    print(f"[STAR] Downloading {model_arg} from HuggingFace ({HF_REPO})...")
    dest_dir = str(SCRIPT_DIR / "models")
    os.makedirs(dest_dir, exist_ok=True)
    return hf_hub_download(
        repo_id=HF_REPO, filename=HF_MODELS[model_arg], local_dir=dest_dir,
    )
# ── Model loading (mirrors STARModelLoader.load_model) ─────────────────
def load_model(model_path: str, precision: str, offload: str, device: torch.device):
    """Load all STAR components (text encoder, UNet, diffusion, VAE).

    Args:
        model_path: Path to the STAR UNet checkpoint (.pt).
        precision: "fp16" / "bf16" / "fp32" — dtype for the UNet weights.
        offload: "disabled" keeps weights on `device`; "model"/"aggressive"
            park them on CPU between uses to save VRAM.
        device: Inference device (CUDA).

    Returns:
        Dict bundling all loaded components plus device/dtype/offload info.
    """
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[precision]
    # Idle placement for weights: GPU only when offloading is disabled.
    keep_on = device if offload == "disabled" else "cpu"
    from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder
    from video_to_video.utils.config import cfg
    print("[STAR] Loading text encoder (OpenCLIP ViT-H-14)...")
    text_encoder = FrozenOpenCLIPEmbedder(device=device, pretrained="laion2b_s32b_b79k")
    text_encoder.model.to(device)
    # Pre-compute the negative-prompt embedding once while the encoder is on
    # the GPU, then move the encoder to its idle location.
    negative_y = text_encoder(cfg.negative_prompt).detach()
    text_encoder.model.to(keep_on)
    from video_to_video.modules.unet_v2v import ControlledV2VUNet
    print("[STAR] Loading UNet + ControlNet...")
    generator = ControlledV2VUNet()
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # strict=False: presumably checkpoint and module key sets differ slightly
    # (extra/missing ControlNet keys) — TODO confirm against upstream STAR.
    generator.load_state_dict(ckpt, strict=False)
    del ckpt
    generator = generator.to(device=keep_on, dtype=dtype)
    generator.eval()
    from video_to_video.diffusion.schedules_sdedit import noise_schedule
    from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
    sigmas = noise_schedule(
        schedule="logsnr_cosine_interp", n=1000,
        zero_terminal_snr=True, scale_min=2.0, scale_max=4.0,
    )
    diffusion = GaussianDiffusion(sigmas=sigmas)
    from diffusers import AutoencoderKLTemporalDecoder
    print("[STAR] Loading temporal VAE...")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        subfolder="vae", variant="fp16",
    )
    vae.eval()
    vae.requires_grad_(False)
    vae.to(keep_on)
    torch.cuda.empty_cache()
    print("[STAR] All models loaded.")
    return {
        "text_encoder": text_encoder,
        "generator": generator,
        "diffusion": diffusion,
        "vae": vae,
        "negative_y": negative_y,
        "device": device,
        "dtype": dtype,
        "offload": offload,
    }
# ── Input reading ──────────────────────────────────────────────────────
def _ffprobe(path: str):
    """Return (width, height, fps, nb_frames) via ffprobe."""
    out = subprocess.check_output([
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-show_format", str(path),
    ])
    info = json.loads(out)
    video = next(s for s in info["streams"] if s["codec_type"] == "video")
    width = int(video["width"])
    height = int(video["height"])
    num, den = (int(x) for x in video.get("r_frame_rate", "24/1").split("/"))
    fps = num / den if den else 24.0
    raw_count = video.get("nb_frames")
    if raw_count and raw_count != "N/A":
        n_frames = int(raw_count)
    else:
        # Some containers (e.g. MKV) don't store a frame count —
        # estimate it from the stream duration instead.
        duration = float(info.get("format", {}).get("duration", 0))
        n_frames = int(duration * fps) or 0
    return width, height, fps, n_frames
def read_video(path: str):
    """Read video → (np array [N,H,W,3] uint8, fps)."""
    w, h, fps, est = _ffprobe(path)
    print(f"[STAR] Input video: {w}x{h}, {fps:.2f} fps, ~{est} frames")
    proc = subprocess.Popen(
        [
            "ffmpeg", "-i", str(path),
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-v", "quiet", "-",
        ],
        stdout=subprocess.PIPE,
    )
    frame_bytes = w * h * 3
    frames = []
    # Pull fixed-size raw RGB frames off ffmpeg's stdout; a short read
    # (including the trailing partial chunk at EOF) ends the stream.
    for chunk in iter(lambda: proc.stdout.read(frame_bytes), b""):
        if len(chunk) < frame_bytes:
            break
        frames.append(np.frombuffer(chunk, dtype=np.uint8).reshape(h, w, 3))
    proc.wait()
    print(f"[STAR] Read {len(frames)} frames")
    return np.stack(frames), fps
def read_image_dir(directory: str):
    """Read image directory → (np array [N,H,W,3] uint8, None)."""
    d = Path(directory)
    # Sorted filename order defines frame order.
    files = sorted(f for f in d.iterdir() if f.suffix.lower() in IMAGE_EXTS)
    if not files:
        raise FileNotFoundError(f"No image files in {d}")
    print(f"[STAR] Loading {len(files)} images from {d}")
    loaded = [np.array(Image.open(f).convert("RGB")) for f in files]
    return np.stack(loaded), None
def read_input(path: str):
    """Auto-detect input type → (np array [N,H,W,3] uint8, fps | None)."""
    src = Path(path)
    if src.is_dir():
        return read_image_dir(path)
    ext = src.suffix.lower()
    if ext in VIDEO_EXTS:
        return read_video(path)
    if ext in IMAGE_EXTS:
        # Single image: add a leading frame axis; no fps applies.
        frame = np.array(Image.open(src).convert("RGB"))
        return frame[np.newaxis], None
    raise ValueError(f"Unsupported input: {path}")
# ── Output writing ─────────────────────────────────────────────────────
class VideoWriter:
    """Stream RGB frames to ffmpeg, optionally copying audio from source.

    Frames are piped to ffmpeg's stdin as raw rgb24 so peak memory stays
    bounded regardless of video length.
    """

    def __init__(self, output_path, fps, width, height,
                 codec="libx264", crf=18, pix_fmt="yuv420p",
                 audio_source=None):
        cmd = [
            "ffmpeg", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}", "-r", str(fps),
            "-i", "-",
        ]
        if audio_source:
            cmd += ["-i", str(audio_source)]
        cmd += ["-map", "0:v:0"]
        if audio_source:
            # "1:a?" — copy audio from the source if it has any.
            cmd += ["-map", "1:a?", "-c:a", "copy"]
        cmd += [
            "-c:v", codec, "-crf", str(crf), "-pix_fmt", pix_fmt,
            "-movflags", "+faststart",
            "-v", "warning",
            str(output_path),
        ]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        self.count = 0

    def write_frame(self, frame_uint8):
        """Write one [H,W,3] uint8 RGB frame."""
        self.proc.stdin.write(frame_uint8.tobytes())
        self.count += 1

    def close(self):
        """Flush the pipe, wait for ffmpeg, and raise if encoding failed.

        Raises:
            RuntimeError: ffmpeg exited non-zero. Previously this was
                silently ignored, so a broken encode still reported success.
        """
        self.proc.stdin.close()
        ret = self.proc.wait()
        if ret != 0:
            raise RuntimeError(f"ffmpeg exited with code {ret}")
class ImageSequenceWriter:
    """Save frames as numbered image files."""

    def __init__(self, out_dir, ext=".png"):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.ext = ext
        self.count = 0

    def write_frame(self, frame_uint8):
        # Zero-padded sequential filenames keep lexical order == frame order.
        target = self.out_dir / f"{self.count:06d}{self.ext}"
        Image.fromarray(frame_uint8).save(target)
        self.count += 1

    def close(self):
        # Nothing buffered — every frame was written eagerly in write_frame.
        pass
class SingleImageWriter:
    """Save a single output image."""

    def __init__(self, path):
        self.path = Path(path)
        # Ensure the destination directory exists before any write.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.count = 0

    def write_frame(self, frame_uint8):
        Image.fromarray(frame_uint8).save(self.path)
        self.count += 1

    def close(self):
        # File is written directly in write_frame; nothing to finalize.
        pass
# ── Output path helpers ────────────────────────────────────────────────
def is_video_path(p):
    """True if *p* carries a recognized video file extension."""
    suffix = Path(p).suffix.lower()
    return suffix in VIDEO_EXTS
def is_image_path(p):
    """True if *p* carries a recognized image file extension."""
    suffix = Path(p).suffix.lower()
    return suffix in IMAGE_EXTS
def auto_output(input_path: str) -> str:
    """Derive a default output path by appending a `_star` suffix."""
    src = Path(input_path)
    if src.is_dir():
        # Directory input → sibling directory named <dir>_star.
        return str(src.parent / f"{src.name}_star")
    # File input → same extension, stem suffixed with _star.
    return str(src.parent / f"{src.stem}_star{src.suffix}")
def make_writer(output_path, fps, w, h, args, input_path, is_single_image):
    """Create the appropriate writer for the output path."""
    if is_single_image and is_image_path(output_path):
        return SingleImageWriter(output_path)
    if is_video_path(output_path):
        # Carry audio over only for video→video runs, unless disabled.
        copy_audio = is_video_path(input_path) and not args.no_audio
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        return VideoWriter(
            output_path, fps, w, h,
            codec=args.codec, crf=args.crf, pix_fmt=args.pix_fmt,
            audio_source=input_path if copy_audio else None,
        )
    # Default: image sequence directory
    return ImageSequenceWriter(output_path, ext=".png")
# ── Tensor ↔ numpy ─────────────────────────────────────────────────────
def to_tensor(frames_uint8):
    """[N,H,W,3] uint8 numpy → [N,H,W,3] float32 torch in [0,1]."""
    # .copy() detaches from read-only buffers (e.g. np.frombuffer views).
    frames = torch.from_numpy(frames_uint8.copy())
    return frames.float().div(255.0)
def to_uint8(tensor):
    """[N,H,W,3] float32 torch in [0,1] → [N,H,W,3] uint8 numpy."""
    scaled = tensor.clamp(0, 1).mul(255)
    return scaled.byte().cpu().numpy()
# ── Segment-based processing with streaming output ─────────────────────
def _run_segment(star_model, frames_uint8, args):
    """Process one segment through STAR → float32 tensor [F,H,W,3]."""
    # Imported lazily: star_pipeline pulls in the comfy stubs set up above.
    from star_pipeline import run_star_inference
    return run_star_inference(
        star_model=star_model,
        images=to_tensor(frames_uint8),
        upscale=args.upscale,
        steps=args.steps,
        guide_scale=args.guide_scale,
        prompt=args.prompt,
        solver_mode=args.solver_mode,
        max_chunk_len=args.max_chunk_len,
        seed=args.seed,
        color_fix=args.color_fix,
    )
def _write_tensor(writer, tensor):
    """Write a float32 [F,H,W,3] tensor as uint8 frames."""
    # Iterating the converted array yields one [H,W,3] frame at a time.
    for frame in to_uint8(tensor):
        writer.write_frame(frame)
def process_and_stream(star_model, input_frames, writer, args):
    """Process in segments, blend overlaps, and stream to writer.

    Segments overlap by ~25% of segment_size and the overlapping output
    frames are cross-faded linearly so segment seams are invisible.

    Bug fixes vs. the previous version: with segment_size <= 2 the stride
    computed to 0 (or negative), looping forever; and an overlap of 0 made
    `remainder[:-0]` drop every frame. Overlap is now clamped to keep the
    stride positive, and zero-overlap (segment_size == 1) is handled
    explicitly.
    """
    total = input_frames.shape[0]
    seg = args.segment_size
    # No segmentation — process everything at once
    if seg <= 0 or total <= seg:
        print(f"[STAR] Processing all {total} frame(s)...")
        result = _run_segment(star_model, input_frames, args)
        _write_tensor(writer, result)
        return
    # Clamp overlap below seg so stride stays >= 1 (seg == 1 → no overlap).
    overlap = min(max(2, seg // 4), seg - 1) if seg > 1 else 0
    stride = max(1, seg - overlap)
    # Build segment boundaries
    segments = []
    start = 0
    while start < total:
        end = min(start + seg, total)
        segments.append((start, end))
        if end == total:
            break
        start += stride
    print(f"[STAR] {total} frames → {len(segments)} segment(s), "
          f"segment_size={seg}, overlap={overlap}")
    prev_tail = None  # float32 tensor on CPU
    for idx, (s, e) in enumerate(segments):
        print(f"\n[STAR] ── Segment {idx + 1}/{len(segments)}: "
              f"frames {s}–{e - 1} ──")
        seg_result = _run_segment(star_model, input_frames[s:e], args)
        # Blend overlap with previous segment's tail
        if prev_tail is not None:
            n = prev_tail.shape[0]
            head = seg_result[:n]
            # Linear cross-fade: weight ramps 0→1 toward the new segment.
            w = torch.linspace(0, 1, n, dtype=seg_result.dtype).view(n, 1, 1, 1)
            blended = prev_tail * (1.0 - w) + head * w
            _write_tensor(writer, blended)
            remainder = seg_result[n:]
        else:
            remainder = seg_result
        if idx < len(segments) - 1 and overlap > 0:
            # Keep tail for blending, write the rest
            prev_tail = remainder[-overlap:].clone()
            _write_tensor(writer, remainder[:-overlap])
        else:
            # Last segment (or no overlap) — write everything
            _write_tensor(writer, remainder)
        del seg_result
        torch.cuda.empty_cache()
# ── Main ────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: validate environment, read input, run STAR, write output."""
    args = parse_args()
    # Validate environment
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. STAR requires a CUDA GPU.")
        sys.exit(1)
    if not shutil.which("ffmpeg") or not shutil.which("ffprobe"):
        # ffmpeg is only mandatory when video files are involved.
        needs_video_io = Path(args.input).suffix.lower() in VIDEO_EXTS or (
            args.output and Path(args.output).suffix.lower() in VIDEO_EXTS
        )
        if needs_video_io:
            print("Error: ffmpeg/ffprobe not found. Install ffmpeg for video I/O.")
            sys.exit(1)
    # Read input
    input_frames, input_fps = read_input(args.input)
    total = input_frames.shape[0]
    h_in, w_in = input_frames.shape[1:3]
    h_out, w_out = h_in * args.upscale, w_in * args.upscale
    is_single = (
        total == 1 and Path(args.input).is_file() and is_image_path(args.input)
    )
    print(f"[STAR] {w_in}x{h_in} → {w_out}x{h_out} ({args.upscale}x), "
          f"{total} frame(s)")
    # Output path and FPS fallback chain: flag → input video → 24.
    output_path = args.output or auto_output(args.input)
    fps = args.fps or input_fps or 24.0
    # Load model
    device = torch.device("cuda")
    model_path = resolve_model_path(args.model)
    print(f"[STAR] Model: {model_path}")
    star_model = load_model(model_path, args.precision, args.offload, device)
    # Create writer and process
    writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
    process_and_stream(star_model, input_frames, writer, args)
    writer.close()
    print(f"\n[STAR] Done! {writer.count} frame(s) → {output_path}")


if __name__ == "__main__":
    main()