15 Commits

Author SHA1 Message Date
dd61ae8d1f Bundle sparse_sage Triton kernel for block-sparse attention
Without sparse attention, the model uses full (dense) attention which
attends to distant irrelevant information, causing ghosting artifacts.
The FlashVSR paper explicitly requires block-sparse attention.

Vendored from SageAttention team (Apache 2.0), pure Triton (no CUDA C++).
Import chain: local sparse_sage → external sageattn.core → SDPA fallback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 19:22:40 +01:00
e7e7c1cb5a Fix sparse attention mask tiling for temporal windows
The local_attn_mask was not being tiled across temporal dimensions,
causing assertion errors in streaming mode and wrong masks otherwise.
Match naxci1 reference: 4D tile/rearrange for Q/K temporal windows,
chunk-based score computation, and topk<=0 guard.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:50:40 +01:00
3b87652184 Fix FlashVSR attention mask and output quality
- Use generate_draft_block_mask_refined for sparse attention mask (matches
  naxci1's generate_draft_block_mask_sage with proper half-block key scoring)
- Remove spurious repeat_interleave(2, dim=-1) from generate_draft_block_mask
  that doubled the key dimension incorrectly
- Add torch.clamp(0, 1) to _to_frames output (matches naxci1's tensor2video)
- Add .to(self.device) on LQ video slices in streaming loop for all pipelines

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:41:43 +01:00
76dff7e573 Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision
Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:06:46 +01:00
fa250897a2 Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection
Root cause: three critical differences from naxci1 reference implementation:

1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ
   conditioning inside the loop. The TCDecoder uses causal convolutions with
   temporal memory that must be built incrementally per-chunk. Batch decode
   breaks this design and loses LQ frame conditioning, causing ghosting.

2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal
   variant reads the OLD cache before writing the new one (truly causal),
   while Buffer writes cache BEFORE the conv call. Using the wrong variant
   misaligns temporal LQ conditioning features.

3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4)
   matching the naxci1 reference approach.

Changes:
- flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning
  and per-chunk color correction (was: batch VAE decode after loop)
- flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including
  full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in
  clear_caches()
- utils.py: add Causal_LQ4x_Proj class
- nodes.py: update progress bar estimation for new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:42:20 +01:00
94d9818675 Fix FlashVSR quality: match naxci1 reference preprocessing
- Remove front dummy frames (not used by reference implementation)
- Use centered reflect padding instead of right/bottom replicate
- Crop output from center matching padding offsets
- Simplify temporal padding to 8k+1 alignment
- Update progress bar estimation to match new formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:10:12 +01:00
ea84ffef7c Fix FlashVSR ghosting: restore 2 front dummy frames matching reference
The pipeline's LQ conditioning indexing expects 2 front dummy frames
(copies of first frame) as warmup. Our previous refactoring removed
these, shifting all LQ conditioning by 2 frames and causing severe
ghosting artifacts.

Now matches the 1038lab reference preprocessing exactly:
1. _prepare_video: 2 tail copies + alignment + 2 front dummies + back padding
2. _restore_video_sequence: strip first 2 warmup frames + trim to original count
3. Crop pipeline output to padded_n before restoration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:49:46 +01:00
4cc6e9c705 Remove debug logging from FlashVSR SegmentUpscale
Issue was a workflow wiring mistake, not a code bug.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:32:23 +01:00
39d0f7af42 Add debug logging for FlashVSR SegmentUpscale output shapes
Helps diagnose issue where segment 1+ runs but produces no image output.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:31:09 +01:00
11e2acb9e0 Fix FlashVSR frame padding to match pipeline requirements
The pipeline requires num_frames % 4 == 1. Our old _pad_video_5d used a
wrong formula that produced non-conforming counts (e.g. 33 input → 35
padded → pipeline rounds to 37, wasting VRAM).

New padding uses num_frames % 8 == 1 (also satisfies % 4 == 1), which
ensures the streaming loop output exactly matches num_frames with zero
waste. Optimal input counts: 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105.

Also removes incorrect 2-frame warmup stripping from _restore_video_sequence
— the pipeline output doesn't have warmup artifacts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:20:02 +01:00
5071c4de4f Fix sageattn fallback: tensors already rearranged when exception fires
When sageattn fails, q/k/v are already in [b,n,s,d] format from the
rearrange before the call. Use SDPA directly on them instead of calling
_sdpa_fallback which expects [b,s,(n*d)] and crashes with a shape error.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:08:01 +01:00
dd69a2fd2b Fix sageattn crash on Blackwell GPUs (sm_120)
SageAttention CUDA kernels don't support Blackwell yet. Catch runtime
failures from sageattn/sparse_sageattn, disable them, and fall back to
PyTorch SDPA. Only pays the try/except cost once per session.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:03:15 +01:00
f40504cbcf Fix crash when flash_attn is installed but broken
Verify attention backend functions are actually callable before marking
them available. Falls back to PyTorch SDPA instead of calling None.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:51:30 +01:00
8317a0603e Reuse FlashVSR models from 1038lab node if already downloaded
Check models/FlashVSR/ (1038lab convention) before models/flashvsr/ to
avoid downloading ~7GB of checkpoints twice. Only create the directory
when actually downloading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:42:10 +01:00
0fecfcee37 Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)
Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines,
v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based
processing with temporal overlap and crossfade blending for bounded RAM on long videos.

Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:12:33 +01:00
27 changed files with 6225 additions and 9 deletions

View File

@@ -1,6 +1,6 @@
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI # ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM. ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM.
## Which model should I use? ## Which model should I use?
@@ -18,6 +18,21 @@ ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://githu
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair. **TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair.
### Video Super-Resolution
FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above.
| | FlashVSR |
|---|----------|
| **Task** | 4x video super-resolution |
| **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) |
| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) |
| **VRAM** | ~812 GB (tiled, tiny mode) / ~1624 GB (full mode) |
| **Params** | ~1.3B (DiT) + ~200M (VAE) |
| **Min input** | 21 frames |
| **Paper** | arXiv 2510.12747 |
| **License** | Apache 2.0 |
## Nodes ## Nodes
### BIM-VFI ### BIM-VFI
@@ -136,7 +151,61 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, p
Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate. Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 **Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
### FlashVSR
FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling.
#### Load FlashVSR Model
Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`.
| Input | Description |
|-------|-------------|
| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) |
| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) |
Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)):
| Checkpoint | Size | Description |
|-----------|------|-------------|
| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) |
| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE |
| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection |
| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) |
| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings |
#### FlashVSR Upscale
Upscales an image batch with 4x spatial super-resolution.
| Input | Description |
|-------|-------------|
| **images** | Input video frames (minimum 21 frames) |
| **model** | Model from the loader node |
| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) |
| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames |
| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) |
| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) |
| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) |
| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) |
| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability |
| **color_fix** | Apply wavelet color correction to prevent color shifts |
| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) |
| **seed** | Random seed for the diffusion process |
#### FlashVSR Segment Upscale
Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
| Input | Description |
|-------|-------------|
| **segment_index** | Which segment to process (0-based) |
| **segment_size** | Number of input frames per segment (minimum 21) |
| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending |
| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) |
Plus all the same upscale parameters as FlashVSR Upscale.
## Installation ## Installation
@@ -147,7 +216,7 @@ cd ComfyUI/custom_nodes
git clone https://github.com/your-user/ComfyUI-Tween.git git clone https://github.com/your-user/ComfyUI-Tween.git
``` ```
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version. Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
> **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with: > **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with:
> ```bash > ```bash
@@ -168,7 +237,8 @@ python install.py
- `timm` (for EMA-VFI and SGM-VFI) - `timm` (for EMA-VFI and SGM-VFI)
- `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download) - `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download)
- `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI) - `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI)
- `huggingface_hub` (for GIMM-VFI model auto-download) - `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download)
- `safetensors` (for FlashVSR checkpoint loading)
## VRAM Guide ## VRAM Guide
@@ -181,7 +251,7 @@ python install.py
## Acknowledgments ## Acknowledgments
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, and the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU). GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths). This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths).
**BiM-VFI:** **BiM-VFI:**
> Wonyong Seo, Jihyong Oh, and Munchurl Kim. > Wonyong Seo, Jihyong Oh, and Munchurl Kim.
@@ -243,6 +313,21 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
} }
``` ```
**FlashVSR:**
> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng.
> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer."
> *arXiv preprint arXiv:2510.12747*, 2025.
> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR)
```bibtex
@article{zhuang2025flashvsr,
title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer},
author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing},
journal={arXiv preprint arXiv:2510.12747},
year={2025}
}
```
## License ## License
The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details. The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
@@ -252,3 +337,5 @@ The EMA-VFI model weights and architecture code are released under the [Apache 2
The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details. The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.
The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI). The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR).

View File

@@ -34,8 +34,8 @@ def _auto_install_deps():
except Exception as e: except Exception as e:
logger.warning(f"[Tween] Could not auto-install cupy: {e}") logger.warning(f"[Tween] Could not auto-install cupy: {e}")
# GIMM-VFI dependencies # GIMM-VFI + FlashVSR dependencies
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"): for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"):
try: try:
__import__(pkg) __import__(pkg)
except ImportError: except ImportError:
@@ -50,6 +50,7 @@ from .nodes import (
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate, LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate, LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale,
) )
WEB_DIRECTORY = "./web" WEB_DIRECTORY = "./web"
@@ -68,6 +69,9 @@ NODE_CLASS_MAPPINGS = {
"LoadGIMMVFIModel": LoadGIMMVFIModel, "LoadGIMMVFIModel": LoadGIMMVFIModel,
"GIMMVFIInterpolate": GIMMVFIInterpolate, "GIMMVFIInterpolate": GIMMVFIInterpolate,
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate, "GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
"LoadFlashVSRModel": LoadFlashVSRModel,
"FlashVSRUpscale": FlashVSRUpscale,
"FlashVSRSegmentUpscale": FlashVSRSegmentUpscale,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@@ -84,4 +88,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadGIMMVFIModel": "Load GIMM-VFI Model", "LoadGIMMVFIModel": "Load GIMM-VFI Model",
"GIMMVFIInterpolate": "GIMM-VFI Interpolate", "GIMMVFIInterpolate": "GIMM-VFI Interpolate",
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate", "GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
"LoadFlashVSRModel": "Load FlashVSR Model",
"FlashVSRUpscale": "FlashVSR Upscale",
"FlashVSRSegmentUpscale": "FlashVSR Segment Upscale",
} }

View File

@@ -0,0 +1,4 @@
from .models.model_manager import ModelManager
from .pipelines import FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
from .models.utils import clean_vram, Buffer_LQ4x_Proj
from .models.TCDecoder import build_tcdecoder

View File

View File

@@ -0,0 +1,21 @@
from ..models.wan_video_dit import WanModel
from ..models.wan_video_vae import WanVideoVAE
model_loader_configs = [
# (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
]
patch_model_loader_configs = [
]

View File

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
- Encoder removed
- Transplant/widening helpers removed
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
from einops import rearrange
import torch.nn.init as init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
# ----------------------------
# Utility / building blocks
# ----------------------------
class IdentityConv2d(nn.Conv2d):
"""Same-shape Conv2d initialized to identity (Dirac)."""
def __init__(self, C, kernel_size=3, bias=False):
pad = kernel_size // 2
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
with torch.no_grad():
init.dirac_(self.weight)
if self.bias is not None:
self.bias.zero_()
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out)
)
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
B, C, F, H, W = x.shape
if F % self.ff != 0:
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
x = torch.cat([first_frame, x], dim=2)
return rearrange(
x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww
).transpose(1, 2)
# ----------------------------
# Generic NTCHW graph executor (kept; used by decoder)
# ----------------------------
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N*T, C, H, W)
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
out = []
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
progress_bar = tqdm(range(T), disable=not show_progress_bar)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt)
work_queue.insert(0, TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt)
if len(mem[i]) > b.stride:
raise ValueError("TPool internal state invalid.")
elif len(mem[i]) == b.stride:
N_, C_, H_, W_ = xt.shape
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
mem[i] = []
work_queue.insert(0, TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C_, H_, W_ = xt.shape
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
work_queue.insert(0, TWorkItem(xt_next, i+1))
else:
xt = b(xt)
work_queue.insert(0, TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x, mem
# ----------------------------
# Decoder-only TAEHV
# ----------------------------
class TAEHV(nn.Module):
image_channels = 3
def __init__(
self,
checkpoint_path="taehv.pth",
decoder_time_upscale=(True, True),
decoder_space_upscale=(True, True, True),
channels = [256, 128, 64, 64],
latent_channels = 16
):
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
Deepening config: how_many_each=1, k=3 (fixed as requested).
"""
super().__init__()
self.latent_channels = latent_channels
n_f = channels
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
# Build the decoder "skeleton"
base_decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
if checkpoint_path is not None:
missing_keys = self.load_state_dict(
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
strict=False
)
print('missing_keys', missing_keys)
# Initialize decoder mem state
self.mem = [None] * len(self.decoder)
@staticmethod
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
new_layers = []
for b in decoder:
new_layers.append(b)
if isinstance(b, nn.ReLU):
# Deduce channel count from preceding layer
C = None
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
C = new_layers[-2].out_channels
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
C = new_layers[-2].conv[-1].out_channels
if C is not None:
for _ in range(how_many_each):
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
new_layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*new_layers)
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
"""Decode a sequence of frames from latents.
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
"""
trim_flag = self.mem[-8] is None # keeps original relative check
if cond is not None:
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
if trim_flag:
return x[:, self.frames_to_trim:]
return x
def forward(self, *args, **kwargs):
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
def clean_mem(self):
self.mem = [None] * len(self.decoder)
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEW2_1DiffusersWrapper(nn.Module):
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
super().__init__()
self.dtype = torch.bfloat16
self.device = "cuda"
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
self.temperal_downsample = [True, True, False] # [sic]
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
def decode(self, latents, return_dict=None):
n, c, t, h, w = latents.shape
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
n, c, t, h, w = latents.shape
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
def clean_mem(self):
self.taehv.clean_mem()
# ----------------------------
# Simplified builder (no small, no transplant, no post-hoc deepening)
# ----------------------------
def build_tcdecoder(new_channels = [512, 256, 128, 128],
device="cuda",
dtype=torch.bfloat16,
new_latent_channels=None):
"""
构建“更宽”的 decoder深度增强IdentityConv2d+ReLU已在 TAEHV 内部完成。
- 不创建 small / 不做移植
- base_ckpt_path 参数保留但不使用(接口兼容)
返回big (单个模型)
"""
if new_latent_channels is not None:
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
else:
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
big.clean_mem()
return big

View File

@@ -0,0 +1 @@
from .model_manager import *

View File

@@ -0,0 +1,402 @@
import os, torch, json, importlib
from typing import List
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
with init_weights_on_device():
model = model_class(**extra_kwargs)
if hasattr(model, "eval"):
model = model.eval()
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
is_loaded = False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
is_loaded = True
break
if not is_loaded:
print(f" Cannot load LoRA: {file_path}")
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
#print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
#print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File

@@ -0,0 +1,3 @@
from .core import sparse_sageattn
__all__ = ["sparse_sageattn"]

View File

@@ -0,0 +1,40 @@
"""
Sparse SageAttention — block-sparse INT8 attention via Triton.
https://github.com/jt-zhang/Sparse_SageAttention_API
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
from .quant_per_block import per_block_int8
from .sparse_int8_attn import forward as sparse_sageattn_fwd
import torch
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
if mask_id is None:
mask_id = torch.ones(
(q.shape[0], q.shape[1],
(q.shape[2] + 128 - 1) // 128,
(q.shape[3] + 64 - 1) // 64),
dtype=torch.int8, device=q.device,
)
output_dtype = q.dtype
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
v = v.to(torch.float16)
seq_dim = 1 if tensor_layout == "NHD" else 2
km = k.mean(dim=seq_dim, keepdim=True)
q_int8, q_scale, k_int8, k_scale = per_block_int8(
q, k, km=km, tensor_layout=tensor_layout,
)
o = sparse_sageattn_fwd(
q_int8, k_int8, mask_id, v, q_scale, k_scale,
is_causal=is_causal, tensor_layout=tensor_layout,
output_dtype=output_dtype,
)
return o

View File

@@ -0,0 +1,110 @@
"""
Per-block INT8 quantization kernel for Sparse SageAttention.
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
import torch
import triton
import triton.language as tl
@triton.jit
def quant_per_block_int8_kernel(
Input, Output, Scale, L,
stride_iz, stride_ih, stride_in,
stride_oz, stride_oh, stride_on,
stride_sz, stride_sh,
sm_scale,
C: tl.constexpr, BLK: tl.constexpr,
):
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = (
Input
+ off_b * stride_iz
+ off_h * stride_ih
+ offs_n[:, None] * stride_in
+ offs_k[None, :]
)
output_ptrs = (
Output
+ off_b * stride_oz
+ off_h * stride_oh
+ offs_n[:, None] * stride_on
+ offs_k[None, :]
)
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.0
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
if km is not None:
k = k - km
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
else:
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
q_scale = torch.empty(
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32,
)
k_scale = torch.empty(
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32,
)
if sm_scale is None:
sm_scale = head_dim ** -0.5
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
quant_per_block_int8_kernel[grid](
q, q_int8, q_scale, qo_len,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_qo, stride_h_qo, stride_seq_qo,
q_scale.stride(0), q_scale.stride(1),
sm_scale=(sm_scale * 1.44269504),
C=head_dim, BLK=BLKQ,
)
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k, k_int8, k_scale, kv_len,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_ko, stride_h_ko, stride_seq_ko,
k_scale.stride(0), k_scale.stride(1),
sm_scale=1.0,
C=head_dim, BLK=BLKK,
)
return q_int8, q_scale, k_int8, k_scale

View File

@@ -0,0 +1,196 @@
"""
Sparse INT8 attention kernel for Sparse SageAttention.
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _attn_fwd_inner(
acc, l_i, old_m, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn, start_m,
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += lo // BLOCK_N
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
elif STAGE == 3:
lo, hi = 0, kv_len
for start_n in range(lo, hi, BLOCK_N):
kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
if kbid:
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask=k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk -= new_m[:, None]
else:
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk = qk - new_m[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(old_m - new_m)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
old_m = new_m
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += 1
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, old_m
@triton.jit
def _attn_fwd(
Q, K, K_blkid, V, Q_scale, K_scale, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_vz, stride_vh, stride_vn,
stride_oz, stride_oh, stride_on,
stride_kbidq, stride_kbidk,
qo_len, kv_len,
H: tl.constexpr, num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
k_scale_offset = (
off_z * (H // num_kv_groups) + off_h // num_kv_groups
) * tl.cdiv(kv_len, BLOCK_N)
k_bid_offset = (
off_z * (H // num_kv_groups) + off_h // num_kv_groups
) * stride_kbidq
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (off_z * stride_qz + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset + start_m
K_ptrs = (
K
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
V_ptrs = (
V
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
O_block_ptr = (
Out
+ (off_z * stride_oz + off_h * stride_oh)
+ offs_m[:, None] * stride_on
+ offs_k[None, :]
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
4 - STAGE, offs_m, offs_n,
)
if STAGE != 1:
acc, l_i, _ = _attn_fwd_inner(
acc, l_i, m_i, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
2, offs_m, offs_n,
)
acc = acc / l_i[:, None]
tl.store(
O_block_ptr,
acc.to(Out.type.element_ty),
mask=(offs_m[:, None] < qo_len),
)
def forward(
q, k, k_block_id, v, q_scale, k_scale,
is_causal=False, tensor_layout="HND", output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3 if is_causal else 1
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
else:
raise ValueError(f"tensor_layout {tensor_layout} not supported")
if is_causal:
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q, k, k_block_id, v, q_scale, k_scale, o,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_v, stride_h_v, stride_seq_v,
stride_bz_o, stride_h_o, stride_seq_o,
k_block_id.stride(1), k_block_id.stride(2),
qo_len, kv_len,
h_qo, num_kv_groups,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return o

View File

@@ -0,0 +1,460 @@
import torch, os, gc
from safetensors import safe_open
from contextlib import contextmanager
from einops import rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import hashlib
CACHE_T = 2
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def clean_vram():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if torch.mps.is_available():
torch.mps.empty_cache()
def get_device_list():
devs = []
if torch.cuda.is_available():
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
if torch.mps.is_available():
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
return devs
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print(cache_x.shape, x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
# print('cache!')
x = F.pad(x, padding, mode='replicate') # mode='replicate'
# print(x[0,0,:,0,0])
return super().forward(x)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
return rearrange(x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww)
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
if i == 0:
continue
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim = 2)
# print(out_x.shape)
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache['conv1'] = None
self.cache['conv2'] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs
class Causal_LQ4x_Proj(nn.Module):
"""Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.
Key difference: reads old cache BEFORE writing new cache (truly causal),
whereas Buffer_LQ4x_Proj writes cache BEFORE conv call.
"""
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
self.cache['conv1'] = cache1_x # writes NEW cache AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
if i == 0:
self.cache['conv2'] = cache2_x
continue
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
self.cache['conv2'] = cache2_x # writes NEW cache AFTER
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim=2)
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache['conv1'] = None
self.cache['conv2'] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD (None) cache
self.cache['conv1'] = cache1_x # writes AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
self.cache['conv1'] = cache1_x # writes AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
self.cache['conv2'] = cache2_x # writes AFTER
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs

View File

@@ -0,0 +1,865 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os
import time
from typing import Tuple, Optional, List
from einops import rearrange
from .utils import hash_state_dict_keys
try:
import flash_attn_interface
assert callable(getattr(flash_attn_interface, "flash_attn_func", None))
FLASH_ATTN_3_AVAILABLE = True
except Exception:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
assert callable(getattr(flash_attn, "flash_attn_func", None))
FLASH_ATTN_2_AVAILABLE = True
except Exception:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
assert callable(sageattn)
SAGE_ATTN_AVAILABLE = True
except Exception:
SAGE_ATTN_AVAILABLE = False
try:
from .sparse_sage.core import sparse_sageattn
assert callable(sparse_sageattn)
SPARSE_SAGE_AVAILABLE = True
except Exception:
try:
from sageattn.core import sparse_sageattn
assert callable(sparse_sageattn)
SPARSE_SAGE_AVAILABLE = True
except Exception:
SPARSE_SAGE_AVAILABLE = False
sparse_sageattn = None
from PIL import Image
import numpy as np
print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, "
f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, "
f"sage_attn={SAGE_ATTN_AVAILABLE}")
# ----------------------------
# Local / window masks
# ----------------------------
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
end_r = start_r + win_h - 1
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = r_all - r_half
end_r = start_r + win_h - 1
start_c = c_all - c_half
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
class WindowPartition3D:
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
@staticmethod
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
B, F, H, W, C = x.shape
wf, wh, ww = win
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
return x.view(-1, wf * wh * ww, C)
@staticmethod
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
F, H, W = orig
wf, wh, ww = win
nf, nh, nw = F // wf, H // wh, W // ww
B = windows.size(0) // (nf * nh * nw)
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
return x.view(B, F, H, W, -1)
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_k = torch.mean(k_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
k_heads = avgpool_k.permute(1, 0, 2)
D = avgpool_q.shape[-1]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
n = flat.shape[1]
apply_topk = min(flat.shape[1]-1, topk)
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
# 修正:上行变量名统一
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
@torch.no_grad()
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
D = avgpool_q.shape[-1]
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
avgpool_k_split = torch.mean(k_w_split, dim=2)
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
scores = torch.cat([scores_1, scores_2], dim=-1)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
assert scores.shape == local_attn_mask.shape, \
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
apply_topk = min(flat.shape[1]-1, topk)
if apply_topk <= 0:
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
else:
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
# ----------------------------
# Attention kernels
# ----------------------------
def _sdpa_fallback(q, k, v, num_heads):
"""PyTorch scaled dot-product attention (always available)."""
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
try:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
base_blockmask = attention_mask
x = sparse_sageattn(
q, k, v,
mask_id=base_blockmask.to(torch.int8),
is_causal=False,
tensor_layout="HND"
)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
except Exception:
SPARSE_SAGE_AVAILABLE = False
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
# q,k,v already rearranged to [b, n, s, d] above
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif compatibility_mode:
x = _sdpa_fallback(q, k, v, num_heads)
elif FLASH_ATTN_3_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn.flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif SAGE_ATTN_AVAILABLE:
try:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = sageattn(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
except Exception:
SAGE_ATTN_AVAILABLE = False
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
# q,k,v already rearranged to [b, n, s, d] above
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
else:
x = _sdpa_fallback(q, k, v, num_heads)
return x
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return (x * (1 + scale) + shift)
def sinusoidal_embedding_1d(dim, position):
half_dim = max(dim // 2, 1)
scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
inv_freq = torch.pow(10000.0, -scale / half_dim)
sinusoid = torch.outer(position.to(torch.float64), inv_freq)
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
return f_freqs_cis, h_freqs_cis, w_freqs_cis
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
half_dim = max(dim // 2, 1)
base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
freqs = torch.pow(theta, -base / max(dim, 1))
steps = torch.arange(end, dtype=torch.float64)
angles = torch.outer(steps, freqs)
return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
orig_dtype = x.dtype
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
x_complex = torch.view_as_complex(reshaped)
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
return x_out.to(orig_dtype)
# ----------------------------
# Norms & Blocks
# ----------------------------
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
def __init__(self, num_heads, enable_sageattention=True):
super().__init__()
self.num_heads = num_heads
self.enable_sageattention = enable_sageattention
def forward(self, q, k, v, attention_mask=None):
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention)
return x
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
self.local_attn_mask = None
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
B, L, D = x.shape
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
assert f==2, "f must be 2"
if is_stream and (pre_cache_k is None or pre_cache_v is None):
assert f==6, " start f must be 6"
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
win = (2, 8, 8)
q = q.view(B, f, h, w, D)
k = k.view(B, f, h, w, D)
v = v.view(B, f, h, w, D)
q_w = WindowPartition3D.partition(q, win)
k_w = WindowPartition3D.partition(k, win)
v_w = WindowPartition3D.partition(v, win)
seqlen = f//win[0]
one_len = k_w.shape[0] // B // seqlen
if pre_cache_k is not None and pre_cache_v is not None:
k_w = torch.cat([pre_cache_k, k_w], dim=0)
v_w = torch.cat([pre_cache_v, v_w], dim=0)
block_n = q_w.shape[0] // B
block_s = q_w.shape[1]
block_n_kv = k_w.shape[0] // B
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
window_size = win[0]*h*w//128
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
self.local_attn_mask_h = h//8
self.local_attn_mask_w = w//8
self.local_range = local_range
attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
if cache_num > kv_len:
cache_k = k_w[one_len:, :, :]
cache_v = v_w[one_len:, :, :]
else:
cache_k = k_w
cache_v = v_w
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
x = WindowPartition3D.reverse(x, win, (f, h, w))
x = x.view(B, f*h*w, D)
if is_stream:
return self.o(x), cache_k, cache_v
return self.o(x)
class CrossAttention(nn.Module):
"""
仅考虑文本 context提供持久 KV 缓存。
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
# 持久缓存
self.cache_k = None
self.cache_v = None
@torch.no_grad()
def init_cache(self, ctx: torch.Tensor):
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
self.cache_k = self.norm_k(self.k(ctx))
self.cache_v = self.v(ctx)
def clear_cache(self):
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
"""
y 即文本上下文(未做其他分支)。
"""
q = self.norm_q(self.q(x))
assert self.cache_k is not None and self.cache_v is not None
k = self.cache_k
v = self.cache_v
x = self.attn(q, k, v)
return self.o(x)
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm3 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
x = self.gate(x, gate_msa, self_attn_output)
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
if is_stream:
return x, self_attn_cache_k, self_attn_cache_v
return x
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
class Head(nn.Module):
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod):
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x
# ----------------------------
# WanModel (no image branch) — init 时即产生 KV 缓存
# ----------------------------
class WanModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool = False,
enable_sageattention: bool = True,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.patch_size = patch_size
# patch embed
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
# text / time embed
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
for _ in range(num_layers)
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
self._cross_kv_initialized = False
# 可选:手动清空 / 重新初始化
# 可选:手动清空 / 重新初始化
def clear_cross_kv(self):
for blk in self.blocks:
blk.cross_attn.clear_cache()
self._cross_kv_initialized = False
@torch.no_grad()
def reinit_cross_kv(self, new_context: torch.Tensor):
ctx_txt = self.text_embedding(new_context)
for blk in self.blocks:
blk.cross_attn.init_cache(ctx_txt)
self._cross_kv_initialized = True
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
LQ_latents: Optional[List[torch.Tensor]] = None,
train_img: bool = False,
topk_ratio: Optional[float] = None,
kv_ratio: Optional[float] = None,
local_num: Optional[int] = None,
is_full_block: bool = False,
causal_idx: Optional[int] = None,
**kwargs,
):
# time / text embeds
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
# 这里仍会嵌入 textCrossAttention 若已有缓存会忽略它)
# context = self.text_embedding(context)
# 输入打补丁
x, (f, h, w) = self.patchify(x)
B = x.shape[0]
# window / masks 超参
win = (2, 8, 8)
seqlen = f//win[0]
if local_num is None:
local_random = random.random()
if local_random < 0.3:
local_num = seqlen - 3
elif local_random < 0.4:
local_num = seqlen - 4
elif local_random < 0.5:
local_num = seqlen - 2
else:
local_num = seqlen
window_size = win[0]*h*w//128
square_num = window_size*window_size
topk_ratio = 2.0
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
if kv_ratio is None:
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
decay_ratio = random.uniform(0.7, 1.0)
# RoPE 3D
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# blocks
for block_id, block in enumerate(self.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x += LQ_latents[block_id]
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None)
x = self.head(x, t)
x = self.unpatchify(x, (f, h, w))
return x
@staticmethod
def state_dict_converter():
return WanModelStateDictConverter()
# ----------------------------
# State dict converter保持原映射已忽略 has_image_input 使用)
# ----------------------------
class WanModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
else:
config = {}
return state_dict_, config
def from_civitai(self, state_dict):
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
# 保留原有哈希匹配返回的 config实现本身不使用 has_image_input 分支
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,847 @@
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
CACHE_T = 2
def check_is_instance(model, module_class):
if isinstance(model, module_class):
return True
if hasattr(model, "module") and isinstance(model.module, module_class):
return True
return False
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
assert s % block_size == 0
num_blocks = s // block_size
# build mask
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
for i in range(num_blocks):
mask[:, :,
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
return mask
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(dim,
dim * 2, (3, 1, 1),
padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim,
dim, (3, 1, 1),
stride=(2, 1, 1),
padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
#attn_mask=block_causal_mask(q, block_size=h * w)
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if check_is_instance(m, CausalConv3d):
count += 1
return count
class VideoVAE_(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def stream_decode(self, z, scale):
# self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# print('self._feat_map', len(self._feat_map))
# cache encode
if self.encoder is not None:
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
# print('self._enc_feat_map', len(self._enc_feat_map))
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16, dim=96):
super().__init__()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
def clear_cache(self):
self.model.clear_cache()
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state for hidden_state in hidden_states]
assert len(hidden_states) == 1
hidden_state = hidden_states[0]
video = self.model.stream_decode(hidden_state, self.scale)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()
class WanVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,3 @@
from .flashvsr_full import FlashVSRFullPipeline
from .flashvsr_tiny import FlashVSRTinyPipeline
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline

View File

@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -0,0 +1,638 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRFullPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
RMS_norm: AutoWrappedModule,
CausalConv3d: AutoWrappedModule,
Upsample: AutoWrappedModule,
torch.nn.SiLU: AutoWrappedModule,
torch.nn.Dropout: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
self.TCDecoder.clean_mem()
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
cur_latents = cur_latents - noise_pred_posi
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
else:
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
# Per-chunk color correction
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1,625 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRTinyPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-14B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# 尺寸修正
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
self.TCDecoder.clean_mem()
LQ_pre_idx = 0
LQ_cur_idx = 0
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
cur_latents = cur_latents - noise_pred_posi
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
# Per-chunk color correction
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1,619 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline仅 dit + vae
# -----------------------------
class FlashVSRTinyLongPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# 尺寸修正
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
self.TCDecoder.clean_mem()
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
# Decode
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
# 颜色校正wavelet
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache保留原逻辑此处默认不启用
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache默认不启用
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1 @@
from .flow_match import FlowMatchScheduler

View File

@@ -0,0 +1,79 @@
import torch
class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -0,0 +1 @@
from .layers import *

View File

@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device
def cast_to(weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
num_param = sum(p.numel() for p in module.parameters())
if max_num_param is not None and total_num_param + num_param > max_num_param:
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
model.vram_management_enabled = True

View File

@@ -1,8 +1,11 @@
import logging import logging
import math
import os
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from .bim_vfi_arch import BiMVFI from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor from .ema_vfi_arch import feature_extractor as ema_feature_extractor
@@ -621,3 +624,248 @@ class GIMMVFIModel:
results.append(torch.clamp(unpadded, 0, 1)) results.append(torch.clamp(unpadded, 0, 1))
return results return results
# ---------------------------------------------------------------------------
# FlashVSR model wrapper (4x video super-resolution)
# ---------------------------------------------------------------------------
class FlashVSRModel:
"""Inference wrapper for FlashVSR diffusion-based video super-resolution.
Supports three pipeline modes:
- full: Standard VAE decode, highest quality
- tiny: TCDecoder decode, faster
- tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
"""
# Minimum input frame count required by the pipeline
MIN_FRAMES = 21
def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
from safetensors.torch import load_file
from .flashvsr_arch import (
ModelManager, FlashVSRFullPipeline,
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
)
from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
from .flashvsr_arch.models.TCDecoder import build_tcdecoder
self.mode = mode
self.device = device
self.dtype = dtype
dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
prompt_path = os.path.join(model_dir, "Prompt.safetensors")
mm = ModelManager(torch_dtype=dtype, device="cpu")
if mode == "full":
mm.load_models([dit_path, vae_path])
self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
self.pipe.vae.model.encoder = None
self.pipe.vae.model.conv1 = None
else:
mm.load_models([dit_path])
Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
self.pipe = Pipeline.from_model_manager(mm, device=device)
# TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
self.pipe.TCDecoder = build_tcdecoder(
[512, 256, 128, 128], device, dtype, 16 + 768,
)
self.pipe.TCDecoder.load_state_dict(
load_file(tcd_path, device=device), strict=False,
)
self.pipe.TCDecoder.clean_mem()
# LQ frame projection — Causal variant for FlashVSR v1.1
self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
if os.path.exists(lq_path):
lq_sd = load_file(lq_path, device="cpu")
cleaned = {}
for k, v in lq_sd.items():
cleaned[k.removeprefix("LQ_proj_in.")] = v
self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
self.pipe.denoising_model().LQ_proj_in.to(device)
self.pipe.to(device, dtype)
self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
self.pipe.init_cross_kv(prompt_path=prompt_path)
self.pipe.load_models_to_device([]) # offload to CPU
def to(self, device):
self.device = device
self.pipe.device = device
return self
def load_to_device(self):
"""Load models to the compute device for inference."""
names = ["dit", "vae"] if self.mode == "full" else ["dit"]
self.pipe.load_models_to_device(names)
def offload(self):
"""Offload models to CPU."""
self.pipe.load_models_to_device([])
def clear_caches(self):
if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
self.pipe.denoising_model().LQ_proj_in.clear_cache()
if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
self.pipe.vae.clear_cache()
if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
self.pipe.TCDecoder.clean_mem()
# ------------------------------------------------------------------
# Frame preprocessing / postprocessing helpers
# ------------------------------------------------------------------
@staticmethod
def _compute_dims(w, h, scale, align=128):
sw, sh = w * scale, h * scale
tw = math.ceil(sw / align) * align
th = math.ceil(sh / align) * align
return sw, sh, tw, th
@staticmethod
def _restore_video_sequence(result, expected):
"""Trim pipeline output to the expected frame count."""
if result.shape[0] > expected:
result = result[:expected]
elif result.shape[0] < expected:
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
result = torch.cat([result, pad], dim=0)
return result
@staticmethod
def _next_8n5(n, minimum=21):
"""Next integer >= n of the form 8k+5 (minimum 21)."""
if n < minimum:
return minimum
return ((n - 5 + 7) // 8) * 8 + 5
def _prepare_video(self, frames, scale):
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
1. Bicubic-upscale each frame to target resolution
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
3. Normalize to [-1, 1]
4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)
Returns:
video: [1, C, F_padded, H, W] tensor
th, tw: padded spatial dimensions
nf: padded frame count
sh, sw: actual (unpadded) spatial dimensions
pad_top, pad_left: spatial padding offsets for output cropping
"""
N, H, W, C = frames.shape
sw, sh, tw, th = self._compute_dims(W, H, scale)
# Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
N_padded = self._next_8n5(N)
# Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
target = N_padded + 4
# Centered spatial padding offsets
pad_top = (th - sh) // 2
pad_bottom = th - sh - pad_top
pad_left = (tw - sw) // 2
pad_right = tw - sw - pad_left
processed = []
for i in range(target):
frame_idx = min(i, N - 1) # clamp to last real frame
frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
# Centered reflect padding (matches naxci1 reference)
try:
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
except RuntimeError:
# Reflect requires pad < input size; fall back to replicate
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
normalized = upscaled * 2.0 - 1.0
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
nf = video.shape[2]
return video, th, tw, nf, sh, sw, pad_top, pad_left
@staticmethod
def _to_frames(video):
"""Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
from einops import rearrange
v = video.squeeze(0) if video.dim() == 5 else video
v = rearrange(v, "C F H W -> F H W C")
return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)
# ------------------------------------------------------------------
# Main upscale method
# ------------------------------------------------------------------
@torch.no_grad()
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
topk_ratio=2.0, kv_ratio=3.0, local_range=11,
color_fix=True, unload_dit=False, seed=1,
progress_bar_cmd=None):
"""Upscale video frames with FlashVSR.
Args:
frames: [F, H, W, C] float32 [0, 1] with F >= 21
scale: Upscaling factor (2 or 4)
tiled: Enable VAE tiled decode (saves VRAM)
tile_size: (H, W) tile size for VAE tiling
topk_ratio: Sparse attention ratio (higher = faster, less detail)
kv_ratio: KV cache ratio (higher = more quality, more VRAM)
local_range: Local attention window (9=sharp, 11=stable)
color_fix: Apply wavelet color correction
unload_dit: Offload DiT before VAE decode (saves VRAM)
seed: Random seed
progress_bar_cmd: Callable wrapping an iterable for progress display
Returns:
[F, H*scale, W*scale, C] float32 [0, 1]
"""
if progress_bar_cmd is None:
from tqdm import tqdm
progress_bar_cmd = tqdm
original_count = frames.shape[0]
# Prepare video tensor (bicubic upscale + centered pad)
video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)
# Move LQ video to compute device (except for "long" mode which streams)
if "long" not in self.pipe.__class__.__name__.lower():
video = video.to(self.pipe.device)
# Run pipeline
out = self.pipe(
prompt="", negative_prompt="",
cfg_scale=1.0, num_inference_steps=1,
seed=seed, tiled=tiled, tile_size=tile_size,
progress_bar_cmd=progress_bar_cmd,
LQ_video=video,
num_frames=nf, height=th, width=tw,
is_full_block=False, if_buffer=True,
topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
kv_ratio=kv_ratio, local_range=local_range,
color_fix=color_fix, unload_dit=unload_dit,
)
# Convert to ComfyUI format with centered spatial crop
result = self._to_frames(out).cpu()
result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]
# Trim to original frame count
result = self._restore_video_sequence(result, original_count)
return result

418
nodes.py
View File

@@ -8,7 +8,7 @@ import torch
import folder_paths import folder_paths
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel
from .bim_vfi_arch import clear_backwarp_cache from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
@@ -1507,3 +1507,419 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
result = result[1:] # skip duplicate boundary frame result = result[1:] # skip duplicate boundary frame
return (result, model) return (result, model)
# ---------------------------------------------------------------------------
# FlashVSR nodes (4x video super-resolution)
# ---------------------------------------------------------------------------
FLASHVSR_HF_REPO = "1038lab/FlashVSR"
FLASHVSR_REQUIRED_FILES = [
"FlashVSR1_1.safetensors",
"Wan2.1_VAE.safetensors",
"LQ_proj_in.safetensors",
"TCDecoder.safetensors",
"Prompt.safetensors",
]
# Check common locations so we reuse models from 1038lab/ComfyUI-FlashVSR
FLASHVSR_MODEL_DIR = None
for _dirname in ("FlashVSR", "flashvsr"):
_candidate = os.path.join(folder_paths.models_dir, _dirname)
if os.path.isdir(_candidate) and all(
os.path.exists(os.path.join(_candidate, f)) for f in FLASHVSR_REQUIRED_FILES
):
FLASHVSR_MODEL_DIR = _candidate
break
if FLASHVSR_MODEL_DIR is None:
# Default to "FlashVSR" (matches 1038lab convention)
FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "FlashVSR")
def download_flashvsr_models(model_dir):
"""Download FlashVSR checkpoints from HuggingFace if missing."""
from huggingface_hub import snapshot_download
missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if not missing:
return
os.makedirs(model_dir, exist_ok=True)
logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. Downloading from HuggingFace...")
snapshot_download(
repo_id=FLASHVSR_HF_REPO,
local_dir=model_dir,
local_dir_use_symlinks=False,
resume_download=True,
)
still_missing = [f for f in FLASHVSR_REQUIRED_FILES
if not os.path.exists(os.path.join(model_dir, f))]
if still_missing:
raise FileNotFoundError(
f"[FlashVSR] Failed to download: {', '.join(still_missing)}. "
f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}"
)
logger.info("[FlashVSR] All checkpoints downloaded successfully.")
class _FlashVSRProgressBar:
"""Wrap an iterable with a ComfyUI ProgressBar."""
def __init__(self, total, pbar, step_ref):
self.total = total
self.pbar = pbar
self.step_ref = step_ref
def __call__(self, iterable):
return self._Wrapper(iterable, self.pbar, self.step_ref)
class _Wrapper:
def __init__(self, iterable, pbar, step_ref):
self.iterable = iterable
self.pbar = pbar
self.step_ref = step_ref
self._iter = iter(iterable)
def __iter__(self):
return self
def __next__(self):
val = next(self._iter)
self.step_ref[0] += 1
self.pbar.update_absolute(self.step_ref[0])
return val
def __len__(self):
return len(self.iterable)
class LoadFlashVSRModel:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["tiny", "tiny-long", "full"], {
"default": "tiny",
"tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. "
"Tiny-long: streaming TCDecoder, lowest VRAM for long videos. "
"Full: standard VAE decode, highest quality but more VRAM.",
}),
"precision": (["bf16", "fp16"], {
"default": "bf16",
"tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.",
}),
}
}
RETURN_TYPES = ("FLASHVSR_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "video/FlashVSR"
def load_model(self, mode, precision):
download_flashvsr_models(FLASHVSR_MODEL_DIR)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if precision == "bf16" else torch.float16
wrapper = FlashVSRModel(
model_dir=FLASHVSR_MODEL_DIR,
mode=mode,
device=device,
dtype=dtype,
)
logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})")
return (wrapper,)
class FlashVSRUpscale:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from the Load FlashVSR Model node.",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.",
}),
"frame_chunk_size": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). "
"Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.",
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
"tooltip": "VAE tile width (in latent space). Larger = faster but more VRAM.",
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
}),
"kv_ratio": ("FLOAT", {
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
"tooltip": "KV cache ratio. Higher = better quality, more VRAM. 3.0 recommended.",
}),
"local_range": ([9, 11], {
"default": 11,
"tooltip": "Local attention window. 9=sharper details, 11=more temporal stability (recommended).",
}),
"color_fix": ("BOOLEAN", {
"default": True,
"tooltip": "Apply color correction to prevent color shifts from the diffusion process.",
}),
"unload_dit": ("BOOLEAN", {
"default": False,
"tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.",
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "Random seed for the diffusion process.",
}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, scale, frame_chunk_size,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
num_frames = images.shape[0]
if num_frames < FlashVSRModel.MIN_FRAMES:
raise ValueError(
f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}"
)
tile_size = (tile_size_h, tile_size_w)
# Build frame chunks
if frame_chunk_size < FlashVSRModel.MIN_FRAMES or frame_chunk_size >= num_frames:
chunks = [(0, num_frames)]
else:
chunks = []
start = 0
while start < num_frames:
end = min(start + frame_chunk_size, num_frames)
chunks.append((start, end))
if end == num_frames:
break
start = end
# If the last chunk is too small, merge it into the previous one
if len(chunks) > 1 and (chunks[-1][1] - chunks[-1][0]) < FlashVSRModel.MIN_FRAMES:
prev_start = chunks[-2][0]
last_end = chunks[-1][1]
chunks = chunks[:-2]
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _prepare_video two-stage padding: next_8n5(N) + 4
def _next_8n5(n, minimum=21):
return minimum if n < minimum else ((n - 5 + 7) // 8) * 8 + 5
total_steps = 0
for cs, ce in chunks:
n = ce - cs
target = _next_8n5(n) + 4 # always 8k+1
total_steps += max(1, (target - 1) // 8 - 2)
pbar = ProgressBar(total_steps)
step_ref = [0]
progress = _FlashVSRProgressBar(total_steps, pbar, step_ref)
model.load_to_device()
result_chunks = []
for chunk_start, chunk_end in chunks:
chunk_frames = images[chunk_start:chunk_end]
chunk_result = model.upscale(
chunk_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
progress_bar_cmd=progress,
)
result_chunks.append(chunk_result)
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
return (torch.cat(result_chunks, dim=0),)
class FlashVSRSegmentUpscale:
"""Process a numbered segment with temporal overlap and crossfade blending.
Chain multiple instances with Save nodes between them to bound peak RAM.
The model pass-through forces sequential execution so each segment
saves and frees RAM before the next starts.
Crossfade blending within the overlap region:
- First (overlap - blend) frames: warmup only, discarded from output
- Last blend frames: linear alpha crossfade with previous segment's tail
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Full input video frames. Minimum 21 frames required.",
}),
"model": ("FLASHVSR_MODEL", {
"tooltip": "FlashVSR model from Load FlashVSR Model. "
"Chain the model output to the next segment node for sequential execution.",
}),
"segment_index": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Which segment to process (0-based).",
}),
"segment_size": ("INT", {
"default": 100, "min": 21, "max": 10000, "step": 1,
"tooltip": "Number of input frames per segment.",
}),
"overlap_frames": ("INT", {
"default": 8, "min": 0, "max": 100, "step": 1,
"tooltip": "Number of overlapping frames between adjacent segments. "
"These frames provide temporal context and crossfade blending.",
}),
"blend_frames": ("INT", {
"default": 4, "min": 0, "max": 50, "step": 1,
"tooltip": "Number of frames within the overlap region to crossfade. "
"Must be <= overlap_frames. The rest of the overlap is warmup (discarded).",
}),
"scale": ("INT", {
"default": 4, "min": 2, "max": 4, "step": 2,
"tooltip": "Upscaling factor.",
}),
"tiled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable VAE tiled decode.",
}),
"tile_size_h": ("INT", {
"default": 60, "min": 16, "max": 256, "step": 4,
}),
"tile_size_w": ("INT", {
"default": 104, "min": 16, "max": 256, "step": 4,
}),
"topk_ratio": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"kv_ratio": ("FLOAT", {
"default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
}),
"local_range": ([9, 11], {
"default": 11,
}),
"color_fix": ("BOOLEAN", {
"default": True,
}),
"unload_dit": ("BOOLEAN", {
"default": False,
}),
"seed": ("INT", {
"default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
}),
}
}
RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL")
RETURN_NAMES = ("images", "model")
FUNCTION = "upscale"
CATEGORY = "video/FlashVSR"
def upscale(self, images, model, segment_index, segment_size,
overlap_frames, blend_frames, scale,
tiled, tile_size_h, tile_size_w,
topk_ratio, kv_ratio, local_range,
color_fix, unload_dit, seed):
total_input = images.shape[0]
blend_frames = min(blend_frames, overlap_frames)
# Clear stale overlap data from previous workflow runs
if segment_index == 0:
model._overlap_tail = None
# Compute segment boundaries
stride = segment_size - overlap_frames
start = segment_index * stride
end = min(start + segment_size, total_input)
if start >= total_input:
# Past the end
return (images[:1], model)
# Ensure minimum frame count
actual_size = end - start
if actual_size < FlashVSRModel.MIN_FRAMES:
start = max(0, end - FlashVSRModel.MIN_FRAMES)
actual_size = end - start
segment_frames = images[start:end]
tile_size = (tile_size_h, tile_size_w)
model.load_to_device()
result = model.upscale(
segment_frames,
scale=scale, tiled=tiled, tile_size=tile_size,
topk_ratio=topk_ratio, kv_ratio=kv_ratio,
local_range=local_range, color_fix=color_fix,
unload_dit=unload_dit, seed=seed,
)
model.clear_caches()
model.offload()
from .flashvsr_arch.models.utils import clean_vram
clean_vram()
# Handle crossfade blending with previous segment's tail
if segment_index > 0 and overlap_frames > 0 and hasattr(model, '_overlap_tail'):
prev_tail = model._overlap_tail # [blend_frames, H, W, C] on CPU
# The overlap region in result: first overlap_frames of the upscaled output
# Within overlap: first (overlap - blend) frames are warmup (discard)
# last blend_frames frames: crossfade with prev_tail
warmup = overlap_frames - blend_frames
if blend_frames > 0 and prev_tail is not None:
# Linear alpha ramp for crossfade
alpha = torch.linspace(0, 1, blend_frames).view(-1, 1, 1, 1)
blended = (1.0 - alpha) * prev_tail + alpha * result[warmup:warmup + blend_frames]
result = torch.cat([blended, result[overlap_frames:]], dim=0)
else:
result = result[overlap_frames:]
elif segment_index > 0 and overlap_frames > 0:
# No previous tail stored, just skip overlap
result = result[overlap_frames:]
# Store tail frames for next segment's crossfade
if overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames:
model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16)
else:
model._overlap_tail = None
return (result, model)

View File

@@ -4,3 +4,4 @@ yacs
easydict easydict
einops einops
huggingface_hub huggingface_hub
safetensors