commit 5f9287cfacae5e06a38a0e4d9f5ae5133d598081 Author: Ethanfel Date: Sat Feb 14 23:20:27 2026 +0100 Initial release: ComfyUI nodes for STAR video super-resolution Two-node package wrapping the STAR (ICCV 2025) diffusion-based video upscaling pipeline: - STAR Model Loader: loads UNet+ControlNet, OpenCLIP text encoder, and temporal VAE with auto-download from HuggingFace - STAR Video Super-Resolution: runs the full diffusion pipeline with configurable upscale factor, guidance, solver mode, chunking, and color correction Includes three VRAM offload modes (disabled/model/aggressive) to support GPUs from 12GB to 40GB+. Co-Authored-By: Claude Opus 4.6 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2af3729 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +*.pyc +*.pyo +.claude/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..64870b4 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "STAR"] + path = STAR + url = https://github.com/NJU-PCALab/STAR.git diff --git a/README.md b/README.md new file mode 100644 index 0000000..f30506c --- /dev/null +++ b/README.md @@ -0,0 +1,88 @@ +# ComfyUI-STAR + +ComfyUI custom nodes for [STAR (Spatial-Temporal Augmentation with Text-to-Video Models for Real-World Video Super-Resolution)](https://github.com/NJU-PCALab/STAR) — a diffusion-based video upscaling model (ICCV 2025). 
+ +## Features + +- **Diffusion-based 4x video super-resolution** with temporal coherence +- **Two model variants**: `light_deg.pt` (light degradation) and `heavy_deg.pt` (heavy degradation) +- **Auto-download**: all models (UNet checkpoint, OpenCLIP text encoder, temporal VAE) download automatically on first use +- **VRAM offloading**: three modes to fit GPUs from 12GB to 40GB+ +- **Long video support**: sliding-window chunking with 50% overlap +- **Color correction**: AdaIN and wavelet-based post-processing + +## Installation + +### ComfyUI Manager + +Search for `ComfyUI-STAR` in ComfyUI Manager and install. + +### Manual + +```bash +cd ComfyUI/custom_nodes +git clone --recursive https://github.com/Ethanfel/Comfyui-STAR.git +cd Comfyui-STAR +pip install -r requirements.txt +``` + +> The `--recursive` flag clones the STAR submodule. If you forgot it, run `git submodule update --init` afterwards. + +## Nodes + +### STAR Model Loader + +Loads the STAR model components (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). + +| Input | Description | +|-------|-------------| +| **model_name** | `light_deg.pt` for mildly degraded video, `heavy_deg.pt` for heavily degraded video. Auto-downloaded from HuggingFace on first use. | +| **precision** | `fp16` (recommended), `bf16`, or `fp32`. | +| **offload** | `disabled` (~39GB VRAM), `model` (~16GB — swaps components to CPU when idle), `aggressive` (~12GB — model offload + single-frame VAE decode). | + +### STAR Video Super-Resolution + +Runs the STAR diffusion pipeline on an image batch. + +| Input | Description | +|-------|-------------| +| **star_model** | Connect from STAR Model Loader. | +| **images** | Input video frames (IMAGE batch). | +| **upscale** | Upscale factor (1–8, default 4). | +| **steps** | Denoising steps (1–100, default 15). Ignored in `fast` mode. | +| **guide_scale** | Classifier-free guidance scale (1–20, default 7.5). | +| **prompt** | Text prompt. Leave empty for STAR's built-in quality prompt.
| +| **solver_mode** | `fast` (optimized 15-step schedule) or `normal` (uniform schedule). | +| **max_chunk_len** | Max frames per chunk (4–128, default 32). Lower = less VRAM for long videos. | +| **seed** | Random seed for reproducibility. | +| **color_fix** | `adain` (match color stats), `wavelet` (preserve low-frequency color), or `none`. | + +## VRAM Requirements + +| Offload Mode | Approximate VRAM | Notes | +|---|---|---| +| disabled | ~39 GB | Fastest — everything on GPU | +| model | ~16 GB | Components swap to CPU between stages | +| aggressive | ~12 GB | Model offload + frame-by-frame VAE decode | + +Reducing `max_chunk_len` further lowers VRAM usage for long videos at the cost of slightly more processing time. + +## Model Weights + +Models are stored in `ComfyUI/models/star/` and auto-downloaded on first use: + +| Model | Use Case | Source | +|-------|----------|--------| +| `light_deg.pt` | Low-res video from the web, mild compression | [HuggingFace](https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/light_deg.pt) | +| `heavy_deg.pt` | Heavily compressed/degraded video | [HuggingFace](https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt) | + +The OpenCLIP text encoder and SVD temporal VAE are downloaded automatically by their respective libraries on first load. + +## Credits + +- [STAR](https://github.com/NJU-PCALab/STAR) by Rui Xie, Yinhong Liu et al. (Nanjing University) — ICCV 2025 +- Based on [I2VGen-XL](https://github.com/ali-vilab/VGen) and [VEnhancer](https://github.com/Vchitect/VEnhancer) + +## License + +This wrapper is MIT licensed. The STAR model weights follow their respective licenses (MIT for I2VGen-XL-based models). 
diff --git a/STAR b/STAR new file mode 160000 index 0000000..69b8bc5 --- /dev/null +++ b/STAR @@ -0,0 +1 @@ +Subproject commit 69b8bc53e4bb04267265fa9a288ee30882c5b2ed diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..39a8c6b --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/install.py b/install.py new file mode 100644 index 0000000..163fd12 --- /dev/null +++ b/install.py @@ -0,0 +1,10 @@ +import os +import subprocess +import sys + +req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") +subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_file]) + +star_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "STAR") +if not os.path.isdir(star_dir): + subprocess.check_call(["git", "clone", "https://github.com/NJU-PCALab/STAR.git", star_dir]) diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..e34de81 --- /dev/null +++ b/nodes.py @@ -0,0 +1,255 @@ +import os +import sys +import torch +import folder_paths +import comfy.model_management as mm + +# Register the "star" model folder so users can drop .pt weights there. +star_model_dir = os.path.join(folder_paths.models_dir, "star") +os.makedirs(star_model_dir, exist_ok=True) +folder_paths.folder_names_and_paths["star"] = ( + [star_model_dir], + folder_paths.supported_pt_extensions, +) + +# Put the cloned STAR repo on sys.path so its internal imports work. +STAR_REPO = os.path.join(os.path.dirname(os.path.realpath(__file__)), "STAR") +if STAR_REPO not in sys.path: + sys.path.insert(0, STAR_REPO) + +# Known models on HuggingFace that can be auto-downloaded. 
HF_REPO = "SherryX/STAR"
# Maps the short checkpoint name shown in the node UI to its path inside HF_REPO.
HF_MODELS = {
    "light_deg.pt": "I2VGen-XL-based/light_deg.pt",
    "heavy_deg.pt": "I2VGen-XL-based/heavy_deg.pt",
}


def _get_model_list():
    """Return the union of files already on disk + known downloadable models."""
    on_disk = set(folder_paths.get_filename_list("star"))
    return sorted(on_disk | set(HF_MODELS))


def _ensure_model(model_name: str) -> str:
    """Return the local path to *model_name*, downloading from HF if needed.

    Raises:
        FileNotFoundError: *model_name* is neither on disk nor a known
            downloadable checkpoint.
    """
    local = folder_paths.get_full_path("star", model_name)
    if local is not None:
        return local

    if model_name not in HF_MODELS:
        raise FileNotFoundError(
            f"Model '{model_name}' not found in {star_model_dir} and is not a known downloadable model."
        )

    from huggingface_hub import hf_hub_download

    print(f"[STAR] Downloading {model_name} from HuggingFace ({HF_REPO})...")
    path = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_MODELS[model_name],
        local_dir=star_model_dir,
    )
    # hf_hub_download mirrors the repo layout (I2VGen-XL-based/...), so expose
    # the file at the star folder root where folder_paths will look next time.
    dest = os.path.join(star_model_dir, model_name)
    if not os.path.exists(dest):
        try:
            os.symlink(path, dest)
        except OSError:
            # Symlinks can fail on Windows (privilege required) or on
            # filesystems that don't support them — fall back to a plain copy.
            import shutil

            shutil.copyfile(path, dest)
    return dest


class STARModelLoader:
    """Loads every STAR pipeline component and bundles them into one dict."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (_get_model_list(), {
                    "tooltip": "STAR checkpoint to load. light_deg for mildly degraded video, heavy_deg for heavily degraded video. Auto-downloaded from HuggingFace on first use.",
                }),
                "precision": (["fp16", "bf16", "fp32"], {
                    "default": "fp16",
                    "tooltip": "Weight precision. fp16 is recommended (fastest, lowest VRAM). bf16 for newer GPUs. fp32 for maximum quality at 2x VRAM cost.",
                }),
                "offload": (["disabled", "model", "aggressive"], {
                    "default": "disabled",
                    "tooltip": "disabled: all on GPU (~39GB). model: swap UNet/VAE/CLIP to CPU when idle (~16GB). aggressive: model offload + single-frame VAE decode (~12GB).",
                }),
            }
        }

    RETURN_TYPES = ("STAR_MODEL",)
    RETURN_NAMES = ("star_model",)
    FUNCTION = "load_model"
    CATEGORY = "STAR"
    DESCRIPTION = "Loads the STAR video super-resolution model (UNet+ControlNet, OpenCLIP text encoder, temporal VAE). All components are auto-downloaded on first use."

    def load_model(self, model_name, precision, offload="disabled"):
        """Load UNet+ControlNet, text encoder, diffusion helper and VAE.

        Returns a one-tuple holding a dict with every component plus the
        device/dtype/offload policy consumed by the sampling pipeline.
        """
        device = mm.get_torch_device()
        dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        dtype = dtype_map[precision]

        # Where to park models when not in use.
        keep_on = device if offload == "disabled" else "cpu"

        model_path = _ensure_model(model_name)

        # ---- Text encoder (OpenCLIP ViT-H-14) ----
        from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder

        text_encoder = FrozenOpenCLIPEmbedder(
            device=device, pretrained="laion2b_s32b_b79k"
        )
        text_encoder.model.to(device)

        # Pre-compute the negative prompt embedding used during sampling.
        from video_to_video.utils.config import cfg

        negative_y = text_encoder(cfg.negative_prompt).detach()

        # Park the text encoder after pre-computing embeddings, keeping its
        # .device attribute in sync the same way the sampling pipeline does
        # when it swaps the encoder on and off the GPU.
        text_encoder.model.to(keep_on)
        if offload != "disabled":
            text_encoder.device = "cpu"

        # ---- UNet + ControlNet ----
        from video_to_video.modules.unet_v2v import ControlledV2VUNet

        generator = ControlledV2VUNet()
        # NOTE: STAR checkpoints are full pickles, so weights_only=False is
        # required — only load checkpoints from trusted sources.
        load_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        if "state_dict" in load_dict:
            load_dict = load_dict["state_dict"]
        generator.load_state_dict(load_dict, strict=False)
        del load_dict
        generator = generator.to(device=keep_on, dtype=dtype)
        generator.eval()

        # ---- Noise schedule + diffusion helper ----
        from video_to_video.diffusion.schedules_sdedit import noise_schedule
        from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

        sigmas = noise_schedule(
            schedule="logsnr_cosine_interp",
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0,
        )
        diffusion = GaussianDiffusion(sigmas=sigmas)

        # ---- Temporal VAE (from HuggingFace diffusers) ----
        from diffusers import AutoencoderKLTemporalDecoder

        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid",
            subfolder="vae",
            variant="fp16",
        )
        vae.eval()
        vae.requires_grad_(False)
        vae.to(keep_on)

        torch.cuda.empty_cache()

        star_model = {
            "text_encoder": text_encoder,
            "generator": generator,
            "diffusion": diffusion,
            "vae": vae,
            "negative_y": negative_y,
            "device": device,
            "dtype": dtype,
            "offload": offload,
        }
        return (star_model,)


class STARVideoSuperResolution:
    """ComfyUI node wrapping the STAR diffusion super-resolution pipeline."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "star_model": ("STAR_MODEL", {
                    "tooltip": "Connect from STAR Model Loader.",
                }),
                "images": ("IMAGE", {
                    "tooltip": "Input video frames (IMAGE batch). Can come from LoadImage, VHS LoadVideo, etc.",
                }),
                "upscale": ("INT", {
                    "default": 4, "min": 1, "max": 8,
                    "tooltip": "Upscale factor applied to the input resolution. 4x is the default. Higher values need more VRAM.",
                }),
                "steps": ("INT", {
                    "default": 15, "min": 1, "max": 100,
                    "tooltip": "Number of denoising steps. Ignored in 'fast' solver mode (hardcoded 15). More steps = better quality but slower.",
                }),
                "guide_scale": ("FLOAT", {
                    "default": 7.5, "min": 1.0, "max": 20.0, "step": 0.5,
                    "tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strongly. 7.5 is a good default.",
                }),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Text prompt describing the desired output. Leave empty to use STAR's built-in quality prompt.",
                }),
                "solver_mode": (["fast", "normal"], {
                    "default": "fast",
                    "tooltip": "fast: optimized 15-step schedule (4 coarse + 11 fine). normal: uniform schedule using the steps parameter.",
                }),
                "max_chunk_len": ("INT", {
                    "default": 32, "min": 4, "max": 128,
                    "tooltip": "Max frames processed at once. Lower values reduce VRAM usage for long videos. Chunks overlap by 50%.",
                }),
                "seed": ("INT", {
                    "default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF,
                    "tooltip": "Random seed for reproducible results.",
                }),
                "color_fix": (["adain", "wavelet", "none"], {
                    "default": "adain",
                    "tooltip": "Post-processing color correction. adain: match color stats from input. wavelet: preserve input low-frequency color. none: no correction.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "upscale_video"
    CATEGORY = "STAR"
    DESCRIPTION = "Upscale video frames using STAR diffusion-based super-resolution."
+ + def upscale_video( + self, + star_model, + images, + upscale, + steps, + guide_scale, + prompt, + solver_mode, + max_chunk_len, + seed, + color_fix, + ): + from .star_pipeline import run_star_inference + + result = run_star_inference( + star_model=star_model, + images=images, + upscale=upscale, + steps=steps, + guide_scale=guide_scale, + prompt=prompt, + solver_mode=solver_mode, + max_chunk_len=max_chunk_len, + seed=seed, + color_fix=color_fix, + ) + return (result,) + + +NODE_CLASS_MAPPINGS = { + "STARModelLoader": STARModelLoader, + "STARVideoSuperResolution": STARVideoSuperResolution, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "STARModelLoader": "STAR Model Loader", + "STARVideoSuperResolution": "STAR Video Super-Resolution", +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..eda67af --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +easydict +einops +open-clip-torch +torchsde +diffusers>=0.25.0 +huggingface_hub diff --git a/star_pipeline.py b/star_pipeline.py new file mode 100644 index 0000000..d79aec9 --- /dev/null +++ b/star_pipeline.py @@ -0,0 +1,290 @@ +import torch +import torch.nn.functional as F +import torch.amp +from einops import rearrange + +import comfy.utils +import comfy.model_management as mm + +from video_to_video.video_to_video_model import pad_to_fit, make_chunks +from video_to_video.utils.config import cfg + + +# --------------------------------------------------------------------------- +# Tensor format conversions +# --------------------------------------------------------------------------- + +def comfyui_to_star_frames(images: torch.Tensor) -> torch.Tensor: + """Convert ComfyUI IMAGE batch to STAR input format. + + ComfyUI: [N, H, W, 3] float32 in [0, 1] + STAR: [N, 3, H, W] float32 in [-1, 1] + """ + t = images.permute(0, 3, 1, 2) # [N,3,H,W] + t = t * 2.0 - 1.0 + return t + + +def star_output_to_comfyui(video: torch.Tensor) -> torch.Tensor: + """Convert STAR output to ComfyUI IMAGE batch. 
+ + STAR output: [1, 3, F, H, W] float32 in [-1, 1] + ComfyUI: [F, H, W, 3] float32 in [0, 1] + """ + v = video.squeeze(0) # [3, F, H, W] + v = v.permute(1, 2, 3, 0) # [F, H, W, 3] + v = (v + 1.0) / 2.0 + v = v.clamp(0.0, 1.0) + return v + + +# --------------------------------------------------------------------------- +# VAE helpers (mirror VideoToVideo_sr methods) +# --------------------------------------------------------------------------- + +def vae_encode(vae, t, chunk_size=1): + """Encode [B, F, C, H, W] video tensor to latent space.""" + num_f = t.shape[1] + t = rearrange(t, "b f c h w -> (b f) c h w") + z_list = [] + for ind in range(0, t.shape[0], chunk_size): + z_list.append(vae.encode(t[ind : ind + chunk_size]).latent_dist.sample()) + z = torch.cat(z_list, dim=0) + z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f) + return z * vae.config.scaling_factor + + +def vae_decode_chunk(vae, z, chunk_size=3): + """Decode latent [B, C, F, H, W] back to pixel frames.""" + z = rearrange(z, "b c f h w -> (b f) c h w") + video = [] + for ind in range(0, z.shape[0], chunk_size): + chunk = z[ind : ind + chunk_size] + num_f = chunk.shape[0] + decoded = vae.decode(chunk / vae.config.scaling_factor, num_frames=num_f).sample + video.append(decoded) + video = torch.cat(video) + return video + + +# --------------------------------------------------------------------------- +# Color correction wrappers +# --------------------------------------------------------------------------- + +def apply_color_fix(output_frames, input_frames_star, method): + """Apply colour correction to the upscaled output. 
+ + output_frames: [F, H, W, 3] float [0, 1] (ComfyUI format) + input_frames_star: [F, 3, H, W] float [-1, 1] (STAR format) + method: "adain" | "wavelet" | "none" + """ + if method == "none": + return output_frames + + from video_super_resolution.color_fix import adain_color_fix, wavelet_color_fix + + # Resize input to match output spatial size for stats transfer + _, h_out, w_out, _ = output_frames.shape + source = F.interpolate( + input_frames_star, size=(h_out, w_out), mode="bilinear", align_corners=False + ) + + # The color_fix functions expect: + # target: [T, H, W, C] in [0, 255] + # source: [T, C, H, W] in [-1, 1] + target = output_frames * 255.0 + + if method == "adain": + result = adain_color_fix(target, source) + else: + result = wavelet_color_fix(target, source) + + return (result / 255.0).clamp(0.0, 1.0) + + +# --------------------------------------------------------------------------- +# Progress-bar integration via trange monkey-patch +# --------------------------------------------------------------------------- + +def _make_progress_trange(pbar, total_steps): + """Return a drop-in replacement for tqdm.auto.trange that drives *pbar*.""" + from tqdm.auto import trange as _real_trange + + def _progress_trange(*args, **kwargs): + kwargs["disable"] = True # silence console output + for val in _real_trange(*args, **kwargs): + yield val + pbar.update(1) + + return _progress_trange + + +# --------------------------------------------------------------------------- +# Main inference entry point +# --------------------------------------------------------------------------- + +def _move(module, device): + """Move a nn.Module to device and free source memory.""" + module.to(device) + if device == "cpu" or (isinstance(device, torch.device) and device.type == "cpu"): + torch.cuda.empty_cache() + + +def run_star_inference( + star_model: dict, + images: torch.Tensor, + upscale: int = 4, + steps: int = 15, + guide_scale: float = 7.5, + prompt: str = "", + solver_mode: 
str = "fast", + max_chunk_len: int = 32, + seed: int = 0, + color_fix: str = "adain", +) -> torch.Tensor: + """Run STAR video super-resolution and return ComfyUI IMAGE batch.""" + + device = star_model["device"] + dtype = star_model["dtype"] + text_encoder = star_model["text_encoder"] + generator = star_model["generator"] + diffusion = star_model["diffusion"] + vae = star_model["vae"] + negative_y = star_model["negative_y"] + offload = star_model.get("offload", "disabled") + + # In aggressive mode use smaller VAE chunks to cut peak VRAM. + vae_enc_chunk = 1 + vae_dec_chunk = 3 + if offload == "aggressive": + vae_dec_chunk = 1 + + total_noise_levels = 1000 + + # -- Convert ComfyUI frames to STAR format -- + video_data = comfyui_to_star_frames(images) # [F, 3, H, W] + + # Keep a copy at input resolution (on CPU) for colour correction later + input_frames_star = video_data.clone().cpu() + + frames_num, _, orig_h, orig_w = video_data.shape + target_h = orig_h * upscale + target_w = orig_w * upscale + + # -- Bilinear upscale to target resolution -- + video_data = F.interpolate(video_data, size=(target_h, target_w), mode="bilinear", align_corners=False) + _, _, h, w = video_data.shape + + # -- Pad to model-friendly resolution -- + padding = pad_to_fit(h, w) + video_data = F.pad(video_data, padding, "constant", 1) + + video_data = video_data.unsqueeze(0).to(device) # [1, F, 3, H_pad, W_pad] + + # ---- Stage 1: Text encoding ---- + if offload != "disabled": + text_encoder.model.to(device) + text_encoder.device = device + text = prompt if prompt.strip() else cfg.positive_prompt + y = text_encoder(text).detach() + if offload != "disabled": + text_encoder.model.to("cpu") + text_encoder.device = "cpu" + torch.cuda.empty_cache() + + # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) -- + with torch.amp.autocast("cuda"): + # ---- Stage 2: VAE encode ---- + if offload != "disabled": + _move(vae, device) + video_data_feature = vae_encode(vae, video_data, 
chunk_size=vae_enc_chunk) + if offload != "disabled": + _move(vae, "cpu") + # Free the full-res pixel tensor — only latents needed from here. + del video_data + torch.cuda.empty_cache() + + t = torch.LongTensor([total_noise_levels - 1]).to(device) + noised_lr = diffusion.diffuse(video_data_feature, t) + + model_kwargs = [{"y": y}, {"y": negative_y}, {"hint": video_data_feature}] + + torch.cuda.empty_cache() + + chunk_inds = ( + make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) + if frames_num > max_chunk_len + else None + ) + # Need at least 2 chunks; a single chunk causes IndexError in + # model_chunk_fn when it accesses chunk_inds[1]. + if chunk_inds is not None and len(chunk_inds) < 2: + chunk_inds = None + + # Monkey-patch trange for progress reporting + import video_to_video.diffusion.solvers_sdedit as _solvers_mod + _orig_trange = _solvers_mod.trange + + # Calculate actual number of sigma steps for progress bar + # (matches logic inside GaussianDiffusion.sample_sr) + if solver_mode == "fast": + num_sigma_steps = 14 # 4 coarse + 11 fine = 15 sigmas, trange iterates len-1 = 14 + else: + num_sigma_steps = steps + pbar = comfy.utils.ProgressBar(num_sigma_steps) + _solvers_mod.trange = _make_progress_trange(pbar, num_sigma_steps) + + # ---- Stage 3: Diffusion (UNet) ---- + if offload != "disabled": + _move(generator, device) + try: + torch.manual_seed(seed) + + gen_vid = diffusion.sample_sr( + noise=noised_lr, + model=generator, + model_kwargs=model_kwargs, + guide_scale=guide_scale, + guide_rescale=0.2, + solver="dpmpp_2m_sde", + solver_mode=solver_mode, + return_intermediate=None, + steps=steps, + t_max=total_noise_levels - 1, + t_min=0, + discretization="trailing", + chunk_inds=chunk_inds, + ) + finally: + _solvers_mod.trange = _orig_trange + if offload != "disabled": + _move(generator, "cpu") + + # Free latents that are no longer needed. 
+ del noised_lr, video_data_feature, model_kwargs + torch.cuda.empty_cache() + + # ---- Stage 4: VAE decode ---- + if offload != "disabled": + _move(vae, device) + vid_tensor_gen = vae_decode_chunk(vae, gen_vid, chunk_size=vae_dec_chunk) + if offload != "disabled": + _move(vae, "cpu") + + # -- Remove padding -- + w1, w2, h1, h2 = padding + vid_tensor_gen = vid_tensor_gen[:, :, h1 : h + h1, w1 : w + w1] + + # -- Reshape to [B, C, F, H, W] then convert to ComfyUI format -- + gen_video = rearrange(vid_tensor_gen, "(b f) c h w -> b c f h w", b=1) + gen_video = gen_video.float().cpu() + + result = star_output_to_comfyui(gen_video) # [F, H, W, 3] + + # -- Color correction -- + result = apply_color_fix(result, input_frames_star, color_fix) + + torch.cuda.empty_cache() + mm.soft_empty_cache() + + return result