Compare commits
20 Commits
old-master
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 91947c0b8c | |||
| c4b69321bb | |||
| f1da0f7876 | |||
| 27c5bcf362 | |||
| d2e7db49c7 | |||
| 9f66233b53 | |||
| 7257c1aa4d | |||
| ebece55ed7 | |||
| a60fb2a25e | |||
| c178f756da | |||
| fb921ae620 | |||
| 4723dc329d | |||
| 8fe382e5ec | |||
| 8311fd0261 | |||
| 396dafeefc | |||
| 13a89c5831 | |||
| 2f1cc17f5c | |||
| b2d7d3b634 | |||
| adc4451716 | |||
| 6dd579dcc7 |
20
.github/workflows/publish.yml
vendored
Normal file
20
.github/workflows/publish.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
name: Publish to Comfy registry
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- "pyproject.toml"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
publish-node:
|
||||||
|
name: Publish Custom Node to registry
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Publish Custom Node
|
||||||
|
uses: Comfy-Org/publish-node-action@main
|
||||||
|
with:
|
||||||
|
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
||||||
373
README.md
373
README.md
@@ -1,55 +1,86 @@
|
|||||||
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR
|
# Tween — Video Frame Interpolation for ComfyUI
|
||||||
|
|
||||||
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM.
|
[](https://registry.comfy.org/)
|
||||||
|
[](https://www.python.org/)
|
||||||
|
[](https://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
[](#which-model-should-i-use)
|
||||||
|
|
||||||
|
Four video frame interpolation models in one package — **BIM-VFI**, **EMA-VFI**, **SGM-VFI**, and **GIMM-VFI**. Designed for long videos with thousands of frames without running out of VRAM.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="assets/model-comparison.svg" alt="Model Comparison" width="720"/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Install from the [ComfyUI Registry](https://registry.comfy.org/) (recommended) or clone manually:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ComfyUI/custom_nodes
|
||||||
|
git clone https://github.com/Ethanfel/ComfyUI-Tween.git
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
All dependencies (`gdown`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are declared in `pyproject.toml` and `requirements.txt`, installed automatically by ComfyUI Manager or pip.
|
||||||
|
|
||||||
|
### cupy (required for BIM-VFI, SGM-VFI, GIMM-VFI)
|
||||||
|
|
||||||
|
[cupy](https://cupy.dev/) provides GPU-accelerated optical flow warping. **EMA-VFI works without it.**
|
||||||
|
|
||||||
|
1. Find your CUDA version:
|
||||||
|
```bash
|
||||||
|
python -c "import torch; print(torch.version.cuda)"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install the matching package:
|
||||||
|
|
||||||
|
| CUDA | Command |
|
||||||
|
|------|---------|
|
||||||
|
| 12.x | `pip install cupy-cuda12x` |
|
||||||
|
| 11.x | `pip install cupy-cuda11x` |
|
||||||
|
|
||||||
|
> Make sure to run pip in the same Python environment as ComfyUI. If cupy is missing, the Load node shows an error with your CUDA version and the exact install command.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>cupy troubleshooting</summary>
|
||||||
|
|
||||||
|
| Problem | Solution |
|
||||||
|
|---------|----------|
|
||||||
|
| `ModuleNotFoundError: No module named 'cupy'` | Install cupy using the steps above |
|
||||||
|
| `cupy` installed but `ImportError` at runtime | CUDA version mismatch — uninstall and reinstall the correct version |
|
||||||
|
| Install hangs or takes very long | cupy wheels are ~800 MB, be patient |
|
||||||
|
| Docker / no build tools | Use the prebuilt wheel: `pip install cupy-cuda12x` (not bare `cupy` which compiles from source) |
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Which model should I use?
|
## Which model should I use?
|
||||||
|
|
||||||
| | BIM-VFI | EMA-VFI | SGM-VFI | GIMM-VFI |
|
| | BIM-VFI | EMA-VFI | SGM-VFI | GIMM-VFI |
|
||||||
|---|---------|---------|---------|----------|
|
|---|---------|---------|---------|----------|
|
||||||
| **Best for** | General-purpose, non-uniform motion | Fast inference, light VRAM | Large motion, occlusion-heavy scenes | High multipliers (4x/8x) in a single pass |
|
| **Best for** | General-purpose | Fast, low VRAM | Large motion | High multipliers (4x/8x) |
|
||||||
| **Quality** | Highest overall | Good | Best on large motion | Good |
|
| **Quality** | Highest | Good | Best on large motion | Good |
|
||||||
| **Speed** | Moderate | Fastest | Slowest | Fast for 4x/8x (single pass) |
|
| **Speed** | Moderate | Fastest | Slowest | Fast for 4x/8x |
|
||||||
| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair | ~2.5 GB/pair |
|
| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair | ~2.5 GB/pair |
|
||||||
| **Params** | ~17M | ~14–65M | ~15M + GMFlow | ~80M (RAFT) / ~123M (FlowFormer) |
|
| **Params** | ~17 M | ~14–65 M | ~15 M + GMFlow | ~80 M (RAFT) / ~123 M (FlowFormer) |
|
||||||
| **Arbitrary timestep** | Yes | Yes (with `_t` checkpoint) | No (fixed 0.5) | Yes (native single-pass) |
|
| **Arbitrary timestep** | Yes | Yes (`_t` checkpoint) | No (fixed 0.5) | Yes (native) |
|
||||||
| **4x/8x mode** | Recursive 2x passes | Recursive 2x passes | Recursive 2x passes | Single forward pass (or recursive) |
|
| **4x/8x** | Recursive passes | Recursive passes | Recursive passes | Single forward pass |
|
||||||
|
| **Requires cupy** | Yes | No | Yes | Yes |
|
||||||
| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 | NeurIPS 2024 |
|
| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 | NeurIPS 2024 |
|
||||||
| **License** | Research only | Apache 2.0 | Apache 2.0 | Apache 2.0 |
|
|
||||||
|
|
||||||
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair.
|
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** for speed or if you can't install cupy. Use **SGM-VFI** for large camera motion. Use **GIMM-VFI** for 4x/8x without recursive passes.
|
||||||
|
|
||||||
### Video Super-Resolution
|
## VRAM Guide
|
||||||
|
|
||||||
FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above.
|
| VRAM | Recommended settings |
|
||||||
|
|------|----------------------|
|
||||||
| | FlashVSR |
|
| 8 GB | `batch_size=1, chunk_size=500` |
|
||||||
|---|----------|
|
| 24 GB | `batch_size=2–4, chunk_size=1000` |
|
||||||
| **Task** | 4x video super-resolution |
|
| 48 GB+ | `batch_size=4–16, all_on_gpu=true` |
|
||||||
| **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) |
|
| 96 GB+ | `batch_size=8–16, all_on_gpu=true, chunk_size=0` |
|
||||||
| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) |
|
|
||||||
| **VRAM** | ~8–12 GB (tiled, tiny mode) / ~16–24 GB (full mode) |
|
|
||||||
| **Params** | ~1.3B (DiT) + ~200M (VAE) |
|
|
||||||
| **Min input** | 21 frames |
|
|
||||||
| **Paper** | arXiv 2510.12747 |
|
|
||||||
| **License** | Apache 2.0 |
|
|
||||||
|
|
||||||
## Nodes
|
## Nodes
|
||||||
|
|
||||||
### BIM-VFI
|
All Interpolate nodes share a common set of controls:
|
||||||
|
|
||||||
#### Load BIM-VFI Model
|
|
||||||
|
|
||||||
Loads the BiM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/bim-vfi/`.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| **model_path** | Checkpoint file from `models/bim-vfi/` |
|
|
||||||
| **auto_pyr_level** | Auto-select pyramid level by resolution (<540p=3, 540p=5, 1080p=6, 4K=7) |
|
|
||||||
| **pyr_level** | Manual pyramid level (3-7), only used when auto is off |
|
|
||||||
|
|
||||||
#### BIM-VFI Interpolate
|
|
||||||
|
|
||||||
Interpolates frames from an image batch.
|
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
@@ -57,207 +88,142 @@ Interpolates frames from an image batch.
|
|||||||
| **model** | Model from the loader node |
|
| **model** | Model from the loader node |
|
||||||
| **multiplier** | 2x, 4x, or 8x frame rate (recursive 2x passes) |
|
| **multiplier** | 2x, 4x, or 8x frame rate (recursive 2x passes) |
|
||||||
| **batch_size** | Frame pairs processed simultaneously (higher = faster, more VRAM) |
|
| **batch_size** | Frame pairs processed simultaneously (higher = faster, more VRAM) |
|
||||||
| **chunk_size** | Process in segments of N input frames (0 = disabled). Bounds VRAM for very long videos. Result is identical to processing all at once |
|
| **chunk_size** | Process in segments of N input frames (0 = disabled). Bounds VRAM for very long videos |
|
||||||
| **keep_device** | Keep model on GPU between pairs (faster, ~200MB constant VRAM) |
|
| **keep_device** | Keep model on GPU between pairs (faster, ~200 MB constant VRAM) |
|
||||||
| **all_on_gpu** | Keep all intermediate frames on GPU (fast, needs large VRAM) |
|
| **all_on_gpu** | Keep all intermediate frames on GPU (fast, needs large VRAM) |
|
||||||
| **clear_cache_after_n_frames** | Clear CUDA cache every N pairs to prevent VRAM buildup |
|
| **clear_cache_after_n_frames** | Clear CUDA cache every N pairs to prevent VRAM buildup |
|
||||||
|
| **source_fps** | Input frame rate. Required when target_fps > 0 |
|
||||||
|
| **target_fps** | Target output FPS. When > 0, overrides multiplier — auto-computes the optimal power-of-2 oversample then selects frames at exact target timestamps. 0 = use multiplier |
|
||||||
|
|
||||||
|
| Output | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| **images** | Interpolated frames at the target FPS (or at the multiplied rate when target_fps = 0) |
|
||||||
|
| **oversampled** | Full power-of-2 oversampled frames before target FPS selection. Same as `images` when target_fps = 0 |
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>BIM-VFI</strong></summary>
|
||||||
|
|
||||||
|
#### Load BIM-VFI Model
|
||||||
|
|
||||||
|
Loads the BiM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/bim-vfi/`.
|
||||||
|
|
||||||
|
| Input | Description |
|
||||||
|
|-------|-------------|
|
||||||
|
| **model_path** | Checkpoint from `models/bim-vfi/` |
|
||||||
|
| **auto_pyr_level** | Auto pyramid level by resolution (<540p=3, 540p=5, 1080p=6, 4K=7) |
|
||||||
|
| **pyr_level** | Manual pyramid level (3–7), used when auto is off |
|
||||||
|
|
||||||
|
#### BIM-VFI Interpolate
|
||||||
|
|
||||||
|
Common controls listed above.
|
||||||
|
|
||||||
#### BIM-VFI Segment Interpolate
|
#### BIM-VFI Segment Interpolate
|
||||||
|
|
||||||
Same as Interpolate but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
|
Processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
|
||||||
|
|
||||||
### Tween Concat Videos
|
</details>
|
||||||
|
|
||||||
Concatenates segment video files into a single video using ffmpeg. Connect from any Segment Interpolate's model output to ensure it runs after all segments are saved. Works with all three models.
|
<details>
|
||||||
|
<summary><strong>EMA-VFI</strong></summary>
|
||||||
### EMA-VFI
|
|
||||||
|
|
||||||
#### Load EMA-VFI Model
|
#### Load EMA-VFI Model
|
||||||
|
|
||||||
Loads an EMA-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/ema-vfi/`. Variant (large/small) and timestep support are auto-detected from the filename.
|
Auto-downloads from Google Drive to `ComfyUI/models/ema-vfi/`. Variant and timestep support are auto-detected from the filename.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| **model_path** | Checkpoint file from `models/ema-vfi/` |
|
| **model_path** | Checkpoint from `models/ema-vfi/` |
|
||||||
| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) |
|
| **tta** | Test-time augmentation (~2x slower, slightly better quality) |
|
||||||
|
|
||||||
Available checkpoints:
|
|
||||||
| Checkpoint | Variant | Params | Arbitrary timestep |
|
| Checkpoint | Variant | Params | Arbitrary timestep |
|
||||||
|-----------|---------|--------|-------------------|
|
|-----------|---------|--------|-------------------|
|
||||||
| `ours_t.pkl` | Large | ~65M | Yes |
|
| `ours_t.pkl` | Large | ~65 M | Yes |
|
||||||
| `ours.pkl` | Large | ~65M | No (fixed 0.5) |
|
| `ours.pkl` | Large | ~65 M | No (fixed 0.5) |
|
||||||
| `ours_small_t.pkl` | Small | ~14M | Yes |
|
| `ours_small_t.pkl` | Small | ~14 M | Yes |
|
||||||
| `ours_small.pkl` | Small | ~14M | No (fixed 0.5) |
|
| `ours_small.pkl` | Small | ~14 M | No (fixed 0.5) |
|
||||||
|
|
||||||
#### EMA-VFI Interpolate
|
#### EMA-VFI Interpolate / Segment Interpolate
|
||||||
|
|
||||||
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
|
Same controls as above.
|
||||||
|
|
||||||
#### EMA-VFI Segment Interpolate
|
</details>
|
||||||
|
|
||||||
Same as EMA-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
|
<details>
|
||||||
|
<summary><strong>SGM-VFI</strong></summary>
|
||||||
### SGM-VFI
|
|
||||||
|
|
||||||
#### Load SGM-VFI Model
|
#### Load SGM-VFI Model
|
||||||
|
|
||||||
Loads an SGM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/sgm-vfi/`. Variant (base/small) is auto-detected from the filename (default is small).
|
Auto-downloads from Google Drive to `ComfyUI/models/sgm-vfi/`. Requires cupy.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| **model_path** | Checkpoint file from `models/sgm-vfi/` |
|
| **model_path** | Checkpoint from `models/sgm-vfi/` |
|
||||||
| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) |
|
| **tta** | Test-time augmentation (~2x slower, slightly better quality) |
|
||||||
| **num_key_points** | Sparsity of global matching (0.0 = global everywhere, 0.5 = default balance, higher = faster) |
|
| **num_key_points** | Global matching sparsity (0.0 = global everywhere, 0.5 = default, higher = faster) |
|
||||||
|
|
||||||
Available checkpoints:
|
|
||||||
| Checkpoint | Variant | Params |
|
| Checkpoint | Variant | Params |
|
||||||
|-----------|---------|--------|
|
|-----------|---------|--------|
|
||||||
| `ours-1-2-points.pkl` | Small | ~15M + GMFlow |
|
| `ours-1-2-points.pkl` | Small | ~15 M + GMFlow |
|
||||||
|
|
||||||
#### SGM-VFI Interpolate
|
#### SGM-VFI Interpolate / Segment Interpolate
|
||||||
|
|
||||||
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
|
Same controls as above.
|
||||||
|
|
||||||
#### SGM-VFI Segment Interpolate
|
</details>
|
||||||
|
|
||||||
Same as SGM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
|
<details>
|
||||||
|
<summary><strong>GIMM-VFI</strong></summary>
|
||||||
### GIMM-VFI
|
|
||||||
|
|
||||||
#### Load GIMM-VFI Model
|
#### Load GIMM-VFI Model
|
||||||
|
|
||||||
Loads a GIMM-VFI checkpoint. Auto-downloads from [HuggingFace](https://huggingface.co/Kijai/GIMM-VFI_safetensors) on first use to `ComfyUI/models/gimm-vfi/`. The matching flow estimator (RAFT or FlowFormer) is auto-detected and downloaded alongside the main model.
|
Auto-downloads from [HuggingFace](https://huggingface.co/Kijai/GIMM-VFI_safetensors) to `ComfyUI/models/gimm-vfi/`. The matching flow estimator (RAFT or FlowFormer) is auto-detected and downloaded alongside.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| **model_path** | Checkpoint file from `models/gimm-vfi/` |
|
| **model_path** | Checkpoint from `models/gimm-vfi/` |
|
||||||
| **ds_factor** | Downscale factor for internal processing (1.0 = full res, 0.5 = half). Lower = less VRAM, faster, less quality. Try 0.5 for 4K inputs |
|
| **ds_factor** | Downscale factor for internal processing (1.0 = full, 0.5 = half). Try 0.5 for 4K inputs |
|
||||||
|
|
||||||
Available checkpoints:
|
|
||||||
| Checkpoint | Variant | Params | Flow estimator (auto-downloaded) |
|
| Checkpoint | Variant | Params | Flow estimator (auto-downloaded) |
|
||||||
|-----------|---------|--------|----------------------------------|
|
|-----------|---------|--------|----------------------------------|
|
||||||
| `gimmvfi_r_arb_lpips_fp32.safetensors` | RAFT | ~80M | `raft-things_fp32.safetensors` |
|
| `gimmvfi_r_arb_lpips_fp32.safetensors` | RAFT | ~80 M | `raft-things_fp32.safetensors` |
|
||||||
| `gimmvfi_f_arb_lpips_fp32.safetensors` | FlowFormer | ~123M | `flowformer_sintel_fp32.safetensors` |
|
| `gimmvfi_f_arb_lpips_fp32.safetensors` | FlowFormer | ~123 M | `flowformer_sintel_fp32.safetensors` |
|
||||||
|
|
||||||
#### GIMM-VFI Interpolate
|
#### GIMM-VFI Interpolate
|
||||||
|
|
||||||
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, plus:
|
Common controls plus:
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| **single_pass** | When enabled (default), generates all intermediate frames per pair in one forward pass using GIMM-VFI's arbitrary-timestep capability. No recursive 2x passes needed for 4x or 8x. Disable to use the standard recursive approach (same as BIM/EMA/SGM) |
|
| **single_pass** | Generate all intermediate frames per pair in one forward pass (default on). No recursive 2x passes needed for 4x/8x. Disable to use the standard recursive approach |
|
||||||
|
|
||||||
#### GIMM-VFI Segment Interpolate
|
#### GIMM-VFI Segment Interpolate
|
||||||
|
|
||||||
Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
|
Same pattern as other Segment nodes.
|
||||||
|
|
||||||
**Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
|
</details>
|
||||||
|
|
||||||
### FlashVSR
|
### Tween Concat Videos
|
||||||
|
|
||||||
FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling.
|
Concatenates segment video files into a single video using ffmpeg. Connect from any Segment Interpolate's model output to ensure it runs after all segments are saved. Works with all four models.
|
||||||
|
|
||||||
#### Load FlashVSR Model
|
### Output frame count
|
||||||
|
|
||||||
Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`.
|
- **Multiplier mode:** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
|
||||||
|
- **Target FPS mode:** `floor((N-1) / source_fps * target_fps) + 1` frames. Automatically oversamples to the nearest power-of-2 above the ratio, then selects frames at exact target timestamps. Downsampling (target < source) also works — frames are selected from the input with no model calls.
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) |
|
|
||||||
| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) |
|
|
||||||
|
|
||||||
Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)):
|
|
||||||
| Checkpoint | Size | Description |
|
|
||||||
|-----------|------|-------------|
|
|
||||||
| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) |
|
|
||||||
| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE |
|
|
||||||
| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection |
|
|
||||||
| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) |
|
|
||||||
| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings |
|
|
||||||
|
|
||||||
#### FlashVSR Upscale
|
|
||||||
|
|
||||||
Upscales an image batch with 4x spatial super-resolution.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| **images** | Input video frames (minimum 21 frames) |
|
|
||||||
| **model** | Model from the loader node |
|
|
||||||
| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) |
|
|
||||||
| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames |
|
|
||||||
| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) |
|
|
||||||
| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) |
|
|
||||||
| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) |
|
|
||||||
| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) |
|
|
||||||
| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability |
|
|
||||||
| **color_fix** | Apply wavelet color correction to prevent color shifts |
|
|
||||||
| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) |
|
|
||||||
| **seed** | Random seed for the diffusion process |
|
|
||||||
|
|
||||||
#### FlashVSR Segment Upscale
|
|
||||||
|
|
||||||
Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| **segment_index** | Which segment to process (0-based) |
|
|
||||||
| **segment_size** | Number of input frames per segment (minimum 21) |
|
|
||||||
| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending |
|
|
||||||
| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) |
|
|
||||||
|
|
||||||
Plus all the same upscale parameters as FlashVSR Upscale.
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
Clone into your ComfyUI `custom_nodes/` directory:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ComfyUI/custom_nodes
|
|
||||||
git clone https://github.com/your-user/ComfyUI-Tween.git
|
|
||||||
```
|
|
||||||
|
|
||||||
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
|
|
||||||
|
|
||||||
> **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with:
|
|
||||||
> ```bash
|
|
||||||
> pip install cupy-cuda12x # replace 12 with your CUDA major version
|
|
||||||
> ```
|
|
||||||
|
|
||||||
To install manually:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ComfyUI-Tween
|
|
||||||
python install.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Requirements
|
|
||||||
|
|
||||||
- PyTorch with CUDA
|
|
||||||
- `cupy` (matching your CUDA version, for BIM-VFI, SGM-VFI, and GIMM-VFI)
|
|
||||||
- `timm` (for EMA-VFI and SGM-VFI)
|
|
||||||
- `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download)
|
|
||||||
- `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI)
|
|
||||||
- `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download)
|
|
||||||
- `safetensors` (for FlashVSR checkpoint loading)
|
|
||||||
|
|
||||||
## VRAM Guide
|
|
||||||
|
|
||||||
| VRAM | Recommended settings |
|
|
||||||
|------|---------------------|
|
|
||||||
| 8 GB | batch_size=1, chunk_size=500 |
|
|
||||||
| 24 GB | batch_size=2-4, chunk_size=1000 |
|
|
||||||
| 48 GB+ | batch_size=4-16, all_on_gpu=true |
|
|
||||||
| 96 GB+ | batch_size=8-16, all_on_gpu=true, chunk_size=0 |
|
|
||||||
|
|
||||||
## Acknowledgments
|
## Acknowledgments
|
||||||
|
|
||||||
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths).
|
| Model | Authors | Venue | Links |
|
||||||
|
|-------|---------|-------|-------|
|
||||||
|
| **BIM-VFI** | Seo, Oh, Kim (KAIST VIC Lab) | CVPR 2025 | [Paper](https://arxiv.org/abs/2412.11365) · [Code](https://github.com/KAIST-VICLab/BiM-VFI) · [Project](https://kaist-viclab.github.io/BiM-VFI_site/) |
|
||||||
|
| **EMA-VFI** | Zhang et al. (MCG-NJU) | CVPR 2023 | [Paper](https://arxiv.org/abs/2303.00440) · [Code](https://github.com/MCG-NJU/EMA-VFI) |
|
||||||
|
| **SGM-VFI** | Zhang et al. (MCG-NJU) | CVPR 2024 | [Paper](https://arxiv.org/abs/2404.06913) · [Code](https://github.com/MCG-NJU/SGM-VFI) |
|
||||||
|
| **GIMM-VFI** | Guo, Li, Loy (S-Lab NTU) | NeurIPS 2024 | [Paper](https://arxiv.org/abs/2407.08680) · [Code](https://github.com/GSeanCDAT/GIMM-VFI) |
|
||||||
|
|
||||||
**BiM-VFI:**
|
GIMM-VFI adaptation from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications.
|
||||||
> Wonyong Seo, Jihyong Oh, and Munchurl Kim.
|
|
||||||
> "BiM-VFI: Bidirectional Motion Field-Guided Frame Interpolation for Video with Non-uniform Motions."
|
<details>
|
||||||
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2025.
|
<summary>BibTeX citations</summary>
|
||||||
> [[arXiv]](https://arxiv.org/abs/2412.11365) [[Project Page]](https://kaist-viclab.github.io/BiM-VFI_site/) [[GitHub]](https://github.com/KAIST-VICLab/BiM-VFI)
|
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{seo2025bimvfi,
|
@inproceedings{seo2025bimvfi,
|
||||||
@@ -266,45 +232,21 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
|
|||||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
year={2025}
|
year={2025}
|
||||||
}
|
}
|
||||||
```
|
|
||||||
|
|
||||||
**EMA-VFI:**
|
|
||||||
> Guozhen Zhang, Yuhan Zhu, Haonan Wang, Youxin Chen, Gangshan Wu, and Limin Wang.
|
|
||||||
> "Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation."
|
|
||||||
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2023.
|
|
||||||
> [[arXiv]](https://arxiv.org/abs/2303.00440) [[GitHub]](https://github.com/MCG-NJU/EMA-VFI)
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@inproceedings{zhang2023emavfi,
|
@inproceedings{zhang2023emavfi,
|
||||||
title={Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation},
|
title={Extracting Motion and Appearance via Inter-Frame Attention for Efficient Video Frame Interpolation},
|
||||||
author={Zhang, Guozhen and Zhu, Yuhan and Wang, Haonan and Chen, Youxin and Wu, Gangshan and Wang, Limin},
|
author={Zhang, Guozhen and Zhu, Yuhan and Wang, Haonan and Chen, Youxin and Wu, Gangshan and Wang, Limin},
|
||||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
year={2023}
|
year={2023}
|
||||||
}
|
}
|
||||||
```
|
|
||||||
|
|
||||||
**SGM-VFI:**
|
|
||||||
> Guozhen Zhang, Yuhan Zhu, Evan Zheran Liu, Haonan Wang, Mingzhen Sun, Gangshan Wu, and Limin Wang.
|
|
||||||
> "Sparse Global Matching for Video Frame Interpolation with Large Motion."
|
|
||||||
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2024.
|
|
||||||
> [[arXiv]](https://arxiv.org/abs/2404.06913) [[GitHub]](https://github.com/MCG-NJU/SGM-VFI)
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@inproceedings{zhang2024sgmvfi,
|
@inproceedings{zhang2024sgmvfi,
|
||||||
title={Sparse Global Matching for Video Frame Interpolation with Large Motion},
|
title={Sparse Global Matching for Video Frame Interpolation with Large Motion},
|
||||||
author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin},
|
author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin},
|
||||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
year={2024}
|
year={2024}
|
||||||
}
|
}
|
||||||
```
|
|
||||||
|
|
||||||
**GIMM-VFI:**
|
|
||||||
> Zujin Guo, Wei Li, and Chen Change Loy.
|
|
||||||
> "Generalizable Implicit Motion Modeling for Video Frame Interpolation."
|
|
||||||
> *Advances in Neural Information Processing Systems (NeurIPS)*, 2024.
|
|
||||||
> [[arXiv]](https://arxiv.org/abs/2407.08680) [[GitHub]](https://github.com/GSeanCDAT/GIMM-VFI)
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@inproceedings{guo2024gimmvfi,
|
@inproceedings{guo2024gimmvfi,
|
||||||
title={Generalizable Implicit Motion Modeling for Video Frame Interpolation},
|
title={Generalizable Implicit Motion Modeling for Video Frame Interpolation},
|
||||||
author={Guo, Zujin and Li, Wei and Loy, Chen Change},
|
author={Guo, Zujin and Li, Wei and Loy, Chen Change},
|
||||||
@@ -313,29 +255,12 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**FlashVSR:**
|
</details>
|
||||||
> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng.
|
|
||||||
> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer."
|
|
||||||
> *arXiv preprint arXiv:2510.12747*, 2025.
|
|
||||||
> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR)
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@article{zhuang2025flashvsr,
|
|
||||||
title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer},
|
|
||||||
author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing},
|
|
||||||
journal={arXiv preprint arXiv:2510.12747},
|
|
||||||
year={2025}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
|
**BIM-VFI:** Research and education only. Commercial use requires permission from Prof. Munchurl Kim (mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI).
|
||||||
|
|
||||||
The EMA-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/EMA-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/EMA-VFI) for details.
|
**EMA-VFI, SGM-VFI, GIMM-VFI:** [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). GIMM-VFI ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
|
||||||
|
|
||||||
The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.
|
**This wrapper code:** [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
|
|
||||||
|
|
||||||
The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR).
|
|
||||||
|
|||||||
59
__init__.py
59
__init__.py
@@ -1,60 +1,11 @@
|
|||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger("Tween")
|
|
||||||
|
|
||||||
|
|
||||||
def _auto_install_deps():
|
|
||||||
"""Auto-install missing dependencies on first load."""
|
|
||||||
# gdown
|
|
||||||
try:
|
|
||||||
import gdown # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
logger.info("[Tween] Installing gdown...")
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
|
|
||||||
|
|
||||||
# timm (required for EMA-VFI's MotionFormer backbone)
|
|
||||||
try:
|
|
||||||
import timm # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
logger.info("[Tween] Installing timm...")
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "timm"])
|
|
||||||
|
|
||||||
# cupy
|
|
||||||
try:
|
|
||||||
import cupy # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
major = int(torch.version.cuda.split(".")[0])
|
|
||||||
cupy_pkg = f"cupy-cuda{major}x"
|
|
||||||
logger.info(f"[Tween] Installing {cupy_pkg} (CUDA {torch.version.cuda})...")
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", cupy_pkg])
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
|
|
||||||
|
|
||||||
# GIMM-VFI + FlashVSR dependencies
|
|
||||||
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"):
|
|
||||||
try:
|
|
||||||
__import__(pkg)
|
|
||||||
except ImportError:
|
|
||||||
logger.info(f"[Tween] Installing {pkg}...")
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
|
|
||||||
|
|
||||||
|
|
||||||
_auto_install_deps()
|
|
||||||
|
|
||||||
from .nodes import (
|
from .nodes import (
|
||||||
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
|
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, TweenConcatVideos,
|
||||||
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
||||||
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
||||||
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
||||||
LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale,
|
VFIOptimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web"
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LoadBIMVFIModel": LoadBIMVFIModel,
|
"LoadBIMVFIModel": LoadBIMVFIModel,
|
||||||
"BIMVFIInterpolate": BIMVFIInterpolate,
|
"BIMVFIInterpolate": BIMVFIInterpolate,
|
||||||
@@ -69,9 +20,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadGIMMVFIModel": LoadGIMMVFIModel,
|
"LoadGIMMVFIModel": LoadGIMMVFIModel,
|
||||||
"GIMMVFIInterpolate": GIMMVFIInterpolate,
|
"GIMMVFIInterpolate": GIMMVFIInterpolate,
|
||||||
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
|
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
|
||||||
"LoadFlashVSRModel": LoadFlashVSRModel,
|
"VFIOptimizer": VFIOptimizer,
|
||||||
"FlashVSRUpscale": FlashVSRUpscale,
|
|
||||||
"FlashVSRSegmentUpscale": FlashVSRSegmentUpscale,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@@ -88,7 +37,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
|
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
|
||||||
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
|
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
|
||||||
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
|
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
|
||||||
"LoadFlashVSRModel": "Load FlashVSR Model",
|
"VFIOptimizer": "VFI Optimizer",
|
||||||
"FlashVSRUpscale": "FlashVSR Upscale",
|
|
||||||
"FlashVSRSegmentUpscale": "FlashVSR Segment Upscale",
|
|
||||||
}
|
}
|
||||||
|
|||||||
84
assets/model-comparison.svg
Normal file
84
assets/model-comparison.svg
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 720 320" width="720" height="320">
|
||||||
|
<defs>
|
||||||
|
<linearGradient id="gQ" x1="0" y1="0" x2="1" y2="0">
|
||||||
|
<stop offset="0%" stop-color="#7aa2f7"/><stop offset="100%" stop-color="#7dcfff"/>
|
||||||
|
</linearGradient>
|
||||||
|
<linearGradient id="gS" x1="0" y1="0" x2="1" y2="0">
|
||||||
|
<stop offset="0%" stop-color="#9ece6a"/><stop offset="100%" stop-color="#73daca"/>
|
||||||
|
</linearGradient>
|
||||||
|
<linearGradient id="gV" x1="0" y1="0" x2="1" y2="0">
|
||||||
|
<stop offset="0%" stop-color="#bb9af7"/><stop offset="100%" stop-color="#d2a8ff"/>
|
||||||
|
</linearGradient>
|
||||||
|
</defs>
|
||||||
|
|
||||||
|
<!-- Background -->
|
||||||
|
<rect width="720" height="320" rx="16" fill="#0d1117"/>
|
||||||
|
|
||||||
|
<!-- ═══ BIM-VFI (top-left) ═══ -->
|
||||||
|
<rect x="10" y="10" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
|
||||||
|
<rect x="11" y="22" width="3" height="121" fill="#3fb950"/>
|
||||||
|
<text x="30" y="38" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">BIM-VFI</text>
|
||||||
|
<text x="30" y="56" fill="#3fb950" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">★ Recommended · Best quality · CVPR 2025</text>
|
||||||
|
<line x1="30" y1="64" x2="330" y2="64" stroke="#30363d" stroke-width="0.5"/>
|
||||||
|
<text x="30" y="82" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
|
||||||
|
<rect x="88" y="72" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="72" width="244" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
|
||||||
|
<text x="30" y="100" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
|
||||||
|
<rect x="88" y="90" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="90" width="146" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
|
||||||
|
<text x="30" y="118" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
|
||||||
|
<rect x="88" y="108" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="108" width="195" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
|
||||||
|
<text x="262" y="143" fill="#f0883e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Research only</text>
|
||||||
|
|
||||||
|
<!-- ═══ EMA-VFI (top-right) ═══ -->
|
||||||
|
<rect x="370" y="10" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
|
||||||
|
<rect x="371" y="22" width="3" height="121" fill="#58a6ff"/>
|
||||||
|
<text x="390" y="38" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">EMA-VFI</text>
|
||||||
|
<text x="390" y="56" fill="#58a6ff" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Fastest · No cupy needed · CVPR 2023</text>
|
||||||
|
<line x1="390" y1="64" x2="690" y2="64" stroke="#30363d" stroke-width="0.5"/>
|
||||||
|
<text x="390" y="82" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
|
||||||
|
<rect x="448" y="72" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="72" width="146" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
|
||||||
|
<text x="390" y="100" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
|
||||||
|
<rect x="448" y="90" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="90" width="244" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
|
||||||
|
<text x="390" y="118" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
|
||||||
|
<rect x="448" y="108" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="108" width="244" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
|
||||||
|
<text x="632" y="143" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
|
||||||
|
|
||||||
|
<!-- ═══ SGM-VFI (bottom-left) ═══ -->
|
||||||
|
<rect x="10" y="165" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
|
||||||
|
<rect x="11" y="177" width="3" height="121" fill="#f0883e"/>
|
||||||
|
<text x="30" y="193" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">SGM-VFI</text>
|
||||||
|
<text x="30" y="211" fill="#f0883e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Large motion specialist · CVPR 2024</text>
|
||||||
|
<line x1="30" y1="219" x2="330" y2="219" stroke="#30363d" stroke-width="0.5"/>
|
||||||
|
<text x="30" y="237" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
|
||||||
|
<rect x="88" y="227" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="227" width="195" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
|
||||||
|
<text x="30" y="255" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
|
||||||
|
<rect x="88" y="245" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="245" width="98" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
|
||||||
|
<text x="30" y="273" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
|
||||||
|
<rect x="88" y="263" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="88" y="263" width="98" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
|
||||||
|
<text x="272" y="298" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
|
||||||
|
|
||||||
|
<!-- ═══ GIMM-VFI (bottom-right) ═══ -->
|
||||||
|
<rect x="370" y="165" width="340" height="145" rx="10" fill="#161b22" stroke="#30363d" stroke-width="1"/>
|
||||||
|
<rect x="371" y="177" width="3" height="121" fill="#bc8cff"/>
|
||||||
|
<text x="390" y="193" fill="#e6edf3" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="15" font-weight="600">GIMM-VFI</text>
|
||||||
|
<text x="390" y="211" fill="#bc8cff" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Single-pass 4×/8× · NeurIPS 2024</text>
|
||||||
|
<line x1="390" y1="219" x2="690" y2="219" stroke="#30363d" stroke-width="0.5"/>
|
||||||
|
<text x="390" y="237" fill="#7aa2f7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Quality</text>
|
||||||
|
<rect x="448" y="227" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="227" width="146" height="11" rx="3" fill="url(#gQ)" opacity="0.85"/>
|
||||||
|
<text x="390" y="255" fill="#9ece6a" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">Speed</text>
|
||||||
|
<rect x="448" y="245" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="245" width="195" height="11" rx="3" fill="url(#gS)" opacity="0.85"/>
|
||||||
|
<text x="390" y="273" fill="#bb9af7" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="11">VRAM</text>
|
||||||
|
<rect x="448" y="263" width="244" height="11" rx="3" fill="#21262d"/>
|
||||||
|
<rect x="448" y="263" width="146" height="11" rx="3" fill="url(#gV)" opacity="0.85"/>
|
||||||
|
<text x="632" y="298" fill="#8b949e" font-family="-apple-system,BlinkMacSystemFont,'Segoe UI','Noto Sans',Helvetica,Arial,sans-serif" font-size="10">Apache 2.0</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 7.8 KiB |
@@ -206,100 +206,6 @@
|
|||||||
"2"
|
"2"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 12,
|
|
||||||
"type": "easy forLoopStart",
|
|
||||||
"pos": [
|
|
||||||
-8160,
|
|
||||||
576
|
|
||||||
],
|
|
||||||
"size": [
|
|
||||||
270,
|
|
||||||
138
|
|
||||||
],
|
|
||||||
"flags": {},
|
|
||||||
"order": 6,
|
|
||||||
"mode": 0,
|
|
||||||
"inputs": [
|
|
||||||
{
|
|
||||||
"name": "initial_value1",
|
|
||||||
"shape": 7,
|
|
||||||
"type": "*",
|
|
||||||
"link": 68
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "total",
|
|
||||||
"type": "INT",
|
|
||||||
"widget": {
|
|
||||||
"name": "total"
|
|
||||||
},
|
|
||||||
"link": 33
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "initial_value2",
|
|
||||||
"type": "*",
|
|
||||||
"link": 44
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "initial_value3",
|
|
||||||
"type": "*",
|
|
||||||
"link": null
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "flow",
|
|
||||||
"shape": 5,
|
|
||||||
"type": "FLOW_CONTROL",
|
|
||||||
"links": [
|
|
||||||
15
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "index",
|
|
||||||
"type": "INT",
|
|
||||||
"links": [
|
|
||||||
25,
|
|
||||||
26
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "value1",
|
|
||||||
"type": "*",
|
|
||||||
"links": [
|
|
||||||
18
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "value2",
|
|
||||||
"type": "*",
|
|
||||||
"links": [
|
|
||||||
21,
|
|
||||||
64
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "value3",
|
|
||||||
"type": "*",
|
|
||||||
"links": null
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"cnr_id": "comfyui-easy-use",
|
|
||||||
"ver": "7c470c67d6df44498e52c902173c1ac77cd5bdfd",
|
|
||||||
"Node name for S&R": "easy forLoopStart",
|
|
||||||
"ue_properties": {
|
|
||||||
"widget_ue_connectable": {},
|
|
||||||
"input_ue_unconnectable": {},
|
|
||||||
"version": "7.6.2"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"widgets_values": [
|
|
||||||
6
|
|
||||||
],
|
|
||||||
"color": "#223",
|
|
||||||
"bgcolor": "#335"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"type": "easy forLoopEnd",
|
"type": "easy forLoopEnd",
|
||||||
@@ -371,85 +277,6 @@
|
|||||||
"color": "#223",
|
"color": "#223",
|
||||||
"bgcolor": "#335"
|
"bgcolor": "#335"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 11,
|
|
||||||
"type": "BIMVFISegmentInterpolate",
|
|
||||||
"pos": [
|
|
||||||
-7584,
|
|
||||||
576
|
|
||||||
],
|
|
||||||
"size": [
|
|
||||||
321.58209228515625,
|
|
||||||
246
|
|
||||||
],
|
|
||||||
"flags": {},
|
|
||||||
"order": 9,
|
|
||||||
"mode": 0,
|
|
||||||
"inputs": [
|
|
||||||
{
|
|
||||||
"name": "images",
|
|
||||||
"type": "IMAGE",
|
|
||||||
"link": 21
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "model",
|
|
||||||
"type": "BIM_VFI_MODEL",
|
|
||||||
"link": 18
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "segment_index",
|
|
||||||
"type": "INT",
|
|
||||||
"widget": {
|
|
||||||
"name": "segment_index"
|
|
||||||
},
|
|
||||||
"link": 25
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "segment_size",
|
|
||||||
"type": "INT",
|
|
||||||
"widget": {
|
|
||||||
"name": "segment_size"
|
|
||||||
},
|
|
||||||
"link": 35
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "images",
|
|
||||||
"type": "IMAGE",
|
|
||||||
"links": [
|
|
||||||
66
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "model",
|
|
||||||
"type": "BIM_VFI_MODEL",
|
|
||||||
"links": [
|
|
||||||
67
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"aux_id": "Comfyui-BIM-VFI.git",
|
|
||||||
"ver": "7cf7162143eaa5b0939e0e122f80bc956baf65ea",
|
|
||||||
"Node name for S&R": "BIMVFISegmentInterpolate",
|
|
||||||
"ue_properties": {
|
|
||||||
"widget_ue_connectable": {},
|
|
||||||
"input_ue_unconnectable": {},
|
|
||||||
"version": "7.6.2"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"widgets_values": [
|
|
||||||
2,
|
|
||||||
40,
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
500
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 3,
|
"id": 3,
|
||||||
"type": "LoadBIMVFIModel",
|
"type": "LoadBIMVFIModel",
|
||||||
@@ -561,7 +388,6 @@
|
|||||||
"video/",
|
"video/",
|
||||||
"tween_sgm",
|
"tween_sgm",
|
||||||
"tween_video_sgm.mp4",
|
"tween_video_sgm.mp4",
|
||||||
true,
|
|
||||||
true
|
true
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -574,7 +400,7 @@
|
|||||||
],
|
],
|
||||||
"size": [
|
"size": [
|
||||||
544,
|
544,
|
||||||
352
|
334
|
||||||
],
|
],
|
||||||
"flags": {},
|
"flags": {},
|
||||||
"order": 10,
|
"order": 10,
|
||||||
@@ -647,11 +473,227 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"id": 16,
|
||||||
|
"type": "PrimitiveInt",
|
||||||
|
"pos": [
|
||||||
|
-9184,
|
||||||
|
544
|
||||||
|
],
|
||||||
|
"size": [
|
||||||
|
270,
|
||||||
|
82
|
||||||
|
],
|
||||||
|
"flags": {},
|
||||||
|
"order": 2,
|
||||||
|
"mode": 0,
|
||||||
|
"inputs": [],
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "INT",
|
||||||
|
"type": "INT",
|
||||||
|
"links": [
|
||||||
|
31,
|
||||||
|
35
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Frames number each loops",
|
||||||
|
"properties": {
|
||||||
|
"cnr_id": "comfy-core",
|
||||||
|
"ver": "0.13.0",
|
||||||
|
"Node name for S&R": "PrimitiveInt",
|
||||||
|
"ue_properties": {
|
||||||
|
"widget_ue_connectable": {},
|
||||||
|
"input_ue_unconnectable": {},
|
||||||
|
"version": "7.6.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"widgets_values": [
|
||||||
|
100,
|
||||||
|
"fixed"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12,
|
||||||
|
"type": "easy forLoopStart",
|
||||||
|
"pos": [
|
||||||
|
-8160,
|
||||||
|
576
|
||||||
|
],
|
||||||
|
"size": [
|
||||||
|
270,
|
||||||
|
138
|
||||||
|
],
|
||||||
|
"flags": {},
|
||||||
|
"order": 6,
|
||||||
|
"mode": 0,
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "initial_value1",
|
||||||
|
"shape": 7,
|
||||||
|
"type": "*",
|
||||||
|
"link": 68
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "total",
|
||||||
|
"type": "INT",
|
||||||
|
"widget": {
|
||||||
|
"name": "total"
|
||||||
|
},
|
||||||
|
"link": 33
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "initial_value2",
|
||||||
|
"type": "*",
|
||||||
|
"link": 44
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "initial_value3",
|
||||||
|
"type": "*",
|
||||||
|
"link": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "flow",
|
||||||
|
"shape": 5,
|
||||||
|
"type": "FLOW_CONTROL",
|
||||||
|
"links": [
|
||||||
|
15
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "index",
|
||||||
|
"type": "INT",
|
||||||
|
"links": [
|
||||||
|
25,
|
||||||
|
26
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "value1",
|
||||||
|
"type": "*",
|
||||||
|
"links": [
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "value2",
|
||||||
|
"type": "*",
|
||||||
|
"links": [
|
||||||
|
21,
|
||||||
|
64
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "value3",
|
||||||
|
"type": "*",
|
||||||
|
"links": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"cnr_id": "comfyui-easy-use",
|
||||||
|
"ver": "7c470c67d6df44498e52c902173c1ac77cd5bdfd",
|
||||||
|
"Node name for S&R": "easy forLoopStart",
|
||||||
|
"ue_properties": {
|
||||||
|
"widget_ue_connectable": {},
|
||||||
|
"input_ue_unconnectable": {},
|
||||||
|
"version": "7.6.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"widgets_values": [
|
||||||
|
6
|
||||||
|
],
|
||||||
|
"color": "#223",
|
||||||
|
"bgcolor": "#335"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 11,
|
||||||
|
"type": "BIMVFISegmentInterpolate",
|
||||||
|
"pos": [
|
||||||
|
-7584,
|
||||||
|
576
|
||||||
|
],
|
||||||
|
"size": [
|
||||||
|
321.58209228515625,
|
||||||
|
294
|
||||||
|
],
|
||||||
|
"flags": {},
|
||||||
|
"order": 9,
|
||||||
|
"mode": 0,
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "images",
|
||||||
|
"type": "IMAGE",
|
||||||
|
"link": 21
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "model",
|
||||||
|
"type": "BIM_VFI_MODEL",
|
||||||
|
"link": 18
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "segment_index",
|
||||||
|
"type": "INT",
|
||||||
|
"widget": {
|
||||||
|
"name": "segment_index"
|
||||||
|
},
|
||||||
|
"link": 25
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "segment_size",
|
||||||
|
"type": "INT",
|
||||||
|
"widget": {
|
||||||
|
"name": "segment_size"
|
||||||
|
},
|
||||||
|
"link": 35
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "images",
|
||||||
|
"type": "IMAGE",
|
||||||
|
"links": [
|
||||||
|
66
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "model",
|
||||||
|
"type": "BIM_VFI_MODEL",
|
||||||
|
"links": [
|
||||||
|
67
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"aux_id": "Comfyui-BIM-VFI.git",
|
||||||
|
"ver": "7cf7162143eaa5b0939e0e122f80bc956baf65ea",
|
||||||
|
"Node name for S&R": "BIMVFISegmentInterpolate",
|
||||||
|
"ue_properties": {
|
||||||
|
"widget_ue_connectable": {},
|
||||||
|
"input_ue_unconnectable": {},
|
||||||
|
"version": "7.6.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"widgets_values": [
|
||||||
|
2,
|
||||||
|
40,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
1,
|
||||||
|
500,
|
||||||
|
16,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
500
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"id": 28,
|
"id": 28,
|
||||||
"type": "VHS_LoadVideoPath",
|
"type": "VHS_LoadVideoPath",
|
||||||
"pos": [
|
"pos": [
|
||||||
-9152,
|
-9184,
|
||||||
704
|
704
|
||||||
],
|
],
|
||||||
"size": [
|
"size": [
|
||||||
@@ -659,7 +701,7 @@
|
|||||||
286
|
286
|
||||||
],
|
],
|
||||||
"flags": {},
|
"flags": {},
|
||||||
"order": 2,
|
"order": 3,
|
||||||
"mode": 0,
|
"mode": 0,
|
||||||
"inputs": [
|
"inputs": [
|
||||||
{
|
{
|
||||||
@@ -738,47 +780,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 16,
|
|
||||||
"type": "PrimitiveInt",
|
|
||||||
"pos": [
|
|
||||||
-9152,
|
|
||||||
576
|
|
||||||
],
|
|
||||||
"size": [
|
|
||||||
270,
|
|
||||||
82
|
|
||||||
],
|
|
||||||
"flags": {},
|
|
||||||
"order": 3,
|
|
||||||
"mode": 0,
|
|
||||||
"inputs": [],
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "INT",
|
|
||||||
"type": "INT",
|
|
||||||
"links": [
|
|
||||||
31,
|
|
||||||
35
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"title": "Frames number each loops",
|
|
||||||
"properties": {
|
|
||||||
"cnr_id": "comfy-core",
|
|
||||||
"ver": "0.13.0",
|
|
||||||
"Node name for S&R": "PrimitiveInt",
|
|
||||||
"ue_properties": {
|
|
||||||
"widget_ue_connectable": {},
|
|
||||||
"input_ue_unconnectable": {},
|
|
||||||
"version": "7.6.2"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"widgets_values": [
|
|
||||||
100,
|
|
||||||
"fixed"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"links": [
|
"links": [
|
||||||
@@ -933,10 +934,10 @@
|
|||||||
"workflowRendererVersion": "LG",
|
"workflowRendererVersion": "LG",
|
||||||
"ue_links": [],
|
"ue_links": [],
|
||||||
"ds": {
|
"ds": {
|
||||||
"scale": 1.0834705943388552,
|
"scale": 0.8954302432552531,
|
||||||
"offset": [
|
"offset": [
|
||||||
10009.878269742538,
|
10389.297857289295,
|
||||||
-100.68482917709798
|
79.21414284327875
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"links_added_by_ue": [],
|
"links_added_by_ue": [],
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
from .models.model_manager import ModelManager
|
|
||||||
from .pipelines import FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
|
|
||||||
from .models.utils import clean_vram, Buffer_LQ4x_Proj
|
|
||||||
from .models.TCDecoder import build_tcdecoder
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
from ..models.wan_video_dit import WanModel
|
|
||||||
from ..models.wan_video_vae import WanVideoVAE
|
|
||||||
|
|
||||||
model_loader_configs = [
|
|
||||||
# (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
|
||||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
|
||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
|
||||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
|
||||||
]
|
|
||||||
huggingface_model_loader_configs = [
|
|
||||||
]
|
|
||||||
patch_model_loader_configs = [
|
|
||||||
]
|
|
||||||
@@ -1,320 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
|
|
||||||
- Encoder removed
|
|
||||||
- Transplant/widening helpers removed
|
|
||||||
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from collections import namedtuple
|
|
||||||
from einops import rearrange
|
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
|
||||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Utility / building blocks
|
|
||||||
# ----------------------------
|
|
||||||
|
|
||||||
class IdentityConv2d(nn.Conv2d):
|
|
||||||
"""Same-shape Conv2d initialized to identity (Dirac)."""
|
|
||||||
def __init__(self, C, kernel_size=3, bias=False):
|
|
||||||
pad = kernel_size // 2
|
|
||||||
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
|
|
||||||
with torch.no_grad():
|
|
||||||
init.dirac_(self.weight)
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias.zero_()
|
|
||||||
|
|
||||||
def conv(n_in, n_out, **kwargs):
|
|
||||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
|
||||||
|
|
||||||
class Clamp(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
return torch.tanh(x / 3) * 3
|
|
||||||
|
|
||||||
class MemBlock(nn.Module):
|
|
||||||
def __init__(self, n_in, n_out):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Sequential(
|
|
||||||
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
|
|
||||||
conv(n_out, n_out), nn.ReLU(inplace=True),
|
|
||||||
conv(n_out, n_out)
|
|
||||||
)
|
|
||||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
|
||||||
self.act = nn.ReLU(inplace=True)
|
|
||||||
def forward(self, x, past):
|
|
||||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
|
||||||
|
|
||||||
class TPool(nn.Module):
|
|
||||||
def __init__(self, n_f, stride):
|
|
||||||
super().__init__()
|
|
||||||
self.stride = stride
|
|
||||||
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
|
|
||||||
def forward(self, x):
|
|
||||||
_NT, C, H, W = x.shape
|
|
||||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
|
||||||
|
|
||||||
class TGrow(nn.Module):
|
|
||||||
def __init__(self, n_f, stride):
|
|
||||||
super().__init__()
|
|
||||||
self.stride = stride
|
|
||||||
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
|
|
||||||
def forward(self, x):
|
|
||||||
_NT, C, H, W = x.shape
|
|
||||||
x = self.conv(x)
|
|
||||||
return x.reshape(-1, C, H, W)
|
|
||||||
|
|
||||||
class PixelShuffle3d(nn.Module):
|
|
||||||
def __init__(self, ff, hh, ww):
|
|
||||||
super().__init__()
|
|
||||||
self.ff = ff
|
|
||||||
self.hh = hh
|
|
||||||
self.ww = ww
|
|
||||||
def forward(self, x):
|
|
||||||
# x: (B, C, F, H, W)
|
|
||||||
B, C, F, H, W = x.shape
|
|
||||||
if F % self.ff != 0:
|
|
||||||
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
|
|
||||||
x = torch.cat([first_frame, x], dim=2)
|
|
||||||
return rearrange(
|
|
||||||
x,
|
|
||||||
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
|
||||||
ff=self.ff, hh=self.hh, ww=self.ww
|
|
||||||
).transpose(1, 2)
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Generic NTCHW graph executor (kept; used by decoder)
|
|
||||||
# ----------------------------
|
|
||||||
|
|
||||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
|
|
||||||
"""
|
|
||||||
Apply a sequential model with memblocks to the given input.
|
|
||||||
Args:
|
|
||||||
- model: nn.Sequential of blocks to apply
|
|
||||||
- x: input data, of dimensions NTCHW
|
|
||||||
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
|
|
||||||
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
|
|
||||||
- show_progress_bar: if True, enables tqdm progressbar display
|
|
||||||
|
|
||||||
Returns NTCHW tensor of output data.
|
|
||||||
"""
|
|
||||||
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
|
|
||||||
N, T, C, H, W = x.shape
|
|
||||||
if parallel:
|
|
||||||
x = x.reshape(N*T, C, H, W)
|
|
||||||
for b in tqdm(model, disable=not show_progress_bar):
|
|
||||||
if isinstance(b, MemBlock):
|
|
||||||
NT, C, H, W = x.shape
|
|
||||||
T = NT // N
|
|
||||||
_x = x.reshape(N, T, C, H, W)
|
|
||||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
|
||||||
x = b(x, mem)
|
|
||||||
else:
|
|
||||||
x = b(x)
|
|
||||||
NT, C, H, W = x.shape
|
|
||||||
T = NT // N
|
|
||||||
x = x.view(N, T, C, H, W)
|
|
||||||
else:
|
|
||||||
out = []
|
|
||||||
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
|
|
||||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
|
||||||
while work_queue:
|
|
||||||
xt, i = work_queue.pop(0)
|
|
||||||
if i == 0:
|
|
||||||
progress_bar.update(1)
|
|
||||||
if i == len(model):
|
|
||||||
out.append(xt)
|
|
||||||
else:
|
|
||||||
b = model[i]
|
|
||||||
if isinstance(b, MemBlock):
|
|
||||||
if mem[i] is None:
|
|
||||||
xt_new = b(xt, xt * 0)
|
|
||||||
mem[i] = xt
|
|
||||||
else:
|
|
||||||
xt_new = b(xt, mem[i])
|
|
||||||
mem[i].copy_(xt)
|
|
||||||
work_queue.insert(0, TWorkItem(xt_new, i+1))
|
|
||||||
elif isinstance(b, TPool):
|
|
||||||
if mem[i] is None:
|
|
||||||
mem[i] = []
|
|
||||||
mem[i].append(xt)
|
|
||||||
if len(mem[i]) > b.stride:
|
|
||||||
raise ValueError("TPool internal state invalid.")
|
|
||||||
elif len(mem[i]) == b.stride:
|
|
||||||
N_, C_, H_, W_ = xt.shape
|
|
||||||
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
|
|
||||||
mem[i] = []
|
|
||||||
work_queue.insert(0, TWorkItem(xt, i+1))
|
|
||||||
elif isinstance(b, TGrow):
|
|
||||||
xt = b(xt)
|
|
||||||
NT, C_, H_, W_ = xt.shape
|
|
||||||
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
|
|
||||||
work_queue.insert(0, TWorkItem(xt_next, i+1))
|
|
||||||
else:
|
|
||||||
xt = b(xt)
|
|
||||||
work_queue.insert(0, TWorkItem(xt, i+1))
|
|
||||||
progress_bar.close()
|
|
||||||
x = torch.stack(out, 1)
|
|
||||||
return x, mem
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Decoder-only TAEHV
|
|
||||||
# ----------------------------
|
|
||||||
|
|
||||||
class TAEHV(nn.Module):
|
|
||||||
image_channels = 3
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
checkpoint_path="taehv.pth",
|
|
||||||
decoder_time_upscale=(True, True),
|
|
||||||
decoder_space_upscale=(True, True, True),
|
|
||||||
channels = [256, 128, 64, 64],
|
|
||||||
latent_channels = 16
|
|
||||||
):
|
|
||||||
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
|
|
||||||
Deepening config: how_many_each=1, k=3 (fixed as requested).
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.latent_channels = latent_channels
|
|
||||||
n_f = channels
|
|
||||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
|
||||||
|
|
||||||
# Build the decoder "skeleton"
|
|
||||||
base_decoder = nn.Sequential(
|
|
||||||
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
|
|
||||||
|
|
||||||
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
|
|
||||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
|
|
||||||
TGrow(n_f[0], 1),
|
|
||||||
conv(n_f[0], n_f[1], bias=False),
|
|
||||||
|
|
||||||
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
|
|
||||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
|
|
||||||
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
|
|
||||||
conv(n_f[1], n_f[2], bias=False),
|
|
||||||
|
|
||||||
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
|
|
||||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
|
|
||||||
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
|
|
||||||
conv(n_f[2], n_f[3], bias=False),
|
|
||||||
|
|
||||||
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
|
|
||||||
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
|
|
||||||
|
|
||||||
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
|
|
||||||
|
|
||||||
if checkpoint_path is not None:
|
|
||||||
missing_keys = self.load_state_dict(
|
|
||||||
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
|
|
||||||
strict=False
|
|
||||||
)
|
|
||||||
print('missing_keys', missing_keys)
|
|
||||||
|
|
||||||
# Initialize decoder mem state
|
|
||||||
self.mem = [None] * len(self.decoder)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
|
|
||||||
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
|
|
||||||
new_layers = []
|
|
||||||
for b in decoder:
|
|
||||||
new_layers.append(b)
|
|
||||||
if isinstance(b, nn.ReLU):
|
|
||||||
# Deduce channel count from preceding layer
|
|
||||||
C = None
|
|
||||||
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
|
|
||||||
C = new_layers[-2].out_channels
|
|
||||||
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
|
|
||||||
C = new_layers[-2].conv[-1].out_channels
|
|
||||||
if C is not None:
|
|
||||||
for _ in range(how_many_each):
|
|
||||||
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
|
|
||||||
new_layers.append(nn.ReLU(inplace=True))
|
|
||||||
return nn.Sequential(*new_layers)
|
|
||||||
|
|
||||||
def patch_tgrow_layers(self, sd):
|
|
||||||
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
|
|
||||||
new_sd = self.state_dict()
|
|
||||||
for i, layer in enumerate(self.decoder):
|
|
||||||
if isinstance(layer, TGrow):
|
|
||||||
key = f"decoder.{i}.conv.weight"
|
|
||||||
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
|
|
||||||
sd[key] = sd[key][-new_sd[key].shape[0]:]
|
|
||||||
return sd
|
|
||||||
|
|
||||||
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
|
|
||||||
"""Decode a sequence of frames from latents.
|
|
||||||
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
|
|
||||||
"""
|
|
||||||
trim_flag = self.mem[-8] is None # keeps original relative check
|
|
||||||
|
|
||||||
if cond is not None:
|
|
||||||
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
|
|
||||||
|
|
||||||
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
|
|
||||||
|
|
||||||
if trim_flag:
|
|
||||||
return x[:, self.frames_to_trim:]
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
|
|
||||||
|
|
||||||
def clean_mem(self):
|
|
||||||
self.mem = [None] * len(self.decoder)
|
|
||||||
|
|
||||||
class DotDict(dict):
|
|
||||||
__getattr__ = dict.__getitem__
|
|
||||||
__setattr__ = dict.__setitem__
|
|
||||||
|
|
||||||
class TAEW2_1DiffusersWrapper(nn.Module):
|
|
||||||
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
|
|
||||||
super().__init__()
|
|
||||||
self.dtype = torch.bfloat16
|
|
||||||
self.device = "cuda"
|
|
||||||
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
|
|
||||||
self.temperal_downsample = [True, True, False] # [sic]
|
|
||||||
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
|
|
||||||
|
|
||||||
def decode(self, latents, return_dict=None):
|
|
||||||
n, c, t, h, w = latents.shape
|
|
||||||
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
|
|
||||||
|
|
||||||
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
|
|
||||||
n, c, t, h, w = latents.shape
|
|
||||||
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
|
|
||||||
|
|
||||||
def clean_mem(self):
|
|
||||||
self.taehv.clean_mem()
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Simplified builder (no small, no transplant, no post-hoc deepening)
|
|
||||||
# ----------------------------
|
|
||||||
|
|
||||||
def build_tcdecoder(new_channels = [512, 256, 128, 128],
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
new_latent_channels=None):
|
|
||||||
"""
|
|
||||||
构建“更宽”的 decoder;深度增强(IdentityConv2d+ReLU)已在 TAEHV 内部完成。
|
|
||||||
- 不创建 small / 不做移植
|
|
||||||
- base_ckpt_path 参数保留但不使用(接口兼容)
|
|
||||||
|
|
||||||
返回:big (单个模型)
|
|
||||||
"""
|
|
||||||
if new_latent_channels is not None:
|
|
||||||
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
|
|
||||||
else:
|
|
||||||
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
|
|
||||||
|
|
||||||
big.clean_mem()
|
|
||||||
return big
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .model_manager import *
|
|
||||||
@@ -1,402 +0,0 @@
|
|||||||
import os, torch, json, importlib
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
|
||||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
|
||||||
|
|
||||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
|
||||||
state_dict_converter = model_class.state_dict_converter()
|
|
||||||
if model_resource == "civitai":
|
|
||||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
|
||||||
elif model_resource == "diffusers":
|
|
||||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
|
||||||
if isinstance(state_dict_results, tuple):
|
|
||||||
model_state_dict, extra_kwargs = state_dict_results
|
|
||||||
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
|
||||||
else:
|
|
||||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
|
||||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
|
||||||
with init_weights_on_device():
|
|
||||||
model = model_class(**extra_kwargs)
|
|
||||||
if hasattr(model, "eval"):
|
|
||||||
model = model.eval()
|
|
||||||
model.load_state_dict(model_state_dict, assign=True)
|
|
||||||
model = model.to(dtype=torch_dtype, device=device)
|
|
||||||
loaded_model_names.append(model_name)
|
|
||||||
loaded_models.append(model)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
|
||||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
|
||||||
else:
|
|
||||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
|
||||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
|
||||||
model = model.half()
|
|
||||||
try:
|
|
||||||
model = model.to(device=device)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
loaded_model_names.append(model_name)
|
|
||||||
loaded_models.append(model)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
|
||||||
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
|
||||||
base_state_dict = base_model.state_dict()
|
|
||||||
base_model.to("cpu")
|
|
||||||
del base_model
|
|
||||||
model = model_class(**extra_kwargs)
|
|
||||||
model.load_state_dict(base_state_dict, strict=False)
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
model.to(dtype=torch_dtype, device=device)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for model_name, model_class in zip(model_names, model_classes):
|
|
||||||
while True:
|
|
||||||
for model_id in range(len(model_manager.model)):
|
|
||||||
base_model_name = model_manager.model_name[model_id]
|
|
||||||
if base_model_name == model_name:
|
|
||||||
base_model_path = model_manager.model_path[model_id]
|
|
||||||
base_model = model_manager.model[model_id]
|
|
||||||
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
|
||||||
patched_model = load_single_patch_model_from_single_file(
|
|
||||||
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
|
||||||
loaded_model_names.append(base_model_name)
|
|
||||||
loaded_models.append(patched_model)
|
|
||||||
model_manager.model.pop(model_id)
|
|
||||||
model_manager.model_path.pop(model_id)
|
|
||||||
model_manager.model_name.pop(model_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorTemplate:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSingleFile:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.keys_hash_with_shape_dict = {}
|
|
||||||
self.keys_hash_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
|
||||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
|
||||||
if keys_hash is not None:
|
|
||||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
return True
|
|
||||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
|
||||||
if keys_hash in self.keys_hash_dict:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
|
|
||||||
# Load models with strict matching
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
|
||||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
# Load models without strict matching
|
|
||||||
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
|
||||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
|
||||||
if keys_hash in self.keys_hash_dict:
|
|
||||||
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
|
||||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
super().__init__(model_loader_configs)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
# Split the state_dict and load from each component
|
|
||||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
|
||||||
valid_state_dict = {}
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
valid_state_dict.update(sub_state_dict)
|
|
||||||
if super().match(file_path, valid_state_dict):
|
|
||||||
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
|
||||||
else:
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
for sub_state_dict in splited_state_dict:
|
|
||||||
if super().match(file_path, sub_state_dict):
|
|
||||||
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromHuggingfaceFolder:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.architecture_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
|
||||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
|
||||||
return False
|
|
||||||
file_list = os.listdir(file_path)
|
|
||||||
if "config.json" not in file_list:
|
|
||||||
return False
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
if "architectures" not in config and "_class_name" not in config:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
|
||||||
for architecture in architectures:
|
|
||||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
|
||||||
if redirected_architecture is not None:
|
|
||||||
architecture = redirected_architecture
|
|
||||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
|
||||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromPatchedSingleFile:
|
|
||||||
def __init__(self, model_loader_configs=[]):
|
|
||||||
self.keys_hash_with_shape_dict = {}
|
|
||||||
for metadata in model_loader_configs:
|
|
||||||
self.add_model_metadata(*metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
|
||||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, file_path="", state_dict={}):
|
|
||||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
|
||||||
return False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
|
|
||||||
# Load models with strict matching
|
|
||||||
loaded_model_names, loaded_models = [], []
|
|
||||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
|
||||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
|
||||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
|
||||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
|
||||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
|
||||||
loaded_model_names += loaded_model_names_
|
|
||||||
loaded_models += loaded_models_
|
|
||||||
return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device="cuda",
|
|
||||||
file_path_list: List[str] = [],
|
|
||||||
):
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.device = device
|
|
||||||
self.model = []
|
|
||||||
self.model_path = []
|
|
||||||
self.model_name = []
|
|
||||||
self.model_detector = [
|
|
||||||
ModelDetectorFromSingleFile(model_loader_configs),
|
|
||||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
|
||||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
|
||||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
|
||||||
]
|
|
||||||
self.load_models(file_path_list)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
|
||||||
print(f"Loading models from file: {file_path}")
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
#print(f" The following models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
|
||||||
print(f"Loading models from folder: {file_path}")
|
|
||||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
#print(f" The following models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
|
||||||
print(f"Loading patch models from file: {file_path}")
|
|
||||||
model_names, models = load_patch_model_from_single_file(
|
|
||||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
print(f" The following patched models are loaded: {model_names}.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
|
||||||
if isinstance(file_path, list):
|
|
||||||
for file_path_ in file_path:
|
|
||||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
|
||||||
else:
|
|
||||||
print(f"Loading LoRA models from file: {file_path}")
|
|
||||||
is_loaded = False
|
|
||||||
if len(state_dict) == 0:
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
|
||||||
for lora in get_lora_loaders():
|
|
||||||
match_results = lora.match(model, state_dict)
|
|
||||||
if match_results is not None:
|
|
||||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
|
||||||
lora_prefix, model_resource = match_results
|
|
||||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
|
||||||
is_loaded = True
|
|
||||||
break
|
|
||||||
if not is_loaded:
|
|
||||||
print(f" Cannot load LoRA: {file_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
|
||||||
#print(f"Loading models from: {file_path}")
|
|
||||||
if device is None: device = self.device
|
|
||||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
|
||||||
if isinstance(file_path, list):
|
|
||||||
state_dict = {}
|
|
||||||
for path in file_path:
|
|
||||||
state_dict.update(load_state_dict(path))
|
|
||||||
elif os.path.isfile(file_path):
|
|
||||||
state_dict = load_state_dict(file_path)
|
|
||||||
else:
|
|
||||||
state_dict = None
|
|
||||||
for model_detector in self.model_detector:
|
|
||||||
if model_detector.match(file_path, state_dict):
|
|
||||||
model_names, models = model_detector.load(
|
|
||||||
file_path, state_dict,
|
|
||||||
device=device, torch_dtype=torch_dtype,
|
|
||||||
allowed_model_names=model_names, model_manager=self
|
|
||||||
)
|
|
||||||
for model_name, model in zip(model_names, models):
|
|
||||||
self.model.append(model)
|
|
||||||
self.model_path.append(file_path)
|
|
||||||
self.model_name.append(model_name)
|
|
||||||
#print(f" The following models are loaded: {model_names}.")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f" We cannot detect the model type. No models are loaded.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
|
||||||
for file_path in file_path_list:
|
|
||||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
|
||||||
fetched_models = []
|
|
||||||
fetched_model_paths = []
|
|
||||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
|
||||||
if file_path is not None and file_path != model_path:
|
|
||||||
continue
|
|
||||||
if model_name == model_name_:
|
|
||||||
fetched_models.append(model)
|
|
||||||
fetched_model_paths.append(model_path)
|
|
||||||
if len(fetched_models) == 0:
|
|
||||||
#print(f"No {model_name} models available.")
|
|
||||||
return None
|
|
||||||
if len(fetched_models) == 1:
|
|
||||||
print(f"Using {model_name} from {fetched_model_paths[0]}")
|
|
||||||
else:
|
|
||||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
|
|
||||||
if require_model_path:
|
|
||||||
return fetched_models[0], fetched_model_paths[0]
|
|
||||||
else:
|
|
||||||
return fetched_models[0]
|
|
||||||
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for model in self.model:
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .core import sparse_sageattn
|
|
||||||
|
|
||||||
__all__ = ["sparse_sageattn"]
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
Sparse SageAttention — block-sparse INT8 attention via Triton.
|
|
||||||
|
|
||||||
https://github.com/jt-zhang/Sparse_SageAttention_API
|
|
||||||
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
Licensed under the Apache License, Version 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .quant_per_block import per_block_int8
|
|
||||||
from .sparse_int8_attn import forward as sparse_sageattn_fwd
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
|
|
||||||
if mask_id is None:
|
|
||||||
mask_id = torch.ones(
|
|
||||||
(q.shape[0], q.shape[1],
|
|
||||||
(q.shape[2] + 128 - 1) // 128,
|
|
||||||
(q.shape[3] + 64 - 1) // 64),
|
|
||||||
dtype=torch.int8, device=q.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_dtype = q.dtype
|
|
||||||
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
|
|
||||||
v = v.to(torch.float16)
|
|
||||||
|
|
||||||
seq_dim = 1 if tensor_layout == "NHD" else 2
|
|
||||||
km = k.mean(dim=seq_dim, keepdim=True)
|
|
||||||
|
|
||||||
q_int8, q_scale, k_int8, k_scale = per_block_int8(
|
|
||||||
q, k, km=km, tensor_layout=tensor_layout,
|
|
||||||
)
|
|
||||||
|
|
||||||
o = sparse_sageattn_fwd(
|
|
||||||
q_int8, k_int8, mask_id, v, q_scale, k_scale,
|
|
||||||
is_causal=is_causal, tensor_layout=tensor_layout,
|
|
||||||
output_dtype=output_dtype,
|
|
||||||
)
|
|
||||||
return o
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
"""
|
|
||||||
Per-block INT8 quantization kernel for Sparse SageAttention.
|
|
||||||
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
Licensed under the Apache License, Version 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def quant_per_block_int8_kernel(
|
|
||||||
Input, Output, Scale, L,
|
|
||||||
stride_iz, stride_ih, stride_in,
|
|
||||||
stride_oz, stride_oh, stride_on,
|
|
||||||
stride_sz, stride_sh,
|
|
||||||
sm_scale,
|
|
||||||
C: tl.constexpr, BLK: tl.constexpr,
|
|
||||||
):
|
|
||||||
off_blk = tl.program_id(0)
|
|
||||||
off_h = tl.program_id(1)
|
|
||||||
off_b = tl.program_id(2)
|
|
||||||
|
|
||||||
offs_n = off_blk * BLK + tl.arange(0, BLK)
|
|
||||||
offs_k = tl.arange(0, C)
|
|
||||||
|
|
||||||
input_ptrs = (
|
|
||||||
Input
|
|
||||||
+ off_b * stride_iz
|
|
||||||
+ off_h * stride_ih
|
|
||||||
+ offs_n[:, None] * stride_in
|
|
||||||
+ offs_k[None, :]
|
|
||||||
)
|
|
||||||
output_ptrs = (
|
|
||||||
Output
|
|
||||||
+ off_b * stride_oz
|
|
||||||
+ off_h * stride_oh
|
|
||||||
+ offs_n[:, None] * stride_on
|
|
||||||
+ offs_k[None, :]
|
|
||||||
)
|
|
||||||
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
|
|
||||||
|
|
||||||
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
|
|
||||||
x = x.to(tl.float32)
|
|
||||||
x *= sm_scale
|
|
||||||
scale = tl.max(tl.abs(x)) / 127.0
|
|
||||||
x_int8 = x / scale
|
|
||||||
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
|
|
||||||
x_int8 = x_int8.to(tl.int8)
|
|
||||||
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
|
|
||||||
tl.store(scale_ptrs, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
|
|
||||||
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
|
|
||||||
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
|
|
||||||
|
|
||||||
if km is not None:
|
|
||||||
k = k - km
|
|
||||||
|
|
||||||
if tensor_layout == "HND":
|
|
||||||
b, h_qo, qo_len, head_dim = q.shape
|
|
||||||
_, h_kv, kv_len, _ = k.shape
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
|
||||||
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
|
||||||
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
|
|
||||||
elif tensor_layout == "NHD":
|
|
||||||
b, qo_len, h_qo, head_dim = q.shape
|
|
||||||
_, kv_len, h_kv, _ = k.shape
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
|
||||||
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
|
||||||
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
|
|
||||||
|
|
||||||
q_scale = torch.empty(
|
|
||||||
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32,
|
|
||||||
)
|
|
||||||
k_scale = torch.empty(
|
|
||||||
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
if sm_scale is None:
|
|
||||||
sm_scale = head_dim ** -0.5
|
|
||||||
|
|
||||||
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
|
|
||||||
quant_per_block_int8_kernel[grid](
|
|
||||||
q, q_int8, q_scale, qo_len,
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q,
|
|
||||||
stride_bz_qo, stride_h_qo, stride_seq_qo,
|
|
||||||
q_scale.stride(0), q_scale.stride(1),
|
|
||||||
sm_scale=(sm_scale * 1.44269504),
|
|
||||||
C=head_dim, BLK=BLKQ,
|
|
||||||
)
|
|
||||||
|
|
||||||
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
|
|
||||||
quant_per_block_int8_kernel[grid](
|
|
||||||
k, k_int8, k_scale, kv_len,
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k,
|
|
||||||
stride_bz_ko, stride_h_ko, stride_seq_ko,
|
|
||||||
k_scale.stride(0), k_scale.stride(1),
|
|
||||||
sm_scale=1.0,
|
|
||||||
C=head_dim, BLK=BLKK,
|
|
||||||
)
|
|
||||||
|
|
||||||
return q_int8, q_scale, k_int8, k_scale
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
"""
|
|
||||||
Sparse INT8 attention kernel for Sparse SageAttention.
|
|
||||||
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
Licensed under the Apache License, Version 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _attn_fwd_inner(
|
|
||||||
acc, l_i, old_m, q, q_scale, kv_len,
|
|
||||||
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
|
||||||
stride_kn, stride_vn, start_m,
|
|
||||||
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
|
|
||||||
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
|
|
||||||
):
|
|
||||||
if STAGE == 1:
|
|
||||||
lo, hi = 0, start_m * BLOCK_M
|
|
||||||
elif STAGE == 2:
|
|
||||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
|
||||||
lo = tl.multiple_of(lo, BLOCK_M)
|
|
||||||
K_scale_ptr += lo // BLOCK_N
|
|
||||||
K_ptrs += stride_kn * lo
|
|
||||||
V_ptrs += stride_vn * lo
|
|
||||||
elif STAGE == 3:
|
|
||||||
lo, hi = 0, kv_len
|
|
||||||
for start_n in range(lo, hi, BLOCK_N):
|
|
||||||
kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
|
|
||||||
if kbid:
|
|
||||||
k_mask = offs_n[None, :] < (kv_len - start_n)
|
|
||||||
k = tl.load(K_ptrs, mask=k_mask)
|
|
||||||
k_scale = tl.load(K_scale_ptr)
|
|
||||||
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
|
|
||||||
if STAGE == 2:
|
|
||||||
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
|
||||||
qk = qk + tl.where(mask, 0, -1.0e6)
|
|
||||||
local_m = tl.max(qk, 1)
|
|
||||||
new_m = tl.maximum(old_m, local_m)
|
|
||||||
qk -= new_m[:, None]
|
|
||||||
else:
|
|
||||||
local_m = tl.max(qk, 1)
|
|
||||||
new_m = tl.maximum(old_m, local_m)
|
|
||||||
qk = qk - new_m[:, None]
|
|
||||||
|
|
||||||
p = tl.math.exp2(qk)
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
alpha = tl.math.exp2(old_m - new_m)
|
|
||||||
l_i = l_i * alpha + l_ij
|
|
||||||
acc = acc * alpha[:, None]
|
|
||||||
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
|
|
||||||
p = p.to(tl.float16)
|
|
||||||
acc += tl.dot(p, v, out_dtype=tl.float16)
|
|
||||||
old_m = new_m
|
|
||||||
K_ptrs += BLOCK_N * stride_kn
|
|
||||||
K_scale_ptr += 1
|
|
||||||
V_ptrs += BLOCK_N * stride_vn
|
|
||||||
return acc, l_i, old_m
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _attn_fwd(
|
|
||||||
Q, K, K_blkid, V, Q_scale, K_scale, Out,
|
|
||||||
stride_qz, stride_qh, stride_qn,
|
|
||||||
stride_kz, stride_kh, stride_kn,
|
|
||||||
stride_vz, stride_vh, stride_vn,
|
|
||||||
stride_oz, stride_oh, stride_on,
|
|
||||||
stride_kbidq, stride_kbidk,
|
|
||||||
qo_len, kv_len,
|
|
||||||
H: tl.constexpr, num_kv_groups: tl.constexpr,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
STAGE: tl.constexpr,
|
|
||||||
):
|
|
||||||
start_m = tl.program_id(0)
|
|
||||||
off_z = tl.program_id(2).to(tl.int64)
|
|
||||||
off_h = tl.program_id(1).to(tl.int64)
|
|
||||||
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
|
|
||||||
k_scale_offset = (
|
|
||||||
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
|
||||||
) * tl.cdiv(kv_len, BLOCK_N)
|
|
||||||
k_bid_offset = (
|
|
||||||
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
|
||||||
) * stride_kbidq
|
|
||||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
|
||||||
offs_k = tl.arange(0, HEAD_DIM)
|
|
||||||
Q_ptrs = (
|
|
||||||
Q
|
|
||||||
+ (off_z * stride_qz + off_h * stride_qh)
|
|
||||||
+ offs_m[:, None] * stride_qn
|
|
||||||
+ offs_k[None, :]
|
|
||||||
)
|
|
||||||
Q_scale_ptr = Q_scale + q_scale_offset + start_m
|
|
||||||
K_ptrs = (
|
|
||||||
K
|
|
||||||
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
|
|
||||||
+ offs_n[None, :] * stride_kn
|
|
||||||
+ offs_k[:, None]
|
|
||||||
)
|
|
||||||
K_scale_ptr = K_scale + k_scale_offset
|
|
||||||
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
|
|
||||||
V_ptrs = (
|
|
||||||
V
|
|
||||||
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
|
|
||||||
+ offs_n[:, None] * stride_vn
|
|
||||||
+ offs_k[None, :]
|
|
||||||
)
|
|
||||||
O_block_ptr = (
|
|
||||||
Out
|
|
||||||
+ (off_z * stride_oz + off_h * stride_oh)
|
|
||||||
+ offs_m[:, None] * stride_on
|
|
||||||
+ offs_k[None, :]
|
|
||||||
)
|
|
||||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
|
||||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
|
||||||
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
|
|
||||||
q_scale = tl.load(Q_scale_ptr)
|
|
||||||
acc, l_i, m_i = _attn_fwd_inner(
|
|
||||||
acc, l_i, m_i, q, q_scale, kv_len,
|
|
||||||
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
|
||||||
stride_kn, stride_vn,
|
|
||||||
start_m,
|
|
||||||
BLOCK_M, HEAD_DIM, BLOCK_N,
|
|
||||||
4 - STAGE, offs_m, offs_n,
|
|
||||||
)
|
|
||||||
if STAGE != 1:
|
|
||||||
acc, l_i, _ = _attn_fwd_inner(
|
|
||||||
acc, l_i, m_i, q, q_scale, kv_len,
|
|
||||||
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
|
||||||
stride_kn, stride_vn,
|
|
||||||
start_m,
|
|
||||||
BLOCK_M, HEAD_DIM, BLOCK_N,
|
|
||||||
2, offs_m, offs_n,
|
|
||||||
)
|
|
||||||
acc = acc / l_i[:, None]
|
|
||||||
tl.store(
|
|
||||||
O_block_ptr,
|
|
||||||
acc.to(Out.type.element_ty),
|
|
||||||
mask=(offs_m[:, None] < qo_len),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
q, k, k_block_id, v, q_scale, k_scale,
|
|
||||||
is_causal=False, tensor_layout="HND", output_dtype=torch.float16,
|
|
||||||
):
|
|
||||||
BLOCK_M = 128
|
|
||||||
BLOCK_N = 64
|
|
||||||
stage = 3 if is_causal else 1
|
|
||||||
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
|
|
||||||
|
|
||||||
if tensor_layout == "HND":
|
|
||||||
b, h_qo, qo_len, head_dim = q.shape
|
|
||||||
_, h_kv, kv_len, _ = k.shape
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
|
||||||
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
|
|
||||||
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
|
|
||||||
elif tensor_layout == "NHD":
|
|
||||||
b, qo_len, h_qo, head_dim = q.shape
|
|
||||||
_, kv_len, h_kv, _ = k.shape
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
|
||||||
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
|
|
||||||
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"tensor_layout {tensor_layout} not supported")
|
|
||||||
|
|
||||||
if is_causal:
|
|
||||||
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
|
|
||||||
|
|
||||||
HEAD_DIM_K = head_dim
|
|
||||||
num_kv_groups = h_qo // h_kv
|
|
||||||
|
|
||||||
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
|
|
||||||
_attn_fwd[grid](
|
|
||||||
q, k, k_block_id, v, q_scale, k_scale, o,
|
|
||||||
stride_bz_q, stride_h_q, stride_seq_q,
|
|
||||||
stride_bz_k, stride_h_k, stride_seq_k,
|
|
||||||
stride_bz_v, stride_h_v, stride_seq_v,
|
|
||||||
stride_bz_o, stride_h_o, stride_seq_o,
|
|
||||||
k_block_id.stride(1), k_block_id.stride(2),
|
|
||||||
qo_len, kv_len,
|
|
||||||
h_qo, num_kv_groups,
|
|
||||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
|
|
||||||
STAGE=stage,
|
|
||||||
num_warps=4 if head_dim == 64 else 8,
|
|
||||||
num_stages=4,
|
|
||||||
)
|
|
||||||
return o
|
|
||||||
@@ -1,460 +0,0 @@
|
|||||||
import torch, os, gc
|
|
||||||
from safetensors import safe_open
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from tqdm import tqdm
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
CACHE_T = 2
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
|
||||||
|
|
||||||
old_register_parameter = torch.nn.Module.register_parameter
|
|
||||||
if include_buffers:
|
|
||||||
old_register_buffer = torch.nn.Module.register_buffer
|
|
||||||
|
|
||||||
def register_empty_parameter(module, name, param):
|
|
||||||
old_register_parameter(module, name, param)
|
|
||||||
if param is not None:
|
|
||||||
param_cls = type(module._parameters[name])
|
|
||||||
kwargs = module._parameters[name].__dict__
|
|
||||||
kwargs["requires_grad"] = param.requires_grad
|
|
||||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
|
||||||
|
|
||||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
|
||||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
|
||||||
if buffer is not None:
|
|
||||||
module._buffers[name] = module._buffers[name].to(device)
|
|
||||||
|
|
||||||
def patch_tensor_constructor(fn):
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
kwargs["device"] = device
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
if include_buffers:
|
|
||||||
tensor_constructors_to_patch = {
|
|
||||||
torch_function_name: getattr(torch, torch_function_name)
|
|
||||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
tensor_constructors_to_patch = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
torch.nn.Module.register_parameter = register_empty_parameter
|
|
||||||
if include_buffers:
|
|
||||||
torch.nn.Module.register_buffer = register_empty_buffer
|
|
||||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
|
||||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
torch.nn.Module.register_parameter = old_register_parameter
|
|
||||||
if include_buffers:
|
|
||||||
torch.nn.Module.register_buffer = old_register_buffer
|
|
||||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
|
||||||
setattr(torch, torch_function_name, old_torch_function)
|
|
||||||
|
|
||||||
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
|
||||||
state_dict = {}
|
|
||||||
for file_name in os.listdir(file_path):
|
|
||||||
if "." in file_name and file_name.split(".")[-1] in [
|
|
||||||
"safetensors", "bin", "ckpt", "pth", "pt"
|
|
||||||
]:
|
|
||||||
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None):
|
|
||||||
if file_path.endswith(".safetensors"):
|
|
||||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
|
||||||
else:
|
|
||||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
|
||||||
state_dict = {}
|
|
||||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
|
||||||
for k in f.keys():
|
|
||||||
state_dict[k] = f.get_tensor(k)
|
|
||||||
if torch_dtype is not None:
|
|
||||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
|
||||||
if torch_dtype is not None:
|
|
||||||
for i in state_dict:
|
|
||||||
if isinstance(state_dict[i], torch.Tensor):
|
|
||||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def search_for_embeddings(state_dict):
|
|
||||||
embeddings = []
|
|
||||||
for k in state_dict:
|
|
||||||
if isinstance(state_dict[k], torch.Tensor):
|
|
||||||
embeddings.append(state_dict[k])
|
|
||||||
elif isinstance(state_dict[k], dict):
|
|
||||||
embeddings += search_for_embeddings(state_dict[k])
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def search_parameter(param, state_dict):
|
|
||||||
for name, param_ in state_dict.items():
|
|
||||||
if param.numel() == param_.numel():
|
|
||||||
if param.shape == param_.shape:
|
|
||||||
if torch.dist(param, param_) < 1e-3:
|
|
||||||
return name
|
|
||||||
else:
|
|
||||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
|
||||||
return name
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
|
||||||
matched_keys = set()
|
|
||||||
with torch.no_grad():
|
|
||||||
for name in source_state_dict:
|
|
||||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
|
||||||
if rename is not None:
|
|
||||||
print(f'"{name}": "{rename}",')
|
|
||||||
matched_keys.add(rename)
|
|
||||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
|
||||||
length = source_state_dict[name].shape[0] // 3
|
|
||||||
rename = []
|
|
||||||
for i in range(3):
|
|
||||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
|
||||||
if None not in rename:
|
|
||||||
print(f'"{name}": {rename},')
|
|
||||||
for rename_ in rename:
|
|
||||||
matched_keys.add(rename_)
|
|
||||||
for name in target_state_dict:
|
|
||||||
if name not in matched_keys:
|
|
||||||
print("Cannot find", name, target_state_dict[name].shape)
|
|
||||||
|
|
||||||
|
|
||||||
def search_for_files(folder, extensions):
|
|
||||||
files = []
|
|
||||||
if os.path.isdir(folder):
|
|
||||||
for file in sorted(os.listdir(folder)):
|
|
||||||
files += search_for_files(os.path.join(folder, file), extensions)
|
|
||||||
elif os.path.isfile(folder):
|
|
||||||
for extension in extensions:
|
|
||||||
if folder.endswith(extension):
|
|
||||||
files.append(folder)
|
|
||||||
break
|
|
||||||
return files
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
|
||||||
keys = []
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
if isinstance(key, str):
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
if with_shape:
|
|
||||||
shape = "_".join(map(str, list(value.shape)))
|
|
||||||
keys.append(key + ":" + shape)
|
|
||||||
keys.append(key)
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
|
||||||
keys.sort()
|
|
||||||
keys_str = ",".join(keys)
|
|
||||||
return keys_str
|
|
||||||
|
|
||||||
|
|
||||||
def split_state_dict_with_prefix(state_dict):
|
|
||||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
|
||||||
prefix_dict = {}
|
|
||||||
for key in keys:
|
|
||||||
prefix = key if "." not in key else key.split(".")[0]
|
|
||||||
if prefix not in prefix_dict:
|
|
||||||
prefix_dict[prefix] = []
|
|
||||||
prefix_dict[prefix].append(key)
|
|
||||||
state_dicts = []
|
|
||||||
for prefix, keys in prefix_dict.items():
|
|
||||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
|
||||||
state_dicts.append(sub_state_dict)
|
|
||||||
return state_dicts
|
|
||||||
|
|
||||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
|
||||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
|
||||||
keys_str = keys_str.encode(encoding="UTF-8")
|
|
||||||
return hashlib.md5(keys_str).hexdigest()
|
|
||||||
|
|
||||||
def clean_vram():
|
|
||||||
gc.collect()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
if torch.mps.is_available():
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
def get_device_list():
|
|
||||||
devs = []
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
|
||||||
|
|
||||||
if torch.mps.is_available():
|
|
||||||
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
|
|
||||||
|
|
||||||
return devs
|
|
||||||
|
|
||||||
class RMS_norm(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
|
||||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
|
||||||
|
|
||||||
self.channel_first = channel_first
|
|
||||||
self.scale = dim**0.5
|
|
||||||
self.gamma = nn.Parameter(torch.ones(shape))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return F.normalize(
|
|
||||||
x, dim=(1 if self.channel_first else
|
|
||||||
-1)) * self.scale * self.gamma + self.bias
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Conv3d):
|
|
||||||
"""
|
|
||||||
Causal 3d convolusion.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
|
||||||
self.padding[1], 2 * self.padding[0], 0)
|
|
||||||
self.padding = (0, 0, 0)
|
|
||||||
|
|
||||||
def forward(self, x, cache_x=None):
|
|
||||||
padding = list(self._padding)
|
|
||||||
if cache_x is not None and self._padding[4] > 0:
|
|
||||||
cache_x = cache_x.to(x.device)
|
|
||||||
# print(cache_x.shape, x.shape)
|
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
|
||||||
padding[4] -= cache_x.shape[2]
|
|
||||||
# print('cache!')
|
|
||||||
x = F.pad(x, padding, mode='replicate') # mode='replicate'
|
|
||||||
# print(x[0,0,:,0,0])
|
|
||||||
|
|
||||||
return super().forward(x)
|
|
||||||
|
|
||||||
class PixelShuffle3d(nn.Module):
|
|
||||||
def __init__(self, ff, hh, ww):
|
|
||||||
super().__init__()
|
|
||||||
self.ff = ff
|
|
||||||
self.hh = hh
|
|
||||||
self.ww = ww
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: (B, C, F, H, W)
|
|
||||||
return rearrange(x,
|
|
||||||
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
|
||||||
ff=self.ff, hh=self.hh, ww=self.ww)
|
|
||||||
|
|
||||||
class Buffer_LQ4x_Proj(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, layer_num=30):
|
|
||||||
super().__init__()
|
|
||||||
self.ff = 1
|
|
||||||
self.hh = 16
|
|
||||||
self.ww = 16
|
|
||||||
self.hidden_dim1 = 2048
|
|
||||||
self.hidden_dim2 = 3072
|
|
||||||
self.layer_num = layer_num
|
|
||||||
|
|
||||||
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
|
||||||
|
|
||||||
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
|
||||||
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
|
||||||
self.act1 = nn.SiLU()
|
|
||||||
|
|
||||||
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
|
||||||
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
|
||||||
self.act2 = nn.SiLU()
|
|
||||||
|
|
||||||
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
|
||||||
|
|
||||||
self.clip_idx = 0
|
|
||||||
|
|
||||||
def forward(self, video):
|
|
||||||
self.clear_cache()
|
|
||||||
# x: (B, C, F, H, W)
|
|
||||||
|
|
||||||
t = video.shape[2]
|
|
||||||
iter_ = 1 + (t - 1) // 4
|
|
||||||
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
|
||||||
video = torch.cat([first_frame, video], dim=2)
|
|
||||||
# print(video.shape)
|
|
||||||
|
|
||||||
out_x = []
|
|
||||||
for i in range(iter_):
|
|
||||||
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv1'] = cache1_x
|
|
||||||
x = self.conv1(x, self.cache['conv1'])
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv2'] = cache2_x
|
|
||||||
if i == 0:
|
|
||||||
continue
|
|
||||||
x = self.conv2(x, self.cache['conv2'])
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = self.act2(x)
|
|
||||||
out_x.append(x)
|
|
||||||
out_x = torch.cat(out_x, dim = 2)
|
|
||||||
# print(out_x.shape)
|
|
||||||
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
|
||||||
outputs = []
|
|
||||||
for i in range(self.layer_num):
|
|
||||||
outputs.append(self.linear_layers[i](out_x))
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.cache = {}
|
|
||||||
self.cache['conv1'] = None
|
|
||||||
self.cache['conv2'] = None
|
|
||||||
self.clip_idx = 0
|
|
||||||
|
|
||||||
def stream_forward(self, video_clip):
|
|
||||||
if self.clip_idx == 0:
|
|
||||||
# self.clear_cache()
|
|
||||||
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
|
||||||
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
|
||||||
x = self.pixel_shuffle(video_clip)
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv1'] = cache1_x
|
|
||||||
x = self.conv1(x, self.cache['conv1'])
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv2'] = cache2_x
|
|
||||||
self.clip_idx += 1
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
x = self.pixel_shuffle(video_clip)
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv1'] = cache1_x
|
|
||||||
x = self.conv1(x, self.cache['conv1'])
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv2'] = cache2_x
|
|
||||||
x = self.conv2(x, self.cache['conv2'])
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = self.act2(x)
|
|
||||||
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
|
||||||
outputs = []
|
|
||||||
for i in range(self.layer_num):
|
|
||||||
outputs.append(self.linear_layers[i](out_x))
|
|
||||||
self.clip_idx += 1
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class Causal_LQ4x_Proj(nn.Module):
|
|
||||||
"""Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.
|
|
||||||
|
|
||||||
Key difference: reads old cache BEFORE writing new cache (truly causal),
|
|
||||||
whereas Buffer_LQ4x_Proj writes cache BEFORE conv call.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, layer_num=30):
|
|
||||||
super().__init__()
|
|
||||||
self.ff = 1
|
|
||||||
self.hh = 16
|
|
||||||
self.ww = 16
|
|
||||||
self.hidden_dim1 = 2048
|
|
||||||
self.hidden_dim2 = 3072
|
|
||||||
self.layer_num = layer_num
|
|
||||||
|
|
||||||
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
|
||||||
|
|
||||||
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
|
|
||||||
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
|
||||||
self.act1 = nn.SiLU()
|
|
||||||
|
|
||||||
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
|
|
||||||
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
|
||||||
self.act2 = nn.SiLU()
|
|
||||||
|
|
||||||
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
|
||||||
|
|
||||||
self.clip_idx = 0
|
|
||||||
|
|
||||||
def forward(self, video):
|
|
||||||
self.clear_cache()
|
|
||||||
t = video.shape[2]
|
|
||||||
iter_ = 1 + (t - 1) // 4
|
|
||||||
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
|
||||||
video = torch.cat([first_frame, video], dim=2)
|
|
||||||
|
|
||||||
out_x = []
|
|
||||||
for i in range(iter_):
|
|
||||||
x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
|
|
||||||
self.cache['conv1'] = cache1_x # writes NEW cache AFTER
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if i == 0:
|
|
||||||
self.cache['conv2'] = cache2_x
|
|
||||||
continue
|
|
||||||
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
|
|
||||||
self.cache['conv2'] = cache2_x # writes NEW cache AFTER
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = self.act2(x)
|
|
||||||
out_x.append(x)
|
|
||||||
out_x = torch.cat(out_x, dim=2)
|
|
||||||
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
|
||||||
outputs = []
|
|
||||||
for i in range(self.layer_num):
|
|
||||||
outputs.append(self.linear_layers[i](out_x))
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.cache = {}
|
|
||||||
self.cache['conv1'] = None
|
|
||||||
self.cache['conv2'] = None
|
|
||||||
self.clip_idx = 0
|
|
||||||
|
|
||||||
def stream_forward(self, video_clip):
|
|
||||||
if self.clip_idx == 0:
|
|
||||||
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
|
||||||
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
|
||||||
x = self.pixel_shuffle(video_clip)
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
x = self.conv1(x, self.cache['conv1']) # reads OLD (None) cache
|
|
||||||
self.cache['conv1'] = cache1_x # writes AFTER
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
self.cache['conv2'] = cache2_x
|
|
||||||
self.clip_idx += 1
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
x = self.pixel_shuffle(video_clip)
|
|
||||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
|
|
||||||
self.cache['conv1'] = cache1_x # writes AFTER
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
|
|
||||||
self.cache['conv2'] = cache2_x # writes AFTER
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = self.act2(x)
|
|
||||||
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
|
||||||
outputs = []
|
|
||||||
for i in range(self.layer_num):
|
|
||||||
outputs.append(self.linear_layers[i](out_x))
|
|
||||||
self.clip_idx += 1
|
|
||||||
return outputs
|
|
||||||
@@ -1,865 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Tuple, Optional, List
|
|
||||||
from einops import rearrange
|
|
||||||
from .utils import hash_state_dict_keys
|
|
||||||
|
|
||||||
try:
|
|
||||||
import flash_attn_interface
|
|
||||||
assert callable(getattr(flash_attn_interface, "flash_attn_func", None))
|
|
||||||
FLASH_ATTN_3_AVAILABLE = True
|
|
||||||
except Exception:
|
|
||||||
FLASH_ATTN_3_AVAILABLE = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
assert callable(getattr(flash_attn, "flash_attn_func", None))
|
|
||||||
FLASH_ATTN_2_AVAILABLE = True
|
|
||||||
except Exception:
|
|
||||||
FLASH_ATTN_2_AVAILABLE = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sageattention import sageattn
|
|
||||||
assert callable(sageattn)
|
|
||||||
SAGE_ATTN_AVAILABLE = True
|
|
||||||
except Exception:
|
|
||||||
SAGE_ATTN_AVAILABLE = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .sparse_sage.core import sparse_sageattn
|
|
||||||
assert callable(sparse_sageattn)
|
|
||||||
SPARSE_SAGE_AVAILABLE = True
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
from sageattn.core import sparse_sageattn
|
|
||||||
assert callable(sparse_sageattn)
|
|
||||||
SPARSE_SAGE_AVAILABLE = True
|
|
||||||
except Exception:
|
|
||||||
SPARSE_SAGE_AVAILABLE = False
|
|
||||||
sparse_sageattn = None
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, "
|
|
||||||
f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, "
|
|
||||||
f"sage_attn={SAGE_ATTN_AVAILABLE}")
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Local / window masks
|
|
||||||
# ----------------------------
|
|
||||||
@torch.no_grad()
|
|
||||||
def build_local_block_mask_shifted_vec(block_h: int,
|
|
||||||
block_w: int,
|
|
||||||
win_h: int = 6,
|
|
||||||
win_w: int = 6,
|
|
||||||
include_self: bool = True,
|
|
||||||
device=None) -> torch.Tensor:
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
H, W = block_h, block_w
|
|
||||||
r = torch.arange(H, device=device)
|
|
||||||
c = torch.arange(W, device=device)
|
|
||||||
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
|
||||||
r_all = YY.reshape(-1)
|
|
||||||
c_all = XX.reshape(-1)
|
|
||||||
r_half = win_h // 2
|
|
||||||
c_half = win_w // 2
|
|
||||||
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
|
|
||||||
end_r = start_r + win_h - 1
|
|
||||||
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
|
|
||||||
end_c = start_c + win_w - 1
|
|
||||||
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
|
||||||
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
|
||||||
mask = in_row & in_col
|
|
||||||
if not include_self:
|
|
||||||
mask.fill_diagonal_(False)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
|
|
||||||
block_w: int,
|
|
||||||
win_h: int = 6,
|
|
||||||
win_w: int = 6,
|
|
||||||
include_self: bool = True,
|
|
||||||
device=None) -> torch.Tensor:
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
H, W = block_h, block_w
|
|
||||||
r = torch.arange(H, device=device)
|
|
||||||
c = torch.arange(W, device=device)
|
|
||||||
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
|
||||||
r_all = YY.reshape(-1)
|
|
||||||
c_all = XX.reshape(-1)
|
|
||||||
r_half = win_h // 2
|
|
||||||
c_half = win_w // 2
|
|
||||||
start_r = r_all - r_half
|
|
||||||
end_r = start_r + win_h - 1
|
|
||||||
start_c = c_all - c_half
|
|
||||||
end_c = start_c + win_w - 1
|
|
||||||
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
|
||||||
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
|
||||||
mask = in_row & in_col
|
|
||||||
if not include_self:
|
|
||||||
mask.fill_diagonal_(False)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class WindowPartition3D:
|
|
||||||
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
|
|
||||||
@staticmethod
|
|
||||||
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
|
|
||||||
B, F, H, W, C = x.shape
|
|
||||||
wf, wh, ww = win
|
|
||||||
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
|
|
||||||
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
|
|
||||||
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
|
|
||||||
return x.view(-1, wf * wh * ww, C)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
|
|
||||||
F, H, W = orig
|
|
||||||
wf, wh, ww = win
|
|
||||||
nf, nh, nw = F // wf, H // wh, W // ww
|
|
||||||
B = windows.size(0) // (nf * nh * nw)
|
|
||||||
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
|
|
||||||
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
|
|
||||||
return x.view(B, F, H, W, -1)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate_draft_block_mask(batch_size, nheads, seqlen,
|
|
||||||
q_w, k_w, topk=10, local_attn_mask=None):
|
|
||||||
assert batch_size == 1, "Only batch_size=1 supported for now"
|
|
||||||
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
|
||||||
avgpool_q = torch.mean(q_w, dim=1)
|
|
||||||
avgpool_k = torch.mean(k_w, dim=1)
|
|
||||||
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
|
||||||
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
|
|
||||||
q_heads = avgpool_q.permute(1, 0, 2)
|
|
||||||
k_heads = avgpool_k.permute(1, 0, 2)
|
|
||||||
D = avgpool_q.shape[-1]
|
|
||||||
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
|
|
||||||
|
|
||||||
repeat_head = scores.shape[0]
|
|
||||||
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
|
||||||
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
|
|
||||||
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
|
||||||
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
|
||||||
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
|
||||||
local_attn_mask = local_attn_mask.to(torch.float32)
|
|
||||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
|
||||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
|
||||||
scores = scores + local_attn_mask
|
|
||||||
|
|
||||||
attn_map = torch.softmax(scores, dim=-1)
|
|
||||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
|
||||||
loop_num, s1, s2 = attn_map.shape
|
|
||||||
flat = attn_map.reshape(loop_num, -1)
|
|
||||||
n = flat.shape[1]
|
|
||||||
apply_topk = min(flat.shape[1]-1, topk)
|
|
||||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
|
||||||
thresholds = thresholds.unsqueeze(1)
|
|
||||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
|
||||||
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
|
|
||||||
# 修正:上行变量名统一
|
|
||||||
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
|
|
||||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
|
|
||||||
q_w, k_w, topk=10, local_attn_mask=None):
|
|
||||||
assert batch_size == 1, "Only batch_size=1 supported for now"
|
|
||||||
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
|
||||||
|
|
||||||
avgpool_q = torch.mean(q_w, dim=1)
|
|
||||||
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
|
||||||
q_heads = avgpool_q.permute(1, 0, 2)
|
|
||||||
D = avgpool_q.shape[-1]
|
|
||||||
|
|
||||||
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
|
|
||||||
avgpool_k_split = torch.mean(k_w_split, dim=2)
|
|
||||||
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
|
|
||||||
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
|
|
||||||
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)
|
|
||||||
|
|
||||||
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
|
|
||||||
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
|
|
||||||
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
|
|
||||||
scores = torch.cat([scores_1, scores_2], dim=-1)
|
|
||||||
|
|
||||||
repeat_head = scores.shape[0]
|
|
||||||
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
|
||||||
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
|
|
||||||
|
|
||||||
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
|
||||||
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
|
||||||
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
|
|
||||||
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
|
||||||
|
|
||||||
local_attn_mask = local_attn_mask.to(torch.float32)
|
|
||||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
|
||||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
|
||||||
|
|
||||||
assert scores.shape == local_attn_mask.shape, \
|
|
||||||
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
|
|
||||||
|
|
||||||
scores = scores + local_attn_mask
|
|
||||||
attn_map = torch.softmax(scores, dim=-1)
|
|
||||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
|
||||||
loop_num, s1, s2 = attn_map.shape
|
|
||||||
flat = attn_map.reshape(loop_num, -1)
|
|
||||||
apply_topk = min(flat.shape[1]-1, topk)
|
|
||||||
|
|
||||||
if apply_topk <= 0:
|
|
||||||
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
|
|
||||||
else:
|
|
||||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
|
||||||
thresholds = thresholds.unsqueeze(1)
|
|
||||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
|
||||||
|
|
||||||
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
|
|
||||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Attention kernels
|
|
||||||
# ----------------------------
|
|
||||||
def _sdpa_fallback(q, k, v, num_heads):
|
|
||||||
"""PyTorch scaled dot-product attention (always available)."""
|
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
x = F.scaled_dot_product_attention(q, k, v)
|
|
||||||
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
|
|
||||||
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
|
|
||||||
|
|
||||||
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
|
|
||||||
try:
|
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
base_blockmask = attention_mask
|
|
||||||
x = sparse_sageattn(
|
|
||||||
q, k, v,
|
|
||||||
mask_id=base_blockmask.to(torch.int8),
|
|
||||||
is_causal=False,
|
|
||||||
tensor_layout="HND"
|
|
||||||
)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
except Exception:
|
|
||||||
SPARSE_SAGE_AVAILABLE = False
|
|
||||||
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
|
||||||
# q,k,v already rearranged to [b, n, s, d] above
|
|
||||||
x = F.scaled_dot_product_attention(q, k, v)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
elif compatibility_mode:
|
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
|
||||||
elif FLASH_ATTN_3_AVAILABLE:
|
|
||||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
|
||||||
if isinstance(x, tuple):
|
|
||||||
x = x[0]
|
|
||||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
|
||||||
elif FLASH_ATTN_2_AVAILABLE:
|
|
||||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
x = flash_attn.flash_attn_func(q, k, v)
|
|
||||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
|
||||||
elif SAGE_ATTN_AVAILABLE:
|
|
||||||
try:
|
|
||||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
|
||||||
x = sageattn(q, k, v)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
except Exception:
|
|
||||||
SAGE_ATTN_AVAILABLE = False
|
|
||||||
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
|
||||||
# q,k,v already rearranged to [b, n, s, d] above
|
|
||||||
x = F.scaled_dot_product_attention(q, k, v)
|
|
||||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
|
||||||
else:
|
|
||||||
x = _sdpa_fallback(q, k, v, num_heads)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
|
||||||
return (x * (1 + scale) + shift)
|
|
||||||
|
|
||||||
|
|
||||||
def sinusoidal_embedding_1d(dim, position):
|
|
||||||
half_dim = max(dim // 2, 1)
|
|
||||||
scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
|
|
||||||
inv_freq = torch.pow(10000.0, -scale / half_dim)
|
|
||||||
sinusoid = torch.outer(position.to(torch.float64), inv_freq)
|
|
||||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
|
||||||
return x.to(position.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
|
||||||
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
|
||||||
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
|
||||||
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
|
||||||
return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
|
||||||
half_dim = max(dim // 2, 1)
|
|
||||||
base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
|
|
||||||
freqs = torch.pow(theta, -base / max(dim, 1))
|
|
||||||
steps = torch.arange(end, dtype=torch.float64)
|
|
||||||
angles = torch.outer(steps, freqs)
|
|
||||||
return torch.polar(torch.ones_like(angles), angles)
|
|
||||||
|
|
||||||
|
|
||||||
def rope_apply(x, freqs, num_heads):
|
|
||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
|
||||||
x_complex = torch.view_as_complex(reshaped)
|
|
||||||
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
|
|
||||||
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
|
|
||||||
return x_out.to(orig_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Norms & Blocks
|
|
||||||
# ----------------------------
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, eps=1e-5):
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
|
|
||||||
def norm(self, x):
|
|
||||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
dtype = x.dtype
|
|
||||||
return self.norm(x.float()).to(dtype) * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionModule(nn.Module):
|
|
||||||
def __init__(self, num_heads, enable_sageattention=True):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.enable_sageattention = enable_sageattention
|
|
||||||
|
|
||||||
def forward(self, q, k, v, attention_mask=None):
|
|
||||||
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
|
||||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = dim // num_heads
|
|
||||||
|
|
||||||
self.q = nn.Linear(dim, dim)
|
|
||||||
self.k = nn.Linear(dim, dim)
|
|
||||||
self.v = nn.Linear(dim, dim)
|
|
||||||
self.o = nn.Linear(dim, dim)
|
|
||||||
self.norm_q = RMSNorm(dim, eps=eps)
|
|
||||||
self.norm_k = RMSNorm(dim, eps=eps)
|
|
||||||
|
|
||||||
self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
|
|
||||||
self.local_attn_mask = None
|
|
||||||
|
|
||||||
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
|
|
||||||
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
|
||||||
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
|
||||||
B, L, D = x.shape
|
|
||||||
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
|
|
||||||
assert f==2, "f must be 2"
|
|
||||||
if is_stream and (pre_cache_k is None or pre_cache_v is None):
|
|
||||||
assert f==6, " start f must be 6"
|
|
||||||
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
|
|
||||||
|
|
||||||
q = self.norm_q(self.q(x))
|
|
||||||
k = self.norm_k(self.k(x))
|
|
||||||
v = self.v(x)
|
|
||||||
q = rope_apply(q, freqs, self.num_heads)
|
|
||||||
k = rope_apply(k, freqs, self.num_heads)
|
|
||||||
|
|
||||||
win = (2, 8, 8)
|
|
||||||
q = q.view(B, f, h, w, D)
|
|
||||||
k = k.view(B, f, h, w, D)
|
|
||||||
v = v.view(B, f, h, w, D)
|
|
||||||
|
|
||||||
q_w = WindowPartition3D.partition(q, win)
|
|
||||||
k_w = WindowPartition3D.partition(k, win)
|
|
||||||
v_w = WindowPartition3D.partition(v, win)
|
|
||||||
|
|
||||||
seqlen = f//win[0]
|
|
||||||
one_len = k_w.shape[0] // B // seqlen
|
|
||||||
if pre_cache_k is not None and pre_cache_v is not None:
|
|
||||||
k_w = torch.cat([pre_cache_k, k_w], dim=0)
|
|
||||||
v_w = torch.cat([pre_cache_v, v_w], dim=0)
|
|
||||||
|
|
||||||
block_n = q_w.shape[0] // B
|
|
||||||
block_s = q_w.shape[1]
|
|
||||||
block_n_kv = k_w.shape[0] // B
|
|
||||||
|
|
||||||
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
|
|
||||||
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
|
||||||
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
|
||||||
|
|
||||||
window_size = win[0]*h*w//128
|
|
||||||
|
|
||||||
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
|
|
||||||
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
|
|
||||||
self.local_attn_mask_h = h//8
|
|
||||||
self.local_attn_mask_w = w//8
|
|
||||||
self.local_range = local_range
|
|
||||||
attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
|
||||||
|
|
||||||
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
|
|
||||||
|
|
||||||
cur_block_n, cur_block_s, _ = k_w.shape
|
|
||||||
cache_num = cur_block_n // one_len
|
|
||||||
if cache_num > kv_len:
|
|
||||||
cache_k = k_w[one_len:, :, :]
|
|
||||||
cache_v = v_w[one_len:, :, :]
|
|
||||||
else:
|
|
||||||
cache_k = k_w
|
|
||||||
cache_v = v_w
|
|
||||||
|
|
||||||
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
|
|
||||||
x = WindowPartition3D.reverse(x, win, (f, h, w))
|
|
||||||
x = x.view(B, f*h*w, D)
|
|
||||||
|
|
||||||
if is_stream:
|
|
||||||
return self.o(x), cache_k, cache_v
|
|
||||||
return self.o(x)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
仅考虑文本 context;提供持久 KV 缓存。
|
|
||||||
"""
|
|
||||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = dim // num_heads
|
|
||||||
|
|
||||||
self.q = nn.Linear(dim, dim)
|
|
||||||
self.k = nn.Linear(dim, dim)
|
|
||||||
self.v = nn.Linear(dim, dim)
|
|
||||||
self.o = nn.Linear(dim, dim)
|
|
||||||
|
|
||||||
self.norm_q = RMSNorm(dim, eps=eps)
|
|
||||||
self.norm_k = RMSNorm(dim, eps=eps)
|
|
||||||
|
|
||||||
self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
|
|
||||||
|
|
||||||
# 持久缓存
|
|
||||||
self.cache_k = None
|
|
||||||
self.cache_v = None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def init_cache(self, ctx: torch.Tensor):
|
|
||||||
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
|
|
||||||
self.cache_k = self.norm_k(self.k(ctx))
|
|
||||||
self.cache_v = self.v(ctx)
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.cache_k = None
|
|
||||||
self.cache_v = None
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
|
|
||||||
"""
|
|
||||||
y 即文本上下文(未做其他分支)。
|
|
||||||
"""
|
|
||||||
q = self.norm_q(self.q(x))
|
|
||||||
assert self.cache_k is not None and self.cache_v is not None
|
|
||||||
k = self.cache_k
|
|
||||||
v = self.cache_v
|
|
||||||
|
|
||||||
x = self.attn(q, k, v)
|
|
||||||
return self.o(x)
|
|
||||||
|
|
||||||
|
|
||||||
class GateModule(nn.Module):
|
|
||||||
def __init__(self,):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x, gate, residual):
|
|
||||||
return x + gate * residual
|
|
||||||
|
|
||||||
|
|
||||||
class DiTBlock(nn.Module):
|
|
||||||
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.ffn_dim = ffn_dim
|
|
||||||
|
|
||||||
self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
|
|
||||||
self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
|
||||||
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
|
||||||
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
|
||||||
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
|
||||||
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
|
||||||
self.gate = GateModule()
|
|
||||||
|
|
||||||
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
|
|
||||||
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
|
||||||
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
|
||||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
|
||||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
|
||||||
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
|
|
||||||
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
|
|
||||||
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
|
|
||||||
|
|
||||||
x = self.gate(x, gate_msa, self_attn_output)
|
|
||||||
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
|
|
||||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
|
||||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
|
||||||
if is_stream:
|
|
||||||
return x, self_attn_cache_k, self_attn_cache_v
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(torch.nn.Module):
|
|
||||||
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = torch.nn.Sequential(
|
|
||||||
nn.LayerNorm(in_dim),
|
|
||||||
nn.Linear(in_dim, in_dim),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(in_dim, out_dim),
|
|
||||||
nn.LayerNorm(out_dim)
|
|
||||||
)
|
|
||||||
self.has_pos_emb = has_pos_emb
|
|
||||||
if has_pos_emb:
|
|
||||||
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.has_pos_emb:
|
|
||||||
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
|
||||||
return self.proj(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Head(nn.Module):
|
|
||||||
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
|
||||||
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
|
||||||
|
|
||||||
def forward(self, x, t_mod):
|
|
||||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
|
||||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# WanModel (no image branch) — init 时即产生 KV 缓存
|
|
||||||
# ----------------------------
|
|
||||||
class WanModel(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
in_dim: int,
|
|
||||||
ffn_dim: int,
|
|
||||||
out_dim: int,
|
|
||||||
text_dim: int,
|
|
||||||
freq_dim: int,
|
|
||||||
eps: float,
|
|
||||||
patch_size: Tuple[int, int, int],
|
|
||||||
num_heads: int,
|
|
||||||
num_layers: int,
|
|
||||||
has_image_input: bool = False,
|
|
||||||
enable_sageattention: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.freq_dim = freq_dim
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
# patch embed
|
|
||||||
self.patch_embedding = nn.Conv3d(
|
|
||||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
|
||||||
|
|
||||||
# text / time embed
|
|
||||||
self.text_embedding = nn.Sequential(
|
|
||||||
nn.Linear(text_dim, dim),
|
|
||||||
nn.GELU(approximate='tanh'),
|
|
||||||
nn.Linear(dim, dim)
|
|
||||||
)
|
|
||||||
self.time_embedding = nn.Sequential(
|
|
||||||
nn.Linear(freq_dim, dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim)
|
|
||||||
)
|
|
||||||
self.time_projection = nn.Sequential(
|
|
||||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
|
||||||
|
|
||||||
# blocks
|
|
||||||
self.blocks = nn.ModuleList([
|
|
||||||
DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
])
|
|
||||||
self.head = Head(dim, out_dim, patch_size, eps)
|
|
||||||
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
|
||||||
|
|
||||||
self._cross_kv_initialized = False
|
|
||||||
|
|
||||||
# 可选:手动清空 / 重新初始化
|
|
||||||
# 可选:手动清空 / 重新初始化
|
|
||||||
def clear_cross_kv(self):
|
|
||||||
for blk in self.blocks:
|
|
||||||
blk.cross_attn.clear_cache()
|
|
||||||
self._cross_kv_initialized = False
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def reinit_cross_kv(self, new_context: torch.Tensor):
|
|
||||||
ctx_txt = self.text_embedding(new_context)
|
|
||||||
for blk in self.blocks:
|
|
||||||
blk.cross_attn.init_cache(ctx_txt)
|
|
||||||
self._cross_kv_initialized = True
|
|
||||||
|
|
||||||
def patchify(self, x: torch.Tensor):
|
|
||||||
x = self.patch_embedding(x)
|
|
||||||
grid_size = x.shape[2:]
|
|
||||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
|
||||||
return x, grid_size # x, grid_size: (f, h, w)
|
|
||||||
|
|
||||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
|
||||||
return rearrange(
|
|
||||||
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
|
||||||
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
|
||||||
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timestep: torch.Tensor,
|
|
||||||
context: torch.Tensor,
|
|
||||||
use_gradient_checkpointing: bool = False,
|
|
||||||
use_gradient_checkpointing_offload: bool = False,
|
|
||||||
LQ_latents: Optional[List[torch.Tensor]] = None,
|
|
||||||
train_img: bool = False,
|
|
||||||
topk_ratio: Optional[float] = None,
|
|
||||||
kv_ratio: Optional[float] = None,
|
|
||||||
local_num: Optional[int] = None,
|
|
||||||
is_full_block: bool = False,
|
|
||||||
causal_idx: Optional[int] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# time / text embeds
|
|
||||||
t = self.time_embedding(
|
|
||||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
|
||||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
|
||||||
|
|
||||||
# 这里仍会嵌入 text(CrossAttention 若已有缓存会忽略它)
|
|
||||||
# context = self.text_embedding(context)
|
|
||||||
|
|
||||||
# 输入打补丁
|
|
||||||
x, (f, h, w) = self.patchify(x)
|
|
||||||
B = x.shape[0]
|
|
||||||
|
|
||||||
# window / masks 超参
|
|
||||||
win = (2, 8, 8)
|
|
||||||
seqlen = f//win[0]
|
|
||||||
if local_num is None:
|
|
||||||
local_random = random.random()
|
|
||||||
if local_random < 0.3:
|
|
||||||
local_num = seqlen - 3
|
|
||||||
elif local_random < 0.4:
|
|
||||||
local_num = seqlen - 4
|
|
||||||
elif local_random < 0.5:
|
|
||||||
local_num = seqlen - 2
|
|
||||||
else:
|
|
||||||
local_num = seqlen
|
|
||||||
|
|
||||||
window_size = win[0]*h*w//128
|
|
||||||
square_num = window_size*window_size
|
|
||||||
topk_ratio = 2.0
|
|
||||||
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
|
|
||||||
|
|
||||||
if kv_ratio is None:
|
|
||||||
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
|
|
||||||
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
|
|
||||||
|
|
||||||
decay_ratio = random.uniform(0.7, 1.0)
|
|
||||||
|
|
||||||
# RoPE 3D
|
|
||||||
freqs = torch.cat([
|
|
||||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
# blocks
|
|
||||||
for block_id, block in enumerate(self.blocks):
|
|
||||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
|
||||||
x += LQ_latents[block_id]
|
|
||||||
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
if use_gradient_checkpointing_offload:
|
|
||||||
with torch.autograd.graph.save_on_cpu():
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
|
||||||
train_img, block_id, kv_len, is_full_block, False,
|
|
||||||
None, None,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
|
||||||
train_img, block_id, kv_len, is_full_block, False,
|
|
||||||
None, None,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
|
|
||||||
train_img, block_id, kv_len, is_full_block, False,
|
|
||||||
None, None)
|
|
||||||
|
|
||||||
x = self.head(x, t)
|
|
||||||
x = self.unpatchify(x, (f, h, w))
|
|
||||||
return x
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return WanModelStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# State dict converter(保持原映射;已忽略 has_image_input 使用)
|
|
||||||
# ----------------------------
|
|
||||||
class WanModelStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
|
||||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
|
||||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
|
||||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
|
||||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
|
||||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
|
||||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
|
||||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
|
||||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
|
||||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
|
||||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
|
||||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
|
||||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
|
||||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
|
||||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
|
||||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
|
||||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
|
||||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
|
||||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
|
||||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
|
||||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
|
||||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
|
||||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
|
||||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
|
||||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
|
||||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
|
||||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
|
||||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
|
||||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
|
||||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
|
||||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
|
||||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
|
||||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
|
||||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
|
||||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
|
||||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
|
||||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
|
||||||
"patch_embedding.bias": "patch_embedding.bias",
|
|
||||||
"patch_embedding.weight": "patch_embedding.weight",
|
|
||||||
"scale_shift_table": "head.modulation",
|
|
||||||
"proj_out.bias": "head.head.bias",
|
|
||||||
"proj_out.weight": "head.head.weight",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
if name in rename_dict:
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
else:
|
|
||||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
|
||||||
if name_ in rename_dict:
|
|
||||||
name_ = rename_dict[name_]
|
|
||||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
|
||||||
state_dict_[name_] = param
|
|
||||||
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
|
||||||
config = {
|
|
||||||
"model_type": "t2v",
|
|
||||||
"patch_size": (1, 2, 2),
|
|
||||||
"text_len": 512,
|
|
||||||
"in_dim": 16,
|
|
||||||
"dim": 5120,
|
|
||||||
"ffn_dim": 13824,
|
|
||||||
"freq_dim": 256,
|
|
||||||
"text_dim": 4096,
|
|
||||||
"out_dim": 16,
|
|
||||||
"num_heads": 40,
|
|
||||||
"num_layers": 40,
|
|
||||||
"window_size": (-1, -1),
|
|
||||||
"qk_norm": True,
|
|
||||||
"cross_attn_norm": True,
|
|
||||||
"eps": 1e-6,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
config = {}
|
|
||||||
return state_dict_, config
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
|
||||||
# 保留原有哈希匹配返回的 config;实现本身不使用 has_image_input 分支
|
|
||||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
|
||||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
|
|
||||||
else:
|
|
||||||
config = {}
|
|
||||||
return state_dict, config
|
|
||||||
|
|
||||||
@@ -1,847 +0,0 @@
|
|||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
CACHE_T = 2
|
|
||||||
|
|
||||||
|
|
||||||
def check_is_instance(model, module_class):
|
|
||||||
if isinstance(model, module_class):
|
|
||||||
return True
|
|
||||||
if hasattr(model, "module") and isinstance(model.module, module_class):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def block_causal_mask(x, block_size):
|
|
||||||
# params
|
|
||||||
b, n, s, _, device = *x.size(), x.device
|
|
||||||
assert s % block_size == 0
|
|
||||||
num_blocks = s // block_size
|
|
||||||
|
|
||||||
# build mask
|
|
||||||
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
|
||||||
for i in range(num_blocks):
|
|
||||||
mask[:, :,
|
|
||||||
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Conv3d):
|
|
||||||
"""
|
|
||||||
Causal 3d convolusion.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
|
||||||
self.padding[1], 2 * self.padding[0], 0)
|
|
||||||
self.padding = (0, 0, 0)
|
|
||||||
|
|
||||||
def forward(self, x, cache_x=None):
|
|
||||||
padding = list(self._padding)
|
|
||||||
if cache_x is not None and self._padding[4] > 0:
|
|
||||||
cache_x = cache_x.to(x.device)
|
|
||||||
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
|
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
|
||||||
padding[4] -= cache_x.shape[2]
|
|
||||||
x = F.pad(x, padding)
|
|
||||||
|
|
||||||
return super().forward(x)
|
|
||||||
|
|
||||||
|
|
||||||
class RMS_norm(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
|
||||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
|
||||||
|
|
||||||
self.channel_first = channel_first
|
|
||||||
self.scale = dim**0.5
|
|
||||||
self.gamma = nn.Parameter(torch.ones(shape))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return F.normalize(
|
|
||||||
x, dim=(1 if self.channel_first else
|
|
||||||
-1)) * self.scale * self.gamma + self.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Upsample):
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
Fix bfloat16 support for nearest neighbor interpolation.
|
|
||||||
"""
|
|
||||||
return super().forward(x.float()).type_as(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Resample(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim, mode):
|
|
||||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
|
||||||
'downsample3d')
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.mode = mode
|
|
||||||
|
|
||||||
# layers
|
|
||||||
if mode == 'upsample2d':
|
|
||||||
self.resample = nn.Sequential(
|
|
||||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
|
||||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
|
||||||
elif mode == 'upsample3d':
|
|
||||||
self.resample = nn.Sequential(
|
|
||||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
|
||||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
|
||||||
self.time_conv = CausalConv3d(dim,
|
|
||||||
dim * 2, (3, 1, 1),
|
|
||||||
padding=(1, 0, 0))
|
|
||||||
|
|
||||||
elif mode == 'downsample2d':
|
|
||||||
self.resample = nn.Sequential(
|
|
||||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
|
||||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
|
||||||
elif mode == 'downsample3d':
|
|
||||||
self.resample = nn.Sequential(
|
|
||||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
|
||||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
|
||||||
self.time_conv = CausalConv3d(dim,
|
|
||||||
dim, (3, 1, 1),
|
|
||||||
stride=(2, 1, 1),
|
|
||||||
padding=(0, 0, 0))
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.resample = nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
|
||||||
b, c, t, h, w = x.size()
|
|
||||||
if self.mode == 'upsample3d':
|
|
||||||
if feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
if feat_cache[idx] is None:
|
|
||||||
feat_cache[idx] = 'Rep'
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] != 'Rep':
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] == 'Rep':
|
|
||||||
cache_x = torch.cat([
|
|
||||||
torch.zeros_like(cache_x).to(cache_x.device),
|
|
||||||
cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if feat_cache[idx] == 'Rep':
|
|
||||||
x = self.time_conv(x)
|
|
||||||
else:
|
|
||||||
x = self.time_conv(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
|
|
||||||
x = x.reshape(b, 2, c, t, h, w)
|
|
||||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
|
||||||
3)
|
|
||||||
x = x.reshape(b, c, t * 2, h, w)
|
|
||||||
t = x.shape[2]
|
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
|
||||||
x = self.resample(x)
|
|
||||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
|
||||||
|
|
||||||
if self.mode == 'downsample3d':
|
|
||||||
if feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
if feat_cache[idx] is None:
|
|
||||||
feat_cache[idx] = x.clone()
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
cache_x = x[:, :, -1:, :, :].clone()
|
|
||||||
x = self.time_conv(
|
|
||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_weight(self, conv):
|
|
||||||
conv_weight = conv.weight
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
|
||||||
one_matrix = torch.eye(c1, c2)
|
|
||||||
init_matrix = one_matrix
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
|
||||||
conv.weight.data.copy_(conv_weight)
|
|
||||||
nn.init.zeros_(conv.bias.data)
|
|
||||||
|
|
||||||
def init_weight2(self, conv):
|
|
||||||
conv_weight = conv.weight.data
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
|
||||||
init_matrix = torch.eye(c1 // 2, c2)
|
|
||||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
|
||||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
|
||||||
conv.weight.data.copy_(conv_weight)
|
|
||||||
nn.init.zeros_(conv.bias.data)
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.in_dim = in_dim
|
|
||||||
self.out_dim = out_dim
|
|
||||||
|
|
||||||
# layers
|
|
||||||
self.residual = nn.Sequential(
|
|
||||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
|
||||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
|
||||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
|
||||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
|
||||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
|
||||||
if in_dim != out_dim else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
|
||||||
h = self.shortcut(x)
|
|
||||||
for layer in self.residual:
|
|
||||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x + h
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
Causal self-attention with a single head.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
# layers
|
|
||||||
self.norm = RMS_norm(dim)
|
|
||||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
|
||||||
self.proj = nn.Conv2d(dim, dim, 1)
|
|
||||||
|
|
||||||
# zero out the last layer params
|
|
||||||
nn.init.zeros_(self.proj.weight)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
identity = x
|
|
||||||
b, c, t, h, w = x.size()
|
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
|
||||||
x = self.norm(x)
|
|
||||||
# compute query, key, value
|
|
||||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
|
||||||
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
|
||||||
|
|
||||||
# apply attention
|
|
||||||
x = F.scaled_dot_product_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
#attn_mask=block_causal_mask(q, block_size=h * w)
|
|
||||||
)
|
|
||||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
|
||||||
|
|
||||||
# output
|
|
||||||
x = self.proj(x)
|
|
||||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
|
||||||
return x + identity
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder3d(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dim=128,
|
|
||||||
z_dim=4,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_downsample=[True, True, False],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.z_dim = z_dim
|
|
||||||
self.dim_mult = dim_mult
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.attn_scales = attn_scales
|
|
||||||
self.temperal_downsample = temperal_downsample
|
|
||||||
|
|
||||||
# dimensions
|
|
||||||
dims = [dim * u for u in [1] + dim_mult]
|
|
||||||
scale = 1.0
|
|
||||||
|
|
||||||
# init block
|
|
||||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
|
||||||
|
|
||||||
# downsample blocks
|
|
||||||
downsamples = []
|
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
|
||||||
# residual (+attention) blocks
|
|
||||||
for _ in range(num_res_blocks):
|
|
||||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
|
||||||
if scale in attn_scales:
|
|
||||||
downsamples.append(AttentionBlock(out_dim))
|
|
||||||
in_dim = out_dim
|
|
||||||
|
|
||||||
# downsample block
|
|
||||||
if i != len(dim_mult) - 1:
|
|
||||||
mode = 'downsample3d' if temperal_downsample[
|
|
||||||
i] else 'downsample2d'
|
|
||||||
downsamples.append(Resample(out_dim, mode=mode))
|
|
||||||
scale /= 2.0
|
|
||||||
self.downsamples = nn.Sequential(*downsamples)
|
|
||||||
|
|
||||||
# middle blocks
|
|
||||||
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
|
||||||
AttentionBlock(out_dim),
|
|
||||||
ResidualBlock(out_dim, out_dim, dropout))
|
|
||||||
|
|
||||||
# output blocks
|
|
||||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
|
||||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
|
||||||
if feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = self.conv1(x)
|
|
||||||
|
|
||||||
## downsamples
|
|
||||||
for layer in self.downsamples:
|
|
||||||
if feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## middle
|
|
||||||
for layer in self.middle:
|
|
||||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## head
|
|
||||||
for layer in self.head:
|
|
||||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder3d(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dim=128,
|
|
||||||
z_dim=4,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_upsample=[False, True, True],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.z_dim = z_dim
|
|
||||||
self.dim_mult = dim_mult
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.attn_scales = attn_scales
|
|
||||||
self.temperal_upsample = temperal_upsample
|
|
||||||
|
|
||||||
# dimensions
|
|
||||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
|
||||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
|
||||||
|
|
||||||
# init block
|
|
||||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
|
||||||
|
|
||||||
# middle blocks
|
|
||||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
|
||||||
AttentionBlock(dims[0]),
|
|
||||||
ResidualBlock(dims[0], dims[0], dropout))
|
|
||||||
|
|
||||||
# upsample blocks
|
|
||||||
upsamples = []
|
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
|
||||||
# residual (+attention) blocks
|
|
||||||
if i == 1 or i == 2 or i == 3:
|
|
||||||
in_dim = in_dim // 2
|
|
||||||
for _ in range(num_res_blocks + 1):
|
|
||||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
|
||||||
if scale in attn_scales:
|
|
||||||
upsamples.append(AttentionBlock(out_dim))
|
|
||||||
in_dim = out_dim
|
|
||||||
|
|
||||||
# upsample block
|
|
||||||
if i != len(dim_mult) - 1:
|
|
||||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
|
||||||
upsamples.append(Resample(out_dim, mode=mode))
|
|
||||||
scale *= 2.0
|
|
||||||
self.upsamples = nn.Sequential(*upsamples)
|
|
||||||
|
|
||||||
# output blocks
|
|
||||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
|
||||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
|
||||||
## conv1
|
|
||||||
if feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = self.conv1(x)
|
|
||||||
|
|
||||||
## middle
|
|
||||||
for layer in self.middle:
|
|
||||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## upsamples
|
|
||||||
for layer in self.upsamples:
|
|
||||||
if feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## head
|
|
||||||
for layer in self.head:
|
|
||||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
|
||||||
idx = feat_idx[0]
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def count_conv3d(model):
|
|
||||||
count = 0
|
|
||||||
for m in model.modules():
|
|
||||||
if check_is_instance(m, CausalConv3d):
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
class VideoVAE_(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
dim=96,
|
|
||||||
z_dim=16,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_downsample=[False, True, True],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.z_dim = z_dim
|
|
||||||
self.dim_mult = dim_mult
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.attn_scales = attn_scales
|
|
||||||
self.temperal_downsample = temperal_downsample
|
|
||||||
self.temperal_upsample = temperal_downsample[::-1]
|
|
||||||
|
|
||||||
# modules
|
|
||||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
|
||||||
attn_scales, self.temperal_downsample, dropout)
|
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
|
||||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
mu, log_var = self.encode(x)
|
|
||||||
z = self.reparameterize(mu, log_var)
|
|
||||||
x_recon = self.decode(z)
|
|
||||||
return x_recon, mu, log_var
|
|
||||||
|
|
||||||
def encode(self, x, scale):
|
|
||||||
self.clear_cache()
|
|
||||||
## cache
|
|
||||||
t = x.shape[2]
|
|
||||||
iter_ = 1 + (t - 1) // 4
|
|
||||||
|
|
||||||
for i in range(iter_):
|
|
||||||
self._enc_conv_idx = [0]
|
|
||||||
if i == 0:
|
|
||||||
out = self.encoder(x[:, :, :1, :, :],
|
|
||||||
feat_cache=self._enc_feat_map,
|
|
||||||
feat_idx=self._enc_conv_idx)
|
|
||||||
else:
|
|
||||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
|
||||||
feat_cache=self._enc_feat_map,
|
|
||||||
feat_idx=self._enc_conv_idx)
|
|
||||||
out = torch.cat([out, out_], 2)
|
|
||||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
|
||||||
if isinstance(scale[0], torch.Tensor):
|
|
||||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
|
||||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
|
||||||
1, self.z_dim, 1, 1, 1)
|
|
||||||
else:
|
|
||||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
|
||||||
mu = (mu - scale[0]) * scale[1]
|
|
||||||
return mu
|
|
||||||
|
|
||||||
def decode(self, z, scale):
|
|
||||||
self.clear_cache()
|
|
||||||
# z: [b,c,t,h,w]
|
|
||||||
if isinstance(scale[0], torch.Tensor):
|
|
||||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
|
||||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
|
||||||
1, self.z_dim, 1, 1, 1)
|
|
||||||
else:
|
|
||||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
|
||||||
z = z / scale[1] + scale[0]
|
|
||||||
iter_ = z.shape[2]
|
|
||||||
x = self.conv2(z)
|
|
||||||
for i in range(iter_):
|
|
||||||
self._conv_idx = [0]
|
|
||||||
if i == 0:
|
|
||||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
else:
|
|
||||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def stream_decode(self, z, scale):
|
|
||||||
# self.clear_cache()
|
|
||||||
# z: [b,c,t,h,w]
|
|
||||||
if isinstance(scale[0], torch.Tensor):
|
|
||||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
|
||||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
|
||||||
1, self.z_dim, 1, 1, 1)
|
|
||||||
else:
|
|
||||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
|
||||||
z = z / scale[1] + scale[0]
|
|
||||||
iter_ = z.shape[2]
|
|
||||||
x = self.conv2(z)
|
|
||||||
for i in range(iter_):
|
|
||||||
self._conv_idx = [0]
|
|
||||||
if i == 0:
|
|
||||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
else:
|
|
||||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
|
||||||
return out
|
|
||||||
|
|
||||||
def reparameterize(self, mu, log_var):
|
|
||||||
std = torch.exp(0.5 * log_var)
|
|
||||||
eps = torch.randn_like(std)
|
|
||||||
return eps * std + mu
|
|
||||||
|
|
||||||
def sample(self, imgs, deterministic=False):
|
|
||||||
mu, log_var = self.encode(imgs)
|
|
||||||
if deterministic:
|
|
||||||
return mu
|
|
||||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
|
||||||
return mu + std * torch.randn_like(std)
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self._conv_num = count_conv3d(self.decoder)
|
|
||||||
self._conv_idx = [0]
|
|
||||||
self._feat_map = [None] * self._conv_num
|
|
||||||
# print('self._feat_map', len(self._feat_map))
|
|
||||||
# cache encode
|
|
||||||
if self.encoder is not None:
|
|
||||||
self._enc_conv_num = count_conv3d(self.encoder)
|
|
||||||
self._enc_conv_idx = [0]
|
|
||||||
self._enc_feat_map = [None] * self._enc_conv_num
|
|
||||||
# print('self._enc_feat_map', len(self._enc_feat_map))
|
|
||||||
|
|
||||||
|
|
||||||
class WanVideoVAE(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, z_dim=16, dim=96):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
mean = [
|
|
||||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
|
||||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
|
||||||
]
|
|
||||||
std = [
|
|
||||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
|
||||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
|
||||||
]
|
|
||||||
self.mean = torch.tensor(mean)
|
|
||||||
self.std = torch.tensor(std)
|
|
||||||
self.scale = [self.mean, 1.0 / self.std]
|
|
||||||
|
|
||||||
# init model
|
|
||||||
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
|
|
||||||
self.upsampling_factor = 8
|
|
||||||
|
|
||||||
|
|
||||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
|
||||||
x = torch.ones((length,))
|
|
||||||
if not left_bound:
|
|
||||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
|
||||||
if not right_bound:
|
|
||||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(self, data, is_bound, border_width):
|
|
||||||
_, _, _, H, W = data.shape
|
|
||||||
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
|
||||||
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
|
||||||
|
|
||||||
h = repeat(h, "H -> H W", H=H, W=W)
|
|
||||||
w = repeat(w, "W -> H W", H=H, W=W)
|
|
||||||
|
|
||||||
mask = torch.stack([h, w]).min(dim=0).values
|
|
||||||
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
|
||||||
_, _, T, H, W = hidden_states.shape
|
|
||||||
size_h, size_w = tile_size
|
|
||||||
stride_h, stride_w = tile_stride
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for h in range(0, H, stride_h):
|
|
||||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
|
||||||
for w in range(0, W, stride_w):
|
|
||||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
|
||||||
h_, w_ = h + size_h, w + size_w
|
|
||||||
tasks.append((h, h_, w, w_))
|
|
||||||
|
|
||||||
data_device = "cpu"
|
|
||||||
computation_device = device
|
|
||||||
|
|
||||||
out_T = T * 4 - 3
|
|
||||||
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
|
||||||
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
|
||||||
|
|
||||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
|
||||||
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
|
||||||
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
|
||||||
|
|
||||||
mask = self.build_mask(
|
|
||||||
hidden_states_batch,
|
|
||||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
|
||||||
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
|
||||||
).to(dtype=hidden_states.dtype, device=data_device)
|
|
||||||
|
|
||||||
target_h = h * self.upsampling_factor
|
|
||||||
target_w = w * self.upsampling_factor
|
|
||||||
values[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_h:target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w:target_w + hidden_states_batch.shape[4],
|
|
||||||
] += hidden_states_batch * mask
|
|
||||||
weight[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += mask
|
|
||||||
values = values / weight
|
|
||||||
values = values.clamp_(-1, 1)
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
def tiled_encode(self, video, device, tile_size, tile_stride):
|
|
||||||
_, _, T, H, W = video.shape
|
|
||||||
size_h, size_w = tile_size
|
|
||||||
stride_h, stride_w = tile_stride
|
|
||||||
|
|
||||||
# Split tasks
|
|
||||||
tasks = []
|
|
||||||
for h in range(0, H, stride_h):
|
|
||||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
|
||||||
for w in range(0, W, stride_w):
|
|
||||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
|
||||||
h_, w_ = h + size_h, w + size_w
|
|
||||||
tasks.append((h, h_, w, w_))
|
|
||||||
|
|
||||||
data_device = "cpu"
|
|
||||||
computation_device = device
|
|
||||||
|
|
||||||
out_T = (T + 3) // 4
|
|
||||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
|
||||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
|
||||||
|
|
||||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
|
||||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
|
||||||
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
|
||||||
|
|
||||||
mask = self.build_mask(
|
|
||||||
hidden_states_batch,
|
|
||||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
|
||||||
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
|
||||||
).to(dtype=video.dtype, device=data_device)
|
|
||||||
|
|
||||||
target_h = h // self.upsampling_factor
|
|
||||||
target_w = w // self.upsampling_factor
|
|
||||||
values[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_h:target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w:target_w + hidden_states_batch.shape[4],
|
|
||||||
] += hidden_states_batch * mask
|
|
||||||
weight[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
target_h: target_h + hidden_states_batch.shape[3],
|
|
||||||
target_w: target_w + hidden_states_batch.shape[4],
|
|
||||||
] += mask
|
|
||||||
values = values / weight
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
def single_encode(self, video, device):
|
|
||||||
video = video.to(device)
|
|
||||||
x = self.model.encode(video, self.scale)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def single_decode(self, hidden_state, device):
|
|
||||||
hidden_state = hidden_state.to(device)
|
|
||||||
video = self.model.decode(hidden_state, self.scale)
|
|
||||||
return video.clamp_(-1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
|
|
||||||
videos = [video.to("cpu") for video in videos]
|
|
||||||
hidden_states = []
|
|
||||||
for video in videos:
|
|
||||||
video = video.unsqueeze(0)
|
|
||||||
if tiled:
|
|
||||||
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
|
||||||
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
|
||||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
|
||||||
else:
|
|
||||||
hidden_state = self.single_encode(video, device)
|
|
||||||
hidden_state = hidden_state.squeeze(0)
|
|
||||||
hidden_states.append(hidden_state)
|
|
||||||
hidden_states = torch.stack(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
|
||||||
videos = []
|
|
||||||
for hidden_state in hidden_states:
|
|
||||||
hidden_state = hidden_state.unsqueeze(0)
|
|
||||||
if tiled:
|
|
||||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
|
||||||
else:
|
|
||||||
video = self.single_decode(hidden_state, device)
|
|
||||||
video = video.squeeze(0)
|
|
||||||
videos.append(video)
|
|
||||||
videos = torch.stack(videos)
|
|
||||||
return videos
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.model.clear_cache()
|
|
||||||
|
|
||||||
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
hidden_states = [hidden_state for hidden_state in hidden_states]
|
|
||||||
assert len(hidden_states) == 1
|
|
||||||
hidden_state = hidden_states[0]
|
|
||||||
video = self.model.stream_decode(hidden_state, self.scale)
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return WanVideoVAEStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class WanVideoVAEStateDictConverter:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
state_dict_ = {}
|
|
||||||
if 'model_state' in state_dict:
|
|
||||||
state_dict = state_dict['model_state']
|
|
||||||
for name in state_dict:
|
|
||||||
state_dict_['model.' + name] = state_dict[name]
|
|
||||||
return state_dict_
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .flashvsr_full import FlashVSRFullPipeline
|
|
||||||
from .flashvsr_tiny import FlashVSRTinyPipeline
|
|
||||||
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms import GaussianBlur
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BasePipeline(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.height_division_factor = height_division_factor
|
|
||||||
self.width_division_factor = width_division_factor
|
|
||||||
self.cpu_offload = False
|
|
||||||
self.model_names = []
|
|
||||||
|
|
||||||
|
|
||||||
def check_resize_height_width(self, height, width):
|
|
||||||
if height % self.height_division_factor != 0:
|
|
||||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
|
||||||
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
|
|
||||||
if width % self.width_division_factor != 0:
|
|
||||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
|
||||||
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
|
|
||||||
return height, width
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_image(self, image):
|
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_images(self, images):
|
|
||||||
return [self.preprocess_image(image) for image in images]
|
|
||||||
|
|
||||||
|
|
||||||
def vae_output_to_image(self, vae_output):
|
|
||||||
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
|
||||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def vae_output_to_video(self, vae_output):
|
|
||||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
|
||||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
|
||||||
if len(latents) > 0:
|
|
||||||
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
|
||||||
height, width = value.shape[-2:]
|
|
||||||
weight = torch.ones_like(value)
|
|
||||||
for latent, mask, scale in zip(latents, masks, scales):
|
|
||||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
|
||||||
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
|
||||||
mask = blur(mask)
|
|
||||||
value += latent * mask * scale
|
|
||||||
weight += mask * scale
|
|
||||||
value /= weight
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
|
|
||||||
if special_kwargs is None:
|
|
||||||
noise_pred_global = inference_callback(prompt_emb_global)
|
|
||||||
else:
|
|
||||||
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
|
|
||||||
if special_local_kwargs_list is None:
|
|
||||||
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
|
||||||
else:
|
|
||||||
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
|
|
||||||
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
|
||||||
return noise_pred
|
|
||||||
|
|
||||||
|
|
||||||
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
|
|
||||||
local_prompts = local_prompts or []
|
|
||||||
masks = masks or []
|
|
||||||
mask_scales = mask_scales or []
|
|
||||||
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
|
||||||
prompt = extended_prompt_dict.get("prompt", prompt)
|
|
||||||
local_prompts += extended_prompt_dict.get("prompts", [])
|
|
||||||
masks += extended_prompt_dict.get("masks", [])
|
|
||||||
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
|
||||||
return prompt, local_prompts, masks, mask_scales
|
|
||||||
|
|
||||||
|
|
||||||
def enable_cpu_offload(self):
|
|
||||||
self.cpu_offload = True
|
|
||||||
|
|
||||||
|
|
||||||
def load_models_to_device(self, loadmodel_names=[]):
|
|
||||||
# only load models to device if cpu_offload is enabled
|
|
||||||
if not self.cpu_offload:
|
|
||||||
return
|
|
||||||
# offload the unneeded models to cpu
|
|
||||||
for model_name in self.model_names:
|
|
||||||
if model_name not in loadmodel_names:
|
|
||||||
model = getattr(self, model_name)
|
|
||||||
if model is not None:
|
|
||||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
|
||||||
for module in model.modules():
|
|
||||||
if hasattr(module, "offload"):
|
|
||||||
module.offload()
|
|
||||||
else:
|
|
||||||
model.cpu()
|
|
||||||
# load the needed models to device
|
|
||||||
for model_name in loadmodel_names:
|
|
||||||
model = getattr(self, model_name)
|
|
||||||
if model is not None:
|
|
||||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
|
||||||
for module in model.modules():
|
|
||||||
if hasattr(module, "onload"):
|
|
||||||
module.onload()
|
|
||||||
else:
|
|
||||||
model.to(self.device)
|
|
||||||
# fresh the cuda cache
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
|
||||||
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
|
||||||
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
|
||||||
return noise
|
|
||||||
@@ -1,638 +0,0 @@
|
|||||||
import types
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional, Tuple, Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
# import pyfiglet
|
|
||||||
|
|
||||||
from ..models.utils import clean_vram
|
|
||||||
from ..models import ModelManager
|
|
||||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
|
||||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
|
||||||
from ..schedulers.flow_match import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
|
||||||
N, C = feat.shape[:2]
|
|
||||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
|
||||||
std = var.sqrt().view(N, C, 1, 1)
|
|
||||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
|
||||||
return mean, std
|
|
||||||
|
|
||||||
|
|
||||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
|
||||||
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = _calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = _calc_mean_std(content_feat)
|
|
||||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
|
||||||
return normalized * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 小波式模糊与分解/重构(ColorCorrector 用)
|
|
||||||
# -----------------------------
|
|
||||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
|
||||||
vals = [
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
[0.125, 0.25, 0.125 ],
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
]
|
|
||||||
return torch.tensor(vals, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
N, C, H, W = x.shape
|
|
||||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
|
||||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
|
||||||
pad = radius
|
|
||||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
|
||||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
high = torch.zeros_like(x)
|
|
||||||
low = x
|
|
||||||
for i in range(levels):
|
|
||||||
radius = 2 ** i
|
|
||||||
blurred = _wavelet_blur(low, radius)
|
|
||||||
high = high + (low - blurred)
|
|
||||||
low = blurred
|
|
||||||
return high, low
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
|
||||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
|
||||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
|
||||||
return c_high + s_low
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Safetensors support ---------
|
|
||||||
# -----------------------------
|
|
||||||
st_load_file = None # Define the variable in global scope first
|
|
||||||
try:
|
|
||||||
from safetensors.torch import load_file as st_load_file
|
|
||||||
except ImportError:
|
|
||||||
# st_load_file remains None if import fails
|
|
||||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
class TorchColorCorrectorWavelet(nn.Module):
|
|
||||||
def __init__(self, levels: int = 5):
|
|
||||||
super().__init__()
|
|
||||||
self.levels = levels
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
|
||||||
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
|
||||||
B, C, f, H, W = x.shape
|
|
||||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
|
||||||
return y, B, f
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
|
||||||
BF, C, H, W = y.shape
|
|
||||||
assert BF == B * f
|
|
||||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
|
||||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
|
||||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
|
||||||
|
|
||||||
B, C, f, H, W = hq_image.shape
|
|
||||||
if chunk_size is None or chunk_size >= f:
|
|
||||||
hq4, B, f = self._flatten_time(hq_image)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_image)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out = self._unflatten_time(out4, B, f)
|
|
||||||
return out
|
|
||||||
|
|
||||||
outs = []
|
|
||||||
for start in range(0, f, chunk_size):
|
|
||||||
end = min(start + chunk_size, f)
|
|
||||||
hq_chunk = hq_image[:, :, start:end]
|
|
||||||
lq_chunk = lq_image[:, :, start:end]
|
|
||||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
|
||||||
outs.append(out_chunk)
|
|
||||||
out = torch.cat(outs, dim=2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版 Pipeline(仅 dit + vae)
|
|
||||||
# -----------------------------
|
|
||||||
class FlashVSRFullPipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
||||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
|
||||||
self.dit: WanModel = None
|
|
||||||
self.vae: WanVideoVAE = None
|
|
||||||
self.model_names = ['dit', 'vae']
|
|
||||||
self.height_division_factor = 16
|
|
||||||
self.width_division_factor = 16
|
|
||||||
self.use_unified_sequence_parallel = False
|
|
||||||
self.prompt_emb_posi = None
|
|
||||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
|
||||||
# 仅管理 dit / vae
|
|
||||||
dtype = next(iter(self.dit.parameters())).dtype
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
|
||||||
enable_vram_management(
|
|
||||||
self.dit,
|
|
||||||
module_map={
|
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
|
||||||
RMSNorm: AutoWrappedModule,
|
|
||||||
},
|
|
||||||
module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device=self.device,
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
max_num_param=num_persistent_param_in_dit,
|
|
||||||
overflow_module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device="cpu",
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
dtype = next(iter(self.vae.parameters())).dtype
|
|
||||||
enable_vram_management(
|
|
||||||
self.vae,
|
|
||||||
module_map={
|
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
|
||||||
torch.nn.Conv2d: AutoWrappedModule,
|
|
||||||
RMS_norm: AutoWrappedModule,
|
|
||||||
CausalConv3d: AutoWrappedModule,
|
|
||||||
Upsample: AutoWrappedModule,
|
|
||||||
torch.nn.SiLU: AutoWrappedModule,
|
|
||||||
torch.nn.Dropout: AutoWrappedModule,
|
|
||||||
},
|
|
||||||
module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device=self.device,
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.enable_cpu_offload()
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
|
||||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
|
||||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
|
||||||
if device is None: device = model_manager.device
|
|
||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
|
||||||
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
|
|
||||||
pipe.fetch_models(model_manager)
|
|
||||||
# 可选:统一序列并行入口(此处默认关闭)
|
|
||||||
pipe.use_unified_sequence_parallel = False
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
def denoising_model(self):
|
|
||||||
return self.dit
|
|
||||||
|
|
||||||
# -------------------------
|
|
||||||
# 新增:显式 KV 预初始化函数
|
|
||||||
# -------------------------
|
|
||||||
def init_cross_kv(
|
|
||||||
self,
|
|
||||||
context_tensor: Optional[torch.Tensor] = None,
|
|
||||||
prompt_path = None
|
|
||||||
):
|
|
||||||
self.load_models_to_device(["dit"])
|
|
||||||
"""
|
|
||||||
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
|
||||||
必须在 __call__ 前显式调用一次。
|
|
||||||
"""
|
|
||||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
|
||||||
if self.dit is None:
|
|
||||||
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
|
||||||
|
|
||||||
if context_tensor is None:
|
|
||||||
if prompt_path is None:
|
|
||||||
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
|
||||||
|
|
||||||
# --- Safetensors loading logic added here ---
|
|
||||||
prompt_path_lower = prompt_path.lower()
|
|
||||||
if prompt_path_lower.endswith(".safetensors"):
|
|
||||||
if st_load_file is None:
|
|
||||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
|
||||||
|
|
||||||
# Load the tensor from safetensors
|
|
||||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
|
||||||
|
|
||||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
|
||||||
if len(loaded_dict) == 1:
|
|
||||||
ctx = list(loaded_dict.values())[0]
|
|
||||||
elif 'context' in loaded_dict: # Common key for text context
|
|
||||||
ctx = loaded_dict['context']
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Default behavior for .pth, .pt, etc.
|
|
||||||
ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
else:
|
|
||||||
ctx = context_tensor
|
|
||||||
|
|
||||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
|
|
||||||
if self.prompt_emb_posi is None:
|
|
||||||
self.prompt_emb_posi = {}
|
|
||||||
self.prompt_emb_posi['context'] = ctx
|
|
||||||
|
|
||||||
if hasattr(self.dit, "reinit_cross_kv"):
|
|
||||||
self.dit.reinit_cross_kv(ctx)
|
|
||||||
else:
|
|
||||||
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
|
||||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
|
||||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
|
||||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
|
||||||
# Scheduler
|
|
||||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
|
||||||
self.load_models_to_device([])
|
|
||||||
|
|
||||||
def prepare_unified_sequence_parallel(self):
|
|
||||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return frames
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
prompt=None,
|
|
||||||
negative_prompt="",
|
|
||||||
denoising_strength=1.0,
|
|
||||||
seed=None,
|
|
||||||
rand_device="gpu",
|
|
||||||
height=480,
|
|
||||||
width=832,
|
|
||||||
num_frames=81,
|
|
||||||
cfg_scale=5.0,
|
|
||||||
num_inference_steps=50,
|
|
||||||
sigma_shift=5.0,
|
|
||||||
tiled=True,
|
|
||||||
tile_size=(60, 104),
|
|
||||||
tile_stride=(30, 52),
|
|
||||||
tea_cache_l1_thresh=None,
|
|
||||||
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
|
||||||
progress_bar_cmd=tqdm,
|
|
||||||
progress_bar_st=None,
|
|
||||||
LQ_video=None,
|
|
||||||
is_full_block=False,
|
|
||||||
if_buffer=False,
|
|
||||||
topk_ratio=2.0,
|
|
||||||
kv_ratio=3.0,
|
|
||||||
local_range = 9,
|
|
||||||
color_fix = True,
|
|
||||||
unload_dit = False,
|
|
||||||
skip_vae = False,
|
|
||||||
):
|
|
||||||
# 只接受 cfg=1.0(与原代码一致)
|
|
||||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
|
||||||
|
|
||||||
# 要求:必须先 init_cross_kv()
|
|
||||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
|
||||||
" pipe.init_cross_kv()\n"
|
|
||||||
"或传入自定义 context:\n"
|
|
||||||
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_frames % 4 != 1:
|
|
||||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
|
||||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
|
||||||
|
|
||||||
# Tiler 参数
|
|
||||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
|
||||||
|
|
||||||
# 初始化噪声
|
|
||||||
if if_buffer:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
else:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
latents = noise
|
|
||||||
|
|
||||||
process_total_num = (num_frames - 1) // 8 - 2
|
|
||||||
is_stream = True
|
|
||||||
|
|
||||||
# 清理可能存在的 LQ_proj_in cache
|
|
||||||
if hasattr(self.dit, "LQ_proj_in"):
|
|
||||||
self.dit.LQ_proj_in.clear_cache()
|
|
||||||
|
|
||||||
frames_total = []
|
|
||||||
LQ_pre_idx = 0
|
|
||||||
LQ_cur_idx = 0
|
|
||||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
|
||||||
self.TCDecoder.clean_mem()
|
|
||||||
|
|
||||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
|
||||||
current_dit_device = next(iter(self.dit.parameters())).device
|
|
||||||
if str(current_dit_device) != str(self.device):
|
|
||||||
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
|
|
||||||
self.dit.to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
pre_cache_k = [None] * len(self.dit.blocks)
|
|
||||||
pre_cache_v = [None] * len(self.dit.blocks)
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 7
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
|
||||||
cur_latents = latents[:, :, :6, :, :]
|
|
||||||
else:
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 2
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
|
||||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
|
||||||
|
|
||||||
# Denoise
|
|
||||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
|
||||||
self.dit,
|
|
||||||
x=cur_latents,
|
|
||||||
timestep=self.timestep,
|
|
||||||
context=None,
|
|
||||||
tea_cache=None,
|
|
||||||
use_unified_sequence_parallel=False,
|
|
||||||
LQ_latents=LQ_latents,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k,
|
|
||||||
pre_cache_v=pre_cache_v,
|
|
||||||
topk_ratio=topk_ratio,
|
|
||||||
kv_ratio=kv_ratio,
|
|
||||||
cur_process_idx=cur_process_idx,
|
|
||||||
t_mod=self.t_mod,
|
|
||||||
t=self.t,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
|
|
||||||
cur_latents = cur_latents - noise_pred_posi
|
|
||||||
|
|
||||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
|
||||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
|
||||||
|
|
||||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
|
||||||
cur_frames = self.TCDecoder.decode_video(
|
|
||||||
cur_latents.transpose(1, 2),
|
|
||||||
parallel=False,
|
|
||||||
show_progress_bar=False,
|
|
||||||
cond=cur_LQ_frame
|
|
||||||
).transpose(1, 2).mul_(2).sub_(1)
|
|
||||||
else:
|
|
||||||
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
|
|
||||||
|
|
||||||
# Per-chunk color correction
|
|
||||||
try:
|
|
||||||
if color_fix:
|
|
||||||
cur_frames = self.ColorCorrector(
|
|
||||||
cur_frames.to(device=self.device),
|
|
||||||
cur_LQ_frame,
|
|
||||||
clip_range=(-1, 1),
|
|
||||||
chunk_size=None,
|
|
||||||
method='adain'
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
frames_total.append(cur_frames.to('cpu'))
|
|
||||||
LQ_pre_idx = LQ_cur_idx
|
|
||||||
|
|
||||||
del cur_frames, cur_latents, cur_LQ_frame
|
|
||||||
clean_vram()
|
|
||||||
|
|
||||||
frames = torch.cat(frames_total, dim=2)
|
|
||||||
return frames[0]
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# TeaCache(保留原逻辑;此处默认不启用)
|
|
||||||
# -----------------------------
|
|
||||||
class TeaCache:
|
|
||||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
|
||||||
self.num_inference_steps = num_inference_steps
|
|
||||||
self.step = 0
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = None
|
|
||||||
self.rel_l1_thresh = rel_l1_thresh
|
|
||||||
self.previous_residual = None
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
self.coefficients_dict = {
|
|
||||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
|
||||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
|
||||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
|
||||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
|
||||||
}
|
|
||||||
if model_id not in self.coefficients_dict:
|
|
||||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
|
||||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
|
||||||
self.coefficients = self.coefficients_dict[model_id]
|
|
||||||
|
|
||||||
def check(self, dit: WanModel, x, t_mod):
|
|
||||||
modulated_inp = t_mod.clone()
|
|
||||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
|
||||||
should_calc = True
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
else:
|
|
||||||
coefficients = self.coefficients
|
|
||||||
rescale_func = np.poly1d(coefficients)
|
|
||||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
|
||||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
|
||||||
if should_calc:
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = modulated_inp
|
|
||||||
self.step = (self.step + 1) % self.num_inference_steps
|
|
||||||
if should_calc:
|
|
||||||
self.previous_hidden_states = x.clone()
|
|
||||||
return not should_calc
|
|
||||||
|
|
||||||
def store(self, hidden_states):
|
|
||||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
def update(self, hidden_states):
|
|
||||||
hidden_states = hidden_states + self.previous_residual
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
|
||||||
# -----------------------------
|
|
||||||
def model_fn_wan_video(
|
|
||||||
dit: WanModel,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timestep: torch.Tensor,
|
|
||||||
context: torch.Tensor,
|
|
||||||
tea_cache: Optional[TeaCache] = None,
|
|
||||||
use_unified_sequence_parallel: bool = False,
|
|
||||||
LQ_latents: Optional[torch.Tensor] = None,
|
|
||||||
is_full_block: bool = False,
|
|
||||||
is_stream: bool = False,
|
|
||||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
|
||||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
|
||||||
topk_ratio: float = 2.0,
|
|
||||||
kv_ratio: float = 3.0,
|
|
||||||
cur_process_idx: int = 0,
|
|
||||||
t_mod : torch.Tensor = None,
|
|
||||||
t : torch.Tensor = None,
|
|
||||||
local_range: int = 9,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# patchify
|
|
||||||
x, (f, h, w) = dit.patchify(x)
|
|
||||||
|
|
||||||
win = (2, 8, 8)
|
|
||||||
seqlen = f // win[0]
|
|
||||||
local_num = seqlen
|
|
||||||
window_size = win[0] * h * w // 128
|
|
||||||
square_num = window_size * window_size
|
|
||||||
topk = int(square_num * topk_ratio) - 1
|
|
||||||
kv_len = int(kv_ratio)
|
|
||||||
|
|
||||||
# RoPE 位置(分段)
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
else:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
# TeaCache(默认不启用)
|
|
||||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
|
||||||
|
|
||||||
# 统一序列并行(此处默认关闭)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|
||||||
get_sequence_parallel_world_size,
|
|
||||||
get_sp_group)
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
|
|
||||||
# Block 堆叠
|
|
||||||
if tea_cache_update:
|
|
||||||
x = tea_cache.update(x)
|
|
||||||
else:
|
|
||||||
for block_id, block in enumerate(dit.blocks):
|
|
||||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
|
||||||
x = x + LQ_latents[block_id]
|
|
||||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
|
||||||
x, context, t_mod, freqs, f, h, w,
|
|
||||||
local_num, topk,
|
|
||||||
block_id=block_id,
|
|
||||||
kv_len=kv_len,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
|
||||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
|
||||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
|
||||||
|
|
||||||
x = dit.head(x, t)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import get_sp_group
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
|
||||||
return x, pre_cache_k, pre_cache_v
|
|
||||||
@@ -1,625 +0,0 @@
|
|||||||
import types
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional, Tuple, Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
# import pyfiglet
|
|
||||||
|
|
||||||
from ..models.utils import clean_vram
|
|
||||||
from ..models import ModelManager
|
|
||||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
|
||||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
|
||||||
from ..schedulers.flow_match import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
|
||||||
N, C = feat.shape[:2]
|
|
||||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
|
||||||
std = var.sqrt().view(N, C, 1, 1)
|
|
||||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
|
||||||
return mean, std
|
|
||||||
|
|
||||||
|
|
||||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
|
||||||
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = _calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = _calc_mean_std(content_feat)
|
|
||||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
|
||||||
return normalized * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 小波式模糊与分解/重构(ColorCorrector 用)
|
|
||||||
# -----------------------------
|
|
||||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
|
||||||
vals = [
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
[0.125, 0.25, 0.125 ],
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
]
|
|
||||||
return torch.tensor(vals, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
N, C, H, W = x.shape
|
|
||||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
|
||||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
|
||||||
pad = radius
|
|
||||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
|
||||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
high = torch.zeros_like(x)
|
|
||||||
low = x
|
|
||||||
for i in range(levels):
|
|
||||||
radius = 2 ** i
|
|
||||||
blurred = _wavelet_blur(low, radius)
|
|
||||||
high = high + (low - blurred)
|
|
||||||
low = blurred
|
|
||||||
return high, low
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
|
||||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
|
||||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
|
||||||
return c_high + s_low
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Safetensors support ---------
|
|
||||||
# -----------------------------
|
|
||||||
st_load_file = None # Define the variable in global scope first
|
|
||||||
try:
|
|
||||||
from safetensors.torch import load_file as st_load_file
|
|
||||||
except ImportError:
|
|
||||||
# st_load_file remains None if import fails
|
|
||||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
class TorchColorCorrectorWavelet(nn.Module):
|
|
||||||
def __init__(self, levels: int = 5):
|
|
||||||
super().__init__()
|
|
||||||
self.levels = levels
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
|
||||||
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
|
||||||
B, C, f, H, W = x.shape
|
|
||||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
|
||||||
return y, B, f
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
|
||||||
BF, C, H, W = y.shape
|
|
||||||
assert BF == B * f
|
|
||||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
|
||||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
|
||||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
|
||||||
|
|
||||||
B, C, f, H, W = hq_image.shape
|
|
||||||
if chunk_size is None or chunk_size >= f:
|
|
||||||
hq4, B, f = self._flatten_time(hq_image)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_image)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out = self._unflatten_time(out4, B, f)
|
|
||||||
return out
|
|
||||||
|
|
||||||
outs = []
|
|
||||||
for start in range(0, f, chunk_size):
|
|
||||||
end = min(start + chunk_size, f)
|
|
||||||
hq_chunk = hq_image[:, :, start:end]
|
|
||||||
lq_chunk = lq_image[:, :, start:end]
|
|
||||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
|
||||||
outs.append(out_chunk)
|
|
||||||
out = torch.cat(outs, dim=2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版 Pipeline(仅 dit + vae)
|
|
||||||
# -----------------------------
|
|
||||||
class FlashVSRTinyPipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
||||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
|
||||||
self.dit: WanModel = None
|
|
||||||
self.vae: WanVideoVAE = None
|
|
||||||
self.model_names = ['dit', 'vae']
|
|
||||||
self.height_division_factor = 16
|
|
||||||
self.width_division_factor = 16
|
|
||||||
self.use_unified_sequence_parallel = False
|
|
||||||
self.prompt_emb_posi = None
|
|
||||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
|
||||||
# 仅管理 dit / vae
|
|
||||||
dtype = next(iter(self.dit.parameters())).dtype
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
|
||||||
enable_vram_management(
|
|
||||||
self.dit,
|
|
||||||
module_map={
|
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
|
||||||
RMSNorm: AutoWrappedModule,
|
|
||||||
},
|
|
||||||
module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device=self.device,
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
max_num_param=num_persistent_param_in_dit,
|
|
||||||
overflow_module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device="cpu",
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.enable_cpu_offload()
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
|
||||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
|
||||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
|
||||||
if device is None: device = model_manager.device
|
|
||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
|
||||||
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
|
|
||||||
pipe.fetch_models(model_manager)
|
|
||||||
# 可选:统一序列并行入口(此处默认关闭)
|
|
||||||
pipe.use_unified_sequence_parallel = False
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
def denoising_model(self):
|
|
||||||
return self.dit
|
|
||||||
|
|
||||||
# -------------------------
|
|
||||||
# 新增:显式 KV 预初始化函数
|
|
||||||
# -------------------------
|
|
||||||
def init_cross_kv(
|
|
||||||
self,
|
|
||||||
context_tensor: Optional[torch.Tensor] = None,
|
|
||||||
prompt_path = None,
|
|
||||||
):
|
|
||||||
self.load_models_to_device(["dit"])
|
|
||||||
"""
|
|
||||||
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
|
||||||
必须在 __call__ 前显式调用一次。
|
|
||||||
"""
|
|
||||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
|
||||||
|
|
||||||
if self.dit is None:
|
|
||||||
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
|
||||||
|
|
||||||
if context_tensor is None:
|
|
||||||
if prompt_path is None:
|
|
||||||
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
|
||||||
|
|
||||||
# --- Safetensors loading logic added here ---
|
|
||||||
prompt_path_lower = prompt_path.lower()
|
|
||||||
if prompt_path_lower.endswith(".safetensors"):
|
|
||||||
if st_load_file is None:
|
|
||||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
|
||||||
|
|
||||||
# Load the tensor from safetensors
|
|
||||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
|
||||||
|
|
||||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
|
||||||
if len(loaded_dict) == 1:
|
|
||||||
ctx = list(loaded_dict.values())[0]
|
|
||||||
elif 'context' in loaded_dict: # Common key for text context
|
|
||||||
ctx = loaded_dict['context']
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Default behavior for .pth, .pt, etc.
|
|
||||||
ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
else:
|
|
||||||
ctx = context_tensor
|
|
||||||
|
|
||||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
|
|
||||||
if self.prompt_emb_posi is None:
|
|
||||||
self.prompt_emb_posi = {}
|
|
||||||
self.prompt_emb_posi['context'] = ctx
|
|
||||||
|
|
||||||
if hasattr(self.dit, "reinit_cross_kv"):
|
|
||||||
self.dit.reinit_cross_kv(ctx)
|
|
||||||
else:
|
|
||||||
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
|
||||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
|
||||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
|
||||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
|
||||||
# Scheduler
|
|
||||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
|
||||||
self.load_models_to_device([])
|
|
||||||
|
|
||||||
def prepare_unified_sequence_parallel(self):
|
|
||||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def decode_video(self, latents, cond=None, **kwargs):
|
|
||||||
frames = self.TCDecoder.decode_video(
|
|
||||||
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
|
|
||||||
parallel=False,
|
|
||||||
show_progress_bar=False,
|
|
||||||
cond=cond
|
|
||||||
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
prompt=None,
|
|
||||||
negative_prompt="",
|
|
||||||
denoising_strength=1.0,
|
|
||||||
seed=None,
|
|
||||||
rand_device="gpu",
|
|
||||||
height=480,
|
|
||||||
width=832,
|
|
||||||
num_frames=81,
|
|
||||||
cfg_scale=5.0,
|
|
||||||
num_inference_steps=50,
|
|
||||||
sigma_shift=5.0,
|
|
||||||
tiled=True,
|
|
||||||
tile_size=(60, 104),
|
|
||||||
tile_stride=(30, 52),
|
|
||||||
tea_cache_l1_thresh=None,
|
|
||||||
tea_cache_model_id="Wan2.1-T2V-14B",
|
|
||||||
progress_bar_cmd=tqdm,
|
|
||||||
progress_bar_st=None,
|
|
||||||
LQ_video=None,
|
|
||||||
is_full_block=False,
|
|
||||||
if_buffer=False,
|
|
||||||
topk_ratio=2.0,
|
|
||||||
kv_ratio=3.0,
|
|
||||||
local_range = 9,
|
|
||||||
color_fix = True,
|
|
||||||
unload_dit = False,
|
|
||||||
skip_vae = False,
|
|
||||||
):
|
|
||||||
# 只接受 cfg=1.0(与原代码一致)
|
|
||||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
|
||||||
|
|
||||||
# 要求:必须先 init_cross_kv()
|
|
||||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
|
||||||
" pipe.init_cross_kv()\n"
|
|
||||||
"或传入自定义 context:\n"
|
|
||||||
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 尺寸修正
|
|
||||||
height, width = self.check_resize_height_width(height, width)
|
|
||||||
if num_frames % 4 != 1:
|
|
||||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
|
||||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
|
||||||
|
|
||||||
# Tiler 参数
|
|
||||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
|
||||||
|
|
||||||
# 初始化噪声
|
|
||||||
if if_buffer:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
else:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
latents = noise
|
|
||||||
|
|
||||||
process_total_num = (num_frames - 1) // 8 - 2
|
|
||||||
is_stream = True
|
|
||||||
|
|
||||||
# 清理可能存在的 LQ_proj_in cache
|
|
||||||
if hasattr(self.dit, "LQ_proj_in"):
|
|
||||||
self.dit.LQ_proj_in.clear_cache()
|
|
||||||
|
|
||||||
frames_total = []
|
|
||||||
self.TCDecoder.clean_mem()
|
|
||||||
LQ_pre_idx = 0
|
|
||||||
LQ_cur_idx = 0
|
|
||||||
|
|
||||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
|
||||||
current_dit_device = next(iter(self.dit.parameters())).device
|
|
||||||
if str(current_dit_device) != str(self.device):
|
|
||||||
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
|
|
||||||
self.dit.to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
pre_cache_k = [None] * len(self.dit.blocks)
|
|
||||||
pre_cache_v = [None] * len(self.dit.blocks)
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 7
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
|
||||||
cur_latents = latents[:, :, :6, :, :]
|
|
||||||
else:
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 2
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
|
||||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
|
||||||
|
|
||||||
# Denoise
|
|
||||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
|
||||||
self.dit,
|
|
||||||
x=cur_latents,
|
|
||||||
timestep=self.timestep,
|
|
||||||
context=None,
|
|
||||||
tea_cache=None,
|
|
||||||
use_unified_sequence_parallel=False,
|
|
||||||
LQ_latents=LQ_latents,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k,
|
|
||||||
pre_cache_v=pre_cache_v,
|
|
||||||
topk_ratio=topk_ratio,
|
|
||||||
kv_ratio=kv_ratio,
|
|
||||||
cur_process_idx=cur_process_idx,
|
|
||||||
t_mod=self.t_mod,
|
|
||||||
t=self.t,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
|
|
||||||
cur_latents = cur_latents - noise_pred_posi
|
|
||||||
|
|
||||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
|
||||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
|
||||||
cur_frames = self.TCDecoder.decode_video(
|
|
||||||
cur_latents.transpose(1, 2),
|
|
||||||
parallel=False,
|
|
||||||
show_progress_bar=False,
|
|
||||||
cond=cur_LQ_frame
|
|
||||||
).transpose(1, 2).mul_(2).sub_(1)
|
|
||||||
|
|
||||||
# Per-chunk color correction
|
|
||||||
try:
|
|
||||||
if color_fix:
|
|
||||||
cur_frames = self.ColorCorrector(
|
|
||||||
cur_frames.to(device=self.device),
|
|
||||||
cur_LQ_frame,
|
|
||||||
clip_range=(-1, 1),
|
|
||||||
chunk_size=None,
|
|
||||||
method='adain'
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
frames_total.append(cur_frames.to('cpu'))
|
|
||||||
LQ_pre_idx = LQ_cur_idx
|
|
||||||
|
|
||||||
del cur_frames, cur_latents, cur_LQ_frame
|
|
||||||
clean_vram()
|
|
||||||
|
|
||||||
frames = torch.cat(frames_total, dim=2)
|
|
||||||
return frames[0]
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# TeaCache(保留原逻辑;此处默认不启用)
|
|
||||||
# -----------------------------
|
|
||||||
class TeaCache:
|
|
||||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
|
||||||
self.num_inference_steps = num_inference_steps
|
|
||||||
self.step = 0
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = None
|
|
||||||
self.rel_l1_thresh = rel_l1_thresh
|
|
||||||
self.previous_residual = None
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
self.coefficients_dict = {
|
|
||||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
|
||||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
|
||||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
|
||||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
|
||||||
}
|
|
||||||
if model_id not in self.coefficients_dict:
|
|
||||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
|
||||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
|
||||||
self.coefficients = self.coefficients_dict[model_id]
|
|
||||||
|
|
||||||
def check(self, dit: WanModel, x, t_mod):
|
|
||||||
modulated_inp = t_mod.clone()
|
|
||||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
|
||||||
should_calc = True
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
else:
|
|
||||||
coefficients = self.coefficients
|
|
||||||
rescale_func = np.poly1d(coefficients)
|
|
||||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
|
||||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
|
||||||
if should_calc:
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = modulated_inp
|
|
||||||
self.step = (self.step + 1) % self.num_inference_steps
|
|
||||||
if should_calc:
|
|
||||||
self.previous_hidden_states = x.clone()
|
|
||||||
return not should_calc
|
|
||||||
|
|
||||||
def store(self, hidden_states):
|
|
||||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
def update(self, hidden_states):
|
|
||||||
hidden_states = hidden_states + self.previous_residual
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
|
||||||
# -----------------------------
|
|
||||||
def model_fn_wan_video(
|
|
||||||
dit: WanModel,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timestep: torch.Tensor,
|
|
||||||
context: torch.Tensor,
|
|
||||||
tea_cache: Optional[TeaCache] = None,
|
|
||||||
use_unified_sequence_parallel: bool = False,
|
|
||||||
LQ_latents: Optional[torch.Tensor] = None,
|
|
||||||
is_full_block: bool = False,
|
|
||||||
is_stream: bool = False,
|
|
||||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
|
||||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
|
||||||
topk_ratio: float = 2.0,
|
|
||||||
kv_ratio: float = 3.0,
|
|
||||||
cur_process_idx: int = 0,
|
|
||||||
t_mod : torch.Tensor = None,
|
|
||||||
t : torch.Tensor = None,
|
|
||||||
local_range: int = 9,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# patchify
|
|
||||||
x, (f, h, w) = dit.patchify(x)
|
|
||||||
|
|
||||||
win = (2, 8, 8)
|
|
||||||
seqlen = f // win[0]
|
|
||||||
local_num = seqlen
|
|
||||||
window_size = win[0] * h * w // 128
|
|
||||||
square_num = window_size * window_size
|
|
||||||
topk = int(square_num * topk_ratio) - 1
|
|
||||||
kv_len = int(kv_ratio)
|
|
||||||
|
|
||||||
# RoPE 位置(分段)
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
else:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
# TeaCache(默认不启用)
|
|
||||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
|
||||||
|
|
||||||
# 统一序列并行(此处默认关闭)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|
||||||
get_sequence_parallel_world_size,
|
|
||||||
get_sp_group)
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
|
|
||||||
# Block 堆叠
|
|
||||||
if tea_cache_update:
|
|
||||||
x = tea_cache.update(x)
|
|
||||||
else:
|
|
||||||
for block_id, block in enumerate(dit.blocks):
|
|
||||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
|
||||||
x = x + LQ_latents[block_id]
|
|
||||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
|
||||||
x, context, t_mod, freqs, f, h, w,
|
|
||||||
local_num, topk,
|
|
||||||
block_id=block_id,
|
|
||||||
kv_len=kv_len,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
|
||||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
|
||||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
|
||||||
|
|
||||||
x = dit.head(x, t)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import get_sp_group
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
|
||||||
return x, pre_cache_k, pre_cache_v
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
import types
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional, Tuple, Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
# import pyfiglet
|
|
||||||
|
|
||||||
from ..models.utils import clean_vram
|
|
||||||
from ..models import ModelManager
|
|
||||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
|
||||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
|
||||||
from ..schedulers.flow_match import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
|
||||||
N, C = feat.shape[:2]
|
|
||||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
|
||||||
std = var.sqrt().view(N, C, 1, 1)
|
|
||||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
|
||||||
return mean, std
|
|
||||||
|
|
||||||
|
|
||||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
|
||||||
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = _calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = _calc_mean_std(content_feat)
|
|
||||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
|
||||||
return normalized * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 小波式模糊与分解/重构(ColorCorrector 用)
|
|
||||||
# -----------------------------
|
|
||||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
|
||||||
vals = [
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
[0.125, 0.25, 0.125 ],
|
|
||||||
[0.0625, 0.125, 0.0625],
|
|
||||||
]
|
|
||||||
return torch.tensor(vals, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
N, C, H, W = x.shape
|
|
||||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
|
||||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
|
||||||
pad = radius
|
|
||||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
|
||||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
|
||||||
high = torch.zeros_like(x)
|
|
||||||
low = x
|
|
||||||
for i in range(levels):
|
|
||||||
radius = 2 ** i
|
|
||||||
blurred = _wavelet_blur(low, radius)
|
|
||||||
high = high + (low - blurred)
|
|
||||||
low = blurred
|
|
||||||
return high, low
|
|
||||||
|
|
||||||
|
|
||||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
|
||||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
|
||||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
|
||||||
return c_high + s_low
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Safetensors support ---------
|
|
||||||
# -----------------------------
|
|
||||||
st_load_file = None # Define the variable in global scope first
|
|
||||||
try:
|
|
||||||
from safetensors.torch import load_file as st_load_file
|
|
||||||
except ImportError:
|
|
||||||
# st_load_file remains None if import fails
|
|
||||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
|
||||||
# -----------------------------
|
|
||||||
class TorchColorCorrectorWavelet(nn.Module):
|
|
||||||
def __init__(self, levels: int = 5):
|
|
||||||
super().__init__()
|
|
||||||
self.levels = levels
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
|
||||||
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
|
||||||
B, C, f, H, W = x.shape
|
|
||||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
|
||||||
return y, B, f
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
|
||||||
BF, C, H, W = y.shape
|
|
||||||
assert BF == B * f
|
|
||||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
|
||||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
|
||||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
|
||||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
|
||||||
|
|
||||||
B, C, f, H, W = hq_image.shape
|
|
||||||
if chunk_size is None or chunk_size >= f:
|
|
||||||
hq4, B, f = self._flatten_time(hq_image)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_image)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out = self._unflatten_time(out4, B, f)
|
|
||||||
return out
|
|
||||||
|
|
||||||
outs = []
|
|
||||||
for start in range(0, f, chunk_size):
|
|
||||||
end = min(start + chunk_size, f)
|
|
||||||
hq_chunk = hq_image[:, :, start:end]
|
|
||||||
lq_chunk = lq_image[:, :, start:end]
|
|
||||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
|
||||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
|
||||||
if method == 'wavelet':
|
|
||||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
|
||||||
elif method == 'adain':
|
|
||||||
out4 = _adain(hq4, lq4)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知 method: {method}")
|
|
||||||
out4 = torch.clamp(out4, *clip_range)
|
|
||||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
|
||||||
outs.append(out_chunk)
|
|
||||||
out = torch.cat(outs, dim=2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版 Pipeline(仅 dit + vae)
|
|
||||||
# -----------------------------
|
|
||||||
class FlashVSRTinyLongPipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
||||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
|
||||||
self.dit: WanModel = None
|
|
||||||
self.vae: WanVideoVAE = None
|
|
||||||
self.model_names = ['dit', 'vae']
|
|
||||||
self.height_division_factor = 16
|
|
||||||
self.width_division_factor = 16
|
|
||||||
self.use_unified_sequence_parallel = False
|
|
||||||
self.prompt_emb_posi = None
|
|
||||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
|
||||||
# 仅管理 dit / vae
|
|
||||||
dtype = next(iter(self.dit.parameters())).dtype
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
|
||||||
enable_vram_management(
|
|
||||||
self.dit,
|
|
||||||
module_map={
|
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
|
||||||
RMSNorm: AutoWrappedModule,
|
|
||||||
},
|
|
||||||
module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device=self.device,
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
max_num_param=num_persistent_param_in_dit,
|
|
||||||
overflow_module_config=dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device="cpu",
|
|
||||||
computation_dtype=self.torch_dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.enable_cpu_offload()
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
|
||||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
|
||||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
|
||||||
if device is None: device = model_manager.device
|
|
||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
|
||||||
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
|
|
||||||
pipe.fetch_models(model_manager)
|
|
||||||
# 可选:统一序列并行入口(此处默认关闭)
|
|
||||||
pipe.use_unified_sequence_parallel = False
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
def denoising_model(self):
|
|
||||||
return self.dit
|
|
||||||
|
|
||||||
# -------------------------
|
|
||||||
# 新增:显式 KV 预初始化函数
|
|
||||||
# -------------------------
|
|
||||||
def init_cross_kv(
|
|
||||||
self,
|
|
||||||
context_tensor: Optional[torch.Tensor] = None,
|
|
||||||
prompt_path = None,
|
|
||||||
):
|
|
||||||
self.load_models_to_device(["dit"])
|
|
||||||
"""
|
|
||||||
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
|
||||||
必须在 __call__ 前显式调用一次。
|
|
||||||
"""
|
|
||||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
|
||||||
|
|
||||||
if self.dit is None:
|
|
||||||
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
|
||||||
|
|
||||||
if context_tensor is None:
|
|
||||||
if prompt_path is None:
|
|
||||||
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
|
||||||
|
|
||||||
# --- Safetensors loading logic added here ---
|
|
||||||
prompt_path_lower = prompt_path.lower()
|
|
||||||
if prompt_path_lower.endswith(".safetensors"):
|
|
||||||
if st_load_file is None:
|
|
||||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
|
||||||
|
|
||||||
# Load the tensor from safetensors
|
|
||||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
|
||||||
|
|
||||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
|
||||||
if len(loaded_dict) == 1:
|
|
||||||
ctx = list(loaded_dict.values())[0]
|
|
||||||
elif 'context' in loaded_dict: # Common key for text context
|
|
||||||
ctx = loaded_dict['context']
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Default behavior for .pth, .pt, etc.
|
|
||||||
ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
|
||||||
else:
|
|
||||||
ctx = context_tensor
|
|
||||||
|
|
||||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
|
|
||||||
if self.prompt_emb_posi is None:
|
|
||||||
self.prompt_emb_posi = {}
|
|
||||||
self.prompt_emb_posi['context'] = ctx
|
|
||||||
|
|
||||||
if hasattr(self.dit, "reinit_cross_kv"):
|
|
||||||
self.dit.reinit_cross_kv(ctx)
|
|
||||||
else:
|
|
||||||
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
|
||||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
|
||||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
|
||||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
|
||||||
# Scheduler
|
|
||||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
|
||||||
self.load_models_to_device([])
|
|
||||||
|
|
||||||
def prepare_unified_sequence_parallel(self):
|
|
||||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return frames
|
|
||||||
|
|
||||||
def decode_video(self, latents, cond=None, **kwargs):
|
|
||||||
frames = self.TCDecoder.decode_video(
|
|
||||||
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
|
|
||||||
parallel=False,
|
|
||||||
show_progress_bar=False,
|
|
||||||
cond=cond
|
|
||||||
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
prompt=None,
|
|
||||||
negative_prompt="",
|
|
||||||
denoising_strength=1.0,
|
|
||||||
seed=None,
|
|
||||||
rand_device="gpu",
|
|
||||||
height=480,
|
|
||||||
width=832,
|
|
||||||
num_frames=81,
|
|
||||||
cfg_scale=5.0,
|
|
||||||
num_inference_steps=50,
|
|
||||||
sigma_shift=5.0,
|
|
||||||
tiled=True,
|
|
||||||
tile_size=(60, 104),
|
|
||||||
tile_stride=(30, 52),
|
|
||||||
tea_cache_l1_thresh=None,
|
|
||||||
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
|
||||||
progress_bar_cmd=tqdm,
|
|
||||||
progress_bar_st=None,
|
|
||||||
LQ_video=None,
|
|
||||||
is_full_block=False,
|
|
||||||
if_buffer=False,
|
|
||||||
topk_ratio=2.0,
|
|
||||||
kv_ratio=3.0,
|
|
||||||
local_range = 9,
|
|
||||||
color_fix = True,
|
|
||||||
unload_dit = False,
|
|
||||||
skip_vae = False,
|
|
||||||
):
|
|
||||||
# 只接受 cfg=1.0(与原代码一致)
|
|
||||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
|
||||||
|
|
||||||
# 要求:必须先 init_cross_kv()
|
|
||||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
|
||||||
" pipe.init_cross_kv()\n"
|
|
||||||
"或传入自定义 context:\n"
|
|
||||||
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 尺寸修正
|
|
||||||
height, width = self.check_resize_height_width(height, width)
|
|
||||||
if num_frames % 4 != 1:
|
|
||||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
|
||||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
|
||||||
|
|
||||||
# Tiler 参数
|
|
||||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
|
||||||
|
|
||||||
# 初始化噪声
|
|
||||||
if if_buffer:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
else:
|
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
|
||||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
latents = noise
|
|
||||||
|
|
||||||
process_total_num = (num_frames - 1) // 8 - 2
|
|
||||||
is_stream = True
|
|
||||||
|
|
||||||
# 清理可能存在的 LQ_proj_in cache
|
|
||||||
if hasattr(self.dit, "LQ_proj_in"):
|
|
||||||
self.dit.LQ_proj_in.clear_cache()
|
|
||||||
|
|
||||||
frames_total = []
|
|
||||||
LQ_pre_idx = 0
|
|
||||||
LQ_cur_idx = 0
|
|
||||||
self.TCDecoder.clean_mem()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
pre_cache_k = [None] * len(self.dit.blocks)
|
|
||||||
pre_cache_v = [None] * len(self.dit.blocks)
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 7
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
|
||||||
cur_latents = latents[:, :, :6, :, :]
|
|
||||||
else:
|
|
||||||
LQ_latents = None
|
|
||||||
inner_loop_num = 2
|
|
||||||
for inner_idx in range(inner_loop_num):
|
|
||||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
|
||||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
|
||||||
) if LQ_video is not None else None
|
|
||||||
if cur is None:
|
|
||||||
continue
|
|
||||||
if LQ_latents is None:
|
|
||||||
LQ_latents = cur
|
|
||||||
else:
|
|
||||||
for layer_idx in range(len(LQ_latents)):
|
|
||||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
|
||||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
|
||||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
|
||||||
|
|
||||||
# 推理(无 motion_controller / vace)
|
|
||||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
|
||||||
self.dit,
|
|
||||||
x=cur_latents,
|
|
||||||
timestep=self.timestep,
|
|
||||||
context=None,
|
|
||||||
tea_cache=None,
|
|
||||||
use_unified_sequence_parallel=False,
|
|
||||||
LQ_latents=LQ_latents,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k,
|
|
||||||
pre_cache_v=pre_cache_v,
|
|
||||||
topk_ratio=topk_ratio,
|
|
||||||
kv_ratio=kv_ratio,
|
|
||||||
cur_process_idx=cur_process_idx,
|
|
||||||
t_mod=self.t_mod,
|
|
||||||
t=self.t,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新 latent
|
|
||||||
cur_latents = cur_latents - noise_pred_posi
|
|
||||||
|
|
||||||
# Decode
|
|
||||||
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
|
|
||||||
cur_frames = self.TCDecoder.decode_video(
|
|
||||||
cur_latents.transpose(1, 2),
|
|
||||||
parallel=False,
|
|
||||||
show_progress_bar=False,
|
|
||||||
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
|
|
||||||
|
|
||||||
# 颜色校正(wavelet)
|
|
||||||
try:
|
|
||||||
if color_fix:
|
|
||||||
cur_frames = self.ColorCorrector(
|
|
||||||
cur_frames.to(device=self.device),
|
|
||||||
cur_LQ_frame,
|
|
||||||
clip_range=(-1, 1),
|
|
||||||
chunk_size=None,
|
|
||||||
method='adain'
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
frames_total.append(cur_frames.to('cpu'))
|
|
||||||
LQ_pre_idx = LQ_cur_idx
|
|
||||||
|
|
||||||
del cur_frames, cur_latents, cur_LQ_frame
|
|
||||||
clean_vram()
|
|
||||||
|
|
||||||
frames = torch.cat(frames_total, dim=2)
|
|
||||||
return frames[0]
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# TeaCache(保留原逻辑;此处默认不启用)
|
|
||||||
# -----------------------------
|
|
||||||
class TeaCache:
|
|
||||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
|
||||||
self.num_inference_steps = num_inference_steps
|
|
||||||
self.step = 0
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = None
|
|
||||||
self.rel_l1_thresh = rel_l1_thresh
|
|
||||||
self.previous_residual = None
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
self.coefficients_dict = {
|
|
||||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
|
||||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
|
||||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
|
||||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
|
||||||
}
|
|
||||||
if model_id not in self.coefficients_dict:
|
|
||||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
|
||||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
|
||||||
self.coefficients = self.coefficients_dict[model_id]
|
|
||||||
|
|
||||||
def check(self, dit: WanModel, x, t_mod):
|
|
||||||
modulated_inp = t_mod.clone()
|
|
||||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
|
||||||
should_calc = True
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
else:
|
|
||||||
coefficients = self.coefficients
|
|
||||||
rescale_func = np.poly1d(coefficients)
|
|
||||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
|
||||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
|
||||||
if should_calc:
|
|
||||||
self.accumulated_rel_l1_distance = 0
|
|
||||||
self.previous_modulated_input = modulated_inp
|
|
||||||
self.step = (self.step + 1) % self.num_inference_steps
|
|
||||||
if should_calc:
|
|
||||||
self.previous_hidden_states = x.clone()
|
|
||||||
return not should_calc
|
|
||||||
|
|
||||||
def store(self, hidden_states):
|
|
||||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
|
||||||
self.previous_hidden_states = None
|
|
||||||
|
|
||||||
def update(self, hidden_states):
|
|
||||||
hidden_states = hidden_states + self.previous_residual
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
|
||||||
# -----------------------------
|
|
||||||
def model_fn_wan_video(
|
|
||||||
dit: WanModel,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timestep: torch.Tensor,
|
|
||||||
context: torch.Tensor,
|
|
||||||
tea_cache: Optional[TeaCache] = None,
|
|
||||||
use_unified_sequence_parallel: bool = False,
|
|
||||||
LQ_latents: Optional[torch.Tensor] = None,
|
|
||||||
is_full_block: bool = False,
|
|
||||||
is_stream: bool = False,
|
|
||||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
|
||||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
|
||||||
topk_ratio: float = 2.0,
|
|
||||||
kv_ratio: float = 3.0,
|
|
||||||
cur_process_idx: int = 0,
|
|
||||||
t_mod : torch.Tensor = None,
|
|
||||||
t : torch.Tensor = None,
|
|
||||||
local_range: int = 9,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# patchify
|
|
||||||
x, (f, h, w) = dit.patchify(x)
|
|
||||||
|
|
||||||
win = (2, 8, 8)
|
|
||||||
seqlen = f // win[0]
|
|
||||||
local_num = seqlen
|
|
||||||
window_size = win[0] * h * w // 128
|
|
||||||
square_num = window_size * window_size
|
|
||||||
topk = int(square_num * topk_ratio) - 1
|
|
||||||
kv_len = int(kv_ratio)
|
|
||||||
|
|
||||||
# RoPE 位置(分段)
|
|
||||||
if cur_process_idx == 0:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
else:
|
|
||||||
freqs = torch.cat([
|
|
||||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
# TeaCache(默认不启用)
|
|
||||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
|
||||||
|
|
||||||
# 统一序列并行(此处默认关闭)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|
||||||
get_sequence_parallel_world_size,
|
|
||||||
get_sp_group)
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
|
|
||||||
# Block 堆叠
|
|
||||||
if tea_cache_update:
|
|
||||||
x = tea_cache.update(x)
|
|
||||||
else:
|
|
||||||
for block_id, block in enumerate(dit.blocks):
|
|
||||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
|
||||||
x = x + LQ_latents[block_id]
|
|
||||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
|
||||||
x, context, t_mod, freqs, f, h, w,
|
|
||||||
local_num, topk,
|
|
||||||
block_id=block_id,
|
|
||||||
kv_len=kv_len,
|
|
||||||
is_full_block=is_full_block,
|
|
||||||
is_stream=is_stream,
|
|
||||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
|
||||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
|
||||||
local_range = local_range,
|
|
||||||
)
|
|
||||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
|
||||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
|
||||||
|
|
||||||
x = dit.head(x, t)
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import get_sp_group
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
|
||||||
return x, pre_cache_k, pre_cache_v
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .flow_match import FlowMatchScheduler
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FlowMatchScheduler():
|
|
||||||
|
|
||||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
|
||||||
self.num_train_timesteps = num_train_timesteps
|
|
||||||
self.shift = shift
|
|
||||||
self.sigma_max = sigma_max
|
|
||||||
self.sigma_min = sigma_min
|
|
||||||
self.inverse_timesteps = inverse_timesteps
|
|
||||||
self.extra_one_step = extra_one_step
|
|
||||||
self.reverse_sigmas = reverse_sigmas
|
|
||||||
self.set_timesteps(num_inference_steps)
|
|
||||||
|
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
|
||||||
if shift is not None:
|
|
||||||
self.shift = shift
|
|
||||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
|
||||||
if self.extra_one_step:
|
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
|
||||||
else:
|
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
|
||||||
if self.inverse_timesteps:
|
|
||||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
|
||||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
|
||||||
if self.reverse_sigmas:
|
|
||||||
self.sigmas = 1 - self.sigmas
|
|
||||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
|
||||||
if training:
|
|
||||||
x = self.timesteps
|
|
||||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
|
||||||
y_shifted = y - y.min()
|
|
||||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
|
||||||
self.linear_timesteps_weights = bsmntw_weighing
|
|
||||||
|
|
||||||
|
|
||||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
|
||||||
sigma = self.sigmas[timestep_id]
|
|
||||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
|
||||||
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
|
||||||
else:
|
|
||||||
sigma_ = self.sigmas[timestep_id + 1]
|
|
||||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
|
||||||
return prev_sample
|
|
||||||
|
|
||||||
|
|
||||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
|
||||||
sigma = self.sigmas[timestep_id]
|
|
||||||
model_output = (sample - sample_stablized) / sigma
|
|
||||||
return model_output
|
|
||||||
|
|
||||||
|
|
||||||
def add_noise(self, original_samples, noise, timestep):
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
|
||||||
sigma = self.sigmas[timestep_id]
|
|
||||||
sample = (1 - sigma) * original_samples + sigma * noise
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
def training_target(self, sample, noise, timestep):
|
|
||||||
target = noise - sample
|
|
||||||
return target
|
|
||||||
|
|
||||||
|
|
||||||
def training_weight(self, timestep):
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
|
||||||
weights = self.linear_timesteps_weights[timestep_id]
|
|
||||||
return weights
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .layers import *
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
import torch, copy
|
|
||||||
from ..models.utils import init_weights_on_device
|
|
||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype, device):
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
class AutoWrappedModule(torch.nn.Module):
|
|
||||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
|
||||||
super().__init__()
|
|
||||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
|
||||||
self.offload_dtype = offload_dtype
|
|
||||||
self.offload_device = offload_device
|
|
||||||
self.onload_dtype = onload_dtype
|
|
||||||
self.onload_device = onload_device
|
|
||||||
self.computation_dtype = computation_dtype
|
|
||||||
self.computation_device = computation_device
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def offload(self):
|
|
||||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def onload(self):
|
|
||||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
|
||||||
self.state = 1
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
|
||||||
module = self.module
|
|
||||||
else:
|
|
||||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
|
||||||
return module(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class AutoWrappedLinear(torch.nn.Linear):
|
|
||||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
|
||||||
with init_weights_on_device(device=torch.device("meta")):
|
|
||||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
|
||||||
self.weight = module.weight
|
|
||||||
self.bias = module.bias
|
|
||||||
self.offload_dtype = offload_dtype
|
|
||||||
self.offload_device = offload_device
|
|
||||||
self.onload_dtype = onload_dtype
|
|
||||||
self.onload_device = onload_device
|
|
||||||
self.computation_dtype = computation_dtype
|
|
||||||
self.computation_device = computation_device
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def offload(self):
|
|
||||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
|
||||||
self.state = 0
|
|
||||||
|
|
||||||
def onload(self):
|
|
||||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
|
||||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
|
||||||
self.state = 1
|
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
|
||||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
|
||||||
weight, bias = self.weight, self.bias
|
|
||||||
else:
|
|
||||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
|
||||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
|
||||||
for name, module in model.named_children():
|
|
||||||
for source_module, target_module in module_map.items():
|
|
||||||
if isinstance(module, source_module):
|
|
||||||
num_param = sum(p.numel() for p in module.parameters())
|
|
||||||
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
|
||||||
module_config_ = overflow_module_config
|
|
||||||
else:
|
|
||||||
module_config_ = module_config
|
|
||||||
module_ = target_module(module, **module_config_)
|
|
||||||
setattr(model, name, module_)
|
|
||||||
total_num_param += num_param
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
|
||||||
return total_num_param
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
|
||||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
|
||||||
model.vram_management_enabled = True
|
|
||||||
|
|
||||||
248
inference.py
248
inference.py
@@ -1,11 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .bim_vfi_arch import BiMVFI
|
from .bim_vfi_arch import BiMVFI
|
||||||
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
|
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
|
||||||
@@ -624,248 +621,3 @@ class GIMMVFIModel:
|
|||||||
results.append(torch.clamp(unpadded, 0, 1))
|
results.append(torch.clamp(unpadded, 0, 1))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# FlashVSR model wrapper (4x video super-resolution)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class FlashVSRModel:
|
|
||||||
"""Inference wrapper for FlashVSR diffusion-based video super-resolution.
|
|
||||||
|
|
||||||
Supports three pipeline modes:
|
|
||||||
- full: Standard VAE decode, highest quality
|
|
||||||
- tiny: TCDecoder decode, faster
|
|
||||||
- tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Minimum input frame count required by the pipeline
|
|
||||||
MIN_FRAMES = 21
|
|
||||||
|
|
||||||
def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from .flashvsr_arch import (
|
|
||||||
ModelManager, FlashVSRFullPipeline,
|
|
||||||
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
|
|
||||||
)
|
|
||||||
from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
|
|
||||||
from .flashvsr_arch.models.TCDecoder import build_tcdecoder
|
|
||||||
|
|
||||||
self.mode = mode
|
|
||||||
self.device = device
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
|
|
||||||
vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
|
|
||||||
lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
|
|
||||||
tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
|
|
||||||
prompt_path = os.path.join(model_dir, "Prompt.safetensors")
|
|
||||||
|
|
||||||
mm = ModelManager(torch_dtype=dtype, device="cpu")
|
|
||||||
|
|
||||||
if mode == "full":
|
|
||||||
mm.load_models([dit_path, vae_path])
|
|
||||||
self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
|
|
||||||
self.pipe.vae.model.encoder = None
|
|
||||||
self.pipe.vae.model.conv1 = None
|
|
||||||
else:
|
|
||||||
mm.load_models([dit_path])
|
|
||||||
Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
|
|
||||||
self.pipe = Pipeline.from_model_manager(mm, device=device)
|
|
||||||
|
|
||||||
# TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
|
|
||||||
self.pipe.TCDecoder = build_tcdecoder(
|
|
||||||
[512, 256, 128, 128], device, dtype, 16 + 768,
|
|
||||||
)
|
|
||||||
self.pipe.TCDecoder.load_state_dict(
|
|
||||||
load_file(tcd_path, device=device), strict=False,
|
|
||||||
)
|
|
||||||
self.pipe.TCDecoder.clean_mem()
|
|
||||||
|
|
||||||
# LQ frame projection — Causal variant for FlashVSR v1.1
|
|
||||||
self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
|
|
||||||
if os.path.exists(lq_path):
|
|
||||||
lq_sd = load_file(lq_path, device="cpu")
|
|
||||||
cleaned = {}
|
|
||||||
for k, v in lq_sd.items():
|
|
||||||
cleaned[k.removeprefix("LQ_proj_in.")] = v
|
|
||||||
self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
|
|
||||||
self.pipe.denoising_model().LQ_proj_in.to(device)
|
|
||||||
|
|
||||||
self.pipe.to(device, dtype)
|
|
||||||
self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
||||||
self.pipe.init_cross_kv(prompt_path=prompt_path)
|
|
||||||
self.pipe.load_models_to_device([]) # offload to CPU
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
self.device = device
|
|
||||||
self.pipe.device = device
|
|
||||||
return self
|
|
||||||
|
|
||||||
def load_to_device(self):
|
|
||||||
"""Load models to the compute device for inference."""
|
|
||||||
names = ["dit", "vae"] if self.mode == "full" else ["dit"]
|
|
||||||
self.pipe.load_models_to_device(names)
|
|
||||||
|
|
||||||
def offload(self):
|
|
||||||
"""Offload models to CPU."""
|
|
||||||
self.pipe.load_models_to_device([])
|
|
||||||
|
|
||||||
def clear_caches(self):
|
|
||||||
if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
|
|
||||||
self.pipe.denoising_model().LQ_proj_in.clear_cache()
|
|
||||||
if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
|
|
||||||
self.pipe.vae.clear_cache()
|
|
||||||
if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
|
|
||||||
self.pipe.TCDecoder.clean_mem()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Frame preprocessing / postprocessing helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_dims(w, h, scale, align=128):
|
|
||||||
sw, sh = w * scale, h * scale
|
|
||||||
tw = math.ceil(sw / align) * align
|
|
||||||
th = math.ceil(sh / align) * align
|
|
||||||
return sw, sh, tw, th
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _restore_video_sequence(result, expected):
|
|
||||||
"""Trim pipeline output to the expected frame count."""
|
|
||||||
if result.shape[0] > expected:
|
|
||||||
result = result[:expected]
|
|
||||||
elif result.shape[0] < expected:
|
|
||||||
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
|
|
||||||
result = torch.cat([result, pad], dim=0)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _next_8n5(n, minimum=21):
|
|
||||||
"""Next integer >= n of the form 8k+5 (minimum 21)."""
|
|
||||||
if n < minimum:
|
|
||||||
return minimum
|
|
||||||
return ((n - 5 + 7) // 8) * 8 + 5
|
|
||||||
|
|
||||||
def _prepare_video(self, frames, scale):
|
|
||||||
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
|
|
||||||
|
|
||||||
Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
|
|
||||||
1. Bicubic-upscale each frame to target resolution
|
|
||||||
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
|
|
||||||
3. Normalize to [-1, 1]
|
|
||||||
4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
|
|
||||||
5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
video: [1, C, F_padded, H, W] tensor
|
|
||||||
th, tw: padded spatial dimensions
|
|
||||||
nf: padded frame count
|
|
||||||
sh, sw: actual (unpadded) spatial dimensions
|
|
||||||
pad_top, pad_left: spatial padding offsets for output cropping
|
|
||||||
"""
|
|
||||||
N, H, W, C = frames.shape
|
|
||||||
sw, sh, tw, th = self._compute_dims(W, H, scale)
|
|
||||||
|
|
||||||
# Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
|
|
||||||
N_padded = self._next_8n5(N)
|
|
||||||
|
|
||||||
# Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
|
|
||||||
target = N_padded + 4
|
|
||||||
|
|
||||||
# Centered spatial padding offsets
|
|
||||||
pad_top = (th - sh) // 2
|
|
||||||
pad_bottom = th - sh - pad_top
|
|
||||||
pad_left = (tw - sw) // 2
|
|
||||||
pad_right = tw - sw - pad_left
|
|
||||||
|
|
||||||
processed = []
|
|
||||||
for i in range(target):
|
|
||||||
frame_idx = min(i, N - 1) # clamp to last real frame
|
|
||||||
frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
|
||||||
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
|
|
||||||
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
|
|
||||||
# Centered reflect padding (matches naxci1 reference)
|
|
||||||
try:
|
|
||||||
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
|
|
||||||
except RuntimeError:
|
|
||||||
# Reflect requires pad < input size; fall back to replicate
|
|
||||||
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
|
|
||||||
normalized = upscaled * 2.0 - 1.0
|
|
||||||
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
|
|
||||||
|
|
||||||
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
|
|
||||||
nf = video.shape[2]
|
|
||||||
|
|
||||||
return video, th, tw, nf, sh, sw, pad_top, pad_left
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _to_frames(video):
|
|
||||||
"""Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
|
|
||||||
from einops import rearrange
|
|
||||||
v = video.squeeze(0) if video.dim() == 5 else video
|
|
||||||
v = rearrange(v, "C F H W -> F H W C")
|
|
||||||
return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Main upscale method
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
|
|
||||||
topk_ratio=2.0, kv_ratio=3.0, local_range=11,
|
|
||||||
color_fix=True, unload_dit=False, seed=1,
|
|
||||||
progress_bar_cmd=None):
|
|
||||||
"""Upscale video frames with FlashVSR.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames: [F, H, W, C] float32 [0, 1] with F >= 21
|
|
||||||
scale: Upscaling factor (2 or 4)
|
|
||||||
tiled: Enable VAE tiled decode (saves VRAM)
|
|
||||||
tile_size: (H, W) tile size for VAE tiling
|
|
||||||
topk_ratio: Sparse attention ratio (higher = faster, less detail)
|
|
||||||
kv_ratio: KV cache ratio (higher = more quality, more VRAM)
|
|
||||||
local_range: Local attention window (9=sharp, 11=stable)
|
|
||||||
color_fix: Apply wavelet color correction
|
|
||||||
unload_dit: Offload DiT before VAE decode (saves VRAM)
|
|
||||||
seed: Random seed
|
|
||||||
progress_bar_cmd: Callable wrapping an iterable for progress display
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[F, H*scale, W*scale, C] float32 [0, 1]
|
|
||||||
"""
|
|
||||||
if progress_bar_cmd is None:
|
|
||||||
from tqdm import tqdm
|
|
||||||
progress_bar_cmd = tqdm
|
|
||||||
|
|
||||||
original_count = frames.shape[0]
|
|
||||||
|
|
||||||
# Prepare video tensor (bicubic upscale + centered pad)
|
|
||||||
video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)
|
|
||||||
|
|
||||||
# Move LQ video to compute device (except for "long" mode which streams)
|
|
||||||
if "long" not in self.pipe.__class__.__name__.lower():
|
|
||||||
video = video.to(self.pipe.device)
|
|
||||||
|
|
||||||
# Run pipeline
|
|
||||||
out = self.pipe(
|
|
||||||
prompt="", negative_prompt="",
|
|
||||||
cfg_scale=1.0, num_inference_steps=1,
|
|
||||||
seed=seed, tiled=tiled, tile_size=tile_size,
|
|
||||||
progress_bar_cmd=progress_bar_cmd,
|
|
||||||
LQ_video=video,
|
|
||||||
num_frames=nf, height=th, width=tw,
|
|
||||||
is_full_block=False, if_buffer=True,
|
|
||||||
topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
|
|
||||||
kv_ratio=kv_ratio, local_range=local_range,
|
|
||||||
color_fix=color_fix, unload_dit=unload_dit,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to ComfyUI format with centered spatial crop
|
|
||||||
result = self._to_frames(out).cpu()
|
|
||||||
result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]
|
|
||||||
|
|
||||||
# Trim to original frame count
|
|
||||||
result = self._restore_video_sequence(result, original_count)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[project]
|
||||||
|
name = "comfyui-tween"
|
||||||
|
description = "Video frame interpolation nodes for ComfyUI using BIM-VFI, EMA-VFI, SGM-VFI, and GIMM-VFI. Designed for long videos with thousands of frames."
|
||||||
|
version = "1.1.0"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"gdown",
|
||||||
|
"timm",
|
||||||
|
"omegaconf",
|
||||||
|
"yacs",
|
||||||
|
"easydict",
|
||||||
|
"einops",
|
||||||
|
"huggingface_hub",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Repository = "https://github.com/Ethanfel/ComfyUI-Tween"
|
||||||
|
|
||||||
|
[tool.comfy]
|
||||||
|
PublisherId = "ethanfel"
|
||||||
|
DisplayName = "Tween - Video Frame Interpolation"
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
gdown
|
gdown
|
||||||
|
timm
|
||||||
omegaconf
|
omegaconf
|
||||||
yacs
|
yacs
|
||||||
easydict
|
easydict
|
||||||
einops
|
einops
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
safetensors
|
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
function fitHeight(node) {
|
|
||||||
node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]);
|
|
||||||
node?.graph?.setDirtyCanvas(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "Tween.VideoPreview",
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData) {
|
|
||||||
if (nodeData?.name !== "TweenConcatVideos") return;
|
|
||||||
|
|
||||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
|
||||||
onNodeCreated?.apply(this, arguments);
|
|
||||||
|
|
||||||
const container = document.createElement("div");
|
|
||||||
const previewWidget = this.addDOMWidget("videopreview", "preview", container, {
|
|
||||||
serialize: false,
|
|
||||||
hideOnZoom: false,
|
|
||||||
getValue() { return container.value; },
|
|
||||||
setValue(v) { container.value = v; },
|
|
||||||
});
|
|
||||||
|
|
||||||
previewWidget.computeSize = function (width) {
|
|
||||||
if (this.aspectRatio && !this.videoEl.hidden) {
|
|
||||||
const height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
|
|
||||||
return [width, height > 0 ? height : -4];
|
|
||||||
}
|
|
||||||
return [width, -4];
|
|
||||||
};
|
|
||||||
|
|
||||||
const previewNode = this;
|
|
||||||
|
|
||||||
previewWidget.videoEl = document.createElement("video");
|
|
||||||
previewWidget.videoEl.controls = true;
|
|
||||||
previewWidget.videoEl.loop = true;
|
|
||||||
previewWidget.videoEl.muted = true;
|
|
||||||
previewWidget.videoEl.style.width = "100%";
|
|
||||||
previewWidget.videoEl.hidden = true;
|
|
||||||
|
|
||||||
previewWidget.videoEl.addEventListener("loadedmetadata", () => {
|
|
||||||
previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
|
|
||||||
fitHeight(previewNode);
|
|
||||||
});
|
|
||||||
previewWidget.videoEl.addEventListener("error", () => {
|
|
||||||
previewWidget.videoEl.hidden = true;
|
|
||||||
fitHeight(previewNode);
|
|
||||||
});
|
|
||||||
|
|
||||||
container.appendChild(previewWidget.videoEl);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onExecuted = nodeType.prototype.onExecuted;
|
|
||||||
nodeType.prototype.onExecuted = function (message) {
|
|
||||||
onExecuted?.apply(this, arguments);
|
|
||||||
|
|
||||||
if (!message?.gifs?.length) return;
|
|
||||||
|
|
||||||
const params = message.gifs[0];
|
|
||||||
const previewWidget = this.widgets?.find((w) => w.name === "videopreview");
|
|
||||||
if (!previewWidget) return;
|
|
||||||
|
|
||||||
const query = new URLSearchParams(params);
|
|
||||||
query.set("timestamp", Date.now());
|
|
||||||
previewWidget.videoEl.src = api.apiURL("/view?" + query);
|
|
||||||
previewWidget.videoEl.hidden = false;
|
|
||||||
previewWidget.videoEl.autoplay = true;
|
|
||||||
};
|
|
||||||
},
|
|
||||||
});
|
|
||||||
Reference in New Issue
Block a user