Initial release: ComfyUI nodes for STAR video super-resolution
Two-node package wrapping the STAR (ICCV 2025) diffusion-based video upscaling pipeline:

- STAR Model Loader: loads UNet+ControlNet, OpenCLIP text encoder, and temporal VAE with auto-download from HuggingFace
- STAR Video Super-Resolution: runs the full diffusion pipeline with configurable upscale factor, guidance, solver mode, chunking, and color correction

Includes three VRAM offload modes (disabled/model/aggressive) to support GPUs from 12GB to 40GB+.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
.claude/
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
[submodule "STAR"]
	path = STAR
	url = https://github.com/NJU-PCALab/STAR.git
88
README.md
Normal file
@@ -0,0 +1,88 @@
# ComfyUI-STAR

ComfyUI custom nodes for [STAR (Spatial-Temporal Augmentation with Text-to-Video Models for Real-World Video Super-Resolution)](https://github.com/NJU-PCALab/STAR), a diffusion-based video upscaling model (ICCV 2025).

## Features

- **Diffusion-based 4x video super-resolution** with temporal coherence
- **Two model variants**: `light_deg.pt` (light degradation) and `heavy_deg.pt` (heavy degradation)
- **Auto-download**: all models (UNet checkpoint, OpenCLIP text encoder, temporal VAE) download automatically on first use
- **VRAM offloading**: three modes to fit GPUs from 12GB to 40GB+
- **Long video support**: sliding-window chunking with 50% overlap
- **Color correction**: AdaIN and wavelet-based post-processing
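
To make the AdaIN option more concrete, the sketch below shows the general idea of matching color statistics: shift and scale one channel of the upscaled frame so that its mean and standard deviation match the same channel of the input frame. It is an illustration of the technique only, not the code STAR ships; `match_channel_stats` is a hypothetical helper that works on plain Python lists rather than tensors.

```python
from statistics import fmean, pstdev


def match_channel_stats(target, source, eps=1e-5):
    """Shift and scale `target` so its mean/std match `source` (AdaIN-style).

    Hypothetical helper for illustration; both arguments are flat lists of
    floats holding the values of a single color channel.
    """
    t_mean, t_std = fmean(target), pstdev(target)
    s_mean, s_std = fmean(source), pstdev(source)
    scale = s_std / (t_std + eps)  # eps avoids division by zero on flat channels
    return [(v - t_mean) * scale + s_mean for v in target]


upscaled_red = [0.2, 0.4, 0.6, 0.8]  # channel values from the upscaled frame
original_red = [0.1, 0.2, 0.3, 0.4]  # same channel from the input frame
corrected = match_channel_stats(upscaled_red, original_red)
```

After the call, `corrected` has the same mean as `original_red` and almost the same spread, which is why this kind of correction keeps the output's overall color close to the source video.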

## Installation

### ComfyUI Manager

Search for `ComfyUI-STAR` in ComfyUI Manager and install.

### Manual

```bash
cd ComfyUI/custom_nodes
git clone --recursive git@192.168.1.1:Ethanfel/Comfyui-STAR.git
cd Comfyui-STAR
pip install -r requirements.txt
```

> The `--recursive` flag clones the STAR submodule. If you forgot it, run `git submodule update --init` afterwards.
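
A clone made without `--recursive` still creates the `STAR` directory, but leaves it empty. The following sketch shows one way to check whether the submodule was actually populated; `star_submodule_ready` is a hypothetical helper name, and the only assumption about layout is that the submodule lives in a `STAR` folder inside the node directory.

```python
import os


def star_submodule_ready(node_dir):
    """Return True if `<node_dir>/STAR` exists and contains at least one entry.

    Hypothetical helper: an existence check alone is not enough, because a
    non-recursive clone leaves an empty STAR directory behind.
    """
    star_dir = os.path.join(node_dir, "STAR")
    return os.path.isdir(star_dir) and bool(os.listdir(star_dir))
```

If the function returns `False` for your `Comfyui-STAR` directory, running `git submodule update --init` there should populate the folder.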

## Nodes

### STAR Model Loader

Loads the STAR model components (UNet+ControlNet, OpenCLIP text encoder, temporal VAE).

| Input | Description |
|-------|-------------|
| **model_name** | `light_deg.pt` for mildly degraded video, `heavy_deg.pt` for heavily degraded video. Auto-downloaded from HuggingFace on first use. |
| **precision** | `fp16` (recommended), `bf16`, or `fp32`. |
| **offload** | `disabled` (~39GB VRAM), `model` (~16GB; swaps components to CPU when idle), `aggressive` (~12GB; model offload + single-frame VAE decode). |
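
As a rough guide to picking an `offload` value, the sketch below maps available VRAM to a mode using the approximate figures quoted above. `suggest_offload` is a hypothetical helper and the thresholds are only those approximate figures; real usage also depends on resolution, frame count, and `max_chunk_len`.

```python
def suggest_offload(vram_gb):
    """Suggest an offload mode from available VRAM in GB.

    Hypothetical helper; thresholds are the approximate figures from the
    table (about 39 GB, 16 GB, and 12 GB), not measured limits.
    """
    if vram_gb >= 39:
        return "disabled"
    if vram_gb >= 16:
        return "model"
    # Below ~16 GB the most conservative mode is the only candidate.
    return "aggressive"


for gb in (48, 24, 12):
    print(gb, "GB ->", suggest_offload(gb))
```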

### STAR Video Super-Resolution

Runs the STAR diffusion pipeline on an image batch.

| Input | Description |
|-------|-------------|
| **star_model** | Connect from STAR Model Loader. |
| **images** | Input video frames (IMAGE batch). |
| **upscale** | Upscale factor (1–8, default 4). |
| **steps** | Denoising steps (1–100, default 15). Ignored in `fast` mode. |
| **guide_scale** | Classifier-free guidance scale (1–20, default 7.5). |
| **prompt** | Text prompt. Leave empty for STAR's built-in quality prompt. |
| **solver_mode** | `fast` (optimized 15-step schedule) or `normal` (uniform schedule). |
| **max_chunk_len** | Max frames per chunk (4–128, default 32). Lower = less VRAM for long videos. |
| **seed** | Random seed for reproducibility. |
| **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. |
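
For scripted use, the two nodes can be wired together in ComfyUI's API-format workflow, where each node is keyed by an id and a link is written as `[source_node_id, output_index]`. The sketch below is an assumption-laden example rather than a tested workflow: the class names `STARModelLoader` and `STARVideoSuperResolution` are the ones registered in `nodes.py`, `LoadImage` and `SaveImage` are stock ComfyUI nodes, and the file name and filename prefix are placeholders.

```python
import json

workflow = {
    "1": {
        "class_type": "STARModelLoader",
        "inputs": {"model_name": "light_deg.pt", "precision": "fp16", "offload": "model"},
    },
    "2": {
        "class_type": "LoadImage",
        "inputs": {"image": "frame_0001.png"},  # placeholder input file
    },
    "3": {
        "class_type": "STARVideoSuperResolution",
        "inputs": {
            "star_model": ["1", 0],  # output 0 of the loader node
            "images": ["2", 0],      # IMAGE output of LoadImage
            "upscale": 4,
            "steps": 15,
            "guide_scale": 7.5,
            "prompt": "",            # empty string selects the built-in prompt
            "solver_mode": "fast",
            "max_chunk_len": 32,
            "seed": 0,
            "color_fix": "adain",
        },
    },
    "4": {
        "class_type": "SaveImage",
        "inputs": {"images": ["3", 0], "filename_prefix": "star"},  # placeholder prefix
    },
}

print(json.dumps(workflow, indent=2))
```

The values shown for the super-resolution node are simply the defaults from the table, so any of them can be omitted from experimentation and changed one at a time.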

## VRAM Requirements

| Offload Mode | Approximate VRAM | Notes |
|---|---|---|
| disabled | ~39 GB | Fastest; everything on GPU |
| model | ~16 GB | Components swap to CPU between stages |
| aggressive | ~12 GB | Model offload + frame-by-frame VAE decode |

Reducing `max_chunk_len` further lowers VRAM usage for long videos at the cost of slightly more processing time.
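
To see why a smaller `max_chunk_len` trades memory for time, the sketch below splits a frame range into windows that overlap by half their length. It is only an illustration of sliding-window chunking: `overlapping_chunks` is a hypothetical function, and STAR's own `make_chunks` may place the boundaries differently.

```python
def overlapping_chunks(num_frames, max_chunk_len):
    """Return [start, end) frame ranges of at most max_chunk_len frames.

    Hypothetical illustration of 50%-overlap windows; not STAR's make_chunks.
    """
    if num_frames <= max_chunk_len:
        return [(0, num_frames)]  # short clips are processed in one pass
    stride = max(1, max_chunk_len // 2)  # advance by half a window
    chunks = []
    start = 0
    while True:
        end = min(start + max_chunk_len, num_frames)
        chunks.append((start, end))
        if end == num_frames:
            break
        start += stride
    return chunks


print(overlapping_chunks(80, 32))  # → [(0, 32), (16, 48), (32, 64), (48, 80)]
print(overlapping_chunks(80, 16))  # more, smaller windows for the same clip
```

Halving the window roughly halves the number of frames held on the GPU at once, but the same clip then needs about twice as many windows.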

## Model Weights

Models are stored in `ComfyUI/models/star/` and auto-downloaded on first use:

| Model | Use Case | Source |
|-------|----------|--------|
| `light_deg.pt` | Low-res video from the web, mild compression | [HuggingFace](https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/light_deg.pt) |
| `heavy_deg.pt` | Heavily compressed/degraded video | [HuggingFace](https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt) |

The OpenCLIP text encoder and SVD temporal VAE are downloaded automatically by their respective libraries on first load.
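
If you prefer to fetch a checkpoint by hand (for example on a machine without outbound access from ComfyUI), the download links above follow a simple pattern. The sketch below rebuilds them from the repository id and file path; `weight_url` is a hypothetical helper, and the `revision` parameter is an assumption that defaults to the `main` branch used in the links.

```python
HF_REPO = "SherryX/STAR"
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}


def weight_url(model_name, revision="main"):
    """Build the direct download URL for a known checkpoint (hypothetical helper)."""
    return f"https://huggingface.co/{HF_REPO}/resolve/{revision}/{HF_MODELS[model_name]}"


for name in HF_MODELS:
    print(name, "->", weight_url(name))
```

A file downloaded this way can be placed directly in `ComfyUI/models/star/` under its short name, such as `light_deg.pt`.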

## Credits

- [STAR](https://github.com/NJU-PCALab/STAR) by Rui Xie, Yinhong Liu et al. (Nanjing University), ICCV 2025
- Based on [I2VGen-XL](https://github.com/ali-vilab/VGen) and [VEnhancer](https://github.com/Vchitect/VEnhancer)

## License

This wrapper is MIT licensed. The STAR model weights follow their respective licenses (MIT for I2VGen-XL-based models).
1
STAR
Submodule
Submodule STAR added at 69b8bc53e4
3
__init__.py
Normal file
@@ -0,0 +1,3 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
10
install.py
Normal file
@@ -0,0 +1,10 @@
import os
import subprocess
import sys

req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_file])

star_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "STAR")
if not os.path.isdir(star_dir) or not os.listdir(star_dir):
    subprocess.check_call(["git", "clone", "https://github.com/NJU-PCALab/STAR.git", star_dir])
255
nodes.py
Normal file
@@ -0,0 +1,255 @@
import os
import sys
import torch
import folder_paths
import comfy.model_management as mm

# Register the "star" model folder so users can drop .pt weights there.
star_model_dir = os.path.join(folder_paths.models_dir, "star")
os.makedirs(star_model_dir, exist_ok=True)
folder_paths.folder_names_and_paths["star"] = (
    [star_model_dir],
    folder_paths.supported_pt_extensions,
)

# Put the cloned STAR repo on sys.path so its internal imports work.
STAR_REPO = os.path.join(os.path.dirname(os.path.realpath(__file__)), "STAR")
if STAR_REPO not in sys.path:
    sys.path.insert(0, STAR_REPO)

# Known models on HuggingFace that can be auto-downloaded.
HF_REPO = "SherryX/STAR"
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}


def _get_model_list():
    """Return the union of files already on disk + known downloadable models."""
    on_disk = set(folder_paths.get_filename_list("star"))
    available = set(HF_MODELS.keys())
    return sorted(on_disk | available)


def _ensure_model(model_name: str) -> str:
    """Return the local path to model_name, downloading from HF if needed."""
    local = folder_paths.get_full_path("star", model_name)
    if local is not None:
        return local

    if model_name not in HF_MODELS:
        raise FileNotFoundError(
            f"Model '{model_name}' not found in {star_model_dir} and is not a known downloadable model."
        )

    from huggingface_hub import hf_hub_download

    print(f"[STAR] Downloading {model_name} from HuggingFace ({HF_REPO})...")
    path = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_MODELS[model_name],
        local_dir=star_model_dir,
    )
    # hf_hub_download may place the file in a subdirectory; symlink into the
    # star folder root so folder_paths can find it next time.
    dest = os.path.join(star_model_dir, model_name)
    if not os.path.exists(dest):
        os.symlink(path, dest)
    return dest


class STARModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (_get_model_list(), {
                    "tooltip": "STAR checkpoint to load. light_deg for mildly degraded video, heavy_deg for heavily degraded video. Auto-downloaded from HuggingFace on first use.",
                }),
                "precision": (["fp16", "bf16", "fp32"], {
                    "default": "fp16",
                    "tooltip": "Weight precision. fp16 is recommended (fastest, lowest VRAM). bf16 for newer GPUs. fp32 for maximum quality at 2x VRAM cost.",
                }),
                "offload": (["disabled", "model", "aggressive"], {
                    "default": "disabled",
                    "tooltip": "disabled: all on GPU (~39GB). model: swap UNet/VAE/CLIP to CPU when idle (~16GB). aggressive: model offload + single-frame VAE decode (~12GB).",
                }),
            }
        }

    RETURN_TYPES = ("STAR_MODEL",)
    RETURN_NAMES = ("star_model",)
    FUNCTION = "load_model"
    CATEGORY = "STAR"
    DESCRIPTION = "Loads the STAR video super-resolution model (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). All components are auto-downloaded on first use."

    def load_model(self, model_name, precision, offload="disabled"):
        device = mm.get_torch_device()
        dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        dtype = dtype_map[precision]

        # Where to park models when not in use.
        keep_on = device if offload == "disabled" else "cpu"

        model_path = _ensure_model(model_name)

        # ---- Text encoder (OpenCLIP ViT-H-14) ----
        from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder

        text_encoder = FrozenOpenCLIPEmbedder(
            device=device, pretrained="laion2b_s32b_b79k"
        )
        text_encoder.model.to(device)

        # Pre-compute the negative prompt embedding used during sampling.
        from video_to_video.utils.config import cfg

        negative_y = text_encoder(cfg.negative_prompt).detach()

        # Park text encoder after pre-computing embeddings.
        text_encoder.model.to(keep_on)

        # ---- UNet + ControlNet ----
        from video_to_video.modules.unet_v2v import ControlledV2VUNet

        generator = ControlledV2VUNet()
        load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        if "state_dict" in load_dict:
            load_dict = load_dict["state_dict"]
        generator.load_state_dict(load_dict, strict=False)
        del load_dict
        generator = generator.to(device=keep_on, dtype=dtype)
        generator.eval()

        # ---- Noise schedule + diffusion helper ----
        from video_to_video.diffusion.schedules_sdedit import noise_schedule
        from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

        sigmas = noise_schedule(
            schedule="logsnr_cosine_interp",
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0,
        )
        diffusion = GaussianDiffusion(sigmas=sigmas)

        # ---- Temporal VAE (from HuggingFace diffusers) ----
        from diffusers import AutoencoderKLTemporalDecoder

        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid",
            subfolder="vae",
            variant="fp16",
        )
        vae.eval()
        vae.requires_grad_(False)
        vae.to(keep_on)

        torch.cuda.empty_cache()

        star_model = {
            "text_encoder": text_encoder,
            "generator": generator,
            "diffusion": diffusion,
            "vae": vae,
            "negative_y": negative_y,
            "device": device,
            "dtype": dtype,
            "offload": offload,
        }
        return (star_model,)


class STARVideoSuperResolution:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "star_model": ("STAR_MODEL", {
                    "tooltip": "Connect from STAR Model Loader.",
                }),
                "images": ("IMAGE", {
                    "tooltip": "Input video frames (IMAGE batch). Can come from LoadImage, VHS LoadVideo, etc.",
                }),
                "upscale": ("INT", {
                    "default": 4, "min": 1, "max": 8,
                    "tooltip": "Upscale factor applied to the input resolution. 4x is the default. Higher values need more VRAM.",
                }),
                "steps": ("INT", {
                    "default": 15, "min": 1, "max": 100,
                    "tooltip": "Number of denoising steps. Ignored in 'fast' solver mode (hardcoded 15). More steps = better quality but slower.",
                }),
                "guide_scale": ("FLOAT", {
                    "default": 7.5, "min": 1.0, "max": 20.0, "step": 0.5,
                    "tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strongly. 7.5 is a good default.",
                }),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Text prompt describing the desired output. Leave empty to use STAR's built-in quality prompt.",
                }),
                "solver_mode": (["fast", "normal"], {
                    "default": "fast",
                    "tooltip": "fast: optimized 15-step schedule (4 coarse + 11 fine). normal: uniform schedule using the steps parameter.",
                }),
                "max_chunk_len": ("INT", {
                    "default": 32, "min": 4, "max": 128,
                    "tooltip": "Max frames processed at once. Lower values reduce VRAM usage for long videos. Chunks overlap by 50%.",
                }),
                "seed": ("INT", {
                    "default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF,
                    "tooltip": "Random seed for reproducible results.",
                }),
                "color_fix": (["adain", "wavelet", "none"], {
                    "default": "adain",
                    "tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "upscale_video"
    CATEGORY = "STAR"
    DESCRIPTION = "Upscale video frames using STAR diffusion-based super-resolution."

    def upscale_video(
        self,
        star_model,
        images,
        upscale,
        steps,
        guide_scale,
        prompt,
        solver_mode,
        max_chunk_len,
        seed,
        color_fix,
    ):
        from .star_pipeline import run_star_inference

        result = run_star_inference(
            star_model=star_model,
            images=images,
            upscale=upscale,
            steps=steps,
            guide_scale=guide_scale,
            prompt=prompt,
            solver_mode=solver_mode,
            max_chunk_len=max_chunk_len,
            seed=seed,
            color_fix=color_fix,
        )
        return (result,)


NODE_CLASS_MAPPINGS = {
    "STARModelLoader": STARModelLoader,
    "STARVideoSuperResolution": STARVideoSuperResolution,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "STARModelLoader": "STAR Model Loader",
    "STARVideoSuperResolution": "STAR Video Super-Resolution",
}
6
requirements.txt
Normal file
@@ -0,0 +1,6 @@
easydict
einops
open-clip-torch
torchsde
diffusers>=0.25.0
huggingface_hub
290
star_pipeline.py
Normal file
@@ -0,0 +1,290 @@
import torch
import torch.nn.functional as F
import torch.amp
from einops import rearrange

import comfy.utils
import comfy.model_management as mm

from video_to_video.video_to_video_model import pad_to_fit, make_chunks
from video_to_video.utils.config import cfg


# ---------------------------------------------------------------------------
# Tensor format conversions
# ---------------------------------------------------------------------------

def comfyui_to_star_frames(images: torch.Tensor) -> torch.Tensor:
    """Convert ComfyUI IMAGE batch to STAR input format.

    ComfyUI: [N, H, W, 3] float32 in [0, 1]
    STAR:    [N, 3, H, W] float32 in [-1, 1]
    """
    t = images.permute(0, 3, 1, 2)  # [N,3,H,W]
    t = t * 2.0 - 1.0
    return t


def star_output_to_comfyui(video: torch.Tensor) -> torch.Tensor:
    """Convert STAR output to ComfyUI IMAGE batch.

    STAR output: [1, 3, F, H, W] float32 in [-1, 1]
    ComfyUI:     [F, H, W, 3] float32 in [0, 1]
    """
    v = video.squeeze(0)  # [3, F, H, W]
    v = v.permute(1, 2, 3, 0)  # [F, H, W, 3]
    v = (v + 1.0) / 2.0
    v = v.clamp(0.0, 1.0)
    return v


# ---------------------------------------------------------------------------
# VAE helpers (mirror VideoToVideo_sr methods)
# ---------------------------------------------------------------------------

def vae_encode(vae, t, chunk_size=1):
    """Encode [B, F, C, H, W] video tensor to latent space."""
    num_f = t.shape[1]
    t = rearrange(t, "b f c h w -> (b f) c h w")
    z_list = []
    for ind in range(0, t.shape[0], chunk_size):
        z_list.append(vae.encode(t[ind : ind + chunk_size]).latent_dist.sample())
    z = torch.cat(z_list, dim=0)
    z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
    return z * vae.config.scaling_factor


def vae_decode_chunk(vae, z, chunk_size=3):
    """Decode latent [B, C, F, H, W] back to pixel frames."""
    z = rearrange(z, "b c f h w -> (b f) c h w")
    video = []
    for ind in range(0, z.shape[0], chunk_size):
        chunk = z[ind : ind + chunk_size]
        num_f = chunk.shape[0]
        decoded = vae.decode(chunk / vae.config.scaling_factor, num_frames=num_f).sample
        video.append(decoded)
    video = torch.cat(video)
    return video


# ---------------------------------------------------------------------------
# Color correction wrappers
# ---------------------------------------------------------------------------

def apply_color_fix(output_frames, input_frames_star, method):
    """Apply colour correction to the upscaled output.

    output_frames: [F, H, W, 3] float [0, 1] (ComfyUI format)
    input_frames_star: [F, 3, H, W] float [-1, 1] (STAR format)
    method: "adain" | "wavelet" | "none"
    """
    if method == "none":
        return output_frames

    from video_super_resolution.color_fix import adain_color_fix, wavelet_color_fix

    # Resize input to match output spatial size for stats transfer
    _, h_out, w_out, _ = output_frames.shape
    source = F.interpolate(
        input_frames_star, size=(h_out, w_out), mode="bilinear", align_corners=False
    )

    # The color_fix functions expect:
    #   target: [T, H, W, C] in [0, 255]
    #   source: [T, C, H, W] in [-1, 1]
    target = output_frames * 255.0

    if method == "adain":
        result = adain_color_fix(target, source)
    else:
        result = wavelet_color_fix(target, source)

    return (result / 255.0).clamp(0.0, 1.0)


# ---------------------------------------------------------------------------
# Progress-bar integration via trange monkey-patch
# ---------------------------------------------------------------------------

def _make_progress_trange(pbar, total_steps):
    """Return a drop-in replacement for tqdm.auto.trange that drives *pbar*."""
    from tqdm.auto import trange as _real_trange

    def _progress_trange(*args, **kwargs):
        kwargs["disable"] = True  # silence console output
        for val in _real_trange(*args, **kwargs):
            yield val
            pbar.update(1)

    return _progress_trange


# ---------------------------------------------------------------------------
# Main inference entry point
# ---------------------------------------------------------------------------

def _move(module, device):
    """Move a nn.Module to device and free source memory."""
    module.to(device)
    if device == "cpu" or (isinstance(device, torch.device) and device.type == "cpu"):
        torch.cuda.empty_cache()


def run_star_inference(
    star_model: dict,
    images: torch.Tensor,
    upscale: int = 4,
    steps: int = 15,
    guide_scale: float = 7.5,
    prompt: str = "",
    solver_mode: str = "fast",
    max_chunk_len: int = 32,
    seed: int = 0,
    color_fix: str = "adain",
) -> torch.Tensor:
    """Run STAR video super-resolution and return ComfyUI IMAGE batch."""

    device = star_model["device"]
    dtype = star_model["dtype"]
    text_encoder = star_model["text_encoder"]
    generator = star_model["generator"]
    diffusion = star_model["diffusion"]
    vae = star_model["vae"]
    negative_y = star_model["negative_y"]
    offload = star_model.get("offload", "disabled")

    # In aggressive mode use smaller VAE chunks to cut peak VRAM.
    vae_enc_chunk = 1
    vae_dec_chunk = 3
    if offload == "aggressive":
        vae_dec_chunk = 1

    total_noise_levels = 1000

    # -- Convert ComfyUI frames to STAR format --
    video_data = comfyui_to_star_frames(images)  # [F, 3, H, W]

    # Keep a copy at input resolution (on CPU) for colour correction later
    input_frames_star = video_data.clone().cpu()

    frames_num, _, orig_h, orig_w = video_data.shape
    target_h = orig_h * upscale
    target_w = orig_w * upscale

    # -- Bilinear upscale to target resolution --
    video_data = F.interpolate(video_data, size=(target_h, target_w), mode="bilinear", align_corners=False)
    _, _, h, w = video_data.shape

    # -- Pad to model-friendly resolution --
    padding = pad_to_fit(h, w)
    video_data = F.pad(video_data, padding, "constant", 1)

    video_data = video_data.unsqueeze(0).to(device)  # [1, F, 3, H_pad, W_pad]

    # ---- Stage 1: Text encoding ----
    if offload != "disabled":
        text_encoder.model.to(device)
        text_encoder.device = device
    text = prompt if prompt.strip() else cfg.positive_prompt
    y = text_encoder(text).detach()
    if offload != "disabled":
        text_encoder.model.to("cpu")
        text_encoder.device = "cpu"
        torch.cuda.empty_cache()

    # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) --
    with torch.amp.autocast("cuda"):
        # ---- Stage 2: VAE encode ----
        if offload != "disabled":
            _move(vae, device)
        video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
        if offload != "disabled":
            _move(vae, "cpu")
        # Free the full-res pixel tensor; only latents are needed from here.
        del video_data
        torch.cuda.empty_cache()

        t = torch.LongTensor([total_noise_levels - 1]).to(device)
        noised_lr = diffusion.diffuse(video_data_feature, t)

        model_kwargs = [{"y": y}, {"y": negative_y}, {"hint": video_data_feature}]

        torch.cuda.empty_cache()

        chunk_inds = (
            make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len)
            if frames_num > max_chunk_len
            else None
        )
        # Need at least 2 chunks; a single chunk causes IndexError in
        # model_chunk_fn when it accesses chunk_inds[1].
        if chunk_inds is not None and len(chunk_inds) < 2:
            chunk_inds = None

        # Monkey-patch trange for progress reporting
        import video_to_video.diffusion.solvers_sdedit as _solvers_mod
        _orig_trange = _solvers_mod.trange

        # Calculate actual number of sigma steps for progress bar
        # (matches logic inside GaussianDiffusion.sample_sr)
        if solver_mode == "fast":
            num_sigma_steps = 14  # 4 coarse + 11 fine = 15 sigmas, trange iterates len-1 = 14
        else:
            num_sigma_steps = steps
        pbar = comfy.utils.ProgressBar(num_sigma_steps)
        _solvers_mod.trange = _make_progress_trange(pbar, num_sigma_steps)

        # ---- Stage 3: Diffusion (UNet) ----
        if offload != "disabled":
            _move(generator, device)
        try:
            torch.manual_seed(seed)

            gen_vid = diffusion.sample_sr(
                noise=noised_lr,
                model=generator,
                model_kwargs=model_kwargs,
                guide_scale=guide_scale,
                guide_rescale=0.2,
                solver="dpmpp_2m_sde",
                solver_mode=solver_mode,
                return_intermediate=None,
                steps=steps,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization="trailing",
                chunk_inds=chunk_inds,
            )
        finally:
            _solvers_mod.trange = _orig_trange
            if offload != "disabled":
                _move(generator, "cpu")

        # Free latents that are no longer needed.
        del noised_lr, video_data_feature, model_kwargs
        torch.cuda.empty_cache()

        # ---- Stage 4: VAE decode ----
        if offload != "disabled":
            _move(vae, device)
        vid_tensor_gen = vae_decode_chunk(vae, gen_vid, chunk_size=vae_dec_chunk)
        if offload != "disabled":
            _move(vae, "cpu")

    # -- Remove padding --
    w1, w2, h1, h2 = padding
    vid_tensor_gen = vid_tensor_gen[:, :, h1 : h + h1, w1 : w + w1]

    # -- Reshape to [B, C, F, H, W] then convert to ComfyUI format --
    gen_video = rearrange(vid_tensor_gen, "(b f) c h w -> b c f h w", b=1)
    gen_video = gen_video.float().cpu()

    result = star_output_to_comfyui(gen_video)  # [F, H, W, 3]

    # -- Color correction --
    result = apply_color_fix(result, input_frames_star, color_fix)

    torch.cuda.empty_cache()
    mm.soft_empty_cache()

    return result