Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)
Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines, v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based processing with temporal overlap and crossfade blending for bounded RAM on long videos. Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
407
nodes.py
407
nodes.py
@@ -8,7 +8,7 @@ import torch
|
||||
import folder_paths
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
|
||||
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel
|
||||
from .bim_vfi_arch import clear_backwarp_cache
|
||||
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
|
||||
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
|
||||
@@ -1507,3 +1507,408 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
|
||||
return (result, model)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FlashVSR nodes (4x video super-resolution)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Hugging Face repository the FlashVSR checkpoints are fetched from.
FLASHVSR_HF_REPO = "1038lab/FlashVSR"

# Checkpoint files that must exist locally before a model can be loaded.
FLASHVSR_REQUIRED_FILES = [
    "FlashVSR1_1.safetensors",
    "Wan2.1_VAE.safetensors",
    "LQ_proj_in.safetensors",
    "TCDecoder.safetensors",
    "Prompt.safetensors",
]

# Checkpoints are stored in a "flashvsr" folder under folder_paths.models_dir;
# the folder is created at import time when it does not exist yet.
FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "flashvsr")
if not os.path.exists(FLASHVSR_MODEL_DIR):
    os.makedirs(FLASHVSR_MODEL_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def download_flashvsr_models(model_dir):
    """Download FlashVSR checkpoints from HuggingFace if missing.

    Args:
        model_dir: Directory expected to contain every file listed in
            ``FLASHVSR_REQUIRED_FILES``.

    Raises:
        ImportError: If a download is required but ``huggingface_hub`` is
            not installed.
        FileNotFoundError: If a required file is still absent after the
            snapshot download has finished.
    """

    def _missing_files():
        # Required checkpoints that are not present in model_dir.
        return [name for name in FLASHVSR_REQUIRED_FILES
                if not os.path.exists(os.path.join(model_dir, name))]

    missing = _missing_files()
    if not missing:
        return

    # Imported only once a download is actually needed, so an already
    # populated model directory works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. Downloading from HuggingFace...")
    # NOTE(review): recent huggingface_hub releases reportedly deprecate
    # local_dir_use_symlinks / resume_download -- confirm against the pinned
    # version before dropping them.
    snapshot_download(
        repo_id=FLASHVSR_HF_REPO,
        local_dir=model_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
    )

    # The snapshot call returning does not prove the required files arrived,
    # so verify again before reporting success.
    still_missing = _missing_files()
    if still_missing:
        raise FileNotFoundError(
            f"[FlashVSR] Failed to download: {', '.join(still_missing)}. "
            f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}"
        )
    logger.info("[FlashVSR] All checkpoints downloaded successfully.")
|
||||
|
||||
|
||||
class _FlashVSRProgressBar:
|
||||
"""Wrap an iterable with a ComfyUI ProgressBar."""
|
||||
|
||||
def __init__(self, total, pbar, step_ref):
|
||||
self.total = total
|
||||
self.pbar = pbar
|
||||
self.step_ref = step_ref
|
||||
|
||||
def __call__(self, iterable):
|
||||
return self._Wrapper(iterable, self.pbar, self.step_ref)
|
||||
|
||||
class _Wrapper:
|
||||
def __init__(self, iterable, pbar, step_ref):
|
||||
self.iterable = iterable
|
||||
self.pbar = pbar
|
||||
self.step_ref = step_ref
|
||||
self._iter = iter(iterable)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
val = next(self._iter)
|
||||
self.step_ref[0] += 1
|
||||
self.pbar.update_absolute(self.step_ref[0])
|
||||
return val
|
||||
|
||||
def __len__(self):
|
||||
return len(self.iterable)
|
||||
|
||||
|
||||
class LoadFlashVSRModel:
    """ComfyUI node that ensures the checkpoints exist and builds a FlashVSRModel."""

    RETURN_TYPES = ("FLASHVSR_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/FlashVSR"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: pipeline mode and weight precision."""
        mode_options = {
            "default": "tiny",
            "tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. "
                       "Tiny-long: streaming TCDecoder, lowest VRAM for long videos. "
                       "Full: standard VAE decode, highest quality but more VRAM.",
        }
        precision_options = {
            "default": "bf16",
            "tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.",
        }
        required = {
            "mode": (["tiny", "tiny-long", "full"], mode_options),
            "precision": (["bf16", "fp16"], precision_options),
        }
        return {"required": required}

    def load_model(self, mode, precision):
        """Download missing checkpoints, then construct the model wrapper.

        Args:
            mode: One of the pipeline modes offered by ``INPUT_TYPES``.
            precision: ``"bf16"`` selects bfloat16; any other value selects
                float16.

        Returns:
            A one-element tuple holding the ``FlashVSRModel`` wrapper.
        """
        download_flashvsr_models(FLASHVSR_MODEL_DIR)

        # First CUDA device when available, CPU otherwise.
        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"

        if precision == "bf16":
            dtype = torch.bfloat16
        else:
            dtype = torch.float16

        wrapper = FlashVSRModel(
            model_dir=FLASHVSR_MODEL_DIR,
            mode=mode,
            device=device,
            dtype=dtype,
        )
        logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})")
        return (wrapper,)
|
||||
|
||||
|
||||
class FlashVSRUpscale:
    """ComfyUI node that upscales a batch of video frames with a FlashVSR model.

    Frames can optionally be processed in consecutive chunks; the upscaled
    chunks are concatenated along the frame dimension.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs and their UI constraints."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input video frames. Minimum 21 frames required.",
                }),
                "model": ("FLASHVSR_MODEL", {
                    "tooltip": "FlashVSR model from the Load FlashVSR Model node.",
                }),
                "scale": ("INT", {
                    "default": 4, "min": 2, "max": 4, "step": 2,
                    "tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.",
                }),
                "frame_chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). "
                               "Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).",
                }),
                "tiled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.",
                }),
                "tile_size_h": ("INT", {
                    "default": 60, "min": 16, "max": 256, "step": 4,
                    "tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.",
                }),
                "tile_size_w": ("INT", {
                    "default": 104, "min": 16, "max": 256, "step": 4,
                    "tooltip": "VAE tile width (in latent space). Larger = faster but more VRAM.",
                }),
                "topk_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                    "tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
                }),
                "kv_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                    "tooltip": "KV cache ratio. Higher = better quality, more VRAM.",
                }),
                "local_range": ([9, 11], {
                    "default": 9,
                    "tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.",
                }),
                "color_fix": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Apply color correction to prevent color shifts from the diffusion process.",
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.",
                }),
                "seed": ("INT", {
                    "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
                    "tooltip": "Random seed for the diffusion process.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "upscale"
    CATEGORY = "video/FlashVSR"

    @staticmethod
    def _plan_chunks(num_frames, frame_chunk_size, min_frames):
        """Split ``[0, num_frames)`` into consecutive ``(start, end)`` ranges.

        A chunk size below ``min_frames`` (including 0) or not smaller than
        ``num_frames`` yields a single range covering everything. A trailing
        range shorter than ``min_frames`` is merged into its predecessor.
        """
        if frame_chunk_size < min_frames or frame_chunk_size >= num_frames:
            return [(0, num_frames)]

        chunks = [(lo, min(lo + frame_chunk_size, num_frames))
                  for lo in range(0, num_frames, frame_chunk_size)]
        if len(chunks) > 1 and chunks[-1][1] - chunks[-1][0] < min_frames:
            tail_end = chunks.pop()[1]
            prev_start = chunks.pop()[0]
            chunks.append((prev_start, tail_end))
        return chunks

    @staticmethod
    def _estimate_steps(chunks):
        """Estimate the total number of pipeline steps for the progress bar.

        The arithmetic is kept exactly as the node originally computed it
        (described there as mirroring ``_pad_video_5d``): two tail frames are
        added, the count is aligned using ``(n + 2 - 5) % 8``, and each chunk
        contributes at least one step.
        """
        total = 0
        for lo, hi in chunks:
            padded = (hi - lo) + 2
            remainder = (padded + 2 - 5) % 8
            if remainder != 0:
                padded += 8 - remainder
            total += max(1, (padded - 1) // 8 - 2)
        return total

    def upscale(self, images, model, scale, frame_chunk_size,
                tiled, tile_size_h, tile_size_w,
                topk_ratio, kv_ratio, local_range,
                color_fix, unload_dit, seed):
        """Upscale ``images`` chunk by chunk and return the joined frames.

        Args:
            images: Frame batch; only ``shape[0]`` (the frame count) and
                slicing along the first dimension are used here.
            model: Object providing ``load_to_device``, ``upscale``,
                ``clear_caches`` and ``offload``.
            frame_chunk_size: Frames per chunk; see ``_plan_chunks`` for the
                values that disable chunking.
            Remaining arguments are forwarded unchanged to ``model.upscale``.

        Returns:
            A one-element tuple holding the upscaled frames.

        Raises:
            ValueError: If fewer than ``FlashVSRModel.MIN_FRAMES`` frames are
                supplied.
        """
        num_frames = images.shape[0]
        if num_frames < FlashVSRModel.MIN_FRAMES:
            raise ValueError(
                f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}"
            )

        tile_size = (tile_size_h, tile_size_w)
        chunks = self._plan_chunks(num_frames, frame_chunk_size, FlashVSRModel.MIN_FRAMES)

        total_steps = self._estimate_steps(chunks)
        pbar = ProgressBar(total_steps)
        # One shared counter so progress keeps increasing across chunks.
        progress = _FlashVSRProgressBar(total_steps, pbar, [0])

        from .flashvsr_arch.models.utils import clean_vram

        model.load_to_device()
        result_chunks = []
        try:
            for chunk_start, chunk_end in chunks:
                chunk_result = model.upscale(
                    images[chunk_start:chunk_end],
                    scale=scale, tiled=tiled, tile_size=tile_size,
                    topk_ratio=topk_ratio, kv_ratio=kv_ratio,
                    local_range=local_range, color_fix=color_fix,
                    unload_dit=unload_dit, seed=seed,
                    progress_bar_cmd=progress,
                )
                result_chunks.append(chunk_result)
                model.clear_caches()
        finally:
            # Always release the device, even when a chunk fails (e.g. OOM),
            # so a failed run does not leave the model resident.
            model.offload()
            clean_vram()

        if len(result_chunks) == 1:
            # Nothing to join; avoid copying the whole result.
            return (result_chunks[0],)
        return (torch.cat(result_chunks, dim=0),)
|
||||
|
||||
|
||||
class FlashVSRSegmentUpscale:
    """Process a numbered segment with temporal overlap and crossfade blending.

    Chain multiple instances with Save nodes between them to bound peak RAM.
    The model pass-through forces sequential execution so each segment
    saves and frees RAM before the next starts.

    Crossfade blending within the overlap region:
    - First (overlap - blend) frames: warmup only, discarded from output
    - Last blend frames: linear alpha crossfade with previous segment's tail

    The tail of each segment is kept on the model object as
    ``model._overlap_tail`` so the next segment node can blend against it.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs and their UI constraints."""
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Full input video frames. Minimum 21 frames required.",
                }),
                "model": ("FLASHVSR_MODEL", {
                    "tooltip": "FlashVSR model from Load FlashVSR Model. "
                               "Chain the model output to the next segment node for sequential execution.",
                }),
                "segment_index": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Which segment to process (0-based).",
                }),
                "segment_size": ("INT", {
                    "default": 100, "min": 21, "max": 10000, "step": 1,
                    "tooltip": "Number of input frames per segment.",
                }),
                "overlap_frames": ("INT", {
                    "default": 8, "min": 0, "max": 100, "step": 1,
                    "tooltip": "Number of overlapping frames between adjacent segments. "
                               "These frames provide temporal context and crossfade blending.",
                }),
                "blend_frames": ("INT", {
                    "default": 4, "min": 0, "max": 50, "step": 1,
                    "tooltip": "Number of frames within the overlap region to crossfade. "
                               "Must be <= overlap_frames. The rest of the overlap is warmup (discarded).",
                }),
                "scale": ("INT", {
                    "default": 4, "min": 2, "max": 4, "step": 2,
                    "tooltip": "Upscaling factor.",
                }),
                "tiled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable VAE tiled decode.",
                }),
                "tile_size_h": ("INT", {
                    "default": 60, "min": 16, "max": 256, "step": 4,
                }),
                "tile_size_w": ("INT", {
                    "default": 104, "min": 16, "max": 256, "step": 4,
                }),
                "topk_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                }),
                "kv_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                }),
                "local_range": ([9, 11], {
                    "default": 9,
                }),
                "color_fix": ("BOOLEAN", {
                    "default": True,
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                }),
                "seed": ("INT", {
                    "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
                }),
            }
        }

    RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "upscale"
    CATEGORY = "video/FlashVSR"

    def upscale(self, images, model, segment_index, segment_size,
                overlap_frames, blend_frames, scale,
                tiled, tile_size_h, tile_size_w,
                topk_ratio, kv_ratio, local_range,
                color_fix, unload_dit, seed):
        """Upscale one segment of ``images`` and blend it with the previous one.

        Args:
            images: Full frame batch; the segment is sliced from it along the
                first dimension.
            model: Object providing ``load_to_device``, ``upscale``,
                ``clear_caches`` and ``offload``; also used to carry
                ``_overlap_tail`` between segment nodes.
            segment_index: 0-based index of the segment to process.
            segment_size: Number of input frames per segment.
            overlap_frames: Frames shared with the previous segment.
            blend_frames: Frames at the end of the overlap to crossfade;
                clamped to ``overlap_frames``.
            Remaining arguments are forwarded unchanged to ``model.upscale``.

        Returns:
            ``(frames, model)``. When the segment starts past the end of the
            input, ``frames`` is the first input frame, returned unprocessed.

        Raises:
            ValueError: If ``overlap_frames`` is not smaller than
                ``segment_size``, which would make the segment stride zero
                or negative.
        """
        if overlap_frames >= segment_size:
            raise ValueError(
                f"overlap_frames ({overlap_frames}) must be smaller than "
                f"segment_size ({segment_size})"
            )

        total_input = images.shape[0]
        blend_frames = min(blend_frames, overlap_frames)

        # Clear stale overlap data from previous workflow runs
        if segment_index == 0:
            model._overlap_tail = None

        # Compute segment boundaries
        stride = segment_size - overlap_frames
        requested_start = segment_index * stride
        end = min(requested_start + segment_size, total_input)

        if requested_start >= total_input:
            # Past the end
            return (images[:1], model)

        # A short final segment is extended backwards so the model still
        # receives the minimum number of frames it needs.
        start = requested_start
        if end - start < FlashVSRModel.MIN_FRAMES:
            start = max(0, end - FlashVSRModel.MIN_FRAMES)
        # Frames added only as context; they lie before the requested start
        # and must not be emitted.
        context_frames = requested_start - start

        tile_size = (tile_size_h, tile_size_w)

        from .flashvsr_arch.models.utils import clean_vram

        model.load_to_device()
        try:
            result = model.upscale(
                images[start:end],
                scale=scale, tiled=tiled, tile_size=tile_size,
                topk_ratio=topk_ratio, kv_ratio=kv_ratio,
                local_range=local_range, color_fix=color_fix,
                unload_dit=unload_dit, seed=seed,
            )
            model.clear_caches()
        finally:
            # Always release the device, even when the segment fails.
            model.offload()
            clean_vram()

        # assumes model.upscale returns one output frame per input frame,
        # as the overlap slicing below already relies on -- TODO confirm
        if context_frames > 0:
            result = result[context_frames:]

        # Handle crossfade blending with previous segment's tail
        if segment_index > 0 and overlap_frames > 0:
            prev_tail = getattr(model, "_overlap_tail", None)
            # Within the overlap: the first (overlap - blend) frames are
            # warmup and discarded; the last blend_frames are crossfaded.
            warmup = overlap_frames - blend_frames
            current_head = result[warmup:overlap_frames]

            can_blend = (
                blend_frames > 0
                and prev_tail is not None
                and prev_tail.shape == current_head.shape
            )
            if can_blend:
                # The tail is stored as float16 on the CPU; bring it and the
                # alpha ramp to the result's device and dtype before mixing.
                prev_tail = prev_tail.to(device=result.device, dtype=result.dtype)
                alpha = torch.linspace(
                    0, 1, blend_frames,
                    device=result.device, dtype=result.dtype,
                ).view(-1, 1, 1, 1)
                blended = (1.0 - alpha) * prev_tail + alpha * current_head
                # NOTE(review): the previous segment has already returned
                # these blend_frames frames un-blended, so joining segment
                # outputs downstream appears to repeat them -- confirm
                # whether the consumer is expected to trim them.
                result = torch.cat([blended, result[overlap_frames:]], dim=0)
            else:
                # No usable tail (missing, or shaped differently from this
                # segment's overlap): just skip the overlap.
                result = result[overlap_frames:]

        # Store tail frames for next segment's crossfade
        if overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames:
            model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16)
        else:
            model._overlap_tail = None

        return (result, model)
|
||||
|
||||
Reference in New Issue
Block a user