Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)

Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines,
v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based
processing with temporal overlap and crossfade blending for bounded RAM on long videos.

Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
commit 0fecfcee37 (parent e253cb244e), 2026-02-13 15:12:33 +01:00
23 changed files with 5733 additions and 9 deletions

nodes.py

@@ -8,7 +8,7 @@ import torch
 import folder_paths
 from comfy.utils import ProgressBar
-from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
+from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel
 from .bim_vfi_arch import clear_backwarp_cache
 from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
 from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
@@ -1507,3 +1507,408 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
result = result[1:] # skip duplicate boundary frame
return (result, model)
# ---------------------------------------------------------------------------
# FlashVSR nodes (4x video super-resolution)
# ---------------------------------------------------------------------------
FLASHVSR_HF_REPO = "1038lab/FlashVSR"
FLASHVSR_REQUIRED_FILES = [
"FlashVSR1_1.safetensors",
"Wan2.1_VAE.safetensors",
"LQ_proj_in.safetensors",
"TCDecoder.safetensors",
"Prompt.safetensors",
]
FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "flashvsr")
os.makedirs(FLASHVSR_MODEL_DIR, exist_ok=True)
def download_flashvsr_models(model_dir):
"""Download FlashVSR checkpoints from HuggingFace if missing."""
from huggingface_hub import snapshot_download
missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if not missing:
return
logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. Downloading from HuggingFace...")
snapshot_download(
repo_id=FLASHVSR_HF_REPO,
local_dir=model_dir,
local_dir_use_symlinks=False,
resume_download=True,
)
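# snapshot_download pulls the full repo snapshot; the re-check below guards
# against a partial or interrupted download.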
still_missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if still_missing:
raise FileNotFoundError(
f"[FlashVSR] Failed to download: {', '.join(still_missing)}. "
f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}"
)
logger.info("[FlashVSR] All checkpoints downloaded successfully.")
class _FlashVSRProgressBar:
"""Wrap an iterable with a ComfyUI ProgressBar."""
def __init__(self, total, pbar, step_ref):
self.total = total
self.pbar = pbar
self.step_ref = step_ref
def __call__(self, iterable):
return self._Wrapper(iterable, self.pbar, self.step_ref)
class _Wrapper:
def __init__(self, iterable, pbar, step_ref):
self.iterable = iterable
self.pbar = pbar
self.step_ref = step_ref
self._iter = iter(iterable)
def __iter__(self):
return self
def __next__(self):
val = next(self._iter)
self.step_ref[0] += 1
self.pbar.update_absolute(self.step_ref[0])
return val
def __len__(self):
return len(self.iterable)
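# __len__ delegates to the wrapped iterable, so this assumes the pipeline
# passes a sized iterable (e.g. a range), not a bare generator.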
class LoadFlashVSRModel:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["tiny", "tiny-long", "full"], {
"default": "tiny",
"tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. "
"Tiny-long: streaming TCDecoder, lowest VRAM for long videos. "
"Full: standard VAE decode, highest quality but more VRAM.",
}),
"precision": (["bf16", "fp16"], {
"default": "bf16",
"tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.",
}),
}
}
RETURN_TYPES = ("FLASHVSR_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "video/FlashVSR"
def load_model(self, mode, precision):
download_flashvsr_models(FLASHVSR_MODEL_DIR)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if precision == "bf16" else torch.float16
wrapper = FlashVSRModel(
model_dir=FLASHVSR_MODEL_DIR,
mode=mode,
device=device,
dtype=dtype,
)
logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})")
return (wrapper,)
class FlashVSRUpscale:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from the Load FlashVSR Model node.",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.",
}),
"frame_chunk_size": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). "
"Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.",
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile width (in latent space). Larger = faster but more VRAM.",
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "KV cache ratio. Higher = better quality, more VRAM.",
}),
"local_range": ([9, 11], {
"default": 9,
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.",
}),
"color_fix": ("BOOLEAN", {
"default": True,
"tooltip": "Apply color correction to prevent color shifts from the diffusion process.",
}),
"unload_dit": ("BOOLEAN", {
"default": False,
"tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.",
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "Random seed for the diffusion process.",
}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, scale, frame_chunk_size,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
num_frames = images.shape[0]
if num_frames < FlashVSRModel.MIN_FRAMES:
raise ValueError(
f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}"
)
tile_size = (tile_size_h, tile_size_w)
# Build frame chunks
if frame_chunk_size < FlashVSRModel.MIN_FRAMES or frame_chunk_size >= num_frames:
chunks = [(0, num_frames)]
else:
chunks = []
start = 0
while start < num_frames:
end = min(start + frame_chunk_size, num_frames)
chunks.append((start, end))
if end == num_frames:
break
start = end
# If the last chunk is too small, merge it into the previous one
if len(chunks) > 1 and (chunks[-1][1] - chunks[-1][0]) < FlashVSRModel.MIN_FRAMES:
prev_start = chunks[-2][0]
last_end = chunks[-1][1]
chunks = chunks[:-2]
chunks.append((prev_start, last_end))
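# Example: 100 frames with frame_chunk_size=33 -> (0,33), (33,66), (66,99),
# (99,100); the 1-frame tail is under MIN_FRAMES, so it merges into (66,100).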
# Estimate total pipeline steps for the progress bar.
# Mirrors _pad_video_5d: append 2 tail frames, then pad until
# (padded_n + 2 - 5) % 8 == 0.
total_steps = 0
for cs, ce in chunks:
padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d
remainder = (padded_n + 2 - 5) % 8
if remainder != 0:
padded_n += 8 - remainder
total_steps += max(1, (padded_n - 1) // 8 - 2)
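# e.g. a 33-frame chunk: padded_n = 35, (35 + 2 - 5) % 8 == 0 (no extra pad),
# giving max(1, (35 - 1) // 8 - 2) = 2 progress steps for that chunk.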
pbar = ProgressBar(total_steps)
step_ref = [0]
progress = _FlashVSRProgressBar(total_steps, pbar, step_ref)
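# load_to_device()/offload() bracket the run, so model weights presumably
# stay off the GPU between invocations.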
model.load_to_device()
result_chunks = []
for chunk_start, chunk_end in chunks:
chunk_frames = images[chunk_start:chunk_end]
chunk_result = model.upscale(
chunk_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
progress_bar_cmd=progress,
)
result_chunks.append(chunk_result)
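# Release model caches and cached VRAM before handing frames back to ComfyUI.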
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
return (torch.cat(result_chunks, dim=0),)
class FlashVSRSegmentUpscale:
"""Process a numbered segment with temporal overlap and crossfade blending.
Chain multiple instances with Save nodes between them to bound peak RAM.
The model pass-through forces sequential execution so each segment
saves and frees RAM before the next starts.
Crossfade blending within the overlap region:
- First (overlap - blend) frames: warmup only, discarded from output
- Last blend frames: linear alpha crossfade with the previous segment's tail
- Non-final segments withhold their own last blend frames; the next segment
re-emits them crossfaded, so saved segments concatenate without duplicates
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Full input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from Load FlashVSR Model. "
"Chain the model output to the next segment node for sequential execution.",
}),
"segment_index": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Which segment to process (0-based).",
}),
"segment_size": ("INT", {
"default": 100, "min": 21, "max": 10000, "step": 1,
"tooltip": "Number of input frames per segment.",
}),
"overlap_frames": ("INT", {
"default": 8, "min": 0, "max": 100, "step": 1,
"tooltip": "Number of overlapping frames between adjacent segments. "
"These frames provide temporal context and crossfade blending.",
}),
"blend_frames": ("INT", {
"default": 4, "min": 0, "max": 50, "step": 1,
"tooltip": "Number of frames within the overlap region to crossfade. "
"Must be <= overlap_frames. The rest of the overlap is warmup (discarded).",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor.",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"local_range": ([9, 11], {
"default": 9,
}),
"color_fix": ("BOOLEAN", {
"default": True,
}),
"unload_dit": ("BOOLEAN", {
"default": False,
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
}),
}
}
RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL")
RETURN_NAMES = ("images", "model")
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, segment_index, segment_size,
overlap_frames, blend_frames, scale,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
total_input = images.shape[0]
blend_frames = min(blend_frames, overlap_frames)
# Clear stale overlap data from previous workflow runs
if segment_index == 0:
model._overlap_tail = None
# Compute segment boundaries
stride = segment_size - overlap_frames
start = segment_index * stride
end = min(start + segment_size, total_input)
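# e.g. segment_size=100, overlap_frames=8 -> stride=92; segment 1 reads input
# frames [92, 192), re-processing the 8 frames segment 0 ended with.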
if start >= total_input:
# Segment index is past the end of the video; return a single placeholder
# frame so downstream nodes still receive an IMAGE
return (images[:1], model)
# Ensure minimum frame count by extending the segment backwards
actual_size = end - start
if actual_size < FlashVSRModel.MIN_FRAMES:
start = max(0, end - FlashVSRModel.MIN_FRAMES)
actual_size = end - start
# Extending backwards enlarges the true overlap with the previous segment's
# output, so track the effective value to keep the trim below from emitting
# duplicate frames
effective_overlap = (overlap_frames + segment_index * stride - start) if segment_index > 0 else 0
segment_frames = images[start:end]
tile_size = (tile_size_h, tile_size_w)
model.load_to_device()
result = model.upscale(
segment_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
)
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
# Handle crossfade blending with the previous segment's tail
if segment_index > 0 and effective_overlap > 0 and hasattr(model, '_overlap_tail'):
prev_tail = model._overlap_tail # [blend_frames, H, W, C] on CPU, fp16
# The overlap region is the first effective_overlap frames of the upscaled
# output: the first (effective_overlap - blend_frames) frames are warmup
# (discarded); the last blend_frames crossfade with prev_tail
warmup = effective_overlap - blend_frames
if blend_frames > 0 and prev_tail is not None:
# Linear alpha ramp; shape [blend_frames, 1, 1, 1] broadcasts over
# the [blend_frames, H, W, C] frame tensors
alpha = torch.linspace(0, 1, blend_frames).view(-1, 1, 1, 1)
blended = (1.0 - alpha) * prev_tail + alpha * result[warmup:warmup + blend_frames]
result = torch.cat([blended, result[effective_overlap:]], dim=0)
else:
result = result[effective_overlap:]
elif segment_index > 0 and effective_overlap > 0:
# No previous tail stored, just skip the overlap
result = result[effective_overlap:]
# Store tail frames for the next segment's crossfade and withhold them from
# this segment's output; the next segment re-emits them crossfaded, so
# straight concatenation of saved segments neither duplicates frames nor
# jumps backwards in time. The final segment keeps its tail intact.
if end < total_input and overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames:
model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16)
result = result[:-blend_frames]
else:
model._overlap_tail = None
return (result, model)