diff --git a/README.md b/README.md index 61462de..0d01af7 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ ComfyUI custom nodes for [STAR (Spatial-Temporal Augmentation with Text-to-Video - **Auto-download**: all models (UNet checkpoint, OpenCLIP text encoder, temporal VAE) download automatically on first use - **VRAM offloading**: three modes to fit GPUs from 12GB to 40GB+ - **Long video support**: sliding-window chunking with 50% overlap +- **Segment-based processing**: bound peak RAM for long videos - **Color correction**: AdaIN and wavelet-based post-processing +- **Standalone CLI**: run from the terminal without ComfyUI for long videos ## Installation @@ -83,6 +85,31 @@ Models are stored in `ComfyUI/models/star/` and auto-downloaded on first use: The OpenCLIP text encoder and SVD temporal VAE are downloaded automatically by their respective libraries on first load. +## Standalone CLI + +For long videos where ComfyUI's RAM usage becomes a bottleneck, use the standalone script directly. It streams output frames to ffmpeg, so output-side memory stays bounded for arbitrarily long videos; note that input frames are still decoded into RAM up front. + +```bash +# Activate your ComfyUI Python environment, then: +python inference.py input.mp4 -o output.mp4 + +# With model offloading for lower VRAM +python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8 + +# Image sequence input/output +python inference.py frames_in/ -o frames_out/ + +# Image sequence to video +python inference.py frames_in/ -o output.mp4 --fps 24 + +# Single image +python inference.py photo.png -o photo_4x.png +``` + +Audio is automatically copied from the input video. Use `--no-audio` to disable. + +Run `python inference.py --help` for all options. + ## Credits - [STAR](https://github.com/NJU-PCALab/STAR) by Rui Xie, Yinhong Liu et al.
#!/usr/bin/env python3
"""STAR Video Super-Resolution — Standalone Inference Script

Memory-efficient video upscaling from the command line. Works outside
ComfyUI — just activate the same Python environment.

Examples
--------
    # Video → video (audio is preserved automatically)
    python inference.py input.mp4 -o output.mp4

    # Lower VRAM (model offload + smaller segments)
    python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8

    # Image sequence → image sequence
    python inference.py frames_in/ -o frames_out/

    # Image sequence → video
    python inference.py frames_in/ -o output.mp4 --fps 24

    # Single image
    python inference.py photo.png -o photo_4x.png
"""

# ── Comfy module stubs (must run before star_pipeline import) ────────────
import os
import sys
import types
from pathlib import Path

SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
sys.path.insert(0, str(SCRIPT_DIR / "STAR"))

import torch  # noqa: E402 — needed for stub defaults

_comfy = types.ModuleType("comfy")
_comfy_utils = types.ModuleType("comfy.utils")
_comfy_mm = types.ModuleType("comfy.model_management")


class _ProgressBar:
    """tqdm-based stand-in for comfy.utils.ProgressBar."""

    def __init__(self, total):
        from tqdm import tqdm

        self._bar = tqdm(total=total, desc="Denoising", unit="step")

    def update(self, n=1):
        self._bar.update(n)

    def __del__(self):
        # Guard: __init__ may have failed before _bar was set.
        if hasattr(self, "_bar"):
            self._bar.close()


_comfy_utils.ProgressBar = _ProgressBar
_comfy_mm.get_torch_device = lambda: torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Fix: calling torch.cuda.empty_cache() unconditionally raises on CPU-only
# torch builds; guard it so the stub is safe wherever star_pipeline calls it.
_comfy_mm.soft_empty_cache = lambda: (
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
)

_comfy.utils = _comfy_utils
_comfy.model_management = _comfy_mm

sys.modules["comfy"] = _comfy
sys.modules["comfy.utils"] = _comfy_utils
sys.modules["comfy.model_management"] = _comfy_mm

# ── Standard imports ────────────────────────────────────────────────────
import argparse  # noqa: E402
import json  # noqa: E402
import shutil  # noqa: E402
import subprocess  # noqa: E402

import numpy as np  # noqa: E402
from PIL import Image  # noqa: E402

# ── Constants ───────────────────────────────────────────────────────────
HF_REPO = "SherryX/STAR"
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".ts"}
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}


# ── Argument parsing ───────────────────────────────────────────────────
def parse_args():
    """Parse command-line arguments for the standalone STAR CLI."""

    class Fmt(argparse.ArgumentDefaultsHelpFormatter,
              argparse.RawDescriptionHelpFormatter):
        pass

    p = argparse.ArgumentParser(
        description="STAR Video Super-Resolution — standalone inference",
        formatter_class=Fmt,
        epilog=__doc__,
    )

    # -- I/O --
    p.add_argument("input", help="Input video file, image file, or directory of frames")
    p.add_argument("-o", "--output",
                   help="Output path (video file, image file, or directory). "
                        "Auto-generated with _star suffix if omitted.")

    # -- Model --
    g = p.add_argument_group("model")
    g.add_argument("--model", default="light_deg.pt",
                   help="Model name (light_deg.pt / heavy_deg.pt) or path to .pt file")
    g.add_argument("--precision", default="fp16", choices=["fp16", "bf16", "fp32"],
                   help="Weight precision")
    g.add_argument("--offload", default="model",
                   choices=["disabled", "model", "aggressive"],
                   help="VRAM offloading strategy")

    # -- Processing --
    g = p.add_argument_group("processing")
    g.add_argument("--upscale", type=int, default=4, help="Upscale factor")
    g.add_argument("--segment-size", type=int, default=16,
                   help="Frames per segment (bounds peak RAM). 0 = all at once")
    g.add_argument("--steps", type=int, default=15, help="Denoising steps")
    g.add_argument("--guide-scale", type=float, default=7.5, help="Guidance scale")
    g.add_argument("--solver-mode", default="fast", choices=["fast", "normal"])
    g.add_argument("--max-chunk-len", type=int, default=32,
                   help="Temporal chunk length inside diffusion loop")
    g.add_argument("--seed", type=int, default=0, help="Random seed")
    g.add_argument("--color-fix", default="adain",
                   choices=["adain", "wavelet", "none"],
                   help="Post-processing color correction")
    g.add_argument("--prompt", default="",
                   help="Text prompt (empty = STAR built-in quality prompt)")

    # -- Video output --
    g = p.add_argument_group("video output")
    g.add_argument("--fps", type=float, default=None,
                   help="Output FPS (default: match input, or 24 for image sequences)")
    g.add_argument("--codec", default="libx264", help="FFmpeg video codec")
    g.add_argument("--crf", type=int, default=18,
                   help="FFmpeg CRF quality (lower = better)")
    g.add_argument("--pix-fmt", default="yuv420p", help="FFmpeg pixel format")
    g.add_argument("--no-audio", action="store_true",
                   help="Do not copy audio from input video")

    return p.parse_args()


# ── Model resolution ───────────────────────────────────────────────────
def resolve_model_path(model_arg: str) -> str:
    """Resolve a model name or path to a local .pt file, downloading if needed.

    Search order: literal path → local models/ dir → hf_hub_download layout
    under models/ → standard ComfyUI models/star/ dir → HuggingFace download.
    """
    if os.path.isfile(model_arg):
        return model_arg

    search = [
        SCRIPT_DIR / "models" / model_arg,
        # Fix: hf_hub_download(local_dir=models) keeps the repo subfolder
        # (e.g. models/I2VGen-XL-based/light_deg.pt); check that location so
        # a previous download is reused instead of hitting the Hub again.
        SCRIPT_DIR / "models" / HF_MODELS.get(model_arg, model_arg),
        # Standard ComfyUI layout: custom_nodes/Comfyui-STAR/../../models/star/
        SCRIPT_DIR / ".." / ".." / "models" / "star" / model_arg,
    ]
    for candidate in search:
        candidate = candidate.resolve()
        if candidate.is_file():
            return str(candidate)

    if model_arg not in HF_MODELS:
        raise FileNotFoundError(
            f"Model '{model_arg}' not found locally and is not a known "
            "downloadable model. Provide a full path, or use "
            "light_deg.pt / heavy_deg.pt."
        )

    from huggingface_hub import hf_hub_download

    print(f"[STAR] Downloading {model_arg} from HuggingFace ({HF_REPO})...")
    dest_dir = str(SCRIPT_DIR / "models")
    os.makedirs(dest_dir, exist_ok=True)
    path = hf_hub_download(
        repo_id=HF_REPO, filename=HF_MODELS[model_arg], local_dir=dest_dir,
    )
    return path


# ── Model loading (mirrors STARModelLoader.load_model) ─────────────────
def load_model(model_path: str, precision: str, offload: str, device: torch.device):
    """Load text encoder, UNet, diffusion schedule and temporal VAE.

    Returns a dict consumed by star_pipeline.run_star_inference. With
    offload != "disabled", weights are parked on CPU between uses.
    """
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[precision]
    keep_on = device if offload == "disabled" else "cpu"

    from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder
    from video_to_video.utils.config import cfg

    print("[STAR] Loading text encoder (OpenCLIP ViT-H-14)...")
    text_encoder = FrozenOpenCLIPEmbedder(device=device, pretrained="laion2b_s32b_b79k")
    text_encoder.model.to(device)
    # Pre-compute the negative-prompt embedding once while weights are on GPU.
    negative_y = text_encoder(cfg.negative_prompt).detach()
    text_encoder.model.to(keep_on)

    from video_to_video.modules.unet_v2v import ControlledV2VUNet

    print("[STAR] Loading UNet + ControlNet...")
    generator = ControlledV2VUNet()
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    generator.load_state_dict(ckpt, strict=False)
    del ckpt
    generator = generator.to(device=keep_on, dtype=dtype)
    generator.eval()

    from video_to_video.diffusion.schedules_sdedit import noise_schedule
    from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

    sigmas = noise_schedule(
        schedule="logsnr_cosine_interp", n=1000,
        zero_terminal_snr=True, scale_min=2.0, scale_max=4.0,
    )
    diffusion = GaussianDiffusion(sigmas=sigmas)

    from diffusers import AutoencoderKLTemporalDecoder

    print("[STAR] Loading temporal VAE...")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        subfolder="vae", variant="fp16",
    )
    vae.eval()
    vae.requires_grad_(False)
    vae.to(keep_on)

    torch.cuda.empty_cache()
    print("[STAR] All models loaded.")

    return {
        "text_encoder": text_encoder,
        "generator": generator,
        "diffusion": diffusion,
        "vae": vae,
        "negative_y": negative_y,
        "device": device,
        "dtype": dtype,
        "offload": offload,
    }


# ── Input reading ──────────────────────────────────────────────────────
def _ffprobe(path: str):
    """Return (width, height, fps, nb_frames) via ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-show_format", str(path),
    ]
    info = json.loads(subprocess.check_output(cmd))
    vs = next(s for s in info["streams"] if s["codec_type"] == "video")
    w, h = int(vs["width"]), int(vs["height"])

    num, den = map(int, vs.get("r_frame_rate", "24/1").split("/"))
    fps = num / den if den else 24.0

    nb = vs.get("nb_frames")
    if nb and nb != "N/A":
        n_frames = int(nb)
    else:
        # Container doesn't store a frame count (e.g. webm) — estimate it.
        dur = float(info.get("format", {}).get("duration", 0))
        n_frames = int(dur * fps) or 0

    return w, h, fps, n_frames


def read_video(path: str):
    """Read video → (np array [N,H,W,3] uint8, fps).

    NOTE: decodes the whole video into RAM; segmentation only bounds
    processing-side memory, not input-side.
    """
    w, h, fps, est = _ffprobe(path)
    print(f"[STAR] Input video: {w}x{h}, {fps:.2f} fps, ~{est} frames")

    cmd = [
        "ffmpeg", "-i", str(path),
        "-f", "rawvideo", "-pix_fmt", "rgb24",
        "-v", "quiet", "-",
    ]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)

    frames = []
    fsize = w * h * 3
    while True:
        # BufferedReader.read(n) returns exactly n bytes until EOF.
        raw = proc.stdout.read(fsize)
        if len(raw) < fsize:
            break
        frames.append(np.frombuffer(raw, dtype=np.uint8).reshape(h, w, 3))

    proc.wait()
    print(f"[STAR] Read {len(frames)} frames")
    return np.stack(frames), fps


def read_image_dir(directory: str):
    """Read image directory (sorted by filename) → (np array [N,H,W,3] uint8, None)."""
    d = Path(directory)
    files = sorted(f for f in d.iterdir() if f.suffix.lower() in IMAGE_EXTS)
    if not files:
        raise FileNotFoundError(f"No image files in {d}")

    print(f"[STAR] Loading {len(files)} images from {d}")
    frames = [np.array(Image.open(f).convert("RGB")) for f in files]
    return np.stack(frames), None


def read_input(path: str):
    """Auto-detect input type → (np array [N,H,W,3] uint8, fps | None)."""
    p = Path(path)
    if p.is_dir():
        return read_image_dir(path)
    if p.suffix.lower() in VIDEO_EXTS:
        return read_video(path)
    if p.suffix.lower() in IMAGE_EXTS:
        img = np.array(Image.open(p).convert("RGB"))
        return img[np.newaxis], None
    raise ValueError(f"Unsupported input: {path}")


# ── Output writing ─────────────────────────────────────────────────────
class VideoWriter:
    """Stream RGB frames to ffmpeg, optionally copying audio from source."""

    def __init__(self, output_path, fps, width, height,
                 codec="libx264", crf=18, pix_fmt="yuv420p",
                 audio_source=None):
        cmd = [
            "ffmpeg", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}", "-r", str(fps),
            "-i", "-",
        ]
        if audio_source:
            cmd += ["-i", str(audio_source)]

        cmd += ["-map", "0:v:0"]
        if audio_source:
            # '1:a?' maps audio only if the source actually has a stream.
            cmd += ["-map", "1:a?", "-c:a", "copy"]

        cmd += [
            "-c:v", codec, "-crf", str(crf), "-pix_fmt", pix_fmt,
            "-movflags", "+faststart",
            "-v", "warning",
        ]
        if audio_source:
            # Fix: stop at the shorter stream so copied audio cannot stretch
            # the output past the rendered video's duration.
            cmd += ["-shortest"]
        cmd += [str(output_path)]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        self.count = 0

    def write_frame(self, frame_uint8):
        self.proc.stdin.write(frame_uint8.tobytes())
        self.count += 1

    def close(self):
        # Closing stdin signals EOF; ffmpeg then finalizes the container.
        self.proc.stdin.close()
        self.proc.wait()


class ImageSequenceWriter:
    """Save frames as numbered image files."""

    def __init__(self, out_dir, ext=".png"):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.ext = ext
        self.count = 0

    def write_frame(self, frame_uint8):
        Image.fromarray(frame_uint8).save(
            self.out_dir / f"{self.count:06d}{self.ext}"
        )
        self.count += 1

    def close(self):
        pass


class SingleImageWriter:
    """Save a single output image."""

    def __init__(self, path):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.count = 0

    def write_frame(self, frame_uint8):
        Image.fromarray(frame_uint8).save(self.path)
        self.count += 1

    def close(self):
        pass


# ── Output path helpers ────────────────────────────────────────────────
def is_video_path(p):
    """True if the path's extension marks it as a video file."""
    return Path(p).suffix.lower() in VIDEO_EXTS


def is_image_path(p):
    """True if the path's extension marks it as an image file."""
    return Path(p).suffix.lower() in IMAGE_EXTS


def auto_output(input_path: str) -> str:
    """Derive a default output path by appending a _star suffix."""
    p = Path(input_path)
    if p.is_dir():
        return str(p.parent / (p.name + "_star"))
    return str(p.parent / (p.stem + "_star" + p.suffix))


def make_writer(output_path, fps, w, h, args, input_path, is_single_image):
    """Create the appropriate writer for the output path."""
    if is_single_image and is_image_path(output_path):
        return SingleImageWriter(output_path)
    if is_video_path(output_path):
        audio_src = input_path if (
            is_video_path(input_path) and not args.no_audio
        ) else None
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        return VideoWriter(
            output_path, fps, w, h,
            codec=args.codec, crf=args.crf, pix_fmt=args.pix_fmt,
            audio_source=audio_src,
        )
    # Default: image sequence directory
    return ImageSequenceWriter(output_path, ext=".png")


# ── Tensor ↔ numpy ─────────────────────────────────────────────────────
def to_tensor(frames_uint8):
    """[N,H,W,3] uint8 numpy → [N,H,W,3] float32 torch in [0,1]."""
    return torch.from_numpy(frames_uint8.copy()).float() / 255.0


def to_uint8(tensor):
    """[N,H,W,3] float32 torch in [0,1] → [N,H,W,3] uint8 numpy."""
    return (tensor.clamp(0, 1) * 255).byte().cpu().numpy()


# ── Segment-based processing with streaming output ─────────────────────
def _run_segment(star_model, frames_uint8, args):
    """Process one segment through STAR → float32 tensor [F,H,W,3]."""
    from star_pipeline import run_star_inference

    tensor = to_tensor(frames_uint8)
    return run_star_inference(
        star_model=star_model,
        images=tensor,
        upscale=args.upscale,
        steps=args.steps,
        guide_scale=args.guide_scale,
        prompt=args.prompt,
        solver_mode=args.solver_mode,
        max_chunk_len=args.max_chunk_len,
        seed=args.seed,
        color_fix=args.color_fix,
    )


def _write_tensor(writer, tensor):
    """Write a float32 [F,H,W,3] tensor as uint8 frames."""
    arr = to_uint8(tensor)
    for i in range(arr.shape[0]):
        writer.write_frame(arr[i])


def process_and_stream(star_model, input_frames, writer, args):
    """Process in segments, blend overlaps, and stream to writer.

    Consecutive segments overlap by seg//4 frames; the overlap region is
    cross-faded linearly so segment seams are not visible.
    """
    total = input_frames.shape[0]
    seg = args.segment_size

    # No segmentation — process everything at once
    if seg <= 0 or total <= seg:
        print(f"[STAR] Processing all {total} frame(s)...")
        result = _run_segment(star_model, input_frames, args)
        _write_tensor(writer, result)
        return

    overlap = max(2, seg // 4)
    stride = seg - overlap

    # Build segment boundaries
    segments = []
    start = 0
    while start < total:
        end = min(start + seg, total)
        segments.append((start, end))
        if end == total:
            break
        start += stride

    print(f"[STAR] {total} frames → {len(segments)} segment(s), "
          f"segment_size={seg}, overlap={overlap}")

    prev_tail = None  # float32 tensor on CPU

    for idx, (s, e) in enumerate(segments):
        print(f"\n[STAR] ── Segment {idx + 1}/{len(segments)}: "
              f"frames {s}–{e - 1} ──")

        seg_result = _run_segment(star_model, input_frames[s:e], args)

        # Blend overlap with previous segment's tail
        if prev_tail is not None:
            n = prev_tail.shape[0]
            head = seg_result[:n]
            w = torch.linspace(0, 1, n, dtype=seg_result.dtype).view(n, 1, 1, 1)
            blended = prev_tail * (1.0 - w) + head * w
            _write_tensor(writer, blended)
            remainder = seg_result[n:]
        else:
            remainder = seg_result

        if idx < len(segments) - 1:
            # Keep tail for blending, write the rest
            prev_tail = remainder[-overlap:].clone()
            _write_tensor(writer, remainder[:-overlap])
        else:
            # Last segment — write everything
            _write_tensor(writer, remainder)

        del seg_result
        torch.cuda.empty_cache()


# ── Main ────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: validate environment, read input, run STAR, write output."""
    args = parse_args()

    # Validate environment
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. STAR requires a CUDA GPU.")
        sys.exit(1)
    if not shutil.which("ffmpeg") or not shutil.which("ffprobe"):
        # ffmpeg is only mandatory when either side of the pipeline is video.
        input_p = Path(args.input)
        if input_p.suffix.lower() in VIDEO_EXTS or (
            args.output and Path(args.output).suffix.lower() in VIDEO_EXTS
        ):
            print("Error: ffmpeg/ffprobe not found. Install ffmpeg for video I/O.")
            sys.exit(1)

    # Read input
    input_frames, input_fps = read_input(args.input)
    total = input_frames.shape[0]
    h_in, w_in = input_frames.shape[1], input_frames.shape[2]
    h_out, w_out = h_in * args.upscale, w_in * args.upscale
    is_single = total == 1 and Path(args.input).is_file() and is_image_path(args.input)

    print(f"[STAR] {w_in}x{h_in} → {w_out}x{h_out} ({args.upscale}x), "
          f"{total} frame(s)")

    # Output path
    output_path = args.output or auto_output(args.input)
    fps = args.fps or input_fps or 24.0

    # Load model
    device = torch.device("cuda")
    model_path = resolve_model_path(args.model)
    print(f"[STAR] Model: {model_path}")
    star_model = load_model(model_path, args.precision, args.offload, device)

    # Create writer and process. Fix: always close the writer so the ffmpeg
    # child process is reaped and the container finalized even if inference
    # raises partway through.
    writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
    try:
        process_and_stream(star_model, input_frames, writer, args)
    finally:
        writer.close()

    print(f"\n[STAR] Done! {writer.count} frame(s) → {output_path}")


if __name__ == "__main__":
    main()