Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)

Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines,
v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based
processing with temporal overlap and crossfade blending for bounded RAM on long videos.

Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 15:12:33 +01:00
parent e253cb244e
commit 0fecfcee37
23 changed files with 5733 additions and 9 deletions

View File

@@ -1,6 +1,6 @@
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM.
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM.
## Which model should I use?
@@ -18,6 +18,21 @@ ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://githu
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair.
### Video Super-Resolution
FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above.
| | FlashVSR |
|---|----------|
| **Task** | 4x video super-resolution |
| **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) |
| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) |
| **VRAM** | ~812 GB (tiled, tiny mode) / ~1624 GB (full mode) |
| **Params** | ~1.3B (DiT) + ~200M (VAE) |
| **Min input** | 21 frames |
| **Paper** | arXiv 2510.12747 |
| **License** | Apache 2.0 |
## Nodes
### BIM-VFI
@@ -136,7 +151,61 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, p
Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
**Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
### FlashVSR
FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling.
#### Load FlashVSR Model
Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`.
| Input | Description |
|-------|-------------|
| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) |
| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) |
Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)):
| Checkpoint | Size | Description |
|-----------|------|-------------|
| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) |
| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE |
| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection |
| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) |
| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings |
#### FlashVSR Upscale
Upscales an image batch with 4x spatial super-resolution.
| Input | Description |
|-------|-------------|
| **images** | Input video frames (minimum 21 frames) |
| **model** | Model from the loader node |
| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) |
| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames |
| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) |
| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) |
| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) |
| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) |
| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability |
| **color_fix** | Apply wavelet color correction to prevent color shifts |
| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) |
| **seed** | Random seed for the diffusion process |
#### FlashVSR Segment Upscale
Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
| Input | Description |
|-------|-------------|
| **segment_index** | Which segment to process (0-based) |
| **segment_size** | Number of input frames per segment (minimum 21) |
| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending |
| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) |
Plus all the same upscale parameters as FlashVSR Upscale.
## Installation
@@ -147,7 +216,7 @@ cd ComfyUI/custom_nodes
git clone https://github.com/your-user/ComfyUI-Tween.git
```
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
> **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with:
> ```bash
@@ -168,7 +237,8 @@ python install.py
- `timm` (for EMA-VFI and SGM-VFI)
- `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download)
- `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI)
- `huggingface_hub` (for GIMM-VFI model auto-download)
- `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download)
- `safetensors` (for FlashVSR checkpoint loading)
## VRAM Guide
@@ -181,7 +251,7 @@ python install.py
## Acknowledgments
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, and the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU). GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths).
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths).
**BiM-VFI:**
> Wonyong Seo, Jihyong Oh, and Munchurl Kim.
@@ -243,6 +313,21 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
}
```
**FlashVSR:**
> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng.
> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer."
> *arXiv preprint arXiv:2510.12747*, 2025.
> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR)
```bibtex
@article{zhuang2025flashvsr,
title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer},
author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing},
journal={arXiv preprint arXiv:2510.12747},
year={2025}
}
```
## License
The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
@@ -252,3 +337,5 @@ The EMA-VFI model weights and architecture code are released under the [Apache 2
The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.
The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR).

View File

@@ -34,8 +34,8 @@ def _auto_install_deps():
except Exception as e:
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
# GIMM-VFI dependencies
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"):
# GIMM-VFI + FlashVSR dependencies
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"):
try:
__import__(pkg)
except ImportError:
@@ -50,6 +50,7 @@ from .nodes import (
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale,
)
WEB_DIRECTORY = "./web"
@@ -68,6 +69,9 @@ NODE_CLASS_MAPPINGS = {
"LoadGIMMVFIModel": LoadGIMMVFIModel,
"GIMMVFIInterpolate": GIMMVFIInterpolate,
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
"LoadFlashVSRModel": LoadFlashVSRModel,
"FlashVSRUpscale": FlashVSRUpscale,
"FlashVSRSegmentUpscale": FlashVSRSegmentUpscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -84,4 +88,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
"LoadFlashVSRModel": "Load FlashVSR Model",
"FlashVSRUpscale": "FlashVSR Upscale",
"FlashVSRSegmentUpscale": "FlashVSR Segment Upscale",
}

View File

@@ -0,0 +1,4 @@
from .models.model_manager import ModelManager
from .pipelines import FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
from .models.utils import clean_vram, Buffer_LQ4x_Proj
from .models.TCDecoder import build_tcdecoder

View File

View File

@@ -0,0 +1,21 @@
from ..models.wan_video_dit import WanModel
from ..models.wan_video_vae import WanVideoVAE
model_loader_configs = [
# (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
]
patch_model_loader_configs = [
]

View File

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
- Encoder removed
- Transplant/widening helpers removed
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
from einops import rearrange
import torch.nn.init as init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
# ----------------------------
# Utility / building blocks
# ----------------------------
class IdentityConv2d(nn.Conv2d):
"""Same-shape Conv2d initialized to identity (Dirac)."""
def __init__(self, C, kernel_size=3, bias=False):
pad = kernel_size // 2
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
with torch.no_grad():
init.dirac_(self.weight)
if self.bias is not None:
self.bias.zero_()
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out)
)
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
B, C, F, H, W = x.shape
if F % self.ff != 0:
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
x = torch.cat([first_frame, x], dim=2)
return rearrange(
x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww
).transpose(1, 2)
# ----------------------------
# Generic NTCHW graph executor (kept; used by decoder)
# ----------------------------
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N*T, C, H, W)
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
out = []
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
progress_bar = tqdm(range(T), disable=not show_progress_bar)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt)
work_queue.insert(0, TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt)
if len(mem[i]) > b.stride:
raise ValueError("TPool internal state invalid.")
elif len(mem[i]) == b.stride:
N_, C_, H_, W_ = xt.shape
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
mem[i] = []
work_queue.insert(0, TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C_, H_, W_ = xt.shape
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
work_queue.insert(0, TWorkItem(xt_next, i+1))
else:
xt = b(xt)
work_queue.insert(0, TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x, mem
# ----------------------------
# Decoder-only TAEHV
# ----------------------------
class TAEHV(nn.Module):
image_channels = 3
def __init__(
self,
checkpoint_path="taehv.pth",
decoder_time_upscale=(True, True),
decoder_space_upscale=(True, True, True),
channels = [256, 128, 64, 64],
latent_channels = 16
):
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
Deepening config: how_many_each=1, k=3 (fixed as requested).
"""
super().__init__()
self.latent_channels = latent_channels
n_f = channels
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
# Build the decoder "skeleton"
base_decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
if checkpoint_path is not None:
missing_keys = self.load_state_dict(
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
strict=False
)
print('missing_keys', missing_keys)
# Initialize decoder mem state
self.mem = [None] * len(self.decoder)
@staticmethod
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
new_layers = []
for b in decoder:
new_layers.append(b)
if isinstance(b, nn.ReLU):
# Deduce channel count from preceding layer
C = None
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
C = new_layers[-2].out_channels
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
C = new_layers[-2].conv[-1].out_channels
if C is not None:
for _ in range(how_many_each):
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
new_layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*new_layers)
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
"""Decode a sequence of frames from latents.
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
"""
trim_flag = self.mem[-8] is None # keeps original relative check
if cond is not None:
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
if trim_flag:
return x[:, self.frames_to_trim:]
return x
def forward(self, *args, **kwargs):
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
def clean_mem(self):
self.mem = [None] * len(self.decoder)
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEW2_1DiffusersWrapper(nn.Module):
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
super().__init__()
self.dtype = torch.bfloat16
self.device = "cuda"
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
self.temperal_downsample = [True, True, False] # [sic]
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
def decode(self, latents, return_dict=None):
n, c, t, h, w = latents.shape
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
n, c, t, h, w = latents.shape
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
def clean_mem(self):
self.taehv.clean_mem()
# ----------------------------
# Simplified builder (no small, no transplant, no post-hoc deepening)
# ----------------------------
def build_tcdecoder(new_channels = [512, 256, 128, 128],
device="cuda",
dtype=torch.bfloat16,
new_latent_channels=None):
"""
构建“更宽”的 decoder深度增强IdentityConv2d+ReLU已在 TAEHV 内部完成。
- 不创建 small / 不做移植
- base_ckpt_path 参数保留但不使用(接口兼容)
返回big (单个模型)
"""
if new_latent_channels is not None:
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
else:
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
big.clean_mem()
return big

View File

@@ -0,0 +1 @@
from .model_manager import *

View File

@@ -0,0 +1,402 @@
import os, torch, json, importlib
from typing import List
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
with init_weights_on_device():
model = model_class(**extra_kwargs)
if hasattr(model, "eval"):
model = model.eval()
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
is_loaded = False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
is_loaded = True
break
if not is_loaded:
print(f" Cannot load LoRA: {file_path}")
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
#print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
#print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File

@@ -0,0 +1,360 @@
import torch, os, gc
from safetensors import safe_open
from contextlib import contextmanager
from einops import rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import hashlib
CACHE_T = 2
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def clean_vram():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if torch.mps.is_available():
torch.mps.empty_cache()
def get_device_list():
devs = []
if torch.cuda.is_available():
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
if torch.mps.is_available():
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
return devs
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print(cache_x.shape, x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
# print('cache!')
x = F.pad(x, padding, mode='replicate') # mode='replicate'
# print(x[0,0,:,0,0])
return super().forward(x)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
return rearrange(x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww)
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
if i == 0:
continue
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim = 2)
# print(out_x.shape)
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache['conv1'] = None
self.cache['conv2'] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs

View File

@@ -0,0 +1,834 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os
import time
from typing import Tuple, Optional, List
from einops import rearrange
from .utils import hash_state_dict_keys
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
from sageattn.core import sparse_sageattn
SPARSE_SAGE_AVAILABLE = True
except ModuleNotFoundError:
SPARSE_SAGE_AVAILABLE = False
sparse_sageattn = None
from PIL import Image
import numpy as np
# ----------------------------
# Local / window masks
# ----------------------------
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
end_r = start_r + win_h - 1
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = r_all - r_half
end_r = start_r + win_h - 1
start_c = c_all - c_half
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
class WindowPartition3D:
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
@staticmethod
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
B, F, H, W, C = x.shape
wf, wh, ww = win
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
return x.view(-1, wf * wh * ww, C)
@staticmethod
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
F, H, W = orig
wf, wh, ww = win
nf, nh, nw = F // wf, H // wh, W // ww
B = windows.size(0) // (nf * nh * nw)
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
return x.view(B, F, H, W, -1)
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_k = torch.mean(k_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
k_heads = avgpool_k.permute(1, 0, 2)
D = avgpool_q.shape[-1]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
n = flat.shape[1]
apply_topk = min(flat.shape[1]-1, topk)
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
# 修正:上行变量名统一
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
mask = mask.repeat_interleave(2, dim=-1)
return mask
@torch.no_grad()
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
avgpool_k_split = torch.mean(k_w_split, dim=2)
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
k_heads = avgpool_k_refined.permute(1, 0, 2)
D = avgpool_q.shape[-1]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
repeat_head = scores.shape[0]
num_q_blocks_local = local_attn_mask.shape[0]
num_k_blocks_local = local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
assert scores.shape == local_attn_mask.shape
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen) # it=seqlen可能需要调整取决于seqlen的含义
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
apply_topk = min(flat.shape[1]-1, topk)
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
# ----------------------------
# Attention kernels
# ----------------------------
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
seqlen = q.shape[1]
seqlen_kv = k.shape[1]
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
base_blockmask = attention_mask
x = sparse_sageattn(
q, k, v,
mask_id=base_blockmask.to(torch.int8),
is_causal=False,
tensor_layout="HND"
)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif compatibility_mode:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_3_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn.flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif SAGE_ATTN_AVAILABLE:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = sageattn(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
else:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return (x * (1 + scale) + shift)
def sinusoidal_embedding_1d(dim, position):
half_dim = max(dim // 2, 1)
scale = torch.arange(half_dim, dtype=torch.float32, device=position.device)
inv_freq = torch.pow(10000.0, -scale / half_dim)
sinusoid = torch.outer(position.to(torch.float32), inv_freq)
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
return f_freqs_cis, h_freqs_cis, w_freqs_cis
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
half_dim = max(dim // 2, 1)
base = torch.arange(0, dim, 2, dtype=torch.float32)[:half_dim]
freqs = torch.pow(theta, -base / max(dim, 1))
steps = torch.arange(end, dtype=torch.float32)
angles = torch.outer(steps, freqs)
return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
orig_dtype = x.dtype
work_dtype = torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype
reshaped = x.to(work_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
x_complex = torch.view_as_complex(reshaped)
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
return x_out.to(orig_dtype)
# ----------------------------
# Norms & Blocks
# ----------------------------
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
def __init__(self, num_heads, enable_sageattention=True):
super().__init__()
self.num_heads = num_heads
self.enable_sageattention = enable_sageattention
def forward(self, q, k, v, attention_mask=None):
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention)
return x
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
self.local_attn_mask = None
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
B, L, D = x.shape
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
assert f==2, "f must be 2"
if is_stream and (pre_cache_k is None or pre_cache_v is None):
assert f==6, " start f must be 6"
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
win = (2, 8, 8)
q = q.view(B, f, h, w, D)
k = k.view(B, f, h, w, D)
v = v.view(B, f, h, w, D)
q_w = WindowPartition3D.partition(q, win)
k_w = WindowPartition3D.partition(k, win)
v_w = WindowPartition3D.partition(v, win)
seqlen = f//win[0]
one_len = k_w.shape[0] // B // seqlen
if pre_cache_k is not None and pre_cache_v is not None:
k_w = torch.cat([pre_cache_k, k_w], dim=0)
v_w = torch.cat([pre_cache_v, v_w], dim=0)
block_n = q_w.shape[0] // B
block_s = q_w.shape[1]
block_n_kv = k_w.shape[0] // B
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
window_size = win[0]*h*w//128
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
self.local_attn_mask_h = h//8
self.local_attn_mask_w = w//8
self.local_range = local_range
attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
if cache_num > kv_len:
cache_k = k_w[one_len:, :, :]
cache_v = v_w[one_len:, :, :]
else:
cache_k = k_w
cache_v = v_w
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
x = WindowPartition3D.reverse(x, win, (f, h, w))
x = x.view(B, f*h*w, D)
if is_stream:
return self.o(x), cache_k, cache_v
return self.o(x)
class CrossAttention(nn.Module):
"""
仅考虑文本 context提供持久 KV 缓存。
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
# 持久缓存
self.cache_k = None
self.cache_v = None
@torch.no_grad()
def init_cache(self, ctx: torch.Tensor):
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
self.cache_k = self.norm_k(self.k(ctx))
self.cache_v = self.v(ctx)
def clear_cache(self):
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
"""
y 即文本上下文(未做其他分支)。
"""
q = self.norm_q(self.q(x))
assert self.cache_k is not None and self.cache_v is not None
k = self.cache_k
v = self.cache_v
x = self.attn(q, k, v)
return self.o(x)
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm3 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
x = self.gate(x, gate_msa, self_attn_output)
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
if is_stream:
return x, self_attn_cache_k, self_attn_cache_v
return x
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
class Head(nn.Module):
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod):
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x
# ----------------------------
# WanModel (no image branch) — init 时即产生 KV 缓存
# ----------------------------
class WanModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool = False,
enable_sageattention: bool = True,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.patch_size = patch_size
# patch embed
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
# text / time embed
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
for _ in range(num_layers)
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
self._cross_kv_initialized = False
# 可选:手动清空 / 重新初始化
# 可选:手动清空 / 重新初始化
def clear_cross_kv(self):
for blk in self.blocks:
blk.cross_attn.clear_cache()
self._cross_kv_initialized = False
@torch.no_grad()
def reinit_cross_kv(self, new_context: torch.Tensor):
ctx_txt = self.text_embedding(new_context)
for blk in self.blocks:
blk.cross_attn.init_cache(ctx_txt)
self._cross_kv_initialized = True
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
LQ_latents: Optional[List[torch.Tensor]] = None,
train_img: bool = False,
topk_ratio: Optional[float] = None,
kv_ratio: Optional[float] = None,
local_num: Optional[int] = None,
is_full_block: bool = False,
causal_idx: Optional[int] = None,
**kwargs,
):
# time / text embeds
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
# 这里仍会嵌入 textCrossAttention 若已有缓存会忽略它)
# context = self.text_embedding(context)
# 输入打补丁
x, (f, h, w) = self.patchify(x)
B = x.shape[0]
# window / masks 超参
win = (2, 8, 8)
seqlen = f//win[0]
if local_num is None:
local_random = random.random()
if local_random < 0.3:
local_num = seqlen - 3
elif local_random < 0.4:
local_num = seqlen - 4
elif local_random < 0.5:
local_num = seqlen - 2
else:
local_num = seqlen
window_size = win[0]*h*w//128
square_num = window_size*window_size
topk_ratio = 2.0
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
if kv_ratio is None:
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
decay_ratio = random.uniform(0.7, 1.0)
# RoPE 3D
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# blocks
for block_id, block in enumerate(self.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x += LQ_latents[block_id]
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None)
x = self.head(x, t)
x = self.unpatchify(x, (f, h, w))
return x
@staticmethod
def state_dict_converter():
return WanModelStateDictConverter()
# ----------------------------
# State dict converter保持原映射已忽略 has_image_input 使用)
# ----------------------------
class WanModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
else:
config = {}
return state_dict_, config
def from_civitai(self, state_dict):
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
# 保留原有哈希匹配返回的 config实现本身不使用 has_image_input 分支
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,847 @@
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
CACHE_T = 2
def check_is_instance(model, module_class):
if isinstance(model, module_class):
return True
if hasattr(model, "module") and isinstance(model.module, module_class):
return True
return False
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
assert s % block_size == 0
num_blocks = s // block_size
# build mask
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
for i in range(num_blocks):
mask[:, :,
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
return mask
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(dim,
dim * 2, (3, 1, 1),
padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim,
dim, (3, 1, 1),
stride=(2, 1, 1),
padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
#attn_mask=block_causal_mask(q, block_size=h * w)
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if check_is_instance(m, CausalConv3d):
count += 1
return count
class VideoVAE_(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def stream_decode(self, z, scale):
# self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# print('self._feat_map', len(self._feat_map))
# cache encode
if self.encoder is not None:
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
# print('self._enc_feat_map', len(self._enc_feat_map))
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16, dim=96):
super().__init__()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
def clear_cache(self):
self.model.clear_cache()
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state for hidden_state in hidden_states]
assert len(hidden_states) == 1
hidden_state = hidden_states[0]
video = self.model.stream_decode(hidden_state, self.scale)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()
class WanVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,3 @@
from .flashvsr_full import FlashVSRFullPipeline
from .flashvsr_tiny import FlashVSRTinyPipeline
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline

View File

@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -0,0 +1,636 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRFullPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
RMS_norm: AutoWrappedModule,
CausalConv3d: AutoWrappedModule,
Upsample: AutoWrappedModule,
torch.nn.SiLU: AutoWrappedModule,
torch.nn.Dropout: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
latents_total = []
self.vae.clear_cache()
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
latents_total.append(cur_latents)
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
try:
del pre_cache_k, pre_cache_v
except NameError:
pass
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
self.dit.to('cpu')
clean_vram()
latents = torch.cat(latents_total, dim=2)
del latents_total
clean_vram()
if skip_vae:
return latents
# Decode
print("[FlashVSR] Starting VAE decoding...")
frames = self.decode_video(latents, **tiler_kwargs)
# 颜色校正wavelet
try:
if color_fix:
frames = self.ColorCorrector(
frames.to(device=LQ_video.device),
LQ_video[:, :, :frames.shape[2], :, :],
clip_range=(-1, 1),
chunk_size=16,
method='adain'
)
except:
pass
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1,633 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRTinyPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-14B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# 尺寸修正
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
latents_total = []
self.TCDecoder.clean_mem()
LQ_pre_idx = 0
LQ_cur_idx = 0
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
latents_total.append(cur_latents)
LQ_pre_idx = LQ_cur_idx
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
try:
del pre_cache_k, pre_cache_v
except NameError:
pass
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
self.dit.to('cpu')
clean_vram()
latents = torch.cat(latents_total, dim=2)
del latents_total
clean_vram()
if skip_vae:
return latents
# Decode
print("[FlashVSR] Starting VAE decoding...")
frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1)
# 颜色校正wavelet
try:
if color_fix:
frames = self.ColorCorrector(
frames.to(device=LQ_video.device),
LQ_video[:, :, :frames.shape[2], :, :],
clip_range=(-1, 1),
chunk_size=16,
method='adain'
)
except:
pass
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1,619 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRTinyLongPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# 尺寸修正
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
self.TCDecoder.clean_mem()
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
# Decode
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
# 颜色校正wavelet
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1 @@
from .flow_match import FlowMatchScheduler

View File

@@ -0,0 +1,79 @@
import torch
class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -0,0 +1 @@
from .layers import *

View File

@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device
def cast_to(weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
num_param = sum(p.numel() for p in module.parameters())
if max_num_param is not None and total_num_param + num_param > max_num_param:
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
model.vram_management_enabled = True

View File

@@ -1,8 +1,11 @@
import logging
import math
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
@@ -621,3 +624,241 @@ class GIMMVFIModel:
results.append(torch.clamp(unpadded, 0, 1))
return results
# ---------------------------------------------------------------------------
# FlashVSR model wrapper (4x video super-resolution)
# ---------------------------------------------------------------------------
class FlashVSRModel:
"""Inference wrapper for FlashVSR diffusion-based video super-resolution.
Supports three pipeline modes:
- full: Standard VAE decode, highest quality
- tiny: TCDecoder decode, faster
- tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
"""
# Minimum input frame count required by the pipeline
MIN_FRAMES = 21
def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
from safetensors.torch import load_file
from .flashvsr_arch import (
ModelManager, FlashVSRFullPipeline,
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
)
from .flashvsr_arch.models.utils import Buffer_LQ4x_Proj
from .flashvsr_arch.models.TCDecoder import build_tcdecoder
self.mode = mode
self.device = device
self.dtype = dtype
dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
prompt_path = os.path.join(model_dir, "Prompt.safetensors")
mm = ModelManager(torch_dtype=dtype, device="cpu")
if mode == "full":
mm.load_models([dit_path, vae_path])
self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
self.pipe.vae.model.encoder = None
self.pipe.vae.model.conv1 = None
else:
mm.load_models([dit_path])
Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
self.pipe = Pipeline.from_model_manager(mm, device=device)
self.pipe.TCDecoder = build_tcdecoder(
[512, 256, 128, 128], device, dtype, 16 + 768,
)
self.pipe.TCDecoder.load_state_dict(
load_file(tcd_path, device=device), strict=False,
)
self.pipe.TCDecoder.clean_mem()
# LQ frame projection
self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype)
if os.path.exists(lq_path):
lq_sd = load_file(lq_path, device="cpu")
cleaned = {}
for k, v in lq_sd.items():
cleaned[k.removeprefix("LQ_proj_in.")] = v
self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
self.pipe.denoising_model().LQ_proj_in.to(device)
self.pipe.to(device, dtype)
self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
self.pipe.init_cross_kv(prompt_path=prompt_path)
self.pipe.load_models_to_device([]) # offload to CPU
def to(self, device):
self.device = device
self.pipe.device = device
return self
def load_to_device(self):
"""Load models to the compute device for inference."""
names = ["dit", "vae"] if self.mode == "full" else ["dit"]
self.pipe.load_models_to_device(names)
def offload(self):
"""Offload models to CPU."""
self.pipe.load_models_to_device([])
def clear_caches(self):
if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
self.pipe.denoising_model().LQ_proj_in.clear_cache()
if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
self.pipe.vae.clear_cache()
# ------------------------------------------------------------------
# Frame preprocessing / postprocessing helpers
# ------------------------------------------------------------------
@staticmethod
def _compute_dims(w, h, scale, align=128):
sw, sh = w * scale, h * scale
tw = math.ceil(sw / align) * align
th = math.ceil(sh / align) * align
return sw, sh, tw, th
@staticmethod
def _pad_video_5d(video):
"""Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline.
Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring
the pipeline's streaming loop gets correct iteration counts.
"""
tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1)
video = torch.cat([video, tail], dim=2)
added = 0
remainder = (video.shape[2] + 2 - 5) % 8
if remainder != 0:
added = 8 - remainder
pad = video[:, :, -1:].repeat(1, 1, added, 1, 1)
video = torch.cat([video, pad], dim=2)
return video, added
@staticmethod
def _restore_video_sequence(result, added_frames, expected):
"""Strip padding and warmup frames from the output."""
if added_frames > 0 and result.shape[0] > added_frames:
result = result[:-added_frames]
# Strip the first 2 pipeline warmup frames
if result.shape[0] > 2:
result = result[2:]
# Adjust to exact expected count
if result.shape[0] > expected:
result = result[:expected]
elif result.shape[0] < expected:
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
result = torch.cat([result, pad], dim=0)
return result
def _prepare_video(self, frames, scale):
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1].
Bicubic-upscales each frame to the target resolution, normalizes to
[-1, 1], then applies temporal padding for the pipeline.
Returns:
video: [1, C, F_padded, H, W] tensor
th, tw: padded spatial dimensions
nf: padded frame count (= video.shape[2])
sh, sw: actual (unpadded) spatial dimensions
added: number of alignment-padding frames added
"""
N, H, W, C = frames.shape
sw, sh, tw, th = self._compute_dims(W, H, scale)
processed = []
for i in range(N):
frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
pad_h, pad_w = th - sh, tw - sw
if pad_h > 0 or pad_w > 0:
upscaled = F.pad(upscaled, (0, pad_w, 0, pad_h), mode='replicate')
normalized = upscaled * 2.0 - 1.0
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
# Apply temporal padding (tail + alignment)
video, added = self._pad_video_5d(video)
nf = video.shape[2]
return video, th, tw, nf, sh, sw, added
@staticmethod
def _to_frames(video):
"""Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
from einops import rearrange
v = video.squeeze(0) if video.dim() == 5 else video
v = rearrange(v, "C F H W -> F H W C")
return (v.float() + 1.0) / 2.0
# ------------------------------------------------------------------
# Main upscale method
# ------------------------------------------------------------------
@torch.no_grad()
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
topk_ratio=2.0, kv_ratio=2.0, local_range=9,
color_fix=True, unload_dit=False, seed=1,
progress_bar_cmd=None):
"""Upscale video frames with FlashVSR.
Args:
frames: [F, H, W, C] float32 [0, 1] with F >= 21
scale: Upscaling factor (2 or 4)
tiled: Enable VAE tiled decode (saves VRAM)
tile_size: (H, W) tile size for VAE tiling
topk_ratio: Sparse attention ratio (higher = faster, less detail)
kv_ratio: KV cache ratio (higher = more quality, more VRAM)
local_range: Local attention window (9=sharp, 11=stable)
color_fix: Apply wavelet color correction
unload_dit: Offload DiT before VAE decode (saves VRAM)
seed: Random seed
progress_bar_cmd: Callable wrapping an iterable for progress display
Returns:
[F, H*scale, W*scale, C] float32 [0, 1]
"""
if progress_bar_cmd is None:
from tqdm import tqdm
progress_bar_cmd = tqdm
original_count = frames.shape[0]
# Prepare video tensor (bicubic upscale + pad)
video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale)
# Move LQ video to compute device (except for "long" mode which streams)
if "long" not in self.pipe.__class__.__name__.lower():
video = video.to(self.pipe.device)
# Run pipeline
out = self.pipe(
prompt="", negative_prompt="",
cfg_scale=1.0, num_inference_steps=1,
seed=seed, tiled=tiled, tile_size=tile_size,
progress_bar_cmd=progress_bar_cmd,
LQ_video=video,
num_frames=nf, height=th, width=tw,
is_full_block=False, if_buffer=True,
topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
kv_ratio=kv_ratio, local_range=local_range,
color_fix=color_fix, unload_dit=unload_dit,
)
# Convert to ComfyUI format and crop spatial padding
result = self._to_frames(out).cpu()[:, :sh, :sw, :]
# Restore original frame count (strip temporal padding + warmup)
result = self._restore_video_sequence(result, added_frames, original_count)
return result

407
nodes.py
View File

@@ -8,7 +8,7 @@ import torch
import folder_paths
from comfy.utils import ProgressBar
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel
from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
@@ -1507,3 +1507,408 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
result = result[1:] # skip duplicate boundary frame
return (result, model)
# ---------------------------------------------------------------------------
# FlashVSR nodes (4x video super-resolution)
# ---------------------------------------------------------------------------
FLASHVSR_HF_REPO = "1038lab/FlashVSR"
FLASHVSR_REQUIRED_FILES = [
"FlashVSR1_1.safetensors",
"Wan2.1_VAE.safetensors",
"LQ_proj_in.safetensors",
"TCDecoder.safetensors",
"Prompt.safetensors",
]
FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "flashvsr")
if not os.path.exists(FLASHVSR_MODEL_DIR):
os.makedirs(FLASHVSR_MODEL_DIR, exist_ok=True)
def download_flashvsr_models(model_dir):
"""Download FlashVSR checkpoints from HuggingFace if missing."""
from huggingface_hub import snapshot_download
missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if not missing:
return
logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. Downloading from HuggingFace...")
snapshot_download(
repo_id=FLASHVSR_HF_REPO,
local_dir=model_dir,
local_dir_use_symlinks=False,
resume_download=True,
)
still_missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if still_missing:
raise FileNotFoundError(
f"[FlashVSR] Failed to download: {', '.join(still_missing)}. "
f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}"
)
logger.info("[FlashVSR] All checkpoints downloaded successfully.")
class _FlashVSRProgressBar:
"""Wrap an iterable with a ComfyUI ProgressBar."""
def __init__(self, total, pbar, step_ref):
self.total = total
self.pbar = pbar
self.step_ref = step_ref
def __call__(self, iterable):
return self._Wrapper(iterable, self.pbar, self.step_ref)
class _Wrapper:
def __init__(self, iterable, pbar, step_ref):
self.iterable = iterable
self.pbar = pbar
self.step_ref = step_ref
self._iter = iter(iterable)
def __iter__(self):
return self
def __next__(self):
val = next(self._iter)
self.step_ref[0] += 1
self.pbar.update_absolute(self.step_ref[0])
return val
def __len__(self):
return len(self.iterable)
class LoadFlashVSRModel:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["tiny", "tiny-long", "full"], {
"default": "tiny",
"tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. "
"Tiny-long: streaming TCDecoder, lowest VRAM for long videos. "
"Full: standard VAE decode, highest quality but more VRAM.",
}),
"precision": (["bf16", "fp16"], {
"default": "bf16",
"tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.",
}),
}
}
RETURN_TYPES = ("FLASHVSR_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "video/FlashVSR"
def load_model(self, mode, precision):
download_flashvsr_models(FLASHVSR_MODEL_DIR)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if precision == "bf16" else torch.float16
wrapper = FlashVSRModel(
model_dir=FLASHVSR_MODEL_DIR,
mode=mode,
device=device,
dtype=dtype,
)
logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})")
return (wrapper,)
class FlashVSRUpscale:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from the Load FlashVSR Model node.",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.",
}),
"frame_chunk_size": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). "
"Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.",
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile width (in latent space). Larger = faster but more VRAM.",
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "KV cache ratio. Higher = better quality, more VRAM.",
}),
"local_range": ([9, 11], {
"default": 9,
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.",
}),
"color_fix": ("BOOLEAN", {
"default": True,
"tooltip": "Apply color correction to prevent color shifts from the diffusion process.",
}),
"unload_dit": ("BOOLEAN", {
"default": False,
"tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.",
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "Random seed for the diffusion process.",
}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, scale, frame_chunk_size,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
num_frames = images.shape[0]
if num_frames < FlashVSRModel.MIN_FRAMES:
raise ValueError(
f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}"
)
tile_size = (tile_size_h, tile_size_w)
# Build frame chunks
if frame_chunk_size < FlashVSRModel.MIN_FRAMES or frame_chunk_size >= num_frames:
chunks = [(0, num_frames)]
else:
chunks = []
start = 0
while start < num_frames:
end = min(start + frame_chunk_size, num_frames)
chunks.append((start, end))
if end == num_frames:
break
start = end
# If the last chunk is too small, merge it into the previous one
if len(chunks) > 1 and (chunks[-1][1] - chunks[-1][0]) < FlashVSRModel.MIN_FRAMES:
prev_start = chunks[-2][0]
last_end = chunks[-1][1]
chunks = chunks[:-2]
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8
total_steps = 0
for cs, ce in chunks:
padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d
remainder = (padded_n + 2 - 5) % 8
if remainder != 0:
padded_n += 8 - remainder
total_steps += max(1, (padded_n - 1) // 8 - 2)
pbar = ProgressBar(total_steps)
step_ref = [0]
progress = _FlashVSRProgressBar(total_steps, pbar, step_ref)
model.load_to_device()
result_chunks = []
for chunk_start, chunk_end in chunks:
chunk_frames = images[chunk_start:chunk_end]
chunk_result = model.upscale(
chunk_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
progress_bar_cmd=progress,
)
result_chunks.append(chunk_result)
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
return (torch.cat(result_chunks, dim=0),)
class FlashVSRSegmentUpscale:
"""Process a numbered segment with temporal overlap and crossfade blending.
Chain multiple instances with Save nodes between them to bound peak RAM.
The model pass-through forces sequential execution so each segment
saves and frees RAM before the next starts.
Crossfade blending within the overlap region:
- First (overlap - blend) frames: warmup only, discarded from output
- Last blend frames: linear alpha crossfade with previous segment's tail
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Full input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from Load FlashVSR Model. "
"Chain the model output to the next segment node for sequential execution.",
}),
"segment_index": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Which segment to process (0-based).",
}),
"segment_size": ("INT", {
"default": 100, "min": 21, "max": 10000, "step": 1,
"tooltip": "Number of input frames per segment.",
}),
"overlap_frames": ("INT", {
"default": 8, "min": 0, "max": 100, "step": 1,
"tooltip": "Number of overlapping frames between adjacent segments. "
"These frames provide temporal context and crossfade blending.",
}),
"blend_frames": ("INT", {
"default": 4, "min": 0, "max": 50, "step": 1,
"tooltip": "Number of frames within the overlap region to crossfade. "
"Must be <= overlap_frames. The rest of the overlap is warmup (discarded).",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor.",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"kv_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"local_range": ([9, 11], {
"default": 9,
}),
"color_fix": ("BOOLEAN", {
"default": True,
}),
"unload_dit": ("BOOLEAN", {
"default": False,
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
}),
}
}
RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL")
RETURN_NAMES = ("images", "model")
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, segment_index, segment_size,
overlap_frames, blend_frames, scale,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
total_input = images.shape[0]
blend_frames = min(blend_frames, overlap_frames)
# Clear stale overlap data from previous workflow runs
if segment_index == 0:
model._overlap_tail = None
# Compute segment boundaries
stride = segment_size - overlap_frames
start = segment_index * stride
end = min(start + segment_size, total_input)
if start >= total_input:
# Past the end
return (images[:1], model)
# Ensure minimum frame count
actual_size = end - start
if actual_size < FlashVSRModel.MIN_FRAMES:
start = max(0, end - FlashVSRModel.MIN_FRAMES)
actual_size = end - start
segment_frames = images[start:end]
tile_size = (tile_size_h, tile_size_w)
model.load_to_device()
result = model.upscale(
segment_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
)
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
# Handle crossfade blending with previous segment's tail
if segment_index > 0 and overlap_frames > 0 and hasattr(model, '_overlap_tail'):
prev_tail = model._overlap_tail # [blend_frames, H, W, C] on CPU
# The overlap region in result: first overlap_frames of the upscaled output
# Within overlap: first (overlap - blend) frames are warmup (discard)
# last blend_frames frames: crossfade with prev_tail
warmup = overlap_frames - blend_frames
if blend_frames > 0 and prev_tail is not None:
# Linear alpha ramp for crossfade
alpha = torch.linspace(0, 1, blend_frames).view(-1, 1, 1, 1)
blended = (1.0 - alpha) * prev_tail + alpha * result[warmup:warmup + blend_frames]
result = torch.cat([blended, result[overlap_frames:]], dim=0)
else:
result = result[overlap_frames:]
elif segment_index > 0 and overlap_frames > 0:
# No previous tail stored, just skip overlap
result = result[overlap_frames:]
# Store tail frames for next segment's crossfade
if overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames:
model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16)
else:
model._overlap_tail = None
return (result, model)

View File

@@ -4,3 +4,4 @@ yacs
easydict
einops
huggingface_hub
safetensors