15 Commits

Author SHA1 Message Date
dd61ae8d1f Bundle sparse_sage Triton kernel for block-sparse attention
Without sparse attention, the model uses full (dense) attention which
attends to distant irrelevant information, causing ghosting artifacts.
The FlashVSR paper explicitly requires block-sparse attention.

Vendored from SageAttention team (Apache 2.0), pure Triton (no CUDA C++).
Import chain: local sparse_sage → external sageattn.core → SDPA fallback.
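A minimal sketch of that fallback chain (module and function names are taken from the commit text; the exact imports may differ):

```python
# Hedged sketch of the import chain above, not the vendored file itself.
try:
    from .sparse_sage import sparse_sageattn        # vendored Triton kernel (preferred)
except Exception:
    try:
        from sageattn.core import sparse_sageattn   # external SageAttention install
    except Exception:
        sparse_sageattn = None                      # caller falls back to PyTorch SDPA
```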

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 19:22:40 +01:00
e7e7c1cb5a Fix sparse attention mask tiling for temporal windows
The local_attn_mask was not being tiled across temporal dimensions,
causing assertion errors in streaming mode and wrong masks otherwise.
Match naxci1 reference: 4D tile/rearrange for Q/K temporal windows,
chunk-based score computation, and topk<=0 guard.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:50:40 +01:00
3b87652184 Fix FlashVSR attention mask and output quality
- Use generate_draft_block_mask_refined for sparse attention mask (matches
  naxci1's generate_draft_block_mask_sage with proper half-block key scoring)
- Remove spurious repeat_interleave(2, dim=-1) from generate_draft_block_mask
  that doubled the key dimension incorrectly
- Add torch.clamp(0, 1) to _to_frames output (matches naxci1's tensor2video)
- Add .to(self.device) on LQ video slices in streaming loop for all pipelines

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:41:43 +01:00
76dff7e573 Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision
Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates
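A minimal sketch of the two steps above (helper name and minimum taken from the commit text; illustrative only):

```python
def next_8n5(n: int, minimum: int = 21) -> int:
    """Smallest integer >= max(n, minimum) of the form 8k + 5."""
    n = max(n, minimum)
    return n + (5 - n) % 8

def padded_frame_count(n: int) -> int:
    """Two-stage padding: pad up to 8k + 5, then add 4, giving 8(k+1) + 1."""
    return next_8n5(n) + 4   # e.g. 50 -> 53 -> 57, so no input frame is ever dropped
```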

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:06:46 +01:00
fa250897a2 Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection
Root cause: three critical differences from the naxci1 reference implementation:

1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ
   conditioning inside the loop. The TCDecoder uses causal convolutions with
   temporal memory that must be built incrementally per-chunk. Batch decode
   breaks this design and loses LQ frame conditioning, causing ghosting.

2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal
   variant reads the OLD cache before writing the new one (truly causal),
   while Buffer writes cache BEFORE the conv call. Using the wrong variant
   misaligns temporal LQ conditioning features (see the sketch after this list).

3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4)
   matching the naxci1 reference approach.
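A minimal sketch of the two cache orderings from point 2, using illustrative class and attribute names rather than the actual Buffer_LQ4x_Proj / Causal_LQ4x_Proj code:

```python
class BufferStyleProj:
    """Writes the cache BEFORE the conv call, so the 'past' it sees is the current chunk."""
    def __init__(self, conv):
        self.conv, self.cache = conv, None

    def __call__(self, x):
        self.cache = x
        return self.conv(x, self.cache)


class CausalStyleProj:
    """Reads the OLD cache first (truly causal), then stores the new one."""
    def __init__(self, conv):
        self.conv, self.cache = conv, None

    def __call__(self, x):
        past = self.cache          # may be None on the first chunk
        out = self.conv(x, past)
        self.cache = x
        return out
```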

Changes:
- flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning
  and per-chunk color correction (was: batch VAE decode after loop)
- flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including
  full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in
  clear_caches()
- utils.py: add Causal_LQ4x_Proj class
- nodes.py: update progress bar estimation for new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:42:20 +01:00
94d9818675 Fix FlashVSR quality: match naxci1 reference preprocessing
- Remove front dummy frames (not used by reference implementation)
- Use centered reflect padding instead of right/bottom replicate
- Crop output from center matching padding offsets
- Simplify temporal padding to 8k+1 alignment
- Update progress bar estimation to match new formula
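A minimal sketch of the centered reflect padding and matching offsets described above, with illustrative names (not the node's exact code):

```python
import torch
import torch.nn.functional as F

def pad_centered(x: torch.Tensor, target_h: int, target_w: int):
    """x: [N, C, H, W]. Pad to (target_h, target_w) with reflect mode, centered."""
    pad_h, pad_w = target_h - x.shape[-2], target_w - x.shape[-1]
    top, left = pad_h // 2, pad_w // 2
    x = F.pad(x, (left, pad_w - left, top, pad_h - top), mode="reflect")
    return x, (top, left)   # offsets reused (scaled by the SR factor) to crop the output
```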

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:10:12 +01:00
ea84ffef7c Fix FlashVSR ghosting: restore 2 front dummy frames matching reference
The pipeline's LQ conditioning indexing expects 2 front dummy frames
(copies of first frame) as warmup. Our previous refactoring removed
these, shifting all LQ conditioning by 2 frames and causing severe
ghosting artifacts.

Now matches the 1038lab reference preprocessing exactly:
1. _prepare_video: 2 tail copies + alignment + 2 front dummies + back padding
2. _restore_video_sequence: strip first 2 warmup frames + trim to original count
3. Crop pipeline output to padded_n before restoration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:49:46 +01:00
4cc6e9c705 Remove debug logging from FlashVSR SegmentUpscale
Issue was a workflow wiring mistake, not a code bug.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:32:23 +01:00
39d0f7af42 Add debug logging for FlashVSR SegmentUpscale output shapes
Helps diagnose issue where segment 1+ runs but produces no image output.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:31:09 +01:00
11e2acb9e0 Fix FlashVSR frame padding to match pipeline requirements
The pipeline requires num_frames % 4 == 1. Our old _pad_video_5d used a
wrong formula that produced non-conforming counts (e.g. 33 input → 35
padded → pipeline rounds to 37, wasting VRAM).

New padding uses num_frames % 8 == 1 (also satisfies % 4 == 1), which
ensures the streaming loop output exactly matches num_frames with zero
waste. Optimal input counts: 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105.
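A minimal sketch of that alignment, assuming a plain round-up to the next 8k + 1:

```python
import math

def pad_to_8k1(n_frames: int) -> int:
    """Smallest value >= n_frames of the form 8k + 1."""
    return math.ceil((n_frames - 1) / 8) * 8 + 1   # 25 -> 25, 30 -> 33, 100 -> 105
```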

Also removes incorrect 2-frame warmup stripping from _restore_video_sequence
— the pipeline output doesn't have warmup artifacts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:20:02 +01:00
5071c4de4f Fix sageattn fallback: tensors already rearranged when exception fires
When sageattn fails, q/k/v are already in [b,n,s,d] format from the
rearrange before the call. Use SDPA directly on them instead of calling
_sdpa_fallback which expects [b,s,(n*d)] and crashes with a shape error.
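A minimal sketch of the fixed fallback path; `sageattn` and the head count are assumed to come from the surrounding attention module:

```python
import torch.nn.functional as F
from einops import rearrange

def attn_with_fallback(q, k, v, num_heads):
    """q/k/v arrive as [b, s, n*d]; sageattn and SDPA both want [b, n, s, d]."""
    q, k, v = (rearrange(t, "b s (n d) -> b n s d", n=num_heads) for t in (q, k, v))
    try:
        out = sageattn(q, k, v)                        # may fail on unsupported GPUs
    except Exception:
        out = F.scaled_dot_product_attention(q, k, v)  # tensors already in SDPA layout
    return rearrange(out, "b n s d -> b s (n d)")
```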

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:08:01 +01:00
dd69a2fd2b Fix sageattn crash on Blackwell GPUs (sm_120)
SageAttention CUDA kernels don't support Blackwell yet. Catch runtime
failures from sageattn/sparse_sageattn, disable them, and fall back to
PyTorch SDPA. Only pays the try/except cost once per session.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:03:15 +01:00
f40504cbcf Fix crash when flash_attn is installed but broken
Verify attention backend functions are actually callable before marking
them available. Falls back to PyTorch SDPA instead of calling None.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:51:30 +01:00
8317a0603e Reuse FlashVSR models from 1038lab node if already downloaded
Check models/FlashVSR/ (1038lab convention) before models/flashvsr/ to
avoid downloading ~7GB of checkpoints twice. Only create the directory
when actually downloading.
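A minimal sketch of that lookup order (illustrative helper, not the node's exact code):

```python
from pathlib import Path

def resolve_flashvsr_dir(models_root: Path) -> Path:
    legacy = models_root / "FlashVSR"        # 1038lab node's convention
    if legacy.is_dir() and any(legacy.glob("*.safetensors")):
        return legacy                        # reuse existing checkpoints
    return models_root / "flashvsr"          # created only when a download is needed
```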

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:42:10 +01:00
0fecfcee37 Add FlashVSR support: diffusion-based 4x video super-resolution (Wan 2.1-1.3B)
Vendor minimal diffsynth subset for FlashVSR inference (full/tiny pipelines,
v1 and v1.1 checkpoints auto-downloaded from HuggingFace). Includes segment-based
processing with temporal overlap and crossfade blending for bounded RAM on long videos.

Nodes: Load FlashVSR Model, FlashVSR Upscale, FlashVSR Segment Upscale.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 15:12:33 +01:00
32 changed files with 6834 additions and 1386 deletions

View File

@@ -1,20 +0,0 @@
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - master
    paths:
      - "pyproject.toml"
jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}

463
README.md
View File

@@ -1,107 +1,41 @@
# Tween — Video Frame Interpolation for ComfyUI # ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR
[![ComfyUI](https://img.shields.io/badge/ComfyUI-Custom_Node-0a7ef0)](https://registry.comfy.org/) ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM.
[![Python 3.10+](https://img.shields.io/badge/Python-3.10+-3776AB?logo=python&logoColor=white)](https://www.python.org/)
[![License](https://img.shields.io/badge/License-Apache_2.0-green.svg)](https://www.apache.org/licenses/LICENSE-2.0)
[![Models](https://img.shields.io/badge/VFI_Models-4-8B5CF6)](#which-model-should-i-use)
Four video frame interpolation models in one package — **BIM-VFI**, **EMA-VFI**, **SGM-VFI**, and **GIMM-VFI**. Designed for long videos with thousands of frames without running out of VRAM.
<p align="center">
<img src="assets/model-comparison.svg" alt="Model Comparison" width="720"/>
</p>
## Installation
Install from the [ComfyUI Registry](https://registry.comfy.org/) (recommended) or clone manually:
```bash
cd ComfyUI/custom_nodes
git clone https://github.com/Ethanfel/ComfyUI-Tween.git
pip install -r requirements.txt
```
All dependencies (`gdown`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are declared in `pyproject.toml` and `requirements.txt`, installed automatically by ComfyUI Manager or pip.
### cupy (required for BIM-VFI, SGM-VFI, GIMM-VFI)
[cupy](https://cupy.dev/) provides GPU-accelerated optical flow warping. **EMA-VFI works without it.**
1. Find your CUDA version:
```bash
python -c "import torch; print(torch.version.cuda)"
```
2. Install the matching package:
| CUDA | Command |
|------|---------|
| 12.x | `pip install cupy-cuda12x` |
| 11.x | `pip install cupy-cuda11x` |
> Make sure to run pip in the same Python environment as ComfyUI. If cupy is missing, the Load node shows an error with your CUDA version and the exact install command.
<details>
<summary>cupy troubleshooting</summary>
| Problem | Solution |
|---------|----------|
| `ModuleNotFoundError: No module named 'cupy'` | Install cupy using the steps above |
| `cupy` installed but `ImportError` at runtime | CUDA version mismatch — uninstall and reinstall the correct version |
| Install hangs or takes very long | cupy wheels are ~800 MB, be patient |
| Docker / no build tools | Use the prebuilt wheel: `pip install cupy-cuda12x` (not bare `cupy` which compiles from source) |
</details>
## Which model should I use? ## Which model should I use?
| | BIM-VFI | EMA-VFI | SGM-VFI | GIMM-VFI | | | BIM-VFI | EMA-VFI | SGM-VFI | GIMM-VFI |
|---|---------|---------|---------|----------| |---|---------|---------|---------|----------|
| **Best for** | General-purpose | Fast, low VRAM | Large motion | High multipliers (4x/8x) | | **Best for** | General-purpose, non-uniform motion | Fast inference, light VRAM | Large motion, occlusion-heavy scenes | High multipliers (4x/8x) in a single pass |
| **Quality** | Highest | Good | Best on large motion | Good | | **Quality** | Highest overall | Good | Best on large motion | Good |
| **Speed** | Moderate | Fastest | Slowest | Fast for 4x/8x | | **Speed** | Moderate | Fastest | Slowest | Fast for 4x/8x (single pass) |
| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair | ~2.5 GB/pair | | **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair | ~2.5 GB/pair |
| **Params** | ~17 M | ~14-65 M | ~15 M + GMFlow | ~80 M (RAFT) / ~123 M (FlowFormer) | | **Params** | ~17M | ~14-65M | ~15M + GMFlow | ~80M (RAFT) / ~123M (FlowFormer) |
| **Arbitrary timestep** | Yes | Yes (`_t` checkpoint) | No (fixed 0.5) | Yes (native) | | **Arbitrary timestep** | Yes | Yes (with `_t` checkpoint) | No (fixed 0.5) | Yes (native single-pass) |
| **4x/8x** | Recursive passes | Recursive passes | Recursive passes | Single forward pass | | **4x/8x mode** | Recursive 2x passes | Recursive 2x passes | Recursive 2x passes | Single forward pass (or recursive) |
| **Requires cupy** | Yes | No | Yes | Yes |
| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 | NeurIPS 2024 | | **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 | NeurIPS 2024 |
| **License** | Research only | Apache 2.0 | Apache 2.0 | Apache 2.0 |
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** for speed or if you can't install cupy. Use **SGM-VFI** for large camera motion. Use **GIMM-VFI** for 4x/8x without recursive passes. **TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair.
## VRAM Guide ### Video Super-Resolution
| VRAM | Recommended settings | FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above.
|------|----------------------|
| 8 GB | `batch_size=1, chunk_size=500` | | | FlashVSR |
| 24 GB | `batch_size=2-4, chunk_size=1000` | |---|----------|
| 48 GB+ | `batch_size=4-16, all_on_gpu=true` | | **Task** | 4x video super-resolution |
| 96 GB+ | `batch_size=8-16, all_on_gpu=true, chunk_size=0` | | **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) |
| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) |
| **VRAM** | ~8-12 GB (tiled, tiny mode) / ~16-24 GB (full mode) |
| **Params** | ~1.3B (DiT) + ~200M (VAE) |
| **Min input** | 21 frames |
| **Paper** | arXiv 2510.12747 |
| **License** | Apache 2.0 |
## Nodes ## Nodes
All Interpolate nodes share a common set of controls: ### BIM-VFI
| Input | Description |
|-------|-------------|
| **images** | Input image batch |
| **model** | Model from the loader node |
| **multiplier** | 2x, 4x, or 8x frame rate (recursive 2x passes) |
| **batch_size** | Frame pairs processed simultaneously (higher = faster, more VRAM) |
| **chunk_size** | Process in segments of N input frames (0 = disabled). Bounds VRAM for very long videos |
| **keep_device** | Keep model on GPU between pairs (faster, ~200 MB constant VRAM) |
| **all_on_gpu** | Keep all intermediate frames on GPU (fast, needs large VRAM) |
| **clear_cache_after_n_frames** | Clear CUDA cache every N pairs to prevent VRAM buildup |
| **source_fps** | Input frame rate. Required when target_fps > 0 |
| **target_fps** | Target output FPS. When > 0, overrides multiplier — auto-computes the optimal power-of-2 oversample then selects frames at exact target timestamps. 0 = use multiplier |
| Output | Description |
|--------|-------------|
| **images** | Interpolated frames at the target FPS (or at the multiplied rate when target_fps = 0) |
| **oversampled** | Full power-of-2 oversampled frames before target FPS selection. Same as `images` when target_fps = 0 |
<details>
<summary><strong>BIM-VFI</strong></summary>
#### Load BIM-VFI Model #### Load BIM-VFI Model
@@ -109,121 +43,221 @@ Loads the BiM-VFI checkpoint. Auto-downloads from Google Drive on first use to `
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| **model_path** | Checkpoint from `models/bim-vfi/` | | **model_path** | Checkpoint file from `models/bim-vfi/` |
| **auto_pyr_level** | Auto pyramid level by resolution (&lt;540p=3, 540p=5, 1080p=6, 4K=7) | | **auto_pyr_level** | Auto-select pyramid level by resolution (&lt;540p=3, 540p=5, 1080p=6, 4K=7) |
| **pyr_level** | Manual pyramid level (3-7), used when auto is off | | **pyr_level** | Manual pyramid level (3-7), only used when auto is off |
#### BIM-VFI Interpolate #### BIM-VFI Interpolate
Common controls listed above. Interpolates frames from an image batch.
| Input | Description |
|-------|-------------|
| **images** | Input image batch |
| **model** | Model from the loader node |
| **multiplier** | 2x, 4x, or 8x frame rate (recursive 2x passes) |
| **batch_size** | Frame pairs processed simultaneously (higher = faster, more VRAM) |
| **chunk_size** | Process in segments of N input frames (0 = disabled). Bounds VRAM for very long videos. Result is identical to processing all at once |
| **keep_device** | Keep model on GPU between pairs (faster, ~200MB constant VRAM) |
| **all_on_gpu** | Keep all intermediate frames on GPU (fast, needs large VRAM) |
| **clear_cache_after_n_frames** | Clear CUDA cache every N pairs to prevent VRAM buildup |
#### BIM-VFI Segment Interpolate #### BIM-VFI Segment Interpolate
Processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution. Same as Interpolate but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
</details>
<details>
<summary><strong>EMA-VFI</strong></summary>
#### Load EMA-VFI Model
Auto-downloads from Google Drive to `ComfyUI/models/ema-vfi/`. Variant and timestep support are auto-detected from the filename.
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint from `models/ema-vfi/` |
| **tta** | Test-time augmentation (~2x slower, slightly better quality) |
| Checkpoint | Variant | Params | Arbitrary timestep |
|-----------|---------|--------|-------------------|
| `ours_t.pkl` | Large | ~65 M | Yes |
| `ours.pkl` | Large | ~65 M | No (fixed 0.5) |
| `ours_small_t.pkl` | Small | ~14 M | Yes |
| `ours_small.pkl` | Small | ~14 M | No (fixed 0.5) |
#### EMA-VFI Interpolate / Segment Interpolate
Same controls as above.
</details>
<details>
<summary><strong>SGM-VFI</strong></summary>
#### Load SGM-VFI Model
Auto-downloads from Google Drive to `ComfyUI/models/sgm-vfi/`. Requires cupy.
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint from `models/sgm-vfi/` |
| **tta** | Test-time augmentation (~2x slower, slightly better quality) |
| **num_key_points** | Global matching sparsity (0.0 = global everywhere, 0.5 = default, higher = faster) |
| Checkpoint | Variant | Params |
|-----------|---------|--------|
| `ours-1-2-points.pkl` | Small | ~15 M + GMFlow |
#### SGM-VFI Interpolate / Segment Interpolate
Same controls as above.
</details>
<details>
<summary><strong>GIMM-VFI</strong></summary>
#### Load GIMM-VFI Model
Auto-downloads from [HuggingFace](https://huggingface.co/Kijai/GIMM-VFI_safetensors) to `ComfyUI/models/gimm-vfi/`. The matching flow estimator (RAFT or FlowFormer) is auto-detected and downloaded alongside.
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint from `models/gimm-vfi/` |
| **ds_factor** | Downscale factor for internal processing (1.0 = full, 0.5 = half). Try 0.5 for 4K inputs |
| Checkpoint | Variant | Params | Flow estimator (auto-downloaded) |
|-----------|---------|--------|----------------------------------|
| `gimmvfi_r_arb_lpips_fp32.safetensors` | RAFT | ~80 M | `raft-things_fp32.safetensors` |
| `gimmvfi_f_arb_lpips_fp32.safetensors` | FlowFormer | ~123 M | `flowformer_sintel_fp32.safetensors` |
#### GIMM-VFI Interpolate
Common controls plus:
| Input | Description |
|-------|-------------|
| **single_pass** | Generate all intermediate frames per pair in one forward pass (default on). No recursive 2x passes needed for 4x/8x. Disable to use the standard recursive approach |
#### GIMM-VFI Segment Interpolate
Same pattern as other Segment nodes.
</details>
### Tween Concat Videos ### Tween Concat Videos
Concatenates segment video files into a single video using ffmpeg. Connect from any Segment Interpolate's model output to ensure it runs after all segments are saved. Works with all four models. Concatenates segment video files into a single video using ffmpeg. Connect from any Segment Interpolate's model output to ensure it runs after all segments are saved. Works with all three models.
### Output frame count ### EMA-VFI
- **Multiplier mode:** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 #### Load EMA-VFI Model
- **Target FPS mode:** `floor((N-1) / source_fps * target_fps) + 1` frames. Automatically oversamples to the nearest power-of-2 above the ratio, then selects frames at exact target timestamps. Downsampling (target < source) also works — frames are selected from the input with no model calls.
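A minimal sketch of that target-FPS selection, with illustrative helper names:

```python
import math

def plan_target_fps(n_frames: int, source_fps: float, target_fps: float):
    out_count = math.floor((n_frames - 1) / source_fps * target_fps) + 1
    mult = 1
    while mult * source_fps < target_fps:   # smallest power of 2 >= fps ratio
        mult *= 2
    # The oversampled timeline has (n_frames - 1) * mult + 1 frames; pick the one
    # closest to each exact target timestamp.
    indices = [round(i * source_fps * mult / target_fps) for i in range(out_count)]
    return mult, indices
```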
Loads an EMA-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/ema-vfi/`. Variant (large/small) and timestep support are auto-detected from the filename.
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint file from `models/ema-vfi/` |
| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) |
Available checkpoints:
| Checkpoint | Variant | Params | Arbitrary timestep |
|-----------|---------|--------|-------------------|
| `ours_t.pkl` | Large | ~65M | Yes |
| `ours.pkl` | Large | ~65M | No (fixed 0.5) |
| `ours_small_t.pkl` | Small | ~14M | Yes |
| `ours_small.pkl` | Small | ~14M | No (fixed 0.5) |
#### EMA-VFI Interpolate
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
#### EMA-VFI Segment Interpolate
Same as EMA-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
### SGM-VFI
#### Load SGM-VFI Model
Loads an SGM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/sgm-vfi/`. Variant (base/small) is auto-detected from the filename (default is small).
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint file from `models/sgm-vfi/` |
| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) |
| **num_key_points** | Sparsity of global matching (0.0 = global everywhere, 0.5 = default balance, higher = faster) |
Available checkpoints:
| Checkpoint | Variant | Params |
|-----------|---------|--------|
| `ours-1-2-points.pkl` | Small | ~15M + GMFlow |
#### SGM-VFI Interpolate
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
#### SGM-VFI Segment Interpolate
Same as SGM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
### GIMM-VFI
#### Load GIMM-VFI Model
Loads a GIMM-VFI checkpoint. Auto-downloads from [HuggingFace](https://huggingface.co/Kijai/GIMM-VFI_safetensors) on first use to `ComfyUI/models/gimm-vfi/`. The matching flow estimator (RAFT or FlowFormer) is auto-detected and downloaded alongside the main model.
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint file from `models/gimm-vfi/` |
| **ds_factor** | Downscale factor for internal processing (1.0 = full res, 0.5 = half). Lower = less VRAM, faster, less quality. Try 0.5 for 4K inputs |
Available checkpoints:
| Checkpoint | Variant | Params | Flow estimator (auto-downloaded) |
|-----------|---------|--------|----------------------------------|
| `gimmvfi_r_arb_lpips_fp32.safetensors` | RAFT | ~80M | `raft-things_fp32.safetensors` |
| `gimmvfi_f_arb_lpips_fp32.safetensors` | FlowFormer | ~123M | `flowformer_sintel_fp32.safetensors` |
#### GIMM-VFI Interpolate
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, plus:
| Input | Description |
|-------|-------------|
| **single_pass** | When enabled (default), generates all intermediate frames per pair in one forward pass using GIMM-VFI's arbitrary-timestep capability. No recursive 2x passes needed for 4x or 8x. Disable to use the standard recursive approach (same as BIM/EMA/SGM) |
#### GIMM-VFI Segment Interpolate
Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
**Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
### FlashVSR
FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling.
#### Load FlashVSR Model
Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`.
| Input | Description |
|-------|-------------|
| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) |
| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) |
Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)):
| Checkpoint | Size | Description |
|-----------|------|-------------|
| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) |
| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE |
| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection |
| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) |
| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings |
#### FlashVSR Upscale
Upscales an image batch with 4x spatial super-resolution.
| Input | Description |
|-------|-------------|
| **images** | Input video frames (minimum 21 frames) |
| **model** | Model from the loader node |
| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) |
| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames |
| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) |
| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) |
| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) |
| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) |
| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability |
| **color_fix** | Apply wavelet color correction to prevent color shifts |
| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) |
| **seed** | Random seed for the diffusion process |
#### FlashVSR Segment Upscale
Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
| Input | Description |
|-------|-------------|
| **segment_index** | Which segment to process (0-based) |
| **segment_size** | Number of input frames per segment (minimum 21) |
| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending |
| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) |
Plus all the same upscale parameters as FlashVSR Upscale.
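A minimal sketch of the overlap crossfade described above (illustrative, not the node's exact blending code):

```python
import torch

def crossfade_join(prev: torch.Tensor, nxt: torch.Tensor, blend: int) -> torch.Tensor:
    """prev/nxt: [T, H, W, C]; the last `blend` frames of prev overlap the first of nxt."""
    w = torch.linspace(0.0, 1.0, blend).view(-1, 1, 1, 1)
    mixed = (1 - w) * prev[-blend:] + w * nxt[:blend]
    return torch.cat([prev[:-blend], mixed, nxt[blend:]], dim=0)
```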
## Installation
Clone into your ComfyUI `custom_nodes/` directory:
```bash
cd ComfyUI/custom_nodes
git clone https://github.com/your-user/ComfyUI-Tween.git
```
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
> **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with:
> ```bash
> pip install cupy-cuda12x # replace 12 with your CUDA major version
> ```
To install manually:
```bash
cd ComfyUI-Tween
python install.py
```
### Requirements
- PyTorch with CUDA
- `cupy` (matching your CUDA version, for BIM-VFI, SGM-VFI, and GIMM-VFI)
- `timm` (for EMA-VFI and SGM-VFI)
- `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download)
- `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI)
- `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download)
- `safetensors` (for FlashVSR checkpoint loading)
## VRAM Guide
| VRAM | Recommended settings |
|------|---------------------|
| 8 GB | batch_size=1, chunk_size=500 |
| 24 GB | batch_size=2-4, chunk_size=1000 |
| 48 GB+ | batch_size=4-16, all_on_gpu=true |
| 96 GB+ | batch_size=8-16, all_on_gpu=true, chunk_size=0 |
## Acknowledgments ## Acknowledgments
| Model | Authors | Venue | Links | This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths).
|-------|---------|-------|-------|
| **BIM-VFI** | Seo, Oh, Kim (KAIST VIC Lab) | CVPR 2025 | [Paper](https://arxiv.org/abs/2412.11365) · [Code](https://github.com/KAIST-VICLab/BiM-VFI) · [Project](https://kaist-viclab.github.io/BiM-VFI_site/) |
| **EMA-VFI** | Zhang et al. (MCG-NJU) | CVPR 2023 | [Paper](https://arxiv.org/abs/2303.00440) · [Code](https://github.com/MCG-NJU/EMA-VFI) |
| **SGM-VFI** | Zhang et al. (MCG-NJU) | CVPR 2024 | [Paper](https://arxiv.org/abs/2404.06913) · [Code](https://github.com/MCG-NJU/SGM-VFI) |
| **GIMM-VFI** | Guo, Li, Loy (S-Lab NTU) | NeurIPS 2024 | [Paper](https://arxiv.org/abs/2407.08680) · [Code](https://github.com/GSeanCDAT/GIMM-VFI) |
GIMM-VFI adaptation from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications. **BiM-VFI:**
> Wonyong Seo, Jihyong Oh, and Munchurl Kim.
<details> > "BiM-VFI: Bidirectional Motion Field-Guided Frame Interpolation for Video with Non-uniform Motions."
<summary>BibTeX citations</summary> > *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2025.
> [[arXiv]](https://arxiv.org/abs/2412.11365) [[Project Page]](https://kaist-viclab.github.io/BiM-VFI_site/) [[GitHub]](https://github.com/KAIST-VICLab/BiM-VFI)
```bibtex ```bibtex
@inproceedings{seo2025bimvfi, @inproceedings{seo2025bimvfi,
@@ -232,21 +266,45 @@ GIMM-VFI adaptation from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/Comfy
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2025} year={2025}
} }
```
**EMA-VFI:**
> Guozhen Zhang, Yuhan Zhu, Haonan Wang, Youxin Chen, Gangshan Wu, and Limin Wang.
> "Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation."
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2023.
> [[arXiv]](https://arxiv.org/abs/2303.00440) [[GitHub]](https://github.com/MCG-NJU/EMA-VFI)
```bibtex
@inproceedings{zhang2023emavfi, @inproceedings{zhang2023emavfi,
title={Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation}, title={Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation},
author={Zhang, Guozhen and Zhu, Yuhan and Wang, Haonan and Chen, Youxin and Wu, Gangshan and Wang, Limin}, author={Zhang, Guozhen and Zhu, Yuhan and Wang, Haonan and Chen, Youxin and Wu, Gangshan and Wang, Limin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2023} year={2023}
} }
```
**SGM-VFI:**
> Guozhen Zhang, Yuhan Zhu, Evan Zheran Liu, Haonan Wang, Mingzhen Sun, Gangshan Wu, and Limin Wang.
> "Sparse Global Matching for Video Frame Interpolation with Large Motion."
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2024.
> [[arXiv]](https://arxiv.org/abs/2404.06913) [[GitHub]](https://github.com/MCG-NJU/SGM-VFI)
```bibtex
@inproceedings{zhang2024sgmvfi, @inproceedings{zhang2024sgmvfi,
title={Sparse Global Matching for Video Frame Interpolation with Large Motion}, title={Sparse Global Matching for Video Frame Interpolation with Large Motion},
author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin}, author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2024} year={2024}
} }
```
**GIMM-VFI:**
> Zujin Guo, Wei Li, and Chen Change Loy.
> "Generalizable Implicit Motion Modeling for Video Frame Interpolation."
> *Advances in Neural Information Processing Systems (NeurIPS)*, 2024.
> [[arXiv]](https://arxiv.org/abs/2407.08680) [[GitHub]](https://github.com/GSeanCDAT/GIMM-VFI)
```bibtex
@inproceedings{guo2024gimmvfi, @inproceedings{guo2024gimmvfi,
title={Generalizable Implicit Motion Modeling for Video Frame Interpolation}, title={Generalizable Implicit Motion Modeling for Video Frame Interpolation},
author={Guo, Zujin and Li, Wei and Loy, Chen Change}, author={Guo, Zujin and Li, Wei and Loy, Chen Change},
@@ -255,12 +313,29 @@ GIMM-VFI adaptation from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/Comfy
} }
``` ```
</details> **FlashVSR:**
> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng.
> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer."
> *arXiv preprint arXiv:2510.12747*, 2025.
> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR)
```bibtex
@article{zhuang2025flashvsr,
title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer},
author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing},
journal={arXiv preprint arXiv:2510.12747},
year={2025}
}
```
## License ## License
**BIM-VFI:** Research and education only. Commercial use requires permission from Prof. Munchurl Kim (mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI). The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
**EMA-VFI, SGM-VFI, GIMM-VFI:** [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). GIMM-VFI ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI). The EMA-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/EMA-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/EMA-VFI) for details.
**This wrapper code:** [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.
The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR).

View File

@@ -1,11 +1,60 @@
import subprocess
import sys
import logging
logger = logging.getLogger("Tween")
def _auto_install_deps():
    """Auto-install missing dependencies on first load."""
    # gdown
    try:
        import gdown # noqa: F401
    except ImportError:
        logger.info("[Tween] Installing gdown...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
    # timm (required for EMA-VFI's MotionFormer backbone)
    try:
        import timm # noqa: F401
    except ImportError:
        logger.info("[Tween] Installing timm...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "timm"])
    # cupy
    try:
        import cupy # noqa: F401
    except ImportError:
        try:
            import torch
            major = int(torch.version.cuda.split(".")[0])
            cupy_pkg = f"cupy-cuda{major}x"
            logger.info(f"[Tween] Installing {cupy_pkg} (CUDA {torch.version.cuda})...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", cupy_pkg])
        except Exception as e:
            logger.warning(f"[Tween] Could not auto-install cupy: {e}")
    # GIMM-VFI + FlashVSR dependencies
    for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"):
        try:
            __import__(pkg)
        except ImportError:
            logger.info(f"[Tween] Installing {pkg}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

_auto_install_deps()
from .nodes import ( from .nodes import (
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos, LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate, LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate, LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
VFIOptimizer, LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale,
) )
WEB_DIRECTORY = "./web"
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LoadBIMVFIModel": LoadBIMVFIModel, "LoadBIMVFIModel": LoadBIMVFIModel,
"BIMVFIInterpolate": BIMVFIInterpolate, "BIMVFIInterpolate": BIMVFIInterpolate,
@@ -20,7 +69,9 @@ NODE_CLASS_MAPPINGS = {
"LoadGIMMVFIModel": LoadGIMMVFIModel, "LoadGIMMVFIModel": LoadGIMMVFIModel,
"GIMMVFIInterpolate": GIMMVFIInterpolate, "GIMMVFIInterpolate": GIMMVFIInterpolate,
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate, "GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
"VFIOptimizer": VFIOptimizer, "LoadFlashVSRModel": LoadFlashVSRModel,
"FlashVSRUpscale": FlashVSRUpscale,
"FlashVSRSegmentUpscale": FlashVSRSegmentUpscale,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@@ -37,5 +88,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadGIMMVFIModel": "Load GIMM-VFI Model", "LoadGIMMVFIModel": "Load GIMM-VFI Model",
"GIMMVFIInterpolate": "GIMM-VFI Interpolate", "GIMMVFIInterpolate": "GIMM-VFI Interpolate",
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate", "GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
"VFIOptimizer": "VFI Optimizer", "LoadFlashVSRModel": "Load FlashVSR Model",
"FlashVSRUpscale": "FlashVSR Upscale",
"FlashVSRSegmentUpscale": "FlashVSR Segment Upscale",
} }

View File

@@ -1,84 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 720 320" width="720" height="320">
<defs>
<linearGradient id="gQ" x1="0" y1="0" x2="1" y2="0">
<stop offset="0%" stop-color="#7aa2f7"/><stop offset="100%" stop-color="#7dcfff"/>
</linearGradient>
<linearGradient id="gS" x1="0" y1="0" x2="1" y2="0">
<stop offset="0%" stop-color="#9ece6a"/><stop offset="100%" stop-color="#73daca"/>
</linearGradient>
<linearGradient id="gV" x1="0" y1="0" x2="1" y2="0">
<stop offset="0%" stop-color="#bb9af7"/><stop offset="100%" stop-color="#d2a8ff"/>
</linearGradient>
</defs>
<!-- Background -->
<rect width="720" height="320" rx="16" fill="#0d1117"/>
<!-- ═══ BIM-VFI (top-left) ═══ -->
<rect x="10" y="10" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
<rect x="11" y="22" width="3" height="121" fill="#3fb950"/>
<text x="30" y="38" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">BIM-VFI</text>
<text x="30" y="56" fill="#3fb950" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">&#9733; Recommended &#183; Best quality &#183; CVPR 2025</text>
<line x1="30" y1="64" x2="330" y2="64" stroke="#30363d" stroke-width="0.5"/>
<text x="30" y="82" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
<rect x="88" y="72" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="72" width="244" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
<text x="30" y="100" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
<rect x="88" y="90" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="90" width="146" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
<text x="30" y="118" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
<rect x="88" y="108" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="108" width="195" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
<text x="262" y="143" fill="#f0883e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Research only</text>
<!-- ═══ EMA-VFI (top-right) ═══ -->
<rect x="370" y="10" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
<rect x="371" y="22" width="3" height="121" fill="#58a6ff"/>
<text x="390" y="38" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">EMA-VFI</text>
<text x="390" y="56" fill="#58a6ff" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Fastest &#183; No cupy needed &#183; CVPR 2023</text>
<line x1="390" y1="64" x2="690" y2="64" stroke="#30363d" stroke-width="0.5"/>
<text x="390" y="82" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
<rect x="448" y="72" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="72" width="146" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
<text x="390" y="100" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
<rect x="448" y="90" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="90" width="244" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
<text x="390" y="118" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
<rect x="448" y="108" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="108" width="244" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
<text x="632" y="143" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
<!-- ═══ SGM-VFI (bottom-left) ═══ -->
<rect x="10" y="165" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
<rect x="11" y="177" width="3" height="121" fill="#f0883e"/>
<text x="30" y="193" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">SGM-VFI</text>
<text x="30" y="211" fill="#f0883e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Large motion specialist &#183; CVPR 2024</text>
<line x1="30" y1="219" x2="330" y2="219" stroke="#30363d" stroke-width="0.5"/>
<text x="30" y="237" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
<rect x="88" y="227" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="227" width="195" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
<text x="30" y="255" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
<rect x="88" y="245" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="245" width="98" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
<text x="30" y="273" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
<rect x="88" y="263" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="88" y="263" width="98" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
<text x="272" y="298" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
<!-- ═══ GIMM-VFI (bottom-right) ═══ -->
<rect x="370" y="165" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
<rect x="371" y="177" width="3" height="121" fill="#bc8cff"/>
<text x="390" y="193" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">GIMM-VFI</text>
<text x="390" y="211" fill="#bc8cff" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Single-pass 4&#215;/8&#215; &#183; NeurIPS 2024</text>
<line x1="390" y1="219" x2="690" y2="219" stroke="#30363d" stroke-width="0.5"/>
<text x="390" y="237" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
<rect x="448" y="227" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="227" width="146" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
<text x="390" y="255" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
<rect x="448" y="245" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="245" width="195" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
<text x="390" y="273" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
<rect x="448" y="263" width="244" height="11" rx="3" fill="#21262d"/>
<rect x="448" y="263" width="146" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
<text x="632" y="298" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
</svg>


View File

@@ -206,6 +206,100 @@
"2" "2"
] ]
}, },
{
"id": 12,
"type": "easy forLoopStart",
"pos": [
-8160,
576
],
"size": [
270,
138
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "initial_value1",
"shape": 7,
"type": "*",
"link": 68
},
{
"name": "total",
"type": "INT",
"widget": {
"name": "total"
},
"link": 33
},
{
"name": "initial_value2",
"type": "*",
"link": 44
},
{
"name": "initial_value3",
"type": "*",
"link": null
}
],
"outputs": [
{
"name": "flow",
"shape": 5,
"type": "FLOW_CONTROL",
"links": [
15
]
},
{
"name": "index",
"type": "INT",
"links": [
25,
26
]
},
{
"name": "value1",
"type": "*",
"links": [
18
]
},
{
"name": "value2",
"type": "*",
"links": [
21,
64
]
},
{
"name": "value3",
"type": "*",
"links": null
}
],
"properties": {
"cnr_id": "comfyui-easy-use",
"ver": "7c470c67d6df44498e52c902173c1ac77cd5bdfd",
"Node name for S&R": "easy forLoopStart",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
6
],
"color": "#223",
"bgcolor": "#335"
},
{ {
"id": 13, "id": 13,
"type": "easy forLoopEnd", "type": "easy forLoopEnd",
@@ -277,6 +371,85 @@
"color": "#223", "color": "#223",
"bgcolor": "#335" "bgcolor": "#335"
}, },
{
"id": 11,
"type": "BIMVFISegmentInterpolate",
"pos": [
-7584,
576
],
"size": [
321.58209228515625,
246
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 21
},
{
"name": "model",
"type": "BIM_VFI_MODEL",
"link": 18
},
{
"name": "segment_index",
"type": "INT",
"widget": {
"name": "segment_index"
},
"link": 25
},
{
"name": "segment_size",
"type": "INT",
"widget": {
"name": "segment_size"
},
"link": 35
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
66
]
},
{
"name": "model",
"type": "BIM_VFI_MODEL",
"links": [
67
]
}
],
"properties": {
"aux_id": "Comfyui-BIM-VFI.git",
"ver": "7cf7162143eaa5b0939e0e122f80bc956baf65ea",
"Node name for S&R": "BIMVFISegmentInterpolate",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
2,
40,
true,
true,
1,
0,
0,
500
]
},
{ {
"id": 3, "id": 3,
"type": "LoadBIMVFIModel", "type": "LoadBIMVFIModel",
@@ -388,6 +561,7 @@
"video/", "video/",
"tween_sgm", "tween_sgm",
"tween_video_sgm.mp4", "tween_video_sgm.mp4",
true,
true true
] ]
}, },
@@ -400,7 +574,7 @@
], ],
"size": [ "size": [
544, 544,
334 352
], ],
"flags": {}, "flags": {},
"order": 10, "order": 10,
@@ -473,227 +647,11 @@
} }
} }
}, },
{
"id": 16,
"type": "PrimitiveInt",
"pos": [
-9184,
544
],
"size": [
270,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "INT",
"type": "INT",
"links": [
31,
35
]
}
],
"title": "Frames number each loops",
"properties": {
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "PrimitiveInt",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
100,
"fixed"
]
},
{
"id": 12,
"type": "easy forLoopStart",
"pos": [
-8160,
576
],
"size": [
270,
138
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "initial_value1",
"shape": 7,
"type": "*",
"link": 68
},
{
"name": "total",
"type": "INT",
"widget": {
"name": "total"
},
"link": 33
},
{
"name": "initial_value2",
"type": "*",
"link": 44
},
{
"name": "initial_value3",
"type": "*",
"link": null
}
],
"outputs": [
{
"name": "flow",
"shape": 5,
"type": "FLOW_CONTROL",
"links": [
15
]
},
{
"name": "index",
"type": "INT",
"links": [
25,
26
]
},
{
"name": "value1",
"type": "*",
"links": [
18
]
},
{
"name": "value2",
"type": "*",
"links": [
21,
64
]
},
{
"name": "value3",
"type": "*",
"links": null
}
],
"properties": {
"cnr_id": "comfyui-easy-use",
"ver": "7c470c67d6df44498e52c902173c1ac77cd5bdfd",
"Node name for S&R": "easy forLoopStart",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
6
],
"color": "#223",
"bgcolor": "#335"
},
{
"id": 11,
"type": "BIMVFISegmentInterpolate",
"pos": [
-7584,
576
],
"size": [
321.58209228515625,
294
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 21
},
{
"name": "model",
"type": "BIM_VFI_MODEL",
"link": 18
},
{
"name": "segment_index",
"type": "INT",
"widget": {
"name": "segment_index"
},
"link": 25
},
{
"name": "segment_size",
"type": "INT",
"widget": {
"name": "segment_size"
},
"link": 35
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
66
]
},
{
"name": "model",
"type": "BIM_VFI_MODEL",
"links": [
67
]
}
],
"properties": {
"aux_id": "Comfyui-BIM-VFI.git",
"ver": "7cf7162143eaa5b0939e0e122f80bc956baf65ea",
"Node name for S&R": "BIMVFISegmentInterpolate",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
2,
40,
true,
true,
1,
500,
16,
0,
0,
500
]
},
{ {
"id": 28, "id": 28,
"type": "VHS_LoadVideoPath", "type": "VHS_LoadVideoPath",
"pos": [ "pos": [
-9184, -9152,
704 704
], ],
"size": [ "size": [
@@ -701,7 +659,7 @@
286 286
], ],
"flags": {}, "flags": {},
"order": 3, "order": 2,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@@ -780,6 +738,47 @@
} }
} }
} }
},
{
"id": 16,
"type": "PrimitiveInt",
"pos": [
-9152,
576
],
"size": [
270,
82
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "INT",
"type": "INT",
"links": [
31,
35
]
}
],
"title": "Frames number each loops",
"properties": {
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "PrimitiveInt",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.6.2"
}
},
"widgets_values": [
100,
"fixed"
]
} }
], ],
"links": [ "links": [
@@ -934,10 +933,10 @@
"workflowRendererVersion": "LG", "workflowRendererVersion": "LG",
"ue_links": [], "ue_links": [],
"ds": { "ds": {
"scale": 0.8954302432552531, "scale": 1.0834705943388552,
"offset": [ "offset": [
10389.297857289295, 10009.878269742538,
79.21414284327875 -100.68482917709798
] ]
}, },
"links_added_by_ue": [], "links_added_by_ue": [],

View File

@@ -0,0 +1,4 @@
from .models.model_manager import ModelManager
from .pipelines import FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
from .models.utils import clean_vram, Buffer_LQ4x_Proj
from .models.TCDecoder import build_tcdecoder

View File

View File

@@ -0,0 +1,21 @@
from ..models.wan_video_dit import WanModel
from ..models.wan_video_vae import WanVideoVAE
model_loader_configs = [
# (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
]
patch_model_loader_configs = [
]

View File

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
- Encoder removed
- Transplant/widening helpers removed
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
from einops import rearrange
import torch.nn.init as init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
# ----------------------------
# Utility / building blocks
# ----------------------------
class IdentityConv2d(nn.Conv2d):
"""Same-shape Conv2d initialized to identity (Dirac)."""
def __init__(self, C, kernel_size=3, bias=False):
pad = kernel_size // 2
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
with torch.no_grad():
init.dirac_(self.weight)
if self.bias is not None:
self.bias.zero_()
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out)
)
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
B, C, F, H, W = x.shape
if F % self.ff != 0:
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
x = torch.cat([first_frame, x], dim=2)
return rearrange(
x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww
).transpose(1, 2)
# ----------------------------
# Generic NTCHW graph executor (kept; used by decoder)
# ----------------------------
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N*T, C, H, W)
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
out = []
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
progress_bar = tqdm(range(T), disable=not show_progress_bar)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt)
work_queue.insert(0, TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt)
if len(mem[i]) > b.stride:
raise ValueError("TPool internal state invalid.")
elif len(mem[i]) == b.stride:
N_, C_, H_, W_ = xt.shape
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
mem[i] = []
work_queue.insert(0, TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C_, H_, W_ = xt.shape
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
work_queue.insert(0, TWorkItem(xt_next, i+1))
else:
xt = b(xt)
work_queue.insert(0, TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x, mem
# ----------------------------
# Decoder-only TAEHV
# ----------------------------
class TAEHV(nn.Module):
image_channels = 3
def __init__(
self,
checkpoint_path="taehv.pth",
decoder_time_upscale=(True, True),
decoder_space_upscale=(True, True, True),
channels = [256, 128, 64, 64],
latent_channels = 16
):
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
    Deepening config: how_many_each=1, k=3 (fixed).
"""
super().__init__()
self.latent_channels = latent_channels
n_f = channels
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
# Build the decoder "skeleton"
base_decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
if checkpoint_path is not None:
missing_keys = self.load_state_dict(
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
strict=False
)
print('missing_keys', missing_keys)
# Initialize decoder mem state
self.mem = [None] * len(self.decoder)
@staticmethod
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
new_layers = []
for b in decoder:
new_layers.append(b)
if isinstance(b, nn.ReLU):
# Deduce channel count from preceding layer
C = None
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
C = new_layers[-2].out_channels
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
C = new_layers[-2].conv[-1].out_channels
if C is not None:
for _ in range(how_many_each):
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
new_layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*new_layers)
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
"""Decode a sequence of frames from latents.
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
"""
trim_flag = self.mem[-8] is None # keeps original relative check
if cond is not None:
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
if trim_flag:
return x[:, self.frames_to_trim:]
return x
def forward(self, *args, **kwargs):
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
def clean_mem(self):
self.mem = [None] * len(self.decoder)
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEW2_1DiffusersWrapper(nn.Module):
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
super().__init__()
self.dtype = torch.bfloat16
self.device = "cuda"
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
self.temperal_downsample = [True, True, False] # [sic]
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
def decode(self, latents, return_dict=None):
n, c, t, h, w = latents.shape
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
n, c, t, h, w = latents.shape
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
def clean_mem(self):
self.taehv.clean_mem()
# ----------------------------
# Simplified builder (no small, no transplant, no post-hoc deepening)
# ----------------------------
def build_tcdecoder(new_channels = [512, 256, 128, 128],
device="cuda",
dtype=torch.bfloat16,
new_latent_channels=None):
"""
    Build the "wider" decoder; deepening (IdentityConv2d+ReLU) is already done inside TAEHV.
    - Does not create a "small" model / no weight transplanting
    - The base_ckpt_path argument is kept but unused (interface compatibility)
    Returns: big (a single model)
"""
if new_latent_channels is not None:
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
else:
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
big.clean_mem()
return big

View File
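A hypothetical end-to-end sketch of the decoder above (editorial, not repository code): build the widened TCDecoder with build_tcdecoder, then decode latents chunk-by-chunk with parallel=False so the MemBlock/TGrow state carries across chunks; the first chunk is trimmed by frames_to_trim, later chunks are not. The channel widths are the function's defaults, while the latent shape, chunk size, and weight loading are assumptions.

import torch
# build_tcdecoder is re-exported by this package's __init__; the exact import path is installation-specific.

decoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device="cuda", dtype=torch.bfloat16).eval()
# decoder.load_state_dict(torch.load("TCDecoder.ckpt", map_location="cpu", weights_only=True))  # path illustrative
decoder.clean_mem()                                    # reset memory state before a new video

latents = torch.randn(1, 9, 16, 60, 104, device="cuda", dtype=torch.bfloat16)  # NTCHW, 16 latent channels
with torch.no_grad():
    for chunk in latents.split(3, dim=1):              # stream a few latent frames at a time
        frames = decoder.decode_video(chunk, parallel=False, show_progress_bar=False)
        print(frames.shape)                            # NTCHW RGB in ~[0, 1]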

@@ -0,0 +1 @@
from .model_manager import *

View File

@@ -0,0 +1,402 @@
import os, torch, json, importlib
from typing import List
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
with init_weights_on_device():
model = model_class(**extra_kwargs)
if hasattr(model, "eval"):
model = model.eval()
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
        return [], []  # no hash matched this state_dict
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
is_loaded = False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
is_loaded = True
break
if not is_loaded:
print(f" Cannot load LoRA: {file_path}")
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
#print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
#print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
#print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File
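All of the detectors above resolve a checkpoint purely from MD5 hashes of its state-dict keys (with or without shapes), compared against the hashes listed in model_loader_configs. A hypothetical helper for checking which hash a local file produces is sketched below; it assumes it is placed inside this package so the relative import resolves, and the file path is illustrative.

from .utils import load_state_dict, hash_state_dict_keys  # same helpers the detectors call

def inspect_checkpoint(path: str) -> None:
    """Print the two hashes ModelDetectorFromSingleFile compares against model_loader_configs."""
    sd = load_state_dict(path)
    print("keys hash              :", hash_state_dict_keys(sd, with_shape=False))
    print("keys hash (with shapes):", hash_state_dict_keys(sd, with_shape=True))

# inspect_checkpoint("checkpoints/example.safetensors")  # illustrative path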

@@ -0,0 +1,3 @@
from .core import sparse_sageattn
__all__ = ["sparse_sageattn"]

View File

@@ -0,0 +1,40 @@
"""
Sparse SageAttention — block-sparse INT8 attention via Triton.
https://github.com/jt-zhang/Sparse_SageAttention_API
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
from .quant_per_block import per_block_int8
from .sparse_int8_attn import forward as sparse_sageattn_fwd
import torch
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
if mask_id is None:
mask_id = torch.ones(
(q.shape[0], q.shape[1],
(q.shape[2] + 128 - 1) // 128,
(q.shape[3] + 64 - 1) // 64),
dtype=torch.int8, device=q.device,
)
output_dtype = q.dtype
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
v = v.to(torch.float16)
seq_dim = 1 if tensor_layout == "NHD" else 2
km = k.mean(dim=seq_dim, keepdim=True)
q_int8, q_scale, k_int8, k_scale = per_block_int8(
q, k, km=km, tensor_layout=tensor_layout,
)
o = sparse_sageattn_fwd(
q_int8, k_int8, mask_id, v, q_scale, k_scale,
is_causal=is_causal, tensor_layout=tensor_layout,
output_dtype=output_dtype,
)
return o

View File
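As a reading aid (not part of the vendored file), here is a minimal, hypothetical call into the wrapper above. Tensors use the default "HND" layout (batch, heads, sequence, head_dim); when mask_id is omitted the wrapper builds an all-ones block mask with one flag per 128-query x 64-key tile, i.e. dense attention. All shapes below are illustrative assumptions.

import torch
# from <this package>.sparse_sage import sparse_sageattn   # import path depends on installation

b, h, s, d = 1, 12, 4096, 128
q = torch.randn(b, h, s, d, device="cuda", dtype=torch.float16)
k = torch.randn(b, h, s, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, h, s, d, device="cuda", dtype=torch.float16)

# One int8 flag per (128-query x 64-key) tile: 1 = compute the tile, 0 = skip it.
mask_id = torch.ones(b, h, (s + 127) // 128, (s + 63) // 64,
                     dtype=torch.int8, device="cuda")

out = sparse_sageattn(q, k, v, mask_id=mask_id, is_causal=False, tensor_layout="HND")
print(out.shape)  # (b, h, s, d), same dtype as q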

@@ -0,0 +1,110 @@
"""
Per-block INT8 quantization kernel for Sparse SageAttention.
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
import torch
import triton
import triton.language as tl
@triton.jit
def quant_per_block_int8_kernel(
Input, Output, Scale, L,
stride_iz, stride_ih, stride_in,
stride_oz, stride_oh, stride_on,
stride_sz, stride_sh,
sm_scale,
C: tl.constexpr, BLK: tl.constexpr,
):
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = (
Input
+ off_b * stride_iz
+ off_h * stride_ih
+ offs_n[:, None] * stride_in
+ offs_k[None, :]
)
output_ptrs = (
Output
+ off_b * stride_oz
+ off_h * stride_oh
+ offs_n[:, None] * stride_on
+ offs_k[None, :]
)
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.0
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
if km is not None:
k = k - km
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
else:
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
q_scale = torch.empty(
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32,
)
k_scale = torch.empty(
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32,
)
if sm_scale is None:
sm_scale = head_dim ** -0.5
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
quant_per_block_int8_kernel[grid](
q, q_int8, q_scale, qo_len,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_qo, stride_h_qo, stride_seq_qo,
q_scale.stride(0), q_scale.stride(1),
sm_scale=(sm_scale * 1.44269504),
C=head_dim, BLK=BLKQ,
)
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k, k_int8, k_scale, kv_len,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_ko, stride_h_ko, stride_seq_ko,
k_scale.stride(0), k_scale.stride(1),
sm_scale=1.0,
C=head_dim, BLK=BLKK,
)
return q_int8, q_scale, k_int8, k_scale

View File
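For intuition (editorial sketch, not part of the vendored kernel): the pure-PyTorch function below mirrors what quant_per_block_int8_kernel computes for one block — pre-scale by sm_scale (the caller folds in log2(e) for Q), keep one absolute-max scale per block, and round half away from zero before the int8 cast. It is a readability aid only, not a replacement for the Triton kernel.

import torch

def quant_block_reference(x_block: torch.Tensor, sm_scale: float = 1.0):
    """x_block: one (BLK, head_dim) slice of Q or K. Returns (int8 block, float32 scale)."""
    x = x_block.to(torch.float32) * sm_scale
    scale = x.abs().max() / 127.0
    y = x / scale
    y = y + 0.5 * torch.where(y >= 0, 1.0, -1.0)   # round half away from zero
    return y.to(torch.int8), scale                  # int8 cast truncates toward zero

q_block = torch.randn(128, 64)
q_int8, q_scale = quant_block_reference(q_block, sm_scale=(64 ** -0.5) * 1.44269504)
print(q_int8.dtype, float(q_scale))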

@@ -0,0 +1,196 @@
"""
Sparse INT8 attention kernel for Sparse SageAttention.
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _attn_fwd_inner(
acc, l_i, old_m, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn, start_m,
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += lo // BLOCK_N
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
elif STAGE == 3:
lo, hi = 0, kv_len
for start_n in range(lo, hi, BLOCK_N):
kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
if kbid:
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask=k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk -= new_m[:, None]
else:
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk = qk - new_m[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(old_m - new_m)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
old_m = new_m
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += 1
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, old_m
@triton.jit
def _attn_fwd(
Q, K, K_blkid, V, Q_scale, K_scale, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_vz, stride_vh, stride_vn,
stride_oz, stride_oh, stride_on,
stride_kbidq, stride_kbidk,
qo_len, kv_len,
H: tl.constexpr, num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
k_scale_offset = (
off_z * (H // num_kv_groups) + off_h // num_kv_groups
) * tl.cdiv(kv_len, BLOCK_N)
k_bid_offset = (
off_z * (H // num_kv_groups) + off_h // num_kv_groups
) * stride_kbidq
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (off_z * stride_qz + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset + start_m
K_ptrs = (
K
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
V_ptrs = (
V
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
O_block_ptr = (
Out
+ (off_z * stride_oz + off_h * stride_oh)
+ offs_m[:, None] * stride_on
+ offs_k[None, :]
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
4 - STAGE, offs_m, offs_n,
)
if STAGE != 1:
acc, l_i, _ = _attn_fwd_inner(
acc, l_i, m_i, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
2, offs_m, offs_n,
)
acc = acc / l_i[:, None]
tl.store(
O_block_ptr,
acc.to(Out.type.element_ty),
mask=(offs_m[:, None] < qo_len),
)
def forward(
q, k, k_block_id, v, q_scale, k_scale,
is_causal=False, tensor_layout="HND", output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3 if is_causal else 1
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
else:
raise ValueError(f"tensor_layout {tensor_layout} not supported")
if is_causal:
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q, k, k_block_id, v, q_scale, k_scale, o,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_v, stride_h_v, stride_seq_v,
stride_bz_o, stride_h_o, stride_seq_o,
k_block_id.stride(1), k_block_id.stride(2),
qo_len, kv_len,
h_qo, num_kv_groups,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return o

View File
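The k_block_id tensor (mask_id in core.py) is what makes this kernel sparse: inside _attn_fwd_inner every 64-wide key block whose flag is zero is skipped outright, so entry [b, h, i, j] answers "may 128-wide query block i attend to 64-wide key block j?". Below is a small, hypothetical helper for building such a mask with a simple diagonal window; the window width is an assumption, not the scheme the pipelines use (they derive the mask from pooled attention scores instead).

import torch

def diagonal_block_mask(b: int, h: int, qo_len: int, kv_len: int,
                        window: int = 8, BLOCK_M: int = 128, BLOCK_N: int = 64) -> torch.Tensor:
    """int8 mask of shape (b, h, ceil(qo_len/BLOCK_M), ceil(kv_len/BLOCK_N));
    1 lets query block i attend to key block j, 0 skips the tile."""
    nq = (qo_len + BLOCK_M - 1) // BLOCK_M
    nk = (kv_len + BLOCK_N - 1) // BLOCK_N
    qi = torch.arange(nq).unsqueeze(1)              # query-block index, in 128-token units
    kj = torch.arange(nk).unsqueeze(0)              # key-block index, in 64-token units
    centre = qi * (BLOCK_M // BLOCK_N)              # each query block spans two key blocks
    keep = (kj - centre).abs() <= window
    return keep.expand(b, h, nq, nk).to(torch.int8).contiguous()

mask_id = diagonal_block_mask(1, 12, 4096, 4096)
print(mask_id.shape)  # torch.Size([1, 12, 32, 64])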

@@ -0,0 +1,460 @@
import torch, os, gc
from safetensors import safe_open
from contextlib import contextmanager
from einops import rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import hashlib
CACHE_T = 2
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def clean_vram():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if torch.mps.is_available():
torch.mps.empty_cache()
def get_device_list():
devs = []
if torch.cuda.is_available():
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
if torch.mps.is_available():
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
return devs
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class CausalConv3d(nn.Conv3d):
"""
    Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print(cache_x.shape, x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
# print('cache!')
x = F.pad(x, padding, mode='replicate') # mode='replicate'
# print(x[0,0,:,0,0])
return super().forward(x)
class PixelShuffle3d(nn.Module):
def __init__(self, ff, hh, ww):
super().__init__()
self.ff = ff
self.hh = hh
self.ww = ww
def forward(self, x):
# x: (B, C, F, H, W)
return rearrange(x,
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
ff=self.ff, hh=self.hh, ww=self.ww)
class Buffer_LQ4x_Proj(nn.Module):
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
# x: (B, C, F, H, W)
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
# print(video.shape)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
if i == 0:
continue
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim = 2)
# print(out_x.shape)
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache['conv1'] = None
self.cache['conv2'] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
# self.clear_cache()
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv1'] = cache1_x
x = self.conv1(x, self.cache['conv1'])
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
x = self.conv2(x, self.cache['conv2'])
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs
class Causal_LQ4x_Proj(nn.Module):
"""Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.
Key difference: reads old cache BEFORE writing new cache (truly causal),
whereas Buffer_LQ4x_Proj writes cache BEFORE conv call.
"""
def __init__(self, in_dim, out_dim, layer_num=30):
super().__init__()
self.ff = 1
self.hh = 16
self.ww = 16
self.hidden_dim1 = 2048
self.hidden_dim2 = 3072
self.layer_num = layer_num
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
self.act1 = nn.SiLU()
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
self.act2 = nn.SiLU()
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
self.clip_idx = 0
def forward(self, video):
self.clear_cache()
t = video.shape[2]
iter_ = 1 + (t - 1) // 4
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video = torch.cat([first_frame, video], dim=2)
out_x = []
for i in range(iter_):
x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
self.cache['conv1'] = cache1_x # writes NEW cache AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
if i == 0:
self.cache['conv2'] = cache2_x
continue
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
self.cache['conv2'] = cache2_x # writes NEW cache AFTER
x = self.norm2(x)
x = self.act2(x)
out_x.append(x)
out_x = torch.cat(out_x, dim=2)
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
return outputs
def clear_cache(self):
self.cache = {}
self.cache['conv1'] = None
self.cache['conv2'] = None
self.clip_idx = 0
def stream_forward(self, video_clip):
if self.clip_idx == 0:
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
video_clip = torch.cat([first_frame, video_clip], dim=2)
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD (None) cache
self.cache['conv1'] = cache1_x # writes AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
self.cache['conv2'] = cache2_x
self.clip_idx += 1
return None
else:
x = self.pixel_shuffle(video_clip)
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
self.cache['conv1'] = cache1_x # writes AFTER
x = self.norm1(x)
x = self.act1(x)
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
self.cache['conv2'] = cache2_x # writes AFTER
x = self.norm2(x)
x = self.act2(x)
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
outputs = []
for i in range(self.layer_num):
outputs.append(self.linear_layers[i](out_x))
self.clip_idx += 1
return outputs

View File
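To illustrate the streaming contract of Causal_LQ4x_Proj (editorial sketch, not repository code): the first stream_forward call only primes the two convolution caches and returns None; every later call consumes a 4-frame LQ chunk and returns a list with one conditioning tensor per DiT block. The resolution, frame count, and out_dim below are assumptions.

import torch

proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=30).to("cuda", torch.bfloat16).eval()
proj.clear_cache()

lq_video = torch.rand(1, 3, 32, 480, 832, device="cuda", dtype=torch.bfloat16)  # (B, C, F, H, W)
with torch.no_grad():
    for chunk in lq_video.split(4, dim=2):      # feed 4 LQ frames per step
        cond = proj.stream_forward(chunk)
        if cond is None:
            continue                            # first chunk only warms the caches
        print(len(cond), cond[0].shape)         # 30 tensors of shape (B, tokens, out_dim)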

@@ -0,0 +1,865 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os
import time
from typing import Tuple, Optional, List
from einops import rearrange
from .utils import hash_state_dict_keys
try:
import flash_attn_interface
assert callable(getattr(flash_attn_interface, "flash_attn_func", None))
FLASH_ATTN_3_AVAILABLE = True
except Exception:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
assert callable(getattr(flash_attn, "flash_attn_func", None))
FLASH_ATTN_2_AVAILABLE = True
except Exception:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
assert callable(sageattn)
SAGE_ATTN_AVAILABLE = True
except Exception:
SAGE_ATTN_AVAILABLE = False
try:
from .sparse_sage.core import sparse_sageattn
assert callable(sparse_sageattn)
SPARSE_SAGE_AVAILABLE = True
except Exception:
try:
from sageattn.core import sparse_sageattn
assert callable(sparse_sageattn)
SPARSE_SAGE_AVAILABLE = True
except Exception:
SPARSE_SAGE_AVAILABLE = False
sparse_sageattn = None
from PIL import Image
import numpy as np
print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, "
f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, "
f"sage_attn={SAGE_ATTN_AVAILABLE}")
# ----------------------------
# Local / window masks
# ----------------------------
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
end_r = start_r + win_h - 1
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
block_w: int,
win_h: int = 6,
win_w: int = 6,
include_self: bool = True,
device=None) -> torch.Tensor:
device = device or torch.device("cpu")
H, W = block_h, block_w
r = torch.arange(H, device=device)
c = torch.arange(W, device=device)
YY, XX = torch.meshgrid(r, c, indexing="ij")
r_all = YY.reshape(-1)
c_all = XX.reshape(-1)
r_half = win_h // 2
c_half = win_w // 2
start_r = r_all - r_half
end_r = start_r + win_h - 1
start_c = c_all - c_half
end_c = start_c + win_w - 1
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
mask = in_row & in_col
if not include_self:
mask.fill_diagonal_(False)
return mask
class WindowPartition3D:
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
@staticmethod
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
B, F, H, W, C = x.shape
wf, wh, ww = win
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
return x.view(-1, wf * wh * ww, C)
@staticmethod
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
F, H, W = orig
wf, wh, ww = win
nf, nh, nw = F // wf, H // wh, W // ww
B = windows.size(0) // (nf * nh * nw)
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
return x.view(B, F, H, W, -1)
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_k = torch.mean(k_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
k_heads = avgpool_k.permute(1, 0, 2)
D = avgpool_q.shape[-1]
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
n = flat.shape[1]
apply_topk = min(flat.shape[1]-1, topk)
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
    mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
    # Fix: variable name unified with the line above (previously this rearranged attn_map).
    # mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
@torch.no_grad()
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
q_w, k_w, topk=10, local_attn_mask=None):
assert batch_size == 1, "Only batch_size=1 supported for now"
assert local_attn_mask is not None, "local_attn_mask must be provided"
avgpool_q = torch.mean(q_w, dim=1)
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
q_heads = avgpool_q.permute(1, 0, 2)
D = avgpool_q.shape[-1]
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
avgpool_k_split = torch.mean(k_w_split, dim=2)
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
scores = torch.cat([scores_1, scores_2], dim=-1)
repeat_head = scores.shape[0]
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
local_attn_mask = local_attn_mask.to(torch.float32)
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
assert scores.shape == local_attn_mask.shape, \
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
scores = scores + local_attn_mask
attn_map = torch.softmax(scores, dim=-1)
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
flat = attn_map.reshape(loop_num, -1)
apply_topk = min(flat.shape[1]-1, topk)
if apply_topk <= 0:
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
else:
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
return mask
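# Illustrative sketch (not part of the original file): exercising the refined draft-mask
# routine above with toy tensors. All shapes here are assumptions chosen to satisfy the
# 128-token (2x8x8) window and the half-block key split; in the model the local mask
# comes from build_local_block_mask_shifted_vec_normal_slide rather than torch.ones.
def _example_draft_mask_usage():
    nheads, seqlen, head_dim = 4, 1, 64
    num_windows = 4                                       # spatial windows per temporal window
    q_w = torch.randn(num_windows, 128, nheads * head_dim)
    k_w = torch.randn(num_windows, 128, nheads * head_dim)
    local_mask = torch.ones(num_windows, num_windows, dtype=torch.bool)
    mask = generate_draft_block_mask_refined(
        1, nheads, seqlen, q_w, k_w, topk=10, local_attn_mask=local_mask)
    return mask.shape  # [1, nheads, num_windows, 2 * num_windows] with these toy sizes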
# ----------------------------
# Attention kernels
# ----------------------------
def _sdpa_fallback(q, k, v, num_heads):
"""PyTorch scaled dot-product attention (always available)."""
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = F.scaled_dot_product_attention(q, k, v)
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
try:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
base_blockmask = attention_mask
x = sparse_sageattn(
q, k, v,
mask_id=base_blockmask.to(torch.int8),
is_causal=False,
tensor_layout="HND"
)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
except Exception:
SPARSE_SAGE_AVAILABLE = False
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
# q,k,v already rearranged to [b, n, s, d] above
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
elif compatibility_mode:
x = _sdpa_fallback(q, k, v, num_heads)
elif FLASH_ATTN_3_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn.flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif SAGE_ATTN_AVAILABLE:
try:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
x = sageattn(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
except Exception:
SAGE_ATTN_AVAILABLE = False
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
# q,k,v already rearranged to [b, n, s, d] above
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
else:
x = _sdpa_fallback(q, k, v, num_heads)
return x
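# Illustrative sketch (not part of the original file): the helper above takes packed
# [batch, seq, heads*dim] tensors and dispatches to whichever kernel is installed.
# compatibility_mode=True forces the always-available SDPA path, which is convenient
# for a quick CPU sanity check.
def _example_flash_attention_usage():
    b, s, n, d = 1, 128, 4, 64
    q, k, v = (torch.randn(b, s, n * d) for _ in range(3))
    out = flash_attention(q, k, v, num_heads=n, compatibility_mode=True)
    return out.shape  # [1, 128, 256]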
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return (x * (1 + scale) + shift)
def sinusoidal_embedding_1d(dim, position):
half_dim = max(dim // 2, 1)
scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
inv_freq = torch.pow(10000.0, -scale / half_dim)
sinusoid = torch.outer(position.to(torch.float64), inv_freq)
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
return f_freqs_cis, h_freqs_cis, w_freqs_cis
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
half_dim = max(dim // 2, 1)
base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
freqs = torch.pow(theta, -base / max(dim, 1))
steps = torch.arange(end, dtype=torch.float64)
angles = torch.outer(steps, freqs)
return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
orig_dtype = x.dtype
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
x_complex = torch.view_as_complex(reshaped)
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
return x_out.to(orig_dtype)
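# Illustrative sketch (not part of the original file): applying RoPE to a packed
# [batch, seq, heads*dim] tensor with 1D frequencies from precompute_freqs_cis.
# The model builds 3D (f, h, w) frequencies the same way before calling rope_apply;
# the float64 round trip inside rope_apply matches the reference precision.
def _example_rope_usage():
    b, s, n, d = 1, 16, 2, 64
    x = torch.randn(b, s, n * d)
    freqs = precompute_freqs_cis(d, end=s).reshape(s, 1, -1)  # [s, 1, d//2], complex
    x_rot = rope_apply(x, freqs, num_heads=n)
    return x_rot.shape  # [1, 16, 128]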
# ----------------------------
# Norms & Blocks
# ----------------------------
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
def __init__(self, num_heads, enable_sageattention=True):
super().__init__()
self.num_heads = num_heads
self.enable_sageattention = enable_sageattention
def forward(self, q, k, v, attention_mask=None):
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention)
return x
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
self.local_attn_mask = None
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
B, L, D = x.shape
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
assert f==2, "f must be 2 when streaming with an existing KV cache"
if is_stream and (pre_cache_k is None or pre_cache_v is None):
assert f==6, "f must be 6 on the first streaming chunk (no KV cache yet)"
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
win = (2, 8, 8)
q = q.view(B, f, h, w, D)
k = k.view(B, f, h, w, D)
v = v.view(B, f, h, w, D)
q_w = WindowPartition3D.partition(q, win)
k_w = WindowPartition3D.partition(k, win)
v_w = WindowPartition3D.partition(v, win)
seqlen = f//win[0]
one_len = k_w.shape[0] // B // seqlen
if pre_cache_k is not None and pre_cache_v is not None:
k_w = torch.cat([pre_cache_k, k_w], dim=0)
v_w = torch.cat([pre_cache_v, v_w], dim=0)
block_n = q_w.shape[0] // B
block_s = q_w.shape[1]
block_n_kv = k_w.shape[0] // B
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
window_size = win[0]*h*w//128
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
self.local_attn_mask_h = h//8
self.local_attn_mask_w = w//8
self.local_range = local_range
attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
if cache_num > kv_len:
cache_k = k_w[one_len:, :, :]
cache_v = v_w[one_len:, :, :]
else:
cache_k = k_w
cache_v = v_w
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
x = WindowPartition3D.reverse(x, win, (f, h, w))
x = x.view(B, f*h*w, D)
if is_stream:
return self.o(x), cache_k, cache_v
return self.o(x)
class CrossAttention(nn.Module):
"""
Text-context-only cross attention; provides a persistent KV cache.
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
# Persistent KV cache
self.cache_k = None
self.cache_v = None
@torch.no_grad()
def init_cache(self, ctx: torch.Tensor):
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
self.cache_k = self.norm_k(self.k(ctx))
self.cache_v = self.v(ctx)
def clear_cache(self):
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
"""
y is the text context (no other branches); unused here since the persistent KV cache is read instead.
"""
q = self.norm_q(self.q(x))
assert self.cache_k is not None and self.cache_v is not None
k = self.cache_k
v = self.cache_v
x = self.attn(q, k, v)
return self.o(x)
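# Illustrative sketch (not part of the original file): the cross attention reads its
# keys/values from a cache primed once from the already-embedded text context, so the
# per-step forward only projects queries. dim=128 and the sequence lengths are assumptions.
def _example_cross_attention_cache():
    attn = CrossAttention(dim=128, num_heads=4)
    ctx = torch.randn(1, 8, 128)       # context after text_embedding, done upstream
    attn.init_cache(ctx)
    x = torch.randn(1, 16, 128)
    out = attn(x, y=None)              # y is ignored; the cached K/V are used
    return out.shape                   # [1, 16, 128]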
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.norm3 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
train_img=False, block_id=None, kv_len=None, is_full_block=False,
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
x = self.gate(x, gate_msa, self_attn_output)
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
if is_stream:
return x, self_attn_cache_k, self_attn_cache_v
return x
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
class Head(nn.Module):
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod):
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x
# ----------------------------
# WanModel (no image branch): the cross-attention KV cache is primed up front via reinit_cross_kv
# ----------------------------
class WanModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool = False,
enable_sageattention: bool = True,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.patch_size = patch_size
# patch embed
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
# text / time embed
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
for _ in range(num_layers)
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
self._cross_kv_initialized = False
# Optional: manually clear / re-initialize the cross-attention KV cache
def clear_cross_kv(self):
for blk in self.blocks:
blk.cross_attn.clear_cache()
self._cross_kv_initialized = False
@torch.no_grad()
def reinit_cross_kv(self, new_context: torch.Tensor):
ctx_txt = self.text_embedding(new_context)
for blk in self.blocks:
blk.cross_attn.init_cache(ctx_txt)
self._cross_kv_initialized = True
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
LQ_latents: Optional[List[torch.Tensor]] = None,
train_img: bool = False,
topk_ratio: Optional[float] = None,
kv_ratio: Optional[float] = None,
local_num: Optional[int] = None,
is_full_block: bool = False,
causal_idx: Optional[int] = None,
**kwargs,
):
# time / text embeds
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
# Text embedding could still happen here (CrossAttention ignores it once its KV cache is set); kept commented out.
# context = self.text_embedding(context)
# Patchify the input
x, (f, h, w) = self.patchify(x)
B = x.shape[0]
# Window / mask hyperparameters
win = (2, 8, 8)
seqlen = f//win[0]
if local_num is None:
local_random = random.random()
if local_random < 0.3:
local_num = seqlen - 3
elif local_random < 0.4:
local_num = seqlen - 4
elif local_random < 0.5:
local_num = seqlen - 2
else:
local_num = seqlen
window_size = win[0]*h*w//128
square_num = window_size*window_size
topk_ratio = 2.0  # fixed here; overrides any topk_ratio argument passed in
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
if kv_ratio is None:
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
decay_ratio = random.uniform(0.7, 1.0)
# RoPE 3D
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# blocks
for block_id, block in enumerate(self.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x += LQ_latents[block_id]
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
train_img, block_id, kv_len, is_full_block, False,
None, None)
x = self.head(x, t)
x = self.unpatchify(x, (f, h, w))
return x
@staticmethod
def state_dict_converter():
return WanModelStateDictConverter()
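# Illustrative sketch (not part of the original file): constructing the DiT with
# 1.3B-style hyperparameters (see the converter configs below) and priming the
# per-block cross-attention KV cache from a text-context tensor. num_layers=1 and
# the context shape are assumptions to keep the sketch small.
def _example_wan_model_setup():
    model = WanModel(dim=1536, in_dim=16, ffn_dim=8960, out_dim=16, text_dim=4096,
                     freq_dim=256, eps=1e-6, patch_size=(1, 2, 2),
                     num_heads=12, num_layers=1)
    context = torch.randn(1, 512, 4096)
    model.reinit_cross_kv(context)      # every DiTBlock.cross_attn now holds cached K/V
    return model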
# ----------------------------
# State dict converter (keeps the original key mapping; has_image_input is ignored)
# ----------------------------
class WanModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
"text_len": 512,
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"window_size": (-1, -1),
"qk_norm": True,
"cross_attn_norm": True,
"eps": 1e-6,
}
else:
config = {}
return state_dict_, config
def from_civitai(self, state_dict):
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
# Keep the original hash-matched configs; this implementation does not use the has_image_input branch
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,847 @@
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
CACHE_T = 2
def check_is_instance(model, module_class):
if isinstance(model, module_class):
return True
if hasattr(model, "module") and isinstance(model.module, module_class):
return True
return False
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
assert s % block_size == 0
num_blocks = s // block_size
# build mask
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
for i in range(num_blocks):
mask[:, :,
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
return mask
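# Illustrative sketch (not part of the original file): with s=4 and block_size=2 the
# first two queries may attend to keys 0-1 and the last two to keys 0-3, i.e. every
# query sees all keys up to and including its own block.
def _example_block_causal_mask():
    x = torch.zeros(1, 1, 4, 8)                  # [batch, heads, seq, dim]
    mask = block_causal_mask(x, block_size=2)    # [1, 1, 4, 4] boolean
    return mask[0, 0]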
class CausalConv3d(nn.Conv3d):
"""
Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
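# Illustrative sketch (not part of the original file): the causal conv pads only the
# past along time, so frame t never sees frame t+1. When streaming, the last CACHE_T
# frames of the previous chunk replace part of that padding, which is how the decoder
# keeps temporal continuity across chunks.
def _example_causal_conv_usage():
    conv = CausalConv3d(4, 4, 3, padding=1)
    chunk_a = torch.randn(1, 4, 4, 8, 8)           # [B, C, T, H, W]
    out_a = conv(chunk_a)                          # first chunk: zero-padded past
    cache = chunk_a[:, :, -CACHE_T:].clone()       # carry the last two frames forward
    chunk_b = torch.randn(1, 4, 4, 8, 8)
    out_b = conv(chunk_b, cache_x=cache)           # past now comes from the cache
    return out_a.shape, out_b.shape                # both [1, 4, 4, 8, 8]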
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(dim,
dim * 2, (3, 1, 1),
padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim,
dim, (3, 1, 1),
stride=(2, 1, 1),
padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
#attn_mask=block_causal_mask(q, block_size=h * w)
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if check_is_instance(m, CausalConv3d):
count += 1
return count
class VideoVAE_(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def stream_decode(self, z, scale):
# self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# print('self._feat_map', len(self._feat_map))
# cache encode
if self.encoder is not None:
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
# print('self._enc_feat_map', len(self._enc_feat_map))
class WanVideoVAE(nn.Module):
def __init__(self, z_dim=16, dim=96):
super().__init__()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
def clear_cache(self):
self.model.clear_cache()
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state for hidden_state in hidden_states]
assert len(hidden_states) == 1
hidden_state = hidden_states[0]
video = self.model.stream_decode(hidden_state, self.scale)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()
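# Illustrative sketch (not part of the original file): tiled_encode / tiled_decode blend
# overlapping tiles with the linear ramps below. A ramp is only applied on sides that are
# not at the true video border, so interior seams are feathered while real edges keep
# full weight. Instantiating WanVideoVAE builds the full (untrained) model; it is used
# here only to reach build_1d_mask.
def _example_tile_blend_ramp():
    vae = WanVideoVAE()
    ramp = vae.build_1d_mask(length=8, left_bound=True, right_bound=False, border_width=4)
    # -> tensor([1.00, 1.00, 1.00, 1.00, 1.00, 0.75, 0.50, 0.25])
    return ramp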
class WanVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for name in state_dict:
state_dict_['model.' + name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,3 @@
from .flashvsr_full import FlashVSRFullPipeline
from .flashvsr_tiny import FlashVSRTinyPipeline
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline

View File

@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise
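# Illustrative sketch (not part of the original file): the division-factor rounding used
# by the pipelines. With the default factor of 64, a 700x1200 request is rounded up to
# the next multiples; an already-aligned 704x1216 would be returned unchanged.
def _example_resolution_rounding():
    pipe = BasePipeline(device="cpu")
    return pipe.check_resize_height_width(700, 1200)   # -> (704, 1216)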

View File

@@ -0,0 +1,638 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# Basic utilities: feature statistics for AdaIN (kept as a fallback; the pipeline defaults to wavelet)
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "AdaIN: N and C must match"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# Wavelet-style blur and decomposition / reconstruction (used by the ColorCorrector)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x must be (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x must be (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
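# Illustrative sketch (not part of the original file): the decomposition splits an image
# into high-frequency detail plus a low-frequency base, and high + low reconstructs the
# input exactly. Color correction keeps the HQ output's detail (high) and swaps in the
# LQ reference's base colors (low).
def _example_wavelet_split():
    img = torch.rand(1, 3, 32, 32)
    high, low = _wavelet_decompose(img, levels=3)
    return torch.allclose(high + low, img, atol=1e-5)   # True: exact split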
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# Stateless color-correction module (video-friendly; defaults to wavelet)
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"Unknown method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"Unknown method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
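# Illustrative sketch (not part of the original file): applying the corrector to a short
# 5-frame clip in (B, 3, f, H, W) layout, two frames at a time to keep peak memory flat.
# The shapes and chunk size are assumptions for illustration only.
def _example_color_correction():
    corrector = TorchColorCorrectorWavelet(levels=5)
    hq = torch.rand(1, 3, 5, 64, 64) * 2 - 1    # model output in [-1, 1]
    lq = torch.rand(1, 3, 5, 64, 64) * 2 - 1    # upscaled LQ reference in [-1, 1]
    out = corrector(hq, lq, method='wavelet', chunk_size=2)
    return out.shape                            # [1, 3, 5, 64, 64]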
# -----------------------------
# Simplified pipeline (dit + vae only)
# -----------------------------
class FlashVSRFullPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# Only dit / vae are managed here
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
RMS_norm: AutoWrappedModule,
CausalConv3d: AutoWrappedModule,
Upsample: AutoWrappedModule,
torch.nn.SiLU: AutoWrappedModule,
torch.nn.Dropout: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# Optional: unified sequence parallel entry point (disabled by default here)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# New: explicit KV pre-initialization
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None
):
self.load_models_to_device(["dit"])
"""
Generate the text context from a fixed prompt and initialize the KV cache of every CrossAttention block in WanModel.
Must be called explicitly once before __call__.
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# Only cfg_scale == 1.0 is accepted (consistent with the original code)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# Requirement: init_cross_kv() must be called first
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Initialize noise
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# Clear any existing LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
self.TCDecoder.clean_mem()
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
cur_latents = cur_latents - noise_pred_posi
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
else:
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
# Per-chunk color correction
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except Exception:
pass  # color correction is best-effort; keep the uncorrected frames on failure
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
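The streaming bookkeeping above is easiest to check with concrete numbers. The sketch below only re-derives the chunk count and the LQ frame pointer from the expressions in the loop, assuming num_frames = 81 (the default); it is not part of the pipeline.

```python
num_frames = 81
process_total_num = (num_frames - 1) // 8 - 2      # 80 // 8 - 2 = 8 chunks

# Chunk 0: inner_loop_num = 7, latent window latents[:, :, :6]
lq_cur_idx = (7 - 1) * 4 - 3                       # 21 LQ frames consumed after the first chunk

# Chunks 1..7: inner_loop_num = 2, latent window of width 2 starting at 4 + 2 * idx
for idx in range(1, process_total_num):
    lq_cur_idx = idx * 8 + 21                      # the LQ pointer advances by 8 frames per chunk
print(process_total_num, lq_cur_idx)               # 8 77
```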
# -----------------------------
# TeaCache (original logic kept; disabled by default here)
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# Simplified model forward wrapper (no vace / no motion_controller)
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE positions (segmented)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache (disabled by default)
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# Unified sequence parallel (disabled by default here)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Transformer block stack
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v
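The sparse-attention sizing at the top of model_fn_wan_video is plain arithmetic over the patchified latent grid. The numbers below use a hypothetical grid (f=6, h=30, w=52) purely for illustration; real values depend on the input resolution and the DiT patchify factors.

```python
f, h, w = 6, 30, 52                        # hypothetical latent grid after patchify
win = (2, 8, 8)
topk_ratio, kv_ratio = 2.0, 3.0            # the defaults used by the pipelines

seqlen = f // win[0]                       # 3 temporal windows of 2 latent frames each
window_size = win[0] * h * w // 128        # 2 * 30 * 52 // 128 = 24
square_num = window_size * window_size     # 576
topk = int(square_num * topk_ratio) - 1    # 1151, the top-k budget handed to each block
kv_len = int(kv_ratio)                     # 3, the kv_len passed to each block
print(seqlen, window_size, square_num, topk, kv_len)
```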

View File

@@ -0,0 +1,625 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# Basic utilities: statistics for AdaIN (kept as an option; the pipeline defaults to wavelet)
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "AdaIN: N and C must match"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# Wavelet-style blur and decomposition/reconstruction (used by the ColorCorrector)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x must be (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x must be (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
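As a quick sanity check on the decomposition above: the high/low bands telescope, so high + low reconstructs the input exactly, and _wavelet_reconstruct simply swaps in the style image's low band. The shapes below are arbitrary.

```python
import torch

x = torch.rand(1, 3, 32, 32)
high, low = _wavelet_decompose(x, levels=5)
print(torch.allclose(high + low, x, atol=1e-5))   # True: the bands sum back to the input

style = torch.rand(1, 3, 32, 32)
out = _wavelet_reconstruct(x, style, levels=5)    # x's detail + style's low-frequency color
```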
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# Stateless color-correction module (video friendly, wavelet by default)
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# Simplified pipeline (dit + vae only)
# -----------------------------
class FlashVSRTinyPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# Manage only dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# Optional: unified sequence parallel entry point (disabled by default here)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# New: explicit KV pre-initialization
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
Generate the text context from a fixed prompt and initialize the KV cache of every CrossAttention block in WanModel.
Must be called explicitly once before __call__.
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder expects (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # back to (B, C, F, H, W), range -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-14B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# Only cfg_scale == 1.0 is accepted (consistent with the original code)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# Requirement: init_cross_kv() must be called first
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# Resolution adjustment
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Initialize noise
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# Clear any existing LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
self.TCDecoder.clean_mem()
LQ_pre_idx = 0
LQ_cur_idx = 0
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
if str(current_dit_device) != str(self.device):
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
self.dit.to(self.device)
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
cur_latents = cur_latents - noise_pred_posi
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
# Per-chunk color correction
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except Exception:
pass  # color correction is best-effort; keep the uncorrected frames on failure
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache (original logic kept; disabled by default here)
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# Simplified model forward wrapper (no vace / no motion_controller)
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE positions (segmented)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache (disabled by default)
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# Unified sequence parallel (disabled by default here)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Transformer block stack
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1,619 @@
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
# import pyfiglet
from ..models.utils import clean_vram
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# Basic utilities: statistics for AdaIN (kept as an option; the pipeline defaults to wavelet)
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "AdaIN: N and C must match"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# Wavelet-style blur and decomposition/reconstruction (used by the ColorCorrector)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x must be (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x must be (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# Safetensors support ---------
# -----------------------------
st_load_file = None # Define the variable in global scope first
try:
from safetensors.torch import load_file as st_load_file
except ImportError:
# st_load_file remains None if import fails
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
# -----------------------------
# Stateless color-correction module (video friendly, wavelet by default)
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# Simplified pipeline (dit + vae only)
# -----------------------------
class FlashVSRTinyLongPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
def enable_vram_management(self, num_persistent_param_in_dit=None):
# Manage only dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# Optional: unified sequence parallel entry point (disabled by default here)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# New: explicit KV pre-initialization
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None,
):
self.load_models_to_device(["dit"])
"""
Generate the text context from a fixed prompt and initialize the KV cache of every CrossAttention block in WanModel.
Must be called explicitly once before __call__.
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
# --- Safetensors loading logic added here ---
prompt_path_lower = prompt_path.lower()
if prompt_path_lower.endswith(".safetensors"):
if st_load_file is None:
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
# Load the tensor from safetensors
loaded_dict = st_load_file(prompt_path, device=self.device)
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
if len(loaded_dict) == 1:
ctx = list(loaded_dict.values())[0]
elif 'context' in loaded_dict: # Common key for text context
ctx = loaded_dict['context']
else:
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
else:
# Default behavior for .pth, .pt, etc.
ctx = torch.load(prompt_path, map_location=self.device)
# --------------------------------------------
# ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder expects (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # back to (B, C, F, H, W), range -1 to 1
return frames
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
skip_vae = False,
):
# Only cfg_scale == 1.0 is accepted (consistent with the original code)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# Requirement: init_cross_kv() must be called first
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
" pipe.init_cross_kv()\n"
"或传入自定义 context\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
# Resolution adjustment
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Initialize noise
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
# Clear any existing LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
self.TCDecoder.clean_mem()
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents = None
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
cur = self.denoising_model().LQ_proj_in.stream_forward(
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
) if LQ_video is not None else None
if cur is None:
continue
if LQ_latents is None:
LQ_latents = cur
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# Denoise (no motion_controller / vace)
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# Update the latent
cur_latents = cur_latents - noise_pred_posi
# Decode
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
# Per-chunk color correction
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except Exception:
pass  # color correction is best-effort; keep the uncorrected frames on failure
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]
# -----------------------------
# TeaCache (original logic kept; disabled by default here)
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# Simplified model forward wrapper (no vace / no motion_controller)
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE positions (segmented)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache (disabled by default)
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# Unified sequence parallel (disabled by default here)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Transformer block stack
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v

View File

@@ -0,0 +1 @@
from .flow_match import FlowMatchScheduler

View File

@@ -0,0 +1,79 @@
import torch
class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights
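A minimal sketch of driving FlowMatchScheduler on its own; the four-step schedule and tensor shape are arbitrary, and the construction mirrors how the FlashVSR pipelines instantiate it (shift=5, sigma_min=0.0, extra_one_step=True).

```python
import torch

scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(4, denoising_strength=1.0, shift=5.0)
print(scheduler.sigmas)      # shifted sigmas: shift * s / (1 + (shift - 1) * s)
print(scheduler.timesteps)   # sigmas * num_train_timesteps

sample = torch.randn(1, 16, 8, 8)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)   # stand-in for the denoiser prediction
    sample = scheduler.step(model_output, t, sample)
```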

View File

@@ -0,0 +1 @@
from .layers import *

View File

@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device


def cast_to(weight, dtype, device):
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight)
    return r


class AutoWrappedModule(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        super().__init__()
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            module = self.module
        else:
            module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
        return module(*args, **kwargs)


class AutoWrappedLinear(torch.nn.Linear):
    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, x, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            weight, bias = self.weight, self.bias
        else:
            weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        return torch.nn.functional.linear(x, weight, bias)


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    for name, module in model.named_children():
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                num_param = sum(p.numel() for p in module.parameters())
                if max_num_param is not None and total_num_param + num_param > max_num_param:
                    module_config_ = overflow_module_config
                else:
                    module_config_ = module_config
                module_ = target_module(module, **module_config_)
                setattr(model, name, module_)
                total_num_param += num_param
                break
        else:
            total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
    return total_num_param


def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
    model.vram_management_enabled = True
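
The wrappers above defer dtype/device movement until forward time. A minimal sketch of wiring them up with enable_vram_management follows; the module_map contents, dtype/device choices, and the toy model are illustrative assumptions (a CUDA device is assumed for the computation target), not the configuration used by the FlashVSR pipelines.

import torch

# Hypothetical wiring: every nn.Linear is swapped for AutoWrappedLinear, so its
# weights stay on the offload device until forward() casts them for computation.
toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))

enable_vram_management(
    toy,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.bfloat16, offload_device="cpu",
        onload_dtype=torch.bfloat16, onload_device="cpu",
        computation_dtype=torch.bfloat16, computation_device="cuda:0",
    ),
)

x = torch.randn(1, 64, dtype=torch.bfloat16, device="cuda:0")
y = toy(x)  # AutoWrappedLinear.forward casts weight/bias to cuda:0 bfloat16 on the fly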


@@ -1,8 +1,11 @@
import logging
import math
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
@@ -621,3 +624,248 @@ class GIMMVFIModel:
            results.append(torch.clamp(unpadded, 0, 1))
        return results
# ---------------------------------------------------------------------------
# FlashVSR model wrapper (4x video super-resolution)
# ---------------------------------------------------------------------------
class FlashVSRModel:
    """Inference wrapper for FlashVSR diffusion-based video super-resolution.

    Supports three pipeline modes:
    - full: Standard VAE decode, highest quality
    - tiny: TCDecoder decode, faster
    - tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
    """

    # Minimum input frame count required by the pipeline
    MIN_FRAMES = 21

    def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
        from safetensors.torch import load_file
        from .flashvsr_arch import (
            ModelManager, FlashVSRFullPipeline,
            FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
        )
        from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
        from .flashvsr_arch.models.TCDecoder import build_tcdecoder
        self.mode = mode
        self.device = device
        self.dtype = dtype
        dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
        vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
        lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
        tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
        prompt_path = os.path.join(model_dir, "Prompt.safetensors")
        mm = ModelManager(torch_dtype=dtype, device="cpu")
        if mode == "full":
            mm.load_models([dit_path, vae_path])
            self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
            self.pipe.vae.model.encoder = None
            self.pipe.vae.model.conv1 = None
        else:
            mm.load_models([dit_path])
            Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
            self.pipe = Pipeline.from_model_manager(mm, device=device)
        # TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
        self.pipe.TCDecoder = build_tcdecoder(
            [512, 256, 128, 128], device, dtype, 16 + 768,
        )
        self.pipe.TCDecoder.load_state_dict(
            load_file(tcd_path, device=device), strict=False,
        )
        self.pipe.TCDecoder.clean_mem()
        # LQ frame projection: Causal variant for FlashVSR v1.1
        self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
        if os.path.exists(lq_path):
            lq_sd = load_file(lq_path, device="cpu")
            cleaned = {}
            for k, v in lq_sd.items():
                cleaned[k.removeprefix("LQ_proj_in.")] = v
            self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
            self.pipe.denoising_model().LQ_proj_in.to(device)
        self.pipe.to(device, dtype)
        self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
        self.pipe.init_cross_kv(prompt_path=prompt_path)
        self.pipe.load_models_to_device([])  # offload to CPU

    def to(self, device):
        self.device = device
        self.pipe.device = device
        return self

    def load_to_device(self):
        """Load models to the compute device for inference."""
        names = ["dit", "vae"] if self.mode == "full" else ["dit"]
        self.pipe.load_models_to_device(names)

    def offload(self):
        """Offload models to CPU."""
        self.pipe.load_models_to_device([])

    def clear_caches(self):
        if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
            self.pipe.denoising_model().LQ_proj_in.clear_cache()
        if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
            self.pipe.vae.clear_cache()
        if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
            self.pipe.TCDecoder.clean_mem()

    # ------------------------------------------------------------------
    # Frame preprocessing / postprocessing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _compute_dims(w, h, scale, align=128):
        sw, sh = w * scale, h * scale
        tw = math.ceil(sw / align) * align
        th = math.ceil(sh / align) * align
        return sw, sh, tw, th

    @staticmethod
    def _restore_video_sequence(result, expected):
        """Trim or pad pipeline output to the expected frame count."""
        if result.shape[0] > expected:
            result = result[:expected]
        elif result.shape[0] < expected:
            pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
            result = torch.cat([result, pad], dim=0)
        return result

    @staticmethod
    def _next_8n5(n, minimum=21):
        """Next integer >= n of the form 8k+5 (minimum 21)."""
        if n < minimum:
            return minimum
        return ((n - 5 + 7) // 8) * 8 + 5

    def _prepare_video(self, frames, scale):
        """Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].

        Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
        1. Bicubic-upscale each frame to target resolution
        2. Centered symmetric padding to 128-pixel alignment (reflect mode)
        3. Normalize to [-1, 1]
        4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
        5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)

        Returns:
            video: [1, C, F_padded, H, W] tensor
            th, tw: padded spatial dimensions
            nf: padded frame count
            sh, sw: actual (unpadded) spatial dimensions
            pad_top, pad_left: spatial padding offsets for output cropping
        """
        N, H, W, C = frames.shape
        sw, sh, tw, th = self._compute_dims(W, H, scale)
        # Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
        N_padded = self._next_8n5(N)
        # Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
        target = N_padded + 4
        # Centered spatial padding offsets
        pad_top = (th - sh) // 2
        pad_bottom = th - sh - pad_top
        pad_left = (tw - sw) // 2
        pad_right = tw - sw - pad_left
        processed = []
        for i in range(target):
            frame_idx = min(i, N - 1)  # clamp to last real frame
            frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
            upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
            if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
                # Centered reflect padding (matches naxci1 reference)
                try:
                    upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
                except RuntimeError:
                    # Reflect requires pad < input size; fall back to replicate
                    upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
            normalized = upscaled * 2.0 - 1.0
            processed.append(normalized.squeeze(0).cpu().to(self.dtype))
        video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
        nf = video.shape[2]
        return video, th, tw, nf, sh, sw, pad_top, pad_left

    @staticmethod
    def _to_frames(video):
        """Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
        from einops import rearrange
        v = video.squeeze(0) if video.dim() == 5 else video
        v = rearrange(v, "C F H W -> F H W C")
        return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)

    # ------------------------------------------------------------------
    # Main upscale method
    # ------------------------------------------------------------------
    @torch.no_grad()
    def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
                topk_ratio=2.0, kv_ratio=3.0, local_range=11,
                color_fix=True, unload_dit=False, seed=1,
                progress_bar_cmd=None):
        """Upscale video frames with FlashVSR.

        Args:
            frames: [F, H, W, C] float32 [0, 1] with F >= 21
            scale: Upscaling factor (2 or 4)
            tiled: Enable VAE tiled decode (saves VRAM)
            tile_size: (H, W) tile size for VAE tiling
            topk_ratio: Sparse attention ratio (higher = faster, less detail)
            kv_ratio: KV cache ratio (higher = more quality, more VRAM)
            local_range: Local attention window (9=sharp, 11=stable)
            color_fix: Apply wavelet color correction
            unload_dit: Offload DiT before VAE decode (saves VRAM)
            seed: Random seed
            progress_bar_cmd: Callable wrapping an iterable for progress display

        Returns:
            [F, H*scale, W*scale, C] float32 [0, 1]
        """
        if progress_bar_cmd is None:
            from tqdm import tqdm
            progress_bar_cmd = tqdm
        original_count = frames.shape[0]
        # Prepare video tensor (bicubic upscale + centered pad)
        video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)
        # Move LQ video to compute device (except for "long" mode, which streams)
        if "long" not in self.pipe.__class__.__name__.lower():
            video = video.to(self.pipe.device)
        # Run pipeline
        out = self.pipe(
            prompt="", negative_prompt="",
            cfg_scale=1.0, num_inference_steps=1,
            seed=seed, tiled=tiled, tile_size=tile_size,
            progress_bar_cmd=progress_bar_cmd,
            LQ_video=video,
            num_frames=nf, height=th, width=tw,
            is_full_block=False, if_buffer=True,
            topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
            kv_ratio=kv_ratio, local_range=local_range,
            color_fix=color_fix, unload_dit=unload_dit,
        )
        # Convert to ComfyUI format with centered spatial crop
        result = self._to_frames(out).cpu()
        result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]
        # Trim to original frame count
        result = self._restore_video_sequence(result, original_count)
        return result
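
As a quick sanity check on the two-stage temporal padding described in _prepare_video, and a sketch of how the wrapper is driven end to end: the model_dir path and the random frame tensor below are placeholders (downloaded weights and a CUDA device are assumed), not values from the source.

import torch

# Two-stage padding check: stage 1 rounds up to 8n+5, stage 2 adds 4,
# so the final count is always of the form 8k+1 and never truncates input.
assert FlashVSRModel._next_8n5(50) == 53               # 53 = 8*6 + 5
assert FlashVSRModel._next_8n5(50) + 4 == 57           # 57 = 8*7 + 1
assert FlashVSRModel._next_8n5(10) == 21               # clamped to the minimum
assert (FlashVSRModel._next_8n5(10) + 4 - 1) % 8 == 0  # 25 = 8*3 + 1

# Hypothetical usage: "/path/to/FlashVSR" stands in for the real model folder.
model = FlashVSRModel("/path/to/FlashVSR", mode="tiny", device="cuda:0")
model.load_to_device()
frames = torch.rand(21, 180, 320, 3)           # [F, H, W, C] in [0, 1], F >= MIN_FRAMES
upscaled = model.upscale(frames, scale=4)      # -> [21, 720, 1280, 3]
model.offload()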

nodes.py (1344 changed lines)

File diff suppressed because it is too large.


@@ -1,22 +0,0 @@
[project]
name = "comfyui-tween"
description = "Video frame interpolation nodes for ComfyUI using BIM-VFI, EMA-VFI, SGM-VFI, and GIMM-VFI. Designed for long videos with thousands of frames."
version = "1.1.0"
license = "Apache-2.0"
requires-python = ">=3.10"
dependencies = [
"gdown",
"timm",
"omegaconf",
"yacs",
"easydict",
"einops",
"huggingface_hub",
]
[project.urls]
Repository = "https://github.com/Ethanfel/ComfyUI-Tween"
[tool.comfy]
PublisherId = "ethanfel"
DisplayName = "Tween - Video Frame Interpolation"


@@ -1,7 +1,7 @@
 gdown
-timm
 omegaconf
 yacs
 easydict
 einops
 huggingface_hub
+safetensors

web/js/tween_preview.js (new file, 72 lines)

@@ -0,0 +1,72 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";

function fitHeight(node) {
    node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]);
    node?.graph?.setDirtyCanvas(true);
}

app.registerExtension({
    name: "Tween.VideoPreview",
    async beforeRegisterNodeDef(nodeType, nodeData) {
        if (nodeData?.name !== "TweenConcatVideos") return;

        const onNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = function () {
            onNodeCreated?.apply(this, arguments);
            const container = document.createElement("div");
            const previewWidget = this.addDOMWidget("videopreview", "preview", container, {
                serialize: false,
                hideOnZoom: false,
                getValue() { return container.value; },
                setValue(v) { container.value = v; },
            });
            previewWidget.computeSize = function (width) {
                if (this.aspectRatio && !this.videoEl.hidden) {
                    const height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
                    return [width, height > 0 ? height : -4];
                }
                return [width, -4];
            };
            const previewNode = this;
            previewWidget.videoEl = document.createElement("video");
            previewWidget.videoEl.controls = true;
            previewWidget.videoEl.loop = true;
            previewWidget.videoEl.muted = true;
            previewWidget.videoEl.style.width = "100%";
            previewWidget.videoEl.hidden = true;
            previewWidget.videoEl.addEventListener("loadedmetadata", () => {
                previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
                fitHeight(previewNode);
            });
            previewWidget.videoEl.addEventListener("error", () => {
                previewWidget.videoEl.hidden = true;
                fitHeight(previewNode);
            });
            container.appendChild(previewWidget.videoEl);
        };

        const onExecuted = nodeType.prototype.onExecuted;
        nodeType.prototype.onExecuted = function (message) {
            onExecuted?.apply(this, arguments);
            if (!message?.gifs?.length) return;
            const params = message.gifs[0];
            const previewWidget = this.widgets?.find((w) => w.name === "videopreview");
            if (!previewWidget) return;
            const query = new URLSearchParams(params);
            query.set("timestamp", Date.now());
            previewWidget.videoEl.src = api.apiURL("/view?" + query);
            previewWidget.videoEl.hidden = false;
            previewWidget.videoEl.autoplay = true;
        };
    },
});