Newer open_clip creates nn.MultiheadAttention with batch_first=True, but STAR's embedder unconditionally permutes to [seq, batch, embed]. This causes a RuntimeError in the text encoder (attn_mask shape mismatch). The patch detects batch_first at runtime and only permutes when needed. Patches in patches/ are auto-applied to the STAR submodule on startup and skip gracefully if already applied. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
656 lines
23 KiB
Python
Executable File
656 lines
23 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""STAR Video Super-Resolution — Standalone Inference Script
|
||
|
||
Memory-efficient video upscaling from the command line. Works outside
|
||
ComfyUI — just activate the same Python environment.
|
||
|
||
Examples
|
||
--------
|
||
# Video → video (audio is preserved automatically)
|
||
python inference.py input.mp4 -o output.mp4
|
||
|
||
# Lower VRAM (model offload + smaller segments)
|
||
python inference.py input.mp4 -o output.mp4 --offload model --segment-size 8
|
||
|
||
# Image sequence → image sequence
|
||
python inference.py frames_in/ -o frames_out/
|
||
|
||
# Image sequence → video
|
||
python inference.py frames_in/ -o output.mp4 --fps 24
|
||
|
||
# Single image
|
||
python inference.py photo.png -o photo_4x.png
|
||
"""
|
||
|
||
# ── Comfy module stubs (must run before star_pipeline import) ────────────
|
||
import os
|
||
import sys
|
||
import types
|
||
from pathlib import Path
|
||
|
||
# Directory containing this script; the STAR checkout is looked up beneath it.
SCRIPT_DIR = Path(__file__).resolve().parent
STAR_REPO = SCRIPT_DIR.joinpath("STAR")
# Prepend both in one step: STAR_REPO ends up first on sys.path, SCRIPT_DIR second.
sys.path[:0] = [str(STAR_REPO), str(SCRIPT_DIR)]
|
||
|
||
# Apply patches from patches/ directory to the STAR submodule.
|
||
import subprocess # noqa: E402
|
||
|
||
# Best-effort: apply every *.patch file in patches/ to the STAR checkout.
# Patches that are already applied are skipped; failures only produce a warning.
_PATCHES_DIR = SCRIPT_DIR / "patches"
if _PATCHES_DIR.is_dir():
    for _patch in sorted(_PATCHES_DIR.iterdir()):
        if _patch.suffix != ".patch":
            continue
        try:
            # `git apply --check --reverse` exits 0 only when the patch could be
            # cleanly reverted, i.e. it is already present in the working tree.
            _already_applied = subprocess.call(
                ["git", "apply", "--check", "--reverse", str(_patch)],
                cwd=str(STAR_REPO),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ) == 0
            if _already_applied:
                continue
            _applied = subprocess.call(
                ["git", "apply", str(_patch)], cwd=str(STAR_REPO)
            ) == 0
        except OSError as _exc:
            # Raised when the git executable is missing or STAR_REPO does not
            # exist; patching is optional, so warn instead of aborting startup.
            print(f"[STAR] Warning: failed to apply patch: {_patch.name} ({_exc})")
            continue
        if _applied:
            print(f"[STAR] Applied patch: {_patch.name}")
        else:
            print(f"[STAR] Warning: failed to apply patch: {_patch.name}")
|
||
|
||
import torch # noqa: E402 — needed for stub defaults
|
||
|
||
# Empty stand-in modules for the `comfy` package and the two submodules stubbed here.
_comfy, _comfy_utils, _comfy_mm = (
    types.ModuleType(_module_name)
    for _module_name in ("comfy", "comfy.utils", "comfy.model_management")
)
|
||
|
||
|
||
class _ProgressBar:
    """Stand-in for comfy.utils.ProgressBar that renders through tqdm."""

    def __init__(self, total):
        # Imported lazily so tqdm is only required once a bar is created.
        import tqdm as _tqdm

        self._bar = _tqdm.tqdm(total=total, desc="Denoising", unit="step")

    def update(self, n=1):
        """Advance the bar by ``n`` steps."""
        self._bar.update(n)

    def __del__(self):
        # __init__ may have failed before `_bar` was assigned.
        try:
            bar = self._bar
        except AttributeError:
            return
        bar.close()
|
||
|
||
|
||
# Populate the stub modules with the few attributes this script provides.
_comfy_utils.ProgressBar = _ProgressBar
_comfy_mm.get_torch_device = lambda: torch.device(
    "cpu" if not torch.cuda.is_available() else "cuda"
)
_comfy_mm.soft_empty_cache = lambda: torch.cuda.empty_cache()

_comfy.utils, _comfy.model_management = _comfy_utils, _comfy_mm

# Register the stubs so `import comfy...` resolves to them.
sys.modules.update({
    "comfy": _comfy,
    "comfy.utils": _comfy_utils,
    "comfy.model_management": _comfy_mm,
})
|
||
|
||
# ── Attention backend dispatcher ──────────────────────────────────────
|
||
import torch.nn.functional as F # noqa: E402
|
||
|
||
# Backend name -> attention callable. "sdpa" is always available and has no
# callable of its own (it is handled via torch's scaled_dot_product_attention).
_ATTN_BACKENDS = {"sdpa": None}

# The genuine xformers kernel, if xformers is importable and not already
# replaced by this script's dispatcher.
_real_xformers_mea = None
try:
    import xformers.ops

    _candidate = xformers.ops.memory_efficient_attention
    _is_dispatcher = getattr(_candidate, "_is_star_dispatcher", False)
    if not _is_dispatcher:
        _real_xformers_mea = _candidate
        _ATTN_BACKENDS["xformers"] = _real_xformers_mea
except ImportError:
    pass

# SageAttention entry points to probe; each one found becomes its own backend.
_SAGE_VARIANTS = [
    "sageattn",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp8_cuda",
]
for _name in _SAGE_VARIANTS:
    try:
        _sage_module = __import__("sageattention", fromlist=[_name])
        _fn = getattr(_sage_module, _name)
    except (ImportError, AttributeError):
        continue
    _ATTN_BACKENDS[_name] = _fn

# Name of the backend currently in use.
_active_attn = "sdpa"
|
||
|
||
|
||
def _set_attn(backend: str):
    """Make ``backend`` the active attention backend.

    Names missing from ``_ATTN_BACKENDS`` trigger a warning and fall back to
    "sdpa". The final choice is printed.
    """
    global _active_attn
    available = backend in _ATTN_BACKENDS
    if not available:
        print(f"[STAR] Warning: backend '{backend}' not available, falling back to sdpa")
    _active_attn = backend if available else "sdpa"
    print(f"[STAR] Attention backend: {_active_attn}")
|
||
|
||
|
||
def _dispatched_mea(q, k, v, attn_bias=None, op=None):
    """Route an xformers-style memory_efficient_attention call to the active backend.

    - "xformers": forwards everything to the real xformers kernel.
    - "sdpa": uses torch's scaled_dot_product_attention with ``attn_bias`` as
      the attention mask; ``op`` is ignored.
    - any other name: looked up in ``_ATTN_BACKENDS`` and called with a leading
      axis added to q/k/v and removed from the result. ``attn_bias`` and ``op``
      are not passed on in this path.
    """
    backend = _active_attn
    if backend == "xformers":
        return _real_xformers_mea(q, k, v, attn_bias=attn_bias, op=op)
    if backend != "sdpa":
        kernel = _ATTN_BACKENDS[backend]
        # NOTE(review): presumably q/k/v arrive as (heads, seq, dim) so the added
        # axis acts as a batch of one under the "HND" layout — confirm with caller.
        batched = [t.unsqueeze(0) for t in (q, k, v)]
        out = kernel(*batched, tensor_layout="HND", is_causal=False)
        return out.squeeze(0)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||
|
||
|
||
# Marker attribute identifying this function as the STAR dispatcher.
setattr(_dispatched_mea, "_is_star_dispatcher", True)

if "xformers" not in sys.modules:
    # xformers was never imported: register minimal fake modules exposing
    # only the dispatcher.
    _xformers_ops = types.ModuleType("xformers.ops")
    _xformers_ops.memory_efficient_attention = _dispatched_mea
    _xformers = types.ModuleType("xformers")
    _xformers.ops = _xformers_ops
    sys.modules["xformers"] = _xformers
    sys.modules["xformers.ops"] = _xformers_ops
else:
    # xformers is loaded: swap its attention entry point for the dispatcher.
    sys.modules["xformers"].ops.memory_efficient_attention = _dispatched_mea

print(f"[STAR] Available attention backends: {list(_ATTN_BACKENDS.keys())}")
|
||
|
||
# ── Standard imports ────────────────────────────────────────────────────
|
||
import argparse # noqa: E402
|
||
import json # noqa: E402
|
||
import shutil # noqa: E402
|
||
|
||
import numpy as np # noqa: E402
|
||
from PIL import Image # noqa: E402
|
||
|
||
# ── Constants ───────────────────────────────────────────────────────────
|
||
# HuggingFace repository that hosts the downloadable checkpoints.
HF_REPO = "SherryX/STAR"
# Short model name -> file path inside HF_REPO.
HF_MODELS = {
    _model_name: f"I2VGen-XL-based/{_model_name}"
    for _model_name in ("light_deg.pt", "heavy_deg.pt")
}
# Lower-case file suffixes treated as video / still-image files.
VIDEO_EXTS = {
    ".mp4", ".mkv", ".avi", ".mov",
    ".webm", ".flv", ".wmv", ".ts",
}
IMAGE_EXTS = {
    ".png", ".jpg", ".jpeg", ".bmp",
    ".tiff", ".tif", ".webp",
}
|
||
|
||
|
||
# ── Argument parsing ───────────────────────────────────────────────────
|
||
def parse_args():
    """Build the command-line parser and return the parsed ``sys.argv``.

    The module docstring is reused verbatim as the help epilog. The choices for
    ``--attention`` are the backends registered in ``_ATTN_BACKENDS``.
    """

    class _HelpFormatter(argparse.ArgumentDefaultsHelpFormatter,
                         argparse.RawDescriptionHelpFormatter):
        """Show defaults in help text while keeping the epilog's line breaks."""

    parser = argparse.ArgumentParser(
        description="STAR Video Super-Resolution — standalone inference",
        formatter_class=_HelpFormatter,
        epilog=__doc__,
    )

    # -- I/O --
    parser.add_argument(
        "input",
        help="Input video file, image file, or directory of frames",
    )
    parser.add_argument(
        "-o", "--output",
        help="Output path (video file, image file, or directory). "
             "Auto-generated with _star suffix if omitted.",
    )

    # -- Model --
    model_opts = parser.add_argument_group("model")
    model_opts.add_argument(
        "--model", default="light_deg.pt",
        help="Model name (light_deg.pt / heavy_deg.pt) or path to .pt file",
    )
    model_opts.add_argument(
        "--precision", default="fp16", choices=["fp16", "bf16", "fp32"],
        help="Weight precision",
    )
    model_opts.add_argument(
        "--offload", default="model",
        choices=["disabled", "model", "aggressive"],
        help="VRAM offloading strategy",
    )

    # -- Processing --
    proc_opts = parser.add_argument_group("processing")
    proc_opts.add_argument("--upscale", type=int, default=4, help="Upscale factor")
    proc_opts.add_argument(
        "--segment-size", type=int, default=16,
        help="Frames per segment (bounds peak RAM). 0 = all at once",
    )
    proc_opts.add_argument("--steps", type=int, default=15, help="Denoising steps")
    proc_opts.add_argument(
        "--guide-scale", type=float, default=7.5, help="Guidance scale",
    )
    proc_opts.add_argument("--solver-mode", default="fast", choices=["fast", "normal"])
    proc_opts.add_argument(
        "--max-chunk-len", type=int, default=32,
        help="Temporal chunk length inside diffusion loop",
    )
    proc_opts.add_argument("--seed", type=int, default=0, help="Random seed")
    proc_opts.add_argument(
        "--color-fix", default="adain",
        choices=["adain", "wavelet", "none"],
        help="Post-processing color correction",
    )
    proc_opts.add_argument(
        "--prompt", default="",
        help="Text prompt (empty = STAR built-in quality prompt)",
    )
    proc_opts.add_argument(
        "--attention", default="sdpa",
        choices=list(_ATTN_BACKENDS.keys()),
        help="Attention backend",
    )

    # -- Video output --
    video_opts = parser.add_argument_group("video output")
    video_opts.add_argument(
        "--fps", type=float, default=None,
        help="Output FPS (default: match input, or 24 for image sequences)",
    )
    video_opts.add_argument("--codec", default="libx264", help="FFmpeg video codec")
    video_opts.add_argument(
        "--crf", type=int, default=18,
        help="FFmpeg CRF quality (lower = better)",
    )
    video_opts.add_argument("--pix-fmt", default="yuv420p", help="FFmpeg pixel format")
    video_opts.add_argument(
        "--no-audio", action="store_true",
        help="Do not copy audio from input video",
    )

    return parser.parse_args()
|
||
|
||
|
||
# ── Model resolution ───────────────────────────────────────────────────
|
||
def resolve_model_path(model_arg: str) -> str:
    """Turn a model name or path into the path of a checkpoint file on disk.

    Resolution order:
      1. ``model_arg`` itself, if it is an existing file.
      2. ``<script dir>/models/<model_arg>``.
      3. ``<script dir>/../../models/star/<model_arg>`` (standard ComfyUI layout).
      4. For names listed in ``HF_MODELS``: the location a previous download
         was saved to, ``<script dir>/models/<path inside the HF repo>``.
      5. Download from HuggingFace (names listed in ``HF_MODELS`` only).

    Raises:
        FileNotFoundError: the model was not found locally and is not one of
            the downloadable models.
    """
    if os.path.isfile(model_arg):
        return model_arg

    models_dir = SCRIPT_DIR / "models"
    hf_filename = HF_MODELS.get(model_arg)

    search = [
        models_dir / model_arg,
        # Standard ComfyUI layout: custom_nodes/Comfyui-STAR/../../models/star/
        SCRIPT_DIR / ".." / ".." / "models" / "star" / model_arg,
    ]
    if hf_filename is not None:
        # hf_hub_download(local_dir=...) replicates the repo's folder structure,
        # so an earlier download sits in a sub-folder of models/. Checking it
        # here avoids re-entering the download path on every run.
        search.append(models_dir / hf_filename)

    for candidate in search:
        candidate = candidate.resolve()
        if candidate.is_file():
            return str(candidate)

    if hf_filename is None:
        raise FileNotFoundError(
            f"Model '{model_arg}' not found locally and is not a known "
            "downloadable model. Provide a full path, or use "
            "light_deg.pt / heavy_deg.pt."
        )

    # Imported lazily: only needed when a download is actually required.
    from huggingface_hub import hf_hub_download

    print(f"[STAR] Downloading {model_arg} from HuggingFace ({HF_REPO})...")
    dest_dir = str(models_dir)
    os.makedirs(dest_dir, exist_ok=True)
    return hf_hub_download(
        repo_id=HF_REPO, filename=hf_filename, local_dir=dest_dir,
    )
|
||
|
||
|
||
# ── Model loading (mirrors STARModelLoader.load_model) ─────────────────
|
||
def load_model(model_path: str, precision: str, offload: str, device: torch.device):
    """Load every STAR component and return them bundled in a dict.

    Args:
        model_path: Checkpoint file for the UNet/ControlNet generator.
        precision: "fp16", "bf16" or "fp32"; dtype the generator is cast to.
        offload: Offloading strategy. With "disabled" the models stay on
            ``device``; with any other value they are parked on the CPU.
        device: Device used to compute the negative-prompt embedding.

    Returns:
        Dict with keys "text_encoder", "generator", "diffusion", "vae",
        "negative_y", "device", "dtype" and "offload".
    """
    dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }[precision]
    # Where the weights rest after loading.
    keep_on = "cpu" if offload != "disabled" else device

    from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder
    from video_to_video.utils.config import cfg

    print("[STAR] Loading text encoder (OpenCLIP ViT-H-14)...")
    text_encoder = FrozenOpenCLIPEmbedder(device=device, pretrained="laion2b_s32b_b79k")
    # Encode the negative prompt once on `device`, then move the encoder to
    # its resting place.
    text_encoder.model.to(device)
    negative_y = text_encoder(cfg.negative_prompt).detach()
    text_encoder.model.to(keep_on)

    from video_to_video.modules.unet_v2v import ControlledV2VUNet

    print("[STAR] Loading UNet + ControlNet...")
    generator = ControlledV2VUNet()
    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from sources you trust.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    # Some checkpoints wrap the weights under a "state_dict" key.
    state = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    generator.load_state_dict(state, strict=False)
    del checkpoint, state
    generator = generator.to(device=keep_on, dtype=dtype)
    generator.eval()

    from video_to_video.diffusion.schedules_sdedit import noise_schedule
    from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

    schedule_kwargs = dict(
        schedule="logsnr_cosine_interp", n=1000,
        zero_terminal_snr=True, scale_min=2.0, scale_max=4.0,
    )
    diffusion = GaussianDiffusion(sigmas=noise_schedule(**schedule_kwargs))

    from diffusers import AutoencoderKLTemporalDecoder

    print("[STAR] Loading temporal VAE...")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        subfolder="vae",
        variant="fp16",
    )
    vae.eval()
    vae.requires_grad_(False)
    vae.to(keep_on)

    torch.cuda.empty_cache()
    print("[STAR] All models loaded.")

    return dict(
        text_encoder=text_encoder,
        generator=generator,
        diffusion=diffusion,
        vae=vae,
        negative_y=negative_y,
        device=device,
        dtype=dtype,
        offload=offload,
    )
|
||
|
||
|
||
# ── Input reading ──────────────────────────────────────────────────────
|
||
def _ffprobe(path: str):
    """Return (width, height, fps, nb_frames) via ffprobe.

    ``fps`` comes from the first video stream's ``r_frame_rate`` (24.0 when the
    denominator is zero). ``nb_frames`` is the stream's own count when present,
    otherwise an estimate from the container duration times fps (0 if unknown).

    Raises:
        ValueError: the file contains no video stream.
        subprocess.CalledProcessError: ffprobe exited with an error.
    """
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-show_format", str(path),
    ]
    info = json.loads(subprocess.check_output(cmd))

    # A bare next() would leak StopIteration for audio-only files.
    vs = next(
        (s for s in info.get("streams", []) if s.get("codec_type") == "video"),
        None,
    )
    if vs is None:
        raise ValueError(f"No video stream found in {path}")
    w, h = int(vs["width"]), int(vs["height"])

    # r_frame_rate is a "numerator/denominator" fraction.
    num, den = map(int, vs.get("r_frame_rate", "24/1").split("/"))
    fps = num / den if den else 24.0

    nb = vs.get("nb_frames")
    if nb and nb != "N/A":
        n_frames = int(nb)
    else:
        dur = float(info.get("format", {}).get("duration", 0))
        n_frames = int(dur * fps)

    return w, h, fps, n_frames
|
||
|
||
|
||
def read_video(path: str):
    """Read video → (np array [N,H,W,3] uint8, fps).

    Frames are decoded by piping raw RGB24 output from ffmpeg; the frame size
    and fps are taken from ``_ffprobe``.

    Raises:
        ValueError: the probed frame size is zero or no frame could be decoded.
    """
    w, h, fps, est = _ffprobe(path)
    print(f"[STAR] Input video: {w}x{h}, {fps:.2f} fps, ~{est} frames")

    fsize = w * h * 3
    if fsize <= 0:
        # read(0) always returns b"", so the loop below would never terminate.
        raise ValueError(f"Invalid video dimensions {w}x{h} for {path}")

    cmd = [
        "ffmpeg", "-i", str(path),
        "-f", "rawvideo", "-pix_fmt", "rgb24",
        "-v", "quiet", "-",
    ]

    frames = []
    # The context manager closes the pipe and waits for ffmpeg on every exit path.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
        while True:
            raw = proc.stdout.read(fsize)
            if len(raw) < fsize:
                # End of stream (a trailing partial frame is discarded).
                break
            frames.append(np.frombuffer(raw, dtype=np.uint8).reshape(h, w, 3))

    if not frames:
        raise ValueError(
            f"No frames could be decoded from {path} "
            f"(ffmpeg exit code {proc.returncode})"
        )

    print(f"[STAR] Read {len(frames)} frames")
    return np.stack(frames), fps
|
||
|
||
|
||
def read_image_dir(directory: str):
    """Load every image in ``directory`` → (np array [N,H,W,3] uint8, None).

    Files are chosen by suffix (``IMAGE_EXTS``, case-insensitive) and loaded in
    sorted path order as RGB.

    Raises:
        FileNotFoundError: the directory holds no matching image file.
    """
    root = Path(directory)
    image_paths = [
        entry for entry in root.iterdir() if entry.suffix.lower() in IMAGE_EXTS
    ]
    image_paths.sort()
    if len(image_paths) == 0:
        raise FileNotFoundError(f"No image files in {root}")

    print(f"[STAR] Loading {len(image_paths)} images from {root}")
    stacked = np.stack(
        [np.array(Image.open(image_path).convert("RGB")) for image_path in image_paths]
    )
    return stacked, None
|
||
|
||
|
||
def read_input(path: str):
    """Load ``path`` by type → (np array [N,H,W,3] uint8, fps | None).

    A directory is read as an image sequence, a file with a video suffix as a
    video, and a file with an image suffix as a single frame (N == 1).

    Raises:
        ValueError: ``path`` is not a directory and has an unknown suffix.
    """
    src = Path(path)
    if src.is_dir():
        return read_image_dir(path)

    ext = src.suffix.lower()
    if ext in VIDEO_EXTS:
        return read_video(path)
    if ext not in IMAGE_EXTS:
        raise ValueError(f"Unsupported input: {path}")

    pixels = np.array(Image.open(src).convert("RGB"))
    # Add a leading frame axis so a single image looks like a 1-frame sequence.
    return pixels[np.newaxis], None
|
||
|
||
|
||
# ── Output writing ─────────────────────────────────────────────────────
|
||
class VideoWriter:
    """Stream RGB frames to ffmpeg, optionally copying audio from source."""

    def __init__(self, output_path, fps, width, height,
                 codec="libx264", crf=18, pix_fmt="yuv420p",
                 audio_source=None):
        """Start an ffmpeg process that encodes raw RGB24 frames from stdin.

        Args:
            output_path: File ffmpeg writes to (overwritten if present).
            fps: Frame rate declared for the raw input.
            width, height: Size of every frame passed to ``write_frame``.
            codec, crf, pix_fmt: Passed to ffmpeg as -c:v, -crf and -pix_fmt.
            audio_source: Optional media file whose audio streams, if any,
                are copied into the output.
        """
        cmd = [
            "ffmpeg", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}", "-r", str(fps),
            "-i", "-",
        ]
        if audio_source:
            cmd += ["-i", str(audio_source)]

        cmd += ["-map", "0:v:0"]
        if audio_source:
            # The trailing "?" makes the audio mapping optional.
            cmd += ["-map", "1:a?", "-c:a", "copy"]

        cmd += [
            "-c:v", codec, "-crf", str(crf), "-pix_fmt", pix_fmt,
            "-movflags", "+faststart",
            "-v", "warning",
            str(output_path),
        ]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        self.count = 0  # frames written so far

    def write_frame(self, frame_uint8):
        """Send one frame's raw bytes to ffmpeg and bump the frame counter."""
        self.proc.stdin.write(frame_uint8.tobytes())
        self.count += 1

    def close(self):
        """Finish the stream and wait for ffmpeg to exit.

        Raises:
            RuntimeError: ffmpeg exited with a non-zero status, so the output
                file is missing or incomplete.
        """
        stdin = self.proc.stdin
        if stdin is not None and not stdin.closed:
            stdin.close()
        returncode = self.proc.wait()
        if returncode != 0:
            # Previously ignored, which reported failed encodes as success.
            raise RuntimeError(
                f"ffmpeg exited with code {returncode} while writing video"
            )
|
||
|
||
|
||
class ImageSequenceWriter:
    """Write each frame as a zero-padded, numbered image file in a directory."""

    def __init__(self, out_dir, ext=".png"):
        """Create ``out_dir`` (with parents) if needed and start numbering at 0."""
        target_dir = Path(out_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        self.out_dir = target_dir
        self.ext = ext
        self.count = 0

    def write_frame(self, frame_uint8):
        """Save the frame as ``<count:06d><ext>`` and advance the counter."""
        destination = self.out_dir / f"{self.count:06d}{self.ext}"
        Image.fromarray(frame_uint8).save(destination)
        self.count += 1

    def close(self):
        """Nothing to release; present for interface parity with other writers."""
        return None
|
||
|
||
|
||
class SingleImageWriter:
    """Write one output image to a fixed path."""

    def __init__(self, path):
        """Remember the target path and make sure its parent directory exists."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        self.path = target
        self.count = 0

    def write_frame(self, frame_uint8):
        """Save the frame to the target path (overwriting) and count it."""
        image = Image.fromarray(frame_uint8)
        image.save(self.path)
        self.count += 1

    def close(self):
        """Nothing to release; present for interface parity with other writers."""
        return None
|
||
|
||
|
||
# ── Output path helpers ────────────────────────────────────────────────
|
||
def is_video_path(p):
    """Return True when ``p``'s suffix (case-insensitive) is in ``VIDEO_EXTS``."""
    suffix = Path(p).suffix
    return suffix.lower() in VIDEO_EXTS
|
||
|
||
|
||
def is_image_path(p):
    """Return True when ``p``'s suffix (case-insensitive) is in ``IMAGE_EXTS``."""
    suffix = Path(p).suffix
    return suffix.lower() in IMAGE_EXTS
|
||
|
||
|
||
def auto_output(input_path: str) -> str:
    """Derive a default output path by adding a ``_star`` suffix.

    An existing directory ``frames`` becomes ``frames_star``; anything else
    keeps its extension, e.g. ``clip.mp4`` becomes ``clip_star.mp4``.
    """
    src = Path(input_path)
    if src.is_dir():
        new_name = f"{src.name}_star"
    else:
        new_name = f"{src.stem}_star{src.suffix}"
    return str(src.parent / new_name)
|
||
|
||
|
||
def make_writer(output_path, fps, w, h, args, input_path, is_single_image):
    """Pick and construct the writer matching ``output_path``.

    - single-image input with an image output path → ``SingleImageWriter``
    - video output path → ``VideoWriter`` (audio is taken from ``input_path``
      when that is a video and ``args.no_audio`` is unset)
    - anything else → ``ImageSequenceWriter`` writing PNG files into the path
    """
    if is_single_image and is_image_path(output_path):
        return SingleImageWriter(output_path)

    if not is_video_path(output_path):
        # Default: image sequence directory
        return ImageSequenceWriter(output_path, ext=".png")

    keep_audio = is_video_path(input_path) and not args.no_audio
    audio_src = input_path if keep_audio else None
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    return VideoWriter(
        output_path, fps, w, h,
        codec=args.codec,
        crf=args.crf,
        pix_fmt=args.pix_fmt,
        audio_source=audio_src,
    )
|
||
|
||
|
||
# ── Tensor ↔ numpy ─────────────────────────────────────────────────────
|
||
def to_tensor(frames_uint8):
    """Convert uint8 numpy frames to a float32 torch tensor scaled to [0, 1].

    The input is copied first, so the returned tensor does not share memory
    with ``frames_uint8``.
    """
    owned = frames_uint8.copy()
    return torch.from_numpy(owned).to(torch.float32).div(255.0)
|
||
|
||
|
||
def to_uint8(tensor):
    """Convert a float torch tensor in [0, 1] to a uint8 numpy array.

    Values are clamped to [0, 1], scaled to [0, 255] and rounded to the
    nearest integer before the cast. The result has the same shape as the
    input and lives on the CPU.
    """
    # Round before .byte(): the cast truncates, so without it 0.999 would map
    # to 254 and a uint8 -> float -> uint8 round trip could lose one level.
    return (tensor.clamp(0, 1) * 255).round().byte().cpu().numpy()
|
||
|
||
|
||
# ── Segment-based processing with streaming output ─────────────────────
|
||
def _run_segment(star_model, frames_uint8, args):
    """Run STAR inference on one segment of uint8 frames.

    Converts the frames with ``to_tensor`` and forwards the processing options
    from ``args`` to ``star_pipeline.run_star_inference``, whose result is
    returned unchanged.
    """
    # Imported lazily, after the module-level stubs are in place.
    from star_pipeline import run_star_inference

    inference_kwargs = dict(
        star_model=star_model,
        images=to_tensor(frames_uint8),
        upscale=args.upscale,
        steps=args.steps,
        guide_scale=args.guide_scale,
        prompt=args.prompt,
        solver_mode=args.solver_mode,
        max_chunk_len=args.max_chunk_len,
        seed=args.seed,
        color_fix=args.color_fix,
    )
    return run_star_inference(**inference_kwargs)
|
||
|
||
|
||
def _write_tensor(writer, tensor):
    """Convert ``tensor`` with ``to_uint8`` and hand each frame to ``writer``."""
    # Iterating the array walks its first (frame) axis.
    for frame in to_uint8(tensor):
        writer.write_frame(frame)
|
||
|
||
|
||
def process_and_stream(star_model, input_frames, writer, args):
    """Process in segments, blend overlaps, and stream to writer.

    With ``args.segment_size <= 0``, or when the clip fits in one segment, all
    frames are processed in a single call. Otherwise consecutive segments share
    ``overlap`` frames; the shared frames are cross-faded linearly from the
    previous segment's result to the current one's, and each output frame is
    written exactly once, in order.

    Assumes ``_run_segment`` returns one output frame per input frame, indexed
    along the first axis — TODO confirm against star_pipeline.
    """
    total = input_frames.shape[0]
    seg = args.segment_size

    # No segmentation — process everything at once
    if seg <= 0 or total <= seg:
        print(f"[STAR] Processing all {total} frame(s)...")
        result = _run_segment(star_model, input_frames, args)
        _write_tensor(writer, result)
        return

    # The overlap must never exceed the stride: otherwise the stride is <= 0
    # (endless loop for segment sizes 1-2) or the kept tail is shorter than
    # the real overlap (misaligned blend for size 3). Sizes >= 4 are unaffected.
    overlap = min(max(2, seg // 4), seg // 2)
    stride = seg - overlap

    # Build segment boundaries
    segments = []
    start = 0
    while start < total:
        end = min(start + seg, total)
        segments.append((start, end))
        if end == total:
            break
        start += stride

    print(f"[STAR] {total} frames → {len(segments)} segment(s), "
          f"segment_size={seg}, overlap={overlap}")

    prev_tail = None  # trailing frames of the previous result, awaiting blending
    last_idx = len(segments) - 1

    for idx, (s, e) in enumerate(segments):
        print(f"\n[STAR] ── Segment {idx + 1}/{len(segments)}: "
              f"frames {s}–{e - 1} ──")

        seg_result = _run_segment(star_model, input_frames[s:e], args)

        # Blend overlap with previous segment's tail
        if prev_tail is not None:
            n = prev_tail.shape[0]
            head = seg_result[:n]
            # Per-frame weight rising from 0 (previous result) to 1 (current).
            w = torch.linspace(
                0, 1, n, dtype=seg_result.dtype, device=seg_result.device
            ).view(n, 1, 1, 1)
            blended = prev_tail * (1.0 - w) + head * w
            _write_tensor(writer, blended)
            remainder = seg_result[n:]
        else:
            remainder = seg_result

        if idx < last_idx and overlap > 0:
            # Keep tail for blending, write the rest
            prev_tail = remainder[-overlap:].clone()
            _write_tensor(writer, remainder[:-overlap])
        else:
            # Last segment, or no overlap (segment size 1) — write everything.
            # The overlap == 0 case must not slice with [-0:], which would
            # select the whole tensor.
            _write_tensor(writer, remainder)

        del seg_result
        torch.cuda.empty_cache()
|
||
|
||
|
||
# ── Main ────────────────────────────────────────────────────────────────
|
||
def main():
    """Command-line entry point: parse options, load input and model, upscale, write output.

    Exits with status 1 when CUDA is unavailable, or when ffmpeg/ffprobe are
    missing and either the input or the requested output is a video.
    """
    args = parse_args()

    # Validate environment
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. STAR requires a CUDA GPU.")
        sys.exit(1)

    tools_present = all(shutil.which(tool) for tool in ("ffmpeg", "ffprobe"))
    if not tools_present:
        video_in = Path(args.input).suffix.lower() in VIDEO_EXTS
        video_out = bool(args.output) and Path(args.output).suffix.lower() in VIDEO_EXTS
        if video_in or video_out:
            print("Error: ffmpeg/ffprobe not found. Install ffmpeg for video I/O.")
            sys.exit(1)

    # Read input
    input_frames, input_fps = read_input(args.input)
    total, h_in, w_in = input_frames.shape[:3]
    h_out = h_in * args.upscale
    w_out = w_in * args.upscale
    is_single = (
        total == 1
        and Path(args.input).is_file()
        and is_image_path(args.input)
    )

    print(f"[STAR] {w_in}x{h_in} → {w_out}x{h_out} ({args.upscale}x), "
          f"{total} frame(s)")

    # Output path and frame rate: explicit option, then input value, then default
    output_path = args.output if args.output else auto_output(args.input)
    fps = args.fps or input_fps or 24.0

    # Load model
    device = torch.device("cuda")
    model_path = resolve_model_path(args.model)
    print(f"[STAR] Model: {model_path}")
    star_model = load_model(model_path, args.precision, args.offload, device)

    _set_attn(args.attention)

    # Create writer and process
    writer = make_writer(
        output_path, fps, w_out, h_out, args, args.input, is_single
    )
    process_and_stream(star_model, input_frames, writer, args)
    writer.close()

    print(f"\n[STAR] Done! {writer.count} frame(s) → {output_path}")


if __name__ == "__main__":
    main()
|