Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)
Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines, v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based processing with temporal overlap and crossfade blending for bounded RAM on long videos.

Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
inference.py (+241)
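The segment-upscale node's overlap blending is not part of this file's diff; as a rough sketch of the crossfade described above (function name and linear window are assumptions, not the vendored code):

    import torch

    def crossfade_overlap(prev_tail, next_head):
        # prev_tail / next_head: [F_overlap, H, W, C] outputs covering the
        # same source frames from two consecutive segments.
        n = prev_tail.shape[0]
        w = torch.linspace(0.0, 1.0, n).view(n, 1, 1, 1)  # ramp 0 -> 1
        return prev_tail * (1.0 - w) + next_head * w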
@@ -1,8 +1,11 @@
import logging
import math
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
@@ -621,3 +624,241 @@ class GIMMVFIModel:
            results.append(torch.clamp(unpadded, 0, 1))

        return results


# ---------------------------------------------------------------------------
# FlashVSR model wrapper (4x video super-resolution)
# ---------------------------------------------------------------------------

class FlashVSRModel:
    """Inference wrapper for FlashVSR diffusion-based video super-resolution.

    Supports three pipeline modes:
    - full: Standard VAE decode, highest quality
    - tiny: TCDecoder decode, faster
    - tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
    """

    # Minimum input frame count required by the pipeline
    MIN_FRAMES = 21

    def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
        from safetensors.torch import load_file
        from .flashvsr_arch import (
            ModelManager, FlashVSRFullPipeline,
            FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
        )
        from .flashvsr_arch.models.utils import Buffer_LQ4x_Proj
        from .flashvsr_arch.models.TCDecoder import build_tcdecoder

        self.mode = mode
        self.device = device
        self.dtype = dtype

        dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
        vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
        lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
        tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
        prompt_path = os.path.join(model_dir, "Prompt.safetensors")

        mm = ModelManager(torch_dtype=dtype, device="cpu")

        if mode == "full":
            mm.load_models([dit_path, vae_path])
            self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
            self.pipe.vae.model.encoder = None
            self.pipe.vae.model.conv1 = None
        else:
            mm.load_models([dit_path])
            Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
            self.pipe = Pipeline.from_model_manager(mm, device=device)
            self.pipe.TCDecoder = build_tcdecoder(
                [512, 256, 128, 128], device, dtype, 16 + 768,
            )
            self.pipe.TCDecoder.load_state_dict(
                load_file(tcd_path, device=device), strict=False,
            )
            self.pipe.TCDecoder.clean_mem()

        # LQ frame projection
        self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype)
        if os.path.exists(lq_path):
            lq_sd = load_file(lq_path, device="cpu")
            cleaned = {}
            for k, v in lq_sd.items():
                cleaned[k.removeprefix("LQ_proj_in.")] = v
            self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
            self.pipe.denoising_model().LQ_proj_in.to(device)

        self.pipe.to(device, dtype)
        self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
        self.pipe.init_cross_kv(prompt_path=prompt_path)
        self.pipe.load_models_to_device([])  # offload to CPU

    def to(self, device):
        self.device = device
        self.pipe.device = device
        return self

    def load_to_device(self):
        """Load models to the compute device for inference."""
        names = ["dit", "vae"] if self.mode == "full" else ["dit"]
        self.pipe.load_models_to_device(names)

    def offload(self):
        """Offload models to CPU."""
        self.pipe.load_models_to_device([])

    def clear_caches(self):
        if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
            self.pipe.denoising_model().LQ_proj_in.clear_cache()
        if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
            self.pipe.vae.clear_cache()

    # ------------------------------------------------------------------
    # Frame preprocessing / postprocessing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_dims(w, h, scale, align=128):
        sw, sh = w * scale, h * scale
        tw = math.ceil(sw / align) * align
        th = math.ceil(sh / align) * align
        return sw, sh, tw, th
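
    # Worked example (illustrative, not from the commit): with a 640x360
    # input and scale=4, _compute_dims returns sw=2560, sh=1440 (the true
    # 4x size) and tw=2560, th=1536 (rounded up to the 128-px alignment).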

    @staticmethod
    def _pad_video_5d(video):
        """Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.

        Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
        the pipeline's streaming loop gets correct iteration counts.
        """
        tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
        video = torch.cat([video, tail], dim=2)
        added = 0
        remainder = (video.shape[2] + 2 - 5) % 8
        if remainder != 0:
            added = 8 - remainder
            pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
            video = torch.cat([video, pad], dim=2)
        return video, added
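
    # Worked example (illustrative): a 21-frame input first gains 2 tail
    # repeats (23 frames); (23 + 2 - 5) % 8 == 4, so added = 4 more frames
    # bring the total to 27, which satisfies (27 + 2 - 5) % 8 == 0.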

    @staticmethod
    def _restore_video_sequence(result, added_frames, expected):
        """Strip padding and warmup frames from the output."""
        if added_frames > 0 and result.shape[0] > added_frames:
            result = result[:-added_frames]
        # Strip the first 2 pipeline warmup frames
        if result.shape[0] > 2:
            result = result[2:]
        # Adjust to exact expected count
        if result.shape[0] > expected:
            result = result[:expected]
        elif result.shape[0] < expected:
            pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
            result = torch.cat([result, pad], dim=0)
        return result
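
    # Continuing the example above (assuming the pipeline returns all 27
    # padded frames): strip added_frames=4 -> 23, strip 2 warmup -> 21,
    # which already matches expected=21, so no final adjustment is needed.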

    def _prepare_video(self, frames, scale):
        """Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1].

        Bicubic-upscales each frame to the target resolution, normalizes to
        [-1, 1], then applies temporal padding for the pipeline.

        Returns:
            video: [1, C, F_padded, H, W] tensor
            th, tw: padded spatial dimensions
            nf: padded frame count (= video.shape[2])
            sh, sw: actual (unpadded) spatial dimensions
            added: number of alignment-padding frames added
        """
        N, H, W, C = frames.shape
        sw, sh, tw, th = self._compute_dims(W, H, scale)

        processed = []
        for i in range(N):
            frame = frames[i].permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
            upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
            pad_h, pad_w = th - sh, tw - sw
            if pad_h > 0 or pad_w > 0:
                upscaled = F.pad(upscaled, (0, pad_w, 0, pad_h), mode='replicate')
            normalized = upscaled * 2.0 - 1.0
            processed.append(normalized.squeeze(0).cpu().to(self.dtype))

        video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)

        # Apply temporal padding (tail + alignment)
        video, added = self._pad_video_5d(video)
        nf = video.shape[2]

        return video, th, tw, nf, sh, sw, added
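
    # Shape sketch (illustrative): frames [21, 360, 640, 3] with scale=4
    # yields video [1, 3, 27, 1536, 2560] with th=1536, tw=2560, nf=27,
    # sh=1440, sw=2560, added=4.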

    @staticmethod
    def _to_frames(video):
        """Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
        from einops import rearrange
        v = video.squeeze(0) if video.dim() == 5 else video
        v = rearrange(v, "C F H W -> F H W C")
        return (v.float() + 1.0) / 2.0

    # ------------------------------------------------------------------
    # Main upscale method
    # ------------------------------------------------------------------

    @torch.no_grad()
    def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
                topk_ratio=2.0, kv_ratio=2.0, local_range=9,
                color_fix=True, unload_dit=False, seed=1,
                progress_bar_cmd=None):
        """Upscale video frames with FlashVSR.

        Args:
            frames: [F, H, W, C] float32 [0, 1] with F >= 21
            scale: Upscaling factor (2 or 4)
            tiled: Enable VAE tiled decode (saves VRAM)
            tile_size: (H, W) tile size for VAE tiling
            topk_ratio: Sparse attention ratio (higher = faster, less detail)
            kv_ratio: KV cache ratio (higher = more quality, more VRAM)
            local_range: Local attention window (9=sharp, 11=stable)
            color_fix: Apply wavelet color correction
            unload_dit: Offload DiT before VAE decode (saves VRAM)
            seed: Random seed
            progress_bar_cmd: Callable wrapping an iterable for progress display

        Returns:
            [F, H*scale, W*scale, C] float32 [0, 1]
        """
        if progress_bar_cmd is None:
            from tqdm import tqdm
            progress_bar_cmd = tqdm

        original_count = frames.shape[0]

        # Prepare video tensor (bicubic upscale + pad)
        video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale)

        # Move LQ video to compute device (except for "long" mode, which streams)
        if "long" not in self.pipe.__class__.__name__.lower():
            video = video.to(self.pipe.device)

        # Run pipeline
        out = self.pipe(
            prompt="", negative_prompt="",
            cfg_scale=1.0, num_inference_steps=1,
            seed=seed, tiled=tiled, tile_size=tile_size,
            progress_bar_cmd=progress_bar_cmd,
            LQ_video=video,
            num_frames=nf, height=th, width=tw,
            is_full_block=False, if_buffer=True,
            topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
            kv_ratio=kv_ratio, local_range=local_range,
            color_fix=color_fix, unload_dit=unload_dit,
        )

        # Convert to ComfyUI format and crop spatial padding
        result = self._to_frames(out).cpu()[:, :sh, :sw, :]

        # Restore original frame count (strip temporal padding + warmup)
        result = self._restore_video_sequence(result, added_frames, original_count)

        return result
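

# Hypothetical usage sketch (not part of the commit; paths and shapes are
# illustrative only):
#
#   model = FlashVSRModel("/models/FlashVSR", mode="tiny", device="cuda:0")
#   model.load_to_device()
#   frames = torch.rand(21, 360, 640, 3)        # [F, H, W, C] in [0, 1]
#   upscaled = model.upscale(frames, scale=4)   # -> [21, 1440, 2560, 3]
#   model.offload()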