Comfyui-STAR/inference.py
Ethanfel 8a440761d1 Fix noise level (900 not 1000) and prompt concatenation to match original STAR
The original STAR inference uses total_noise_levels=900, preserving the
input's structure during SDEdit. We had 1000, which starts from near-pure
noise and destroys the input. The quality prompt is now always appended to
the user's text instead of being used only as a fallback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 02:03:34 +01:00
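
The gist of the change, as a rough sketch (total_noise_levels comes from the
commit message above; the prompt variable names are illustrative, not copied
from the diff):

    # SDEdit: enter the noise schedule at level 900 of 1000 instead of the
    # very top, so the structure of the low-res input survives denoising.
    total_noise_levels = 900  # was 1000

    # Quality prompt: always appended, no longer a fallback-only default.
    prompt = f"{user_text} {QUALITY_PROMPT}" if user_text else QUALITY_PROMPT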

#!/usr/bin/env python3
"""STAR Video Super-Resolution — Standalone Inference Script
Memory-efficient video upscaling from the command line. Works outside
ComfyUI — just activate the same Python environment.
Examples
--------
# Video → video (audio is preserved automatically)
python inference.py input.mp4 -o output.mp4
# Lower VRAM (model offload + smaller segments)
python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8
# Image sequence → image sequence
python inference.py frames_in/ -o frames_out/
# Image sequence → video
python inference.py frames_in/ -o output.mp4 --fps 24
# Single image
python inference.py photo.png -o photo_4x.png
"""
# ── Comfy module stubs (must run before star_pipeline import) ────────────
import os
import sys
import types
from pathlib import Path

SCRIPT_DIR = Path(__file__).resolve().parent
STAR_REPO = SCRIPT_DIR / "STAR"
sys.path.insert(0, str(SCRIPT_DIR))
sys.path.insert(0, str(STAR_REPO))

# Apply patches from patches/ directory to the STAR submodule.
import subprocess  # noqa: E402

_PATCHES_DIR = SCRIPT_DIR / "patches"
if _PATCHES_DIR.is_dir():
    for _patch in sorted(_PATCHES_DIR.iterdir()):
        if _patch.suffix != ".patch":
            continue
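        # `git apply --check --reverse` succeeds only when the patch is
        # already applied; a nonzero exit code means it still needs applying.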
        if subprocess.call(
            ["git", "apply", "--check", "--reverse", str(_patch)],
            cwd=str(STAR_REPO), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        ) != 0:
            if subprocess.call(["git", "apply", str(_patch)], cwd=str(STAR_REPO)) == 0:
                print(f"[STAR] Applied patch: {_patch.name}")
            else:
                print(f"[STAR] Warning: failed to apply patch: {_patch.name}")

import torch  # noqa: E402 — needed for stub defaults

_comfy = types.ModuleType("comfy")
_comfy_utils = types.ModuleType("comfy.utils")
_comfy_mm = types.ModuleType("comfy.model_management")


class _ProgressBar:
    """tqdm-based stand-in for comfy.utils.ProgressBar."""

    def __init__(self, total):
        from tqdm import tqdm
        self._bar = tqdm(total=total, desc="Denoising", unit="step")

    def update(self, n=1):
        self._bar.update(n)

    def __del__(self):
        if hasattr(self, "_bar"):
            self._bar.close()


_comfy_utils.ProgressBar = _ProgressBar
_comfy_mm.get_torch_device = lambda: torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
_comfy_mm.soft_empty_cache = lambda: torch.cuda.empty_cache()

_comfy.utils = _comfy_utils
_comfy.model_management = _comfy_mm
sys.modules["comfy"] = _comfy
sys.modules["comfy.utils"] = _comfy_utils
sys.modules["comfy.model_management"] = _comfy_mm

# ── Attention backend dispatcher ──────────────────────────────────────
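# STAR's UNet calls xformers.ops.memory_efficient_attention directly. We
# replace that entry point with a dispatcher so the same call site can route
# to SDPA, real xformers, SageAttention, or a naive math fallback at runtime.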
import torch.nn.functional as F  # noqa: E402

_ATTN_BACKENDS = {"sdpa": None}
_real_xformers_mea = None
try:
    import xformers.ops
    _candidate = xformers.ops.memory_efficient_attention
    if not getattr(_candidate, "_is_star_dispatcher", False):
        _real_xformers_mea = _candidate
        _ATTN_BACKENDS["xformers"] = _real_xformers_mea
except ImportError:
    pass

_SAGE_VARIANTS = [
    "sageattn",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp8_cuda",
]
for _name in _SAGE_VARIANTS:
    try:
        _fn = getattr(__import__("sageattention", fromlist=[_name]), _name)
        _ATTN_BACKENDS[_name] = _fn
    except (ImportError, AttributeError):
        pass

# Manual attention (guaranteed correct, used as diagnostic baseline)
_ATTN_BACKENDS["math"] = "math"

_active_attn = "sdpa"


def _set_attn(backend: str):
    global _active_attn
    if backend not in _ATTN_BACKENDS:
        print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
        backend = "sdpa"
    _active_attn = backend
    print(f"[STAR] Attention backend: {backend}")


def _dispatched_mea(q, k, v, attn_bias=None, op=None):
    if _active_attn == "xformers":
        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
    if _active_attn == "math":
        # Naive batched attention — slow but guaranteed correct.
        scale = q.shape[-1] ** -0.5
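        # Chunk the batch so each q·kᵀ score matrix allocation stays near
        # 2**28 bytes (~256 MB), keeping the fallback usable on long sequences.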
        cs = max(1, 2**28 // (q.shape[1] * q.shape[1] * max(q.element_size(), 1)))
        outs = []
        for i in range(0, q.shape[0], cs):
            qi, ki, vi = q[i:i+cs], k[i:i+cs], v[i:i+cs]
            a = torch.bmm(qi * scale, ki.transpose(1, 2))
            if attn_bias is not None:
                a = a + (attn_bias[i:i+cs] if attn_bias.shape[0] > 1 else attn_bias)
            outs.append(torch.bmm(a.softmax(dim=-1), vi))
        return torch.cat(outs)
    if _active_attn == "sdpa":
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    # SageAttention variants: need 4D tensors (batch, heads, seq, dim)
    fn = _ATTN_BACKENDS[_active_attn]
    return fn(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
        tensor_layout="HND", is_causal=False,
    ).squeeze(0)


_dispatched_mea._is_star_dispatcher = True

if "xformers" in sys.modules:
    sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea
else:
    _xformers = types.ModuleType("xformers")
    _xformers_ops = types.ModuleType("xformers.ops")
    _xformers_ops.memory_efficient_attention = _dispatched_mea
    _xformers.ops = _xformers_ops
    sys.modules["xformers"] = _xformers
    sys.modules["xformers.ops"] = _xformers_ops

print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")

# ── Standard imports ────────────────────────────────────────────────────
import argparse # noqa: E402
import json # noqa: E402
import shutil # noqa: E402
import numpy as np # noqa: E402
from PIL import Image # noqa: E402
# ── Constants ───────────────────────────────────────────────────────────
HF_REPO = "SherryX/STAR"
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".ts"}
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}

# ── Argument parsing ───────────────────────────────────────────────────
def parse_args():
    class Fmt(argparse.ArgumentDefaultsHelpFormatter,
              argparse.RawDescriptionHelpFormatter):
        pass

    p = argparse.ArgumentParser(
        description="STAR Video Super-Resolution — standalone inference",
        formatter_class=Fmt,
        epilog=__doc__,
    )

    # -- I/O --
    p.add_argument("input", help="Input video file, image file, or directory of frames")
    p.add_argument("-o", "--output",
                   help="Output path (video file, image file, or directory). "
                        "Auto-generated with _star suffix if omitted.")

    # -- Model --
    g = p.add_argument_group("model")
    g.add_argument("--model", default="light_deg.pt",
                   help="Model name (light_deg.pt / heavy_deg.pt) or path to .pt file")
    g.add_argument("--precision", default="fp16", choices=["fp16", "bf16", "fp32"],
                   help="Weight precision")
    g.add_argument("--offload", default="model",
                   choices=["disabled", "model", "aggressive"],
                   help="VRAM offloading strategy")

    # -- Processing --
    g = p.add_argument_group("processing")
    g.add_argument("--upscale", type=int, default=4, help="Upscale factor")
    g.add_argument("--segment-size", type=int, default=16,
                   help="Frames per segment (bounds peak RAM). 0 = all at once")
    g.add_argument("--steps", type=int, default=15, help="Denoising steps")
    g.add_argument("--guide-scale", type=float, default=7.5, help="Guidance scale")
    g.add_argument("--solver-mode", default="fast", choices=["fast", "normal"])
    g.add_argument("--max-chunk-len", type=int, default=32,
                   help="Temporal chunk length inside diffusion loop")
    g.add_argument("--seed", type=int, default=0, help="Random seed")
    g.add_argument("--color-fix", default="adain",
                   choices=["adain", "wavelet", "none"],
                   help="Post-processing color correction")
g.add_argument("--prompt", default="",
help="Text prompt (empty = STAR built-in quality prompt)")
g.add_argument("--attention", default="sdpa",
choices=list(_ATTN_BACKENDS.keys()),
help="Attention backend")
# -- Video output --
g = p.add_argument_group("video output")
g.add_argument("--fps", type=float, default=None,
help="Output FPS (default: match input, or 24 for image sequences)")
g.add_argument("--codec", default="libx264", help="FFmpeg video codec")
g.add_argument("--crf", type=int, default=18,
help="FFmpeg CRF quality (lower = better)")
g.add_argument("--pix-fmt", default="yuv420p", help="FFmpeg pixel format")
g.add_argument("--no-audio", action="store_true",
help="Do not copy audio from input video")
return p.parse_args()
# ── Model resolution ───────────────────────────────────────────────────
def resolve_model_path(model_arg: str) -> str:
    if os.path.isfile(model_arg):
        return model_arg
    search = [
        SCRIPT_DIR / "models" / model_arg,
        # Standard ComfyUI layout: custom_nodes/Comfyui-STAR/../../models/star/
        SCRIPT_DIR / ".." / ".." / "models" / "star" / model_arg,
    ]
    for candidate in search:
        candidate = candidate.resolve()
        if candidate.is_file():
            return str(candidate)
    if model_arg not in HF_MODELS:
        raise FileNotFoundError(
            f"Model '{model_arg}' not found locally and is not a known "
            "downloadable model. Provide a full path, or use "
            "light_deg.pt / heavy_deg.pt."
        )
    from huggingface_hub import hf_hub_download
    print(f"[STAR] Downloading {model_arg} from HuggingFace ({HF_REPO})...")
    dest_dir = str(SCRIPT_DIR / "models")
    os.makedirs(dest_dir, exist_ok=True)
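    # hf_hub_download mirrors the repo's folder layout under local_dir, so the
    # returned path keeps the I2VGen-XL-based/ subfolder from HF_MODELS.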
    path = hf_hub_download(
        repo_id=HF_REPO, filename=HF_MODELS[model_arg], local_dir=dest_dir,
    )
    return path

# ── Model loading (mirrors STARModelLoader.load_model) ─────────────────
def load_model(model_path: str, precision: str, offload: str, device: torch.device):
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[precision]
    keep_on = device if offload == "disabled" else "cpu"
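    # When offloading, weights are parked on the CPU between uses and are
    # expected to be shuttled to the GPU on demand during inference; the text
    # encoder below is moved over just long enough to encode the negative
    # prompt.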
    from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder
    from video_to_video.utils.config import cfg

    print("[STAR] Loading text encoder (OpenCLIP ViT-H-14)...")
    text_encoder = FrozenOpenCLIPEmbedder(device=device, pretrained="laion2b_s32b_b79k")
    text_encoder.model.to(device)
    negative_y = text_encoder(cfg.negative_prompt).detach()
    text_encoder.model.to(keep_on)

    from video_to_video.modules.unet_v2v import ControlledV2VUNet

    print("[STAR] Loading UNet + ControlNet...")
    generator = ControlledV2VUNet()
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    generator.load_state_dict(ckpt, strict=False)
    del ckpt
    generator = generator.to(device=keep_on, dtype=dtype)
    generator.eval()

    from video_to_video.diffusion.schedules_sdedit import noise_schedule
    from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

    sigmas = noise_schedule(
        schedule="logsnr_cosine_interp", n=1000,
        zero_terminal_snr=True, scale_min=2.0, scale_max=4.0,
    )
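    # 1000 sigma levels; per the commit above, SDEdit enters the chain at
    # level 900 so the low-res input's structure survives denoising.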
    diffusion = GaussianDiffusion(sigmas=sigmas)

    from diffusers import AutoencoderKLTemporalDecoder

    print("[STAR] Loading temporal VAE...")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        subfolder="vae", variant="fp16",
    )
    vae.eval()
    vae.requires_grad_(False)
    vae.to(keep_on)

    torch.cuda.empty_cache()
    print("[STAR] All models loaded.")
    return {
        "text_encoder": text_encoder,
        "generator": generator,
        "diffusion": diffusion,
        "vae": vae,
        "negative_y": negative_y,
        "device": device,
        "dtype": dtype,
        "offload": offload,
    }

# ── Input reading ──────────────────────────────────────────────────────
def _ffprobe(path: str):
    """Return (width, height, fps, nb_frames) via ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-show_format", str(path),
    ]
    info = json.loads(subprocess.check_output(cmd))
    vs = next(s for s in info["streams"] if s["codec_type"] == "video")
    w, h = int(vs["width"]), int(vs["height"])
    num, den = map(int, vs.get("r_frame_rate", "24/1").split("/"))
    fps = num / den if den else 24.0
    nb = vs.get("nb_frames")
    if nb and nb != "N/A":
        n_frames = int(nb)
    else:
        dur = float(info.get("format", {}).get("duration", 0))
        n_frames = int(dur * fps) or 0
    return w, h, fps, n_frames


def read_video(path: str):
    """Read video → (np array [N,H,W,3] uint8, fps)."""
    w, h, fps, est = _ffprobe(path)
    print(f"[STAR] Input video: {w}x{h}, {fps:.2f} fps, ~{est} frames")
    cmd = [
        "ffmpeg", "-i", str(path),
        "-f", "rawvideo", "-pix_fmt", "rgb24",
        "-v", "quiet", "-",
    ]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    frames = []
    fsize = w * h * 3
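    # Each frame arrives as exactly w*h*3 bytes of packed RGB24 on ffmpeg's
    # stdout; a short read signals end of stream.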
    while True:
        raw = proc.stdout.read(fsize)
        if len(raw) < fsize:
            break
        frames.append(np.frombuffer(raw, dtype=np.uint8).reshape(h, w, 3))
    proc.wait()
    print(f"[STAR] Read {len(frames)} frames")
    return np.stack(frames), fps


def read_image_dir(directory: str):
    """Read image directory → (np array [N,H,W,3] uint8, None)."""
    d = Path(directory)
    files = sorted(f for f in d.iterdir() if f.suffix.lower() in IMAGE_EXTS)
    if not files:
        raise FileNotFoundError(f"No image files in {d}")
    print(f"[STAR] Loading {len(files)} images from {d}")
    frames = [np.array(Image.open(f).convert("RGB")) for f in files]
    return np.stack(frames), None


def read_input(path: str):
    """Auto-detect input type → (np array [N,H,W,3] uint8, fps | None)."""
    p = Path(path)
    if p.is_dir():
        return read_image_dir(path)
    if p.suffix.lower() in VIDEO_EXTS:
        return read_video(path)
    if p.suffix.lower() in IMAGE_EXTS:
        img = np.array(Image.open(p).convert("RGB"))
        return img[np.newaxis], None
    raise ValueError(f"Unsupported input: {path}")

# ── Output writing ─────────────────────────────────────────────────────
class VideoWriter:
    """Stream RGB frames to ffmpeg, optionally copying audio from source."""

    def __init__(self, output_path, fps, width, height,
                 codec="libx264", crf=18, pix_fmt="yuv420p",
                 audio_source=None):
        cmd = [
            "ffmpeg", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}", "-r", str(fps),
            "-i", "-",
        ]
        if audio_source:
            cmd += ["-i", str(audio_source)]
        cmd += ["-map", "0:v:0"]
        if audio_source:
            cmd += ["-map", "1:a?", "-c:a", "copy"]
        cmd += [
            "-c:v", codec, "-crf", str(crf), "-pix_fmt", pix_fmt,
            "-movflags", "+faststart",
            "-v", "warning",
            str(output_path),
        ]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        self.count = 0

    def write_frame(self, frame_uint8):
        self.proc.stdin.write(frame_uint8.tobytes())
        self.count += 1

    def close(self):
        self.proc.stdin.close()
        self.proc.wait()


class ImageSequenceWriter:
    """Save frames as numbered image files."""

    def __init__(self, out_dir, ext=".png"):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.ext = ext
        self.count = 0

    def write_frame(self, frame_uint8):
        Image.fromarray(frame_uint8).save(
            self.out_dir / f"{self.count:06d}{self.ext}"
        )
        self.count += 1

    def close(self):
        pass


class SingleImageWriter:
    """Save a single output image."""

    def __init__(self, path):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.count = 0

    def write_frame(self, frame_uint8):
        Image.fromarray(frame_uint8).save(self.path)
        self.count += 1

    def close(self):
        pass

# ── Output path helpers ────────────────────────────────────────────────
def is_video_path(p):
    return Path(p).suffix.lower() in VIDEO_EXTS


def is_image_path(p):
    return Path(p).suffix.lower() in IMAGE_EXTS


def auto_output(input_path: str) -> str:
    p = Path(input_path)
    if p.is_dir():
        return str(p.parent / (p.name + "_star"))
    return str(p.parent / (p.stem + "_star" + p.suffix))


def make_writer(output_path, fps, w, h, args, input_path, is_single_image):
    """Create the appropriate writer for the output path."""
    if is_single_image and is_image_path(output_path):
        return SingleImageWriter(output_path)
    if is_video_path(output_path):
        audio_src = input_path if (
            is_video_path(input_path) and not args.no_audio
        ) else None
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        return VideoWriter(
            output_path, fps, w, h,
            codec=args.codec, crf=args.crf, pix_fmt=args.pix_fmt,
            audio_source=audio_src,
        )
    # Default: image sequence directory
    return ImageSequenceWriter(output_path, ext=".png")

# ── Tensor ↔ numpy ─────────────────────────────────────────────────────
def to_tensor(frames_uint8):
    """[N,H,W,3] uint8 numpy → [N,H,W,3] float32 torch in [0,1]."""
    return torch.from_numpy(frames_uint8.copy()).float() / 255.0


def to_uint8(tensor):
    """[N,H,W,3] float32 torch in [0,1] → [N,H,W,3] uint8 numpy."""
    return (tensor.clamp(0, 1) * 255).byte().cpu().numpy()

# ── Segment-based processing with streaming output ─────────────────────
def _run_segment(star_model, frames_uint8, args):
    """Process one segment through STAR → float32 tensor [F,H,W,3]."""
    from star_pipeline import run_star_inference
    tensor = to_tensor(frames_uint8)
    return run_star_inference(
        star_model=star_model,
        images=tensor,
        upscale=args.upscale,
        steps=args.steps,
        guide_scale=args.guide_scale,
        prompt=args.prompt,
        solver_mode=args.solver_mode,
        max_chunk_len=args.max_chunk_len,
        seed=args.seed,
        color_fix=args.color_fix,
    )


def _write_tensor(writer, tensor):
    """Write a float32 [F,H,W,3] tensor as uint8 frames."""
    arr = to_uint8(tensor)
    for i in range(arr.shape[0]):
        writer.write_frame(arr[i])


def process_and_stream(star_model, input_frames, writer, args):
    """Process in segments, blend overlaps, and stream to writer."""
    total = input_frames.shape[0]
    seg = args.segment_size

    # No segmentation — process everything at once
    if seg <= 0 or total <= seg:
        print(f"[STAR] Processing all {total} frame(s)...")
        result = _run_segment(star_model, input_frames, args)
        _write_tensor(writer, result)
        return

    overlap = max(2, seg // 4)
    stride = seg - overlap
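    # Consecutive segments share `overlap` frames; those frames are linearly
    # cross-faded on write so segment boundaries do not leave visible seams.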

    # Build segment boundaries
    segments = []
    start = 0
    while start < total:
        end = min(start + seg, total)
        segments.append((start, end))
        if end == total:
            break
        start += stride
    print(f"[STAR] {total} frames → {len(segments)} segment(s), "
          f"segment_size={seg}, overlap={overlap}")

    prev_tail = None  # float32 tensor on CPU
    for idx, (s, e) in enumerate(segments):
        print(f"\n[STAR] ── Segment {idx + 1}/{len(segments)}: "
              f"frames {s}–{e - 1} ──")
        seg_result = _run_segment(star_model, input_frames[s:e], args)

        # Blend overlap with previous segment's tail
        if prev_tail is not None:
            n = prev_tail.shape[0]
            head = seg_result[:n]
            w = torch.linspace(0, 1, n, dtype=seg_result.dtype).view(n, 1, 1, 1)
            blended = prev_tail * (1.0 - w) + head * w
            _write_tensor(writer, blended)
            remainder = seg_result[n:]
        else:
            remainder = seg_result

        if idx < len(segments) - 1:
            # Keep tail for blending, write the rest
            prev_tail = remainder[-overlap:].clone()
            _write_tensor(writer, remainder[:-overlap])
        else:
            # Last segment — write everything
            _write_tensor(writer, remainder)
        del seg_result
        torch.cuda.empty_cache()

# ── Main ────────────────────────────────────────────────────────────────
def main():
    args = parse_args()

    # Validate environment
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. STAR requires a CUDA GPU.")
        sys.exit(1)
    if not shutil.which("ffmpeg") or not shutil.which("ffprobe"):
        input_p = Path(args.input)
        if input_p.suffix.lower() in VIDEO_EXTS or (
            args.output and Path(args.output).suffix.lower() in VIDEO_EXTS
        ):
            print("Error: ffmpeg/ffprobe not found. Install ffmpeg for video I/O.")
            sys.exit(1)

    # Read input
    input_frames, input_fps = read_input(args.input)
    total = input_frames.shape[0]
    h_in, w_in = input_frames.shape[1], input_frames.shape[2]
    h_out, w_out = h_in * args.upscale, w_in * args.upscale
    is_single = total == 1 and Path(args.input).is_file() and is_image_path(args.input)
    print(f"[STAR] {w_in}x{h_in} → {w_out}x{h_out} ({args.upscale}x), "
          f"{total} frame(s)")

    # Output path
    output_path = args.output or auto_output(args.input)
    fps = args.fps or input_fps or 24.0

    # Load model
    device = torch.device("cuda")
    model_path = resolve_model_path(args.model)
    print(f"[STAR] Model: {model_path}")
    star_model = load_model(model_path, args.precision, args.offload, device)
    _set_attn(args.attention)

    # Create writer and process
    writer = make_writer(output_path, fps, w_out, h_out, args, args.input, is_single)
    process_and_stream(star_model, input_frames, writer, args)
    writer.close()
    print(f"\n[STAR] Done! {writer.count} frame(s) → {output_path}")


if __name__ == "__main__":
    main()