Compare commits
15 Commits
master...old-master
| Author | SHA1 | Date |
|---|---|---|
| | dd61ae8d1f | |
| | e7e7c1cb5a | |
| | 3b87652184 | |
| | 76dff7e573 | |
| | fa250897a2 | |
| | 94d9818675 | |
| | ea84ffef7c | |
| | 4cc6e9c705 | |
| | 39d0f7af42 | |
| | 11e2acb9e0 | |
| | 5071c4de4f | |
| | dd69a2fd2b | |
| | f40504cbcf | |
| | 8317a0603e | |
| | 0fecfcee37 | |
99
README.md
@@ -1,6 +1,6 @@
|
||||
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI
|
||||
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR
|
||||
|
||||
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM.
|
||||
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM.
|
||||
|
||||
## Which model should I use?
|
||||
|
||||
@@ -18,6 +18,21 @@ ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://githu
|
||||
|
||||
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair.
|
||||
|
||||
### Video Super-Resolution
|
||||
|
||||
FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above.
|
||||
|
||||
| | FlashVSR |
|---|----------|
| **Task** | 4x video super-resolution |
| **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) |
| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) |
| **VRAM** | ~8–12 GB (tiled, tiny mode) / ~16–24 GB (full mode) |
| **Params** | ~1.3B (DiT) + ~200M (VAE) |
| **Min input** | 21 frames |
| **Paper** | arXiv 2510.12747 |
| **License** | Apache 2.0 |
|
||||
|
||||
## Nodes
|
||||
|
||||
### BIM-VFI
|
||||
@@ -136,7 +151,61 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, p
|
||||
|
||||
Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
|
||||
|
||||
**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
|
||||
**Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
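
The counts above follow from inserting `factor - 1` new frames into each of the `N - 1` gaps between consecutive input frames. A minimal sketch of that arithmetic (the function name `vfi_output_frames` is hypothetical and not part of this repository):

```python
def vfi_output_frames(n_frames: int, factor: int) -> int:
    """Output frame count for N input frames at a 2x/4x/8x factor.

    Hypothetical helper for illustration only: every one of the
    N - 1 gaps becomes `factor` frame steps, plus the final frame.
    """
    if n_frames < 1:
        raise ValueError("need at least one input frame")
    if factor not in (2, 4, 8):
        raise ValueError("factor must be 2, 4 or 8")
    # factor * (N - 1) + 1 expands to 2N-1, 4N-3 and 8N-7.
    return factor * (n_frames - 1) + 1
```

This only applies to the VFI nodes; FlashVSR changes spatial resolution and leaves the frame count to its own pipeline.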
|
||||
|
||||
### FlashVSR
|
||||
|
||||
FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling.
|
||||
|
||||
#### Load FlashVSR Model
|
||||
|
||||
Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`.
|
||||
|
||||
| Input | Description |
|-------|-------------|
| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) |
| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) |
|
||||
|
||||
Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)):
|
||||
| Checkpoint | Size | Description |
|-----------|------|-------------|
| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) |
| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE |
| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection |
| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) |
| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings |
|
||||
|
||||
#### FlashVSR Upscale
|
||||
|
||||
Upscales an image batch with 4x spatial super-resolution.
|
||||
|
||||
| Input | Description |
|-------|-------------|
| **images** | Input video frames (minimum 21 frames) |
| **model** | Model from the loader node |
| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) |
| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames |
| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) |
| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) |
| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) |
| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) |
| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability |
| **color_fix** | Apply wavelet color correction to prevent color shifts |
| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) |
| **seed** | Random seed for the diffusion process |
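
To illustrate how `frame_chunk_size` interacts with the 21-frame minimum, here is a minimal sketch of one way to plan chunk boundaries. The helper `plan_chunks` and its policy of folding a too-short tail into the previous chunk are assumptions for illustration; how the node itself handles a short final chunk is not stated here.

```python
MIN_FRAMES = 21  # minimum input length stated in the table above


def plan_chunks(total_frames: int, chunk_size: int) -> list[tuple[int, int]]:
    """Return (start, end) frame ranges, end exclusive.

    Hypothetical planner: chunk_size == 0 means "all at once", and a
    final chunk shorter than MIN_FRAMES is merged into its predecessor.
    """
    if total_frames < MIN_FRAMES:
        raise ValueError(f"need at least {MIN_FRAMES} frames")
    if chunk_size == 0 or chunk_size >= total_frames:
        return [(0, total_frames)]
    if chunk_size < MIN_FRAMES:
        raise ValueError(f"chunk size must be 0 or >= {MIN_FRAMES}")
    chunks = [
        (start, min(start + chunk_size, total_frames))
        for start in range(0, total_frames, chunk_size)
    ]
    last_start, last_end = chunks[-1]
    if len(chunks) > 1 and last_end - last_start < MIN_FRAMES:
        # Assumed policy: extend the previous chunk over the short tail.
        chunks.pop()
        prev_start, _ = chunks.pop()
        chunks.append((prev_start, last_end))
    return chunks
```

With this policy the last chunk can be up to `chunk_size + 20` frames long, so peak VRAM should be budgeted for that case rather than for `chunk_size` alone.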
|
||||
|
||||
#### FlashVSR Segment Upscale
|
||||
|
||||
Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution.
|
||||
|
||||
| Input | Description |
|-------|-------------|
| **segment_index** | Which segment to process (0-based) |
| **segment_size** | Number of input frames per segment (minimum 21) |
| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending |
| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) |
|
||||
|
||||
Plus all the same upscale parameters as FlashVSR Upscale.
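
The crossfade over `blend_frames` can be pictured with a short sketch. It assumes a linear ramp and uses plain floats as stand-ins for frames; the node's actual blend curve is not specified here, and the helper names are hypothetical.

```python
def crossfade_weights(blend_frames: int) -> list[float]:
    """Weights for the incoming segment, rising linearly towards 1.

    Assumption: a linear ramp that excludes the 0.0 and 1.0 endpoints.
    """
    return [(i + 1) / (blend_frames + 1) for i in range(blend_frames)]


def crossfade(prev_tail: list[float], next_head: list[float]) -> list[float]:
    """Blend the tail of one segment with the head of the next one."""
    if len(prev_tail) != len(next_head):
        raise ValueError("both sides of the overlap must have equal length")
    weights = crossfade_weights(len(prev_tail))
    return [
        (1.0 - w) * a + w * b
        for w, a, b in zip(weights, prev_tail, next_head)
    ]


def check_blend(overlap_frames: int, blend_frames: int) -> None:
    """Enforce the documented constraint blend_frames <= overlap_frames."""
    if blend_frames > overlap_frames:
        raise ValueError("blend_frames must be <= overlap_frames")
```

For real image tensors the same weights would be applied per frame, one weight broadcast over the pixels of each frame in the overlap.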
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -147,7 +216,7 @@ cd ComfyUI/custom_nodes
|
||||
git clone https://github.com/your-user/ComfyUI-Tween.git
|
||||
```
|
||||
|
||||
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
|
||||
Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version.
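
If you want to check which of these packages are already importable before the first ComfyUI start, a minimal sketch is shown below. It assumes each dependency's import name matches the name listed above; the helper `missing_packages` is hypothetical and is not the auto-install code in `__init__.py`.

```python
import importlib.util


def missing_packages(names: list[str]) -> list[str]:
    """Return the top-level module names that cannot be found.

    Hypothetical helper: find_spec locates a module without importing
    it, so heavy packages such as cupy are not loaded by this check.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    deps = ["gdown", "cupy", "timm", "omegaconf", "easydict",
            "yacs", "einops", "huggingface_hub", "safetensors"]
    print("missing:", missing_packages(deps))
```

Run it with the same Python interpreter that ComfyUI uses, otherwise the result describes a different environment.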
|
||||
|
||||
> **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. missing build tools in Docker), install manually with:
|
||||
> ```bash
|
||||
@@ -168,7 +237,8 @@ python install.py
|
||||
- `timm` (for EMA-VFI and SGM-VFI)
|
||||
- `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download)
|
||||
- `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI)
|
||||
- `huggingface_hub` (for GIMM-VFI model auto-download)
|
||||
- `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download)
|
||||
- `safetensors` (for FlashVSR checkpoint loading)
|
||||
|
||||
## VRAM Guide
|
||||
|
||||
@@ -181,7 +251,7 @@ python install.py
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, and the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU). GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths).
|
||||
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths).
|
||||
|
||||
**BiM-VFI:**
|
||||
> Wonyong Seo, Jihyong Oh, and Munchurl Kim.
|
||||
@@ -243,6 +313,21 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
|
||||
}
|
||||
```
|
||||
|
||||
**FlashVSR:**
|
||||
> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng.
|
||||
> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer."
|
||||
> *arXiv preprint arXiv:2510.12747*, 2025.
|
||||
> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR)
|
||||
|
||||
```bibtex
|
||||
@article{zhuang2025flashvsr,
|
||||
title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer},
|
||||
author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing},
|
||||
journal={arXiv preprint arXiv:2510.12747},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
|
||||
@@ -252,3 +337,5 @@ The EMA-VFI model weights and architecture code are released under the [Apache 2
|
||||
The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.
|
||||
|
||||
The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI).
|
||||
|
||||
The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR).
|
||||
|
||||
11
__init__.py
@@ -34,8 +34,8 @@ def _auto_install_deps():
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tween] Could not auto-install cupy: {e}")
|
||||
|
||||
# GIMM-VFI dependencies
|
||||
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"):
|
||||
# GIMM-VFI + FlashVSR dependencies
|
||||
for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"):
|
||||
try:
|
||||
__import__(pkg)
|
||||
except ImportError:
|
||||
@@ -50,6 +50,7 @@ from .nodes import (
|
||||
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
|
||||
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
|
||||
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
||||
LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale,
|
||||
)
|
||||
|
||||
WEB_DIRECTORY = "./web"
|
||||
@@ -68,6 +69,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LoadGIMMVFIModel": LoadGIMMVFIModel,
|
||||
"GIMMVFIInterpolate": GIMMVFIInterpolate,
|
||||
"GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate,
|
||||
"LoadFlashVSRModel": LoadFlashVSRModel,
|
||||
"FlashVSRUpscale": FlashVSRUpscale,
|
||||
"FlashVSRSegmentUpscale": FlashVSRSegmentUpscale,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -84,4 +88,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoadGIMMVFIModel": "Load GIMM-VFI Model",
|
||||
"GIMMVFIInterpolate": "GIMM-VFI Interpolate",
|
||||
"GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate",
|
||||
"LoadFlashVSRModel": "Load FlashVSR Model",
|
||||
"FlashVSRUpscale": "FlashVSR Upscale",
|
||||
"FlashVSRSegmentUpscale": "FlashVSR Segment Upscale",
|
||||
}
|
||||
|
||||
4
flashvsr_arch/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .models.model_manager import ModelManager
from .pipelines import FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
from .models.utils import clean_vram, Buffer_LQ4x_Proj
from .models.TCDecoder import build_tcdecoder
0
flashvsr_arch/configs/__init__.py
Normal file
21
flashvsr_arch/configs/model_config.py
Normal file
@@ -0,0 +1,21 @@
from ..models.wan_video_dit import WanModel
from ..models.wan_video_vae import WanVideoVAE

model_loader_configs = [
    # (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
]
patch_model_loader_configs = [
]
320
flashvsr_arch/models/TCDecoder.py
Normal file
@@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
|
||||
- Encoder removed
|
||||
- Transplant/widening helpers removed
|
||||
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm.auto import tqdm
|
||||
from collections import namedtuple
|
||||
from einops import rearrange
|
||||
import torch.nn.init as init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||
|
||||
# ----------------------------
|
||||
# Utility / building blocks
|
||||
# ----------------------------
|
||||
|
||||
class IdentityConv2d(nn.Conv2d):
    """Same-shape Conv2d initialized to identity (Dirac)."""
    def __init__(self, C, kernel_size=3, bias=False):
        pad = kernel_size // 2
        super().__init__(C, C, kernel_size, padding=pad, bias=bias)
        with torch.no_grad():
            init.dirac_(self.weight)
            if self.bias is not None:
                self.bias.zero_()

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3
|
||||
|
||||
class MemBlock(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
|
||||
conv(n_out, n_out), nn.ReLU(inplace=True),
|
||||
conv(n_out, n_out)
|
||||
)
|
||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.act = nn.ReLU(inplace=True)
|
||||
def forward(self, x, past):
|
||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||
|
||||
class TPool(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||
|
||||
class TGrow(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
class PixelShuffle3d(nn.Module):
|
||||
def __init__(self, ff, hh, ww):
|
||||
super().__init__()
|
||||
self.ff = ff
|
||||
self.hh = hh
|
||||
self.ww = ww
|
||||
def forward(self, x):
|
||||
# x: (B, C, F, H, W)
|
||||
B, C, F, H, W = x.shape
|
||||
if F % self.ff != 0:
|
||||
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
|
||||
x = torch.cat([first_frame, x], dim=2)
|
||||
return rearrange(
|
||||
x,
|
||||
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
||||
ff=self.ff, hh=self.hh, ww=self.ww
|
||||
).transpose(1, 2)
|
||||
|
||||
# ----------------------------
|
||||
# Generic NTCHW graph executor (kept; used by decoder)
|
||||
# ----------------------------
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
|
||||
"""
|
||||
Apply a sequential model with memblocks to the given input.
|
||||
Args:
|
||||
- model: nn.Sequential of blocks to apply
|
||||
- x: input data, of dimensions NTCHW
|
||||
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
|
||||
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
|
||||
- show_progress_bar: if True, enables tqdm progressbar display
|
||||
|
||||
Returns NTCHW tensor of output data.
|
||||
"""
|
||||
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
|
||||
N, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(N*T, C, H, W)
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
NT, C, H, W = x.shape
|
||||
T = NT // N
|
||||
_x = x.reshape(N, T, C, H, W)
|
||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
NT, C, H, W = x.shape
|
||||
T = NT // N
|
||||
x = x.view(N, T, C, H, W)
|
||||
else:
|
||||
out = []
|
||||
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
while work_queue:
|
||||
xt, i = work_queue.pop(0)
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
else:
|
||||
b = model[i]
|
||||
if isinstance(b, MemBlock):
|
||||
if mem[i] is None:
|
||||
xt_new = b(xt, xt * 0)
|
||||
mem[i] = xt
|
||||
else:
|
||||
xt_new = b(xt, mem[i])
|
||||
mem[i].copy_(xt)
|
||||
work_queue.insert(0, TWorkItem(xt_new, i+1))
|
||||
elif isinstance(b, TPool):
|
||||
if mem[i] is None:
|
||||
mem[i] = []
|
||||
mem[i].append(xt)
|
||||
if len(mem[i]) > b.stride:
|
||||
raise ValueError("TPool internal state invalid.")
|
||||
elif len(mem[i]) == b.stride:
|
||||
N_, C_, H_, W_ = xt.shape
|
||||
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
|
||||
mem[i] = []
|
||||
work_queue.insert(0, TWorkItem(xt, i+1))
|
||||
elif isinstance(b, TGrow):
|
||||
xt = b(xt)
|
||||
NT, C_, H_, W_ = xt.shape
|
||||
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
|
||||
work_queue.insert(0, TWorkItem(xt_next, i+1))
|
||||
else:
|
||||
xt = b(xt)
|
||||
work_queue.insert(0, TWorkItem(xt, i+1))
|
||||
progress_bar.close()
|
||||
x = torch.stack(out, 1)
|
||||
return x, mem
|
||||
|
||||
# ----------------------------
|
||||
# Decoder-only TAEHV
|
||||
# ----------------------------
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
image_channels = 3
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path="taehv.pth",
|
||||
decoder_time_upscale=(True, True),
|
||||
decoder_space_upscale=(True, True, True),
|
||||
channels = [256, 128, 64, 64],
|
||||
latent_channels = 16
|
||||
):
|
||||
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
|
||||
Deepening config: how_many_each=1, k=3 (fixed as requested).
|
||||
"""
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
n_f = channels
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
|
||||
# Build the decoder "skeleton"
|
||||
base_decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
|
||||
|
||||
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
|
||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
|
||||
TGrow(n_f[0], 1),
|
||||
conv(n_f[0], n_f[1], bias=False),
|
||||
|
||||
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
|
||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
|
||||
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
|
||||
conv(n_f[1], n_f[2], bias=False),
|
||||
|
||||
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
|
||||
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
|
||||
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
|
||||
conv(n_f[2], n_f[3], bias=False),
|
||||
|
||||
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
|
||||
)
|
||||
|
||||
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
|
||||
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
|
||||
|
||||
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
|
||||
|
||||
if checkpoint_path is not None:
|
||||
missing_keys = self.load_state_dict(
|
||||
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
|
||||
strict=False
|
||||
)
|
||||
print('missing_keys', missing_keys)
|
||||
|
||||
# Initialize decoder mem state
|
||||
self.mem = [None] * len(self.decoder)
|
||||
|
||||
@staticmethod
|
||||
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
|
||||
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
|
||||
new_layers = []
|
||||
for b in decoder:
|
||||
new_layers.append(b)
|
||||
if isinstance(b, nn.ReLU):
|
||||
# Deduce channel count from preceding layer
|
||||
C = None
|
||||
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
|
||||
C = new_layers[-2].out_channels
|
||||
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
|
||||
C = new_layers[-2].conv[-1].out_channels
|
||||
if C is not None:
|
||||
for _ in range(how_many_each):
|
||||
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
|
||||
new_layers.append(nn.ReLU(inplace=True))
|
||||
return nn.Sequential(*new_layers)
|
||||
|
||||
def patch_tgrow_layers(self, sd):
|
||||
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
|
||||
new_sd = self.state_dict()
|
||||
for i, layer in enumerate(self.decoder):
|
||||
if isinstance(layer, TGrow):
|
||||
key = f"decoder.{i}.conv.weight"
|
||||
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
|
||||
sd[key] = sd[key][-new_sd[key].shape[0]:]
|
||||
return sd
|
||||
|
||||
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
|
||||
"""Decode a sequence of frames from latents.
|
||||
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
|
||||
"""
|
||||
trim_flag = self.mem[-8] is None # keeps original relative check
|
||||
|
||||
if cond is not None:
|
||||
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
|
||||
|
||||
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
|
||||
|
||||
if trim_flag:
|
||||
return x[:, self.frames_to_trim:]
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
|
||||
|
||||
def clean_mem(self):
|
||||
self.mem = [None] * len(self.decoder)
|
||||
|
||||
class DotDict(dict):
|
||||
__getattr__ = dict.__getitem__
|
||||
__setattr__ = dict.__setitem__
|
||||
|
||||
class TAEW2_1DiffusersWrapper(nn.Module):
|
||||
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
|
||||
super().__init__()
|
||||
self.dtype = torch.bfloat16
|
||||
self.device = "cuda"
|
||||
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
|
||||
self.temperal_downsample = [True, True, False] # [sic]
|
||||
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
|
||||
|
||||
def decode(self, latents, return_dict=None):
|
||||
n, c, t, h, w = latents.shape
|
||||
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
|
||||
|
||||
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
|
||||
n, c, t, h, w = latents.shape
|
||||
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
||||
def clean_mem(self):
|
||||
self.taehv.clean_mem()
|
||||
|
||||
# ----------------------------
|
||||
# Simplified builder (no small, no transplant, no post-hoc deepening)
|
||||
# ----------------------------
|
||||
|
||||
def build_tcdecoder(new_channels = [512, 256, 128, 128],
                    device="cuda",
                    dtype=torch.bfloat16,
                    new_latent_channels=None):
    """
    Build a "wider" decoder; the deepening (IdentityConv2d+ReLU) is already done inside TAEHV.
    - Does not create a small model / does not transplant weights
    - The base_ckpt_path parameter is kept but unused (interface compatibility)

    Returns: big (a single model)
    """
    if new_latent_channels is not None:
        big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
    else:
        big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()

    big.clean_mem()
    return big
1
flashvsr_arch/models/__init__.py
Normal file
@@ -0,0 +1 @@
from .model_manager import *
402
flashvsr_arch/models/model_manager.py
Normal file
@@ -0,0 +1,402 @@
|
||||
import os, torch, json, importlib
|
||||
from typing import List
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
||||
state_dict_converter = model_class.state_dict_converter()
|
||||
if model_resource == "civitai":
|
||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
||||
elif model_resource == "diffusers":
|
||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
||||
if isinstance(state_dict_results, tuple):
|
||||
model_state_dict, extra_kwargs = state_dict_results
|
||||
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
with init_weights_on_device():
|
||||
model = model_class(**extra_kwargs)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
else:
|
||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
model = model.to(device=device)
|
||||
except:
|
||||
pass
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
||||
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
||||
base_state_dict = base_model.state_dict()
|
||||
base_model.to("cpu")
|
||||
del base_model
|
||||
model = model_class(**extra_kwargs)
|
||||
model.load_state_dict(base_state_dict, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.to(dtype=torch_dtype, device=device)
|
||||
return model
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        # Patch every loaded base model with this name. Popping from the manager's
        # lists shifts the indices, so the scan restarts after each patch.
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f" Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                # The scan found no remaining base model with this name.
                break
    return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorTemplate:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
return False
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
self.keys_hash_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
||||
if keys_hash is not None:
|
||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Neither hash matched: return empty lists instead of unbound locals.
        return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
super().__init__(model_loader_configs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    # Load the component that matched, not the merged dict that just failed to match.
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromHuggingfaceFolder:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.architecture_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
return False
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if "architectures" not in config and "_class_name" not in config:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
loaded_model_names, loaded_models = [], []
|
||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
||||
for architecture in architectures:
|
||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||
if redirected_architecture is not None:
|
||||
architecture = redirected_architecture
|
||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromPatchedSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
loaded_model_names, loaded_models = [], []
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda",
|
||||
file_path_list: List[str] = [],
|
||||
):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = []
|
||||
self.model_path = []
|
||||
self.model_name = []
|
||||
self.model_detector = [
|
||||
ModelDetectorFromSingleFile(model_loader_configs),
|
||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||
]
|
||||
self.load_models(file_path_list)
|
||||
|
||||
|
||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||
print(f"Loading models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
#print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||
print(f"Loading models from folder: {file_path}")
|
||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
#print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
||||
print(f"Loading patch models from file: {file_path}")
|
||||
model_names, models = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following patched models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
if isinstance(file_path, list):
|
||||
for file_path_ in file_path:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
is_loaded = False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
is_loaded = True
|
||||
break
|
||||
if not is_loaded:
|
||||
print(f" Cannot load LoRA: {file_path}")
|
||||
|
||||
|
||||
    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        #print(f"Loading models from: {file_path}")
        if device is None: device = self.device
        if torch_dtype is None: torch_dtype = self.torch_dtype
        if isinstance(file_path, list):
            state_dict = {}
            for path in file_path:
                state_dict.update(load_state_dict(path))
        elif os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                #print(f" The following models are loaded: {model_names}.")
                break
        else:
            print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if file_path is not None and file_path != model_path:
|
||||
continue
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
#print(f"No {model_name} models available.")
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}")
|
||||
else:
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
|
||||
if require_model_path:
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
else:
|
||||
return fetched_models[0]
|
||||
|
||||
|
||||
def to(self, device):
|
||||
for model in self.model:
|
||||
model.to(device)
|
||||
|
||||
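`ModelManager.load_model` tries each detector in order and stops at the first one whose `match` succeeds; if none matches, the `for ... else` branch reports that the model type could not be detected. A minimal sketch of that dispatch pattern follows. It uses stdlib-only stand-ins: `SuffixDetector` and `load_first_match` are hypothetical names for illustration and are not part of this repository.

```python
class SuffixDetector:
    """Hypothetical detector: matches a path by suffix and 'loads' a label."""

    def __init__(self, suffix, label):
        self.suffix = suffix
        self.label = label

    def match(self, file_path):
        return file_path.endswith(self.suffix)

    def load(self, file_path):
        # Mirrors the (names, models) pair returned by the detectors above.
        return [self.label], [f"<{self.label} from {file_path}>"]


def load_first_match(detectors, file_path):
    for detector in detectors:
        if detector.match(file_path):
            names, models = detector.load(file_path)
            break  # the first matching detector wins; later ones are not tried
    else:
        # Runs only when the loop finished without hitting `break`.
        names, models = [], []
    return names, models


detectors = [
    SuffixDetector(".safetensors", "single_file"),
    SuffixDetector(".json", "config"),
]
print(load_first_match(detectors, "model.safetensors"))
print(load_first_match(detectors, "model.bin"))
```

Because the first match wins, the order of the detector list matters: in `ModelManager.__init__`, `ModelDetectorFromSingleFile` is consulted before `ModelDetectorFromSplitedSingleFile`.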
3
flashvsr_arch/models/sparse_sage/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .core import sparse_sageattn
|
||||
|
||||
__all__ = ["sparse_sageattn"]
|
||||
40
flashvsr_arch/models/sparse_sage/core.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Sparse SageAttention — block-sparse INT8 attention via Triton.
|
||||
|
||||
https://github.com/jt-zhang/Sparse_SageAttention_API
|
||||
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
Licensed under the Apache License, Version 2.0
|
||||
"""
|
||||
|
||||
from .quant_per_block import per_block_int8
|
||||
from .sparse_int8_attn import forward as sparse_sageattn_fwd
|
||||
import torch
|
||||
|
||||
|
||||
def sparse_sageattn(q, k, v, mask_id=None, is_causal=False, tensor_layout="HND"):
    # Sequence length is dim 1 for "NHD" and dim 2 for "HND"; heads take the other one.
    seq_dim = 1 if tensor_layout == "NHD" else 2
    head_axis = 3 - seq_dim

    if mask_id is None:
        # Dense default: one flag per (128-row query block, 64-row key block).
        # The key-block count comes from k's sequence length, not from head_dim.
        mask_id = torch.ones(
            (q.shape[0], q.shape[head_axis],
             (q.shape[seq_dim] + 128 - 1) // 128,
             (k.shape[seq_dim] + 64 - 1) // 64),
            dtype=torch.int8, device=q.device,
        )

    output_dtype = q.dtype
    if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
        v = v.to(torch.float16)

    km = k.mean(dim=seq_dim, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_block_int8(
        q, k, km=km, tensor_layout=tensor_layout,
    )

    o = sparse_sageattn_fwd(
        q_int8, k_int8, mask_id, v, q_scale, k_scale,
        is_causal=is_causal, tensor_layout=tensor_layout,
        output_dtype=output_dtype,
    )
    return o
|
||||
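The attention kernel reads one mask flag per (query block, key block) pair, with `BLOCK_M = 128` query rows and `BLOCK_N = 64` key rows per block. A minimal sketch of the block counts this implies for a dense mask, assuming a `(batch, heads, query_blocks, key_blocks)` layout; `ceil_div` and `dense_block_mask_shape` are hypothetical helper names, not functions from this repository:

```python
def ceil_div(n, block):
    """Number of blocks of size `block` needed to cover `n` rows."""
    return (n + block - 1) // block


def dense_block_mask_shape(batch, heads, qo_len, kv_len, block_m=128, block_n=64):
    # One flag per (query block, key block) pair for every batch entry and head.
    return (batch, heads, ceil_div(qo_len, block_m), ceil_div(kv_len, block_n))


print(dense_block_mask_shape(1, 8, 300, 300))
print(dense_block_mask_shape(2, 4, 256, 128))
```

A partial final block still needs its own flag, which is why the counts round up rather than down.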
110
flashvsr_arch/models/sparse_sage/quant_per_block.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Per-block INT8 quantization kernel for Sparse SageAttention.
|
||||
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
Licensed under the Apache License, Version 2.0
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def quant_per_block_int8_kernel(
|
||||
Input, Output, Scale, L,
|
||||
stride_iz, stride_ih, stride_in,
|
||||
stride_oz, stride_oh, stride_on,
|
||||
stride_sz, stride_sh,
|
||||
sm_scale,
|
||||
C: tl.constexpr, BLK: tl.constexpr,
|
||||
):
|
||||
off_blk = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
off_b = tl.program_id(2)
|
||||
|
||||
offs_n = off_blk * BLK + tl.arange(0, BLK)
|
||||
offs_k = tl.arange(0, C)
|
||||
|
||||
input_ptrs = (
|
||||
Input
|
||||
+ off_b * stride_iz
|
||||
+ off_h * stride_ih
|
||||
+ offs_n[:, None] * stride_in
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
output_ptrs = (
|
||||
Output
|
||||
+ off_b * stride_oz
|
||||
+ off_h * stride_oh
|
||||
+ offs_n[:, None] * stride_on
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
|
||||
|
||||
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
|
||||
x = x.to(tl.float32)
|
||||
x *= sm_scale
|
||||
scale = tl.max(tl.abs(x)) / 127.0
|
||||
x_int8 = x / scale
|
||||
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
|
||||
x_int8 = x_int8.to(tl.int8)
|
||||
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
|
||||
tl.store(scale_ptrs, scale)
|
||||
|
||||
|
||||
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
|
||||
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
|
||||
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
|
||||
|
||||
if km is not None:
|
||||
k = k - km
|
||||
|
||||
if tensor_layout == "HND":
|
||||
b, h_qo, qo_len, head_dim = q.shape
|
||||
_, h_kv, kv_len, _ = k.shape
|
||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
||||
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
|
||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
||||
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
|
||||
elif tensor_layout == "NHD":
|
||||
b, qo_len, h_qo, head_dim = q.shape
|
||||
_, kv_len, h_kv, _ = k.shape
|
||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
||||
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
|
||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
||||
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
|
||||
else:
|
||||
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
|
||||
|
||||
q_scale = torch.empty(
|
||||
(b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32,
|
||||
)
|
||||
k_scale = torch.empty(
|
||||
(b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32,
|
||||
)
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = head_dim ** -0.5
|
||||
|
||||
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
|
||||
quant_per_block_int8_kernel[grid](
|
||||
q, q_int8, q_scale, qo_len,
|
||||
stride_bz_q, stride_h_q, stride_seq_q,
|
||||
stride_bz_qo, stride_h_qo, stride_seq_qo,
|
||||
q_scale.stride(0), q_scale.stride(1),
|
||||
sm_scale=(sm_scale * 1.44269504),
|
||||
C=head_dim, BLK=BLKQ,
|
||||
)
|
||||
|
||||
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
|
||||
quant_per_block_int8_kernel[grid](
|
||||
k, k_int8, k_scale, kv_len,
|
||||
stride_bz_k, stride_h_k, stride_seq_k,
|
||||
stride_bz_ko, stride_h_ko, stride_seq_ko,
|
||||
k_scale.stride(0), k_scale.stride(1),
|
||||
sm_scale=1.0,
|
||||
C=head_dim, BLK=BLKK,
|
||||
)
|
||||
|
||||
return q_int8, q_scale, k_int8, k_scale
|
||||
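The kernel above gives each block of rows a single scale, `max(|x|) / 127`, divides by it, adds `0.5` with the sign of the value, and casts to int8. Assuming the int8 cast truncates toward zero, that amounts to rounding half away from zero. A minimal pure-Python sketch of the same arithmetic on a flat list follows. `quantize_per_block` and `dequantize_per_block` are hypothetical names, and the all-zero-block fallback is an addition of this sketch that the kernel does not have.

```python
def quantize_per_block(values, block_size, pre_scale=1.0):
    """Quantize a flat list block by block; returns (int values, per-block scales)."""
    quantized, scales = [], []
    for start in range(0, len(values), block_size):
        block = [v * pre_scale for v in values[start:start + block_size]]
        max_abs = max(abs(v) for v in block)
        # Assumption of this sketch only: use scale 1.0 for an all-zero block
        # to avoid dividing by zero.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        scales.append(scale)
        for v in block:
            scaled = v / scale
            # Push by 0.5 in the direction of the sign, then truncate toward zero.
            scaled += 0.5 if scaled >= 0 else -0.5
            quantized.append(int(scaled))
    return quantized, scales


def dequantize_per_block(quantized, scales, block_size):
    """Map the integers back to floats using each block's scale."""
    return [q * scales[i // block_size] for i, q in enumerate(quantized)]


values = [127.0, -63.4, 10.6, 0.0, 254.0, -100.0]
q, s = quantize_per_block(values, block_size=4)
print(q, s)
print(dequantize_per_block(q, s, block_size=4))
```

In `per_block_int8`, the query side is called with `sm_scale * 1.44269504` (the softmax scale times log2(e)), which matches the attention kernel's use of `exp2`; the key side uses a scale of `1.0`. The `pre_scale` argument in the sketch plays that role.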
196
flashvsr_arch/models/sparse_sage/sparse_int8_attn.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Sparse INT8 attention kernel for Sparse SageAttention.
|
||||
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
Licensed under the Apache License, Version 2.0
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _attn_fwd_inner(
    acc, l_i, old_m, q, q_scale, kv_len,
    K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
    stride_kn, stride_vn, start_m,
    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
):
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        K_scale_ptr += lo // BLOCK_N
        K_ptrs += stride_kn * lo
        V_ptrs += stride_vn * lo
    elif STAGE == 3:
        lo, hi = 0, kv_len
    for start_n in range(lo, hi, BLOCK_N):
        kbid = tl.load(K_bid_ptr + start_n // BLOCK_N)
        if kbid:
            k_mask = offs_n[None, :] < (kv_len - start_n)
            k = tl.load(K_ptrs, mask=k_mask)
            k_scale = tl.load(K_scale_ptr)
            qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
            if STAGE == 2:
                mask = offs_m[:, None] >= (start_n + offs_n[None, :])
                qk = qk + tl.where(mask, 0, -1.0e6)
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk -= new_m[:, None]
            else:
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk = qk - new_m[:, None]

            p = tl.math.exp2(qk)
            l_ij = tl.sum(p, 1)
            alpha = tl.math.exp2(old_m - new_m)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, None]
            v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
            p = p.to(tl.float16)
            acc += tl.dot(p, v, out_dtype=tl.float16)
            old_m = new_m
        # Advance to the next key/value block whether or not this one was skipped.
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += 1
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, old_m
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd(
|
||||
Q, K, K_blkid, V, Q_scale, K_scale, Out,
|
||||
stride_qz, stride_qh, stride_qn,
|
||||
stride_kz, stride_kh, stride_kn,
|
||||
stride_vz, stride_vh, stride_vn,
|
||||
stride_oz, stride_oh, stride_on,
|
||||
stride_kbidq, stride_kbidk,
|
||||
qo_len, kv_len,
|
||||
H: tl.constexpr, num_kv_groups: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_z = tl.program_id(2).to(tl.int64)
|
||||
off_h = tl.program_id(1).to(tl.int64)
|
||||
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
|
||||
k_scale_offset = (
|
||||
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
||||
) * tl.cdiv(kv_len, BLOCK_N)
|
||||
k_bid_offset = (
|
||||
off_z * (H // num_kv_groups) + off_h // num_kv_groups
|
||||
) * stride_kbidq
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
Q_ptrs = (
|
||||
Q
|
||||
+ (off_z * stride_qz + off_h * stride_qh)
|
||||
+ offs_m[:, None] * stride_qn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
Q_scale_ptr = Q_scale + q_scale_offset + start_m
|
||||
K_ptrs = (
|
||||
K
|
||||
+ (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh)
|
||||
+ offs_n[None, :] * stride_kn
|
||||
+ offs_k[:, None]
|
||||
)
|
||||
K_scale_ptr = K_scale + k_scale_offset
|
||||
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
|
||||
V_ptrs = (
|
||||
V
|
||||
+ (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh)
|
||||
+ offs_n[:, None] * stride_vn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
O_block_ptr = (
|
||||
Out
|
||||
+ (off_z * stride_oz + off_h * stride_oh)
|
||||
+ offs_m[:, None] * stride_on
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
|
||||
q_scale = tl.load(Q_scale_ptr)
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, q_scale, kv_len,
|
||||
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
||||
stride_kn, stride_vn,
|
||||
start_m,
|
||||
BLOCK_M, HEAD_DIM, BLOCK_N,
|
||||
4 - STAGE, offs_m, offs_n,
|
||||
)
|
||||
if STAGE != 1:
|
||||
acc, l_i, _ = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, q_scale, kv_len,
|
||||
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs,
|
||||
stride_kn, stride_vn,
|
||||
start_m,
|
||||
BLOCK_M, HEAD_DIM, BLOCK_N,
|
||||
2, offs_m, offs_n,
|
||||
)
|
||||
acc = acc / l_i[:, None]
|
||||
tl.store(
|
||||
O_block_ptr,
|
||||
acc.to(Out.type.element_ty),
|
||||
mask=(offs_m[:, None] < qo_len),
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
q, k, k_block_id, v, q_scale, k_scale,
|
||||
is_causal=False, tensor_layout="HND", output_dtype=torch.float16,
|
||||
):
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
stage = 3 if is_causal else 1
|
||||
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
b, h_qo, qo_len, head_dim = q.shape
|
||||
_, h_kv, kv_len, _ = k.shape
|
||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
||||
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
|
||||
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
|
||||
elif tensor_layout == "NHD":
|
||||
b, qo_len, h_qo, head_dim = q.shape
|
||||
_, kv_len, h_kv, _ = k.shape
|
||||
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
||||
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
||||
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
|
||||
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
|
||||
else:
|
||||
raise ValueError(f"tensor_layout {tensor_layout} not supported")
|
||||
|
||||
if is_causal:
|
||||
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
|
||||
|
||||
HEAD_DIM_K = head_dim
|
||||
num_kv_groups = h_qo // h_kv
|
||||
|
||||
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
|
||||
_attn_fwd[grid](
|
||||
q, k, k_block_id, v, q_scale, k_scale, o,
|
||||
stride_bz_q, stride_h_q, stride_seq_q,
|
||||
stride_bz_k, stride_h_k, stride_seq_k,
|
||||
stride_bz_v, stride_h_v, stride_seq_v,
|
||||
stride_bz_o, stride_h_o, stride_seq_o,
|
||||
k_block_id.stride(1), k_block_id.stride(2),
|
||||
qo_len, kv_len,
|
||||
h_qo, num_kv_groups,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
|
||||
STAGE=stage,
|
||||
num_warps=4 if head_dim == 64 else 8,
|
||||
num_stages=4,
|
||||
)
|
||||
return o
|
||||
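The inner loop above combines two ideas: an online (streaming) softmax that keeps a running maximum, denominator, and accumulator, and a per-block flag that skips whole key blocks. A minimal pure-Python sketch for a single query row with scalar values follows. `blockwise_sparse_attention` and `reference_attention` are hypothetical names; the logits are assumed to be already scaled by log2(e), so base-2 exponentials stand in for the kernel's `exp2`.

```python
def blockwise_sparse_attention(scores, values, block_mask, block_n=2):
    """Online softmax over key blocks, skipping blocks whose mask flag is 0."""
    m = float("-inf")   # running maximum of the logits seen so far
    l = 1.0             # running denominator; rescaled to 0 by the first kept block
    acc = 0.0           # running weighted sum of values
    for blk, start in enumerate(range(0, len(scores), block_n)):
        if not block_mask[blk]:
            continue    # skipped block contributes nothing
        s = scores[start:start + block_n]
        v = values[start:start + block_n]
        new_m = max(m, max(s))
        p = [2.0 ** (x - new_m) for x in s]
        alpha = 2.0 ** (m - new_m)          # rescales earlier partial results
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v))
        m = new_m
    return acc / l


def reference_attention(scores, values, block_mask, block_n=2):
    """Direct base-2 softmax over the keys that belong to kept blocks."""
    kept = [i for i in range(len(scores)) if block_mask[i // block_n]]
    m = max(scores[i] for i in kept)
    weights = [2.0 ** (scores[i] - m) for i in kept]
    return sum(w * values[i] for w, i in zip(weights, kept)) / sum(weights)


scores = [1.0, 3.0, -2.0, 0.5, 4.0, 0.0]
values = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
mask = [1, 0, 1]
print(blockwise_sparse_attention(scores, values, mask))
print(reference_attention(scores, values, mask))
```

Starting the denominator at 1.0 rather than 0, as the kernel does with `l_i`, has no effect once any block is kept, since the first rescale factor is zero. When every block is skipped, it makes the final division return 0 instead of dividing by zero.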
460
flashvsr_arch/models/utils.py
Normal file
@@ -0,0 +1,460 @@
|
||||
import torch, os, gc
|
||||
from safetensors import safe_open
|
||||
from contextlib import contextmanager
|
||||
from einops import rearrange, repeat
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = torch.nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
try:
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
|
||||
|
||||
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
for file_name in os.listdir(file_path):
|
||||
if "." in file_name and file_name.split(".")[-1] in [
|
||||
"safetensors", "bin", "ckpt", "pth", "pt"
|
||||
]:
|
||||
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def search_for_embeddings(state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
|
||||
for name, param_ in state_dict.items():
|
||||
if param.numel() == param_.numel():
|
||||
if param.shape == param_.shape:
|
||||
if torch.dist(param, param_) < 1e-3:
|
||||
return name
|
||||
else:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
||||
matched_keys = set()
|
||||
with torch.no_grad():
|
||||
for name in source_state_dict:
|
||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
||||
if rename is not None:
|
||||
print(f'"{name}": "{rename}",')
|
||||
matched_keys.add(rename)
|
||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
||||
length = source_state_dict[name].shape[0] // 3
|
||||
rename = []
|
||||
for i in range(3):
|
||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
||||
if None not in rename:
|
||||
print(f'"{name}": {rename},')
|
||||
for rename_ in rename:
|
||||
matched_keys.add(rename_)
|
||||
for name in target_state_dict:
|
||||
if name not in matched_keys:
|
||||
print("Cannot find", name, target_state_dict[name].shape)
|
||||
|
||||
|
||||
def search_for_files(folder, extensions):
|
||||
files = []
|
||||
if os.path.isdir(folder):
|
||||
for file in sorted(os.listdir(folder)):
|
||||
files += search_for_files(os.path.join(folder, file), extensions)
|
||||
elif os.path.isfile(folder):
|
||||
for extension in extensions:
|
||||
if folder.endswith(extension):
|
||||
files.append(folder)
|
||||
break
|
||||
return files
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
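The key-hashing helpers above can be illustrated without tensors. The following is a minimal sketch, not the repository's code: it assumes shapes are passed as plain tuples (the real helper reads them from `torch.Tensor.shape`), ignores nested dicts, and assumes the plain key is appended in addition to the `key:shape` entry. `fingerprint` is a hypothetical name.

```python
import hashlib

def fingerprint(shapes, with_shape=True):
    # shapes: hypothetical mapping of parameter name -> shape tuple
    keys = []
    for key, shape in shapes.items():
        if with_shape:
            keys.append(key + ":" + "_".join(map(str, shape)))
        keys.append(key)
    keys.sort()  # sorting makes the digest independent of insertion order
    return hashlib.md5(",".join(keys).encode("UTF-8")).hexdigest()

a = fingerprint({"w": (2, 3), "b": (3,)})
b = fingerprint({"b": (3,), "w": (2, 3)})  # same keys, different order
c = fingerprint({"w": (3, 2), "b": (3,)})  # same keys, different shape
```

Under these assumptions `a` and `b` match while `c` differs, which is what lets a key hash identify a checkpoint layout regardless of how the dict was built.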
|
||||
|
||||
def clean_vram():
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
if torch.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
def get_device_list():
|
||||
devs = []
|
||||
if torch.cuda.is_available():
|
||||
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
||||
|
||||
if torch.mps.is_available():
|
||||
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
|
||||
|
||||
return devs
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
# print(cache_x.shape, x.shape)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
# print('cache!')
|
||||
x = F.pad(x, padding, mode='replicate') # mode='replicate'
|
||||
# print(x[0,0,:,0,0])
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
class PixelShuffle3d(nn.Module):
|
||||
def __init__(self, ff, hh, ww):
|
||||
super().__init__()
|
||||
self.ff = ff
|
||||
self.hh = hh
|
||||
self.ww = ww
|
||||
|
||||
def forward(self, x):
|
||||
# x: (B, C, F, H, W)
|
||||
return rearrange(x,
|
||||
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
||||
ff=self.ff, hh=self.hh, ww=self.ww)
|
||||
|
||||
class Buffer_LQ4x_Proj(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, layer_num=30):
|
||||
super().__init__()
|
||||
self.ff = 1
|
||||
self.hh = 16
|
||||
self.ww = 16
|
||||
self.hidden_dim1 = 2048
|
||||
self.hidden_dim2 = 3072
|
||||
self.layer_num = layer_num
|
||||
|
||||
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
||||
|
||||
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
||||
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
||||
self.act1 = nn.SiLU()
|
||||
|
||||
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
||||
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
||||
self.act2 = nn.SiLU()
|
||||
|
||||
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
||||
|
||||
self.clip_idx = 0
|
||||
|
||||
def forward(self, video):
|
||||
self.clear_cache()
|
||||
# x: (B, C, F, H, W)
|
||||
|
||||
t = video.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
||||
video = torch.cat([first_frame, video], dim=2)
|
||||
# print(video.shape)
|
||||
|
||||
out_x = []
|
||||
for i in range(iter_):
|
||||
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv1'] = cache1_x
|
||||
x = self.conv1(x, self.cache['conv1'])
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv2'] = cache2_x
|
||||
if i == 0:
|
||||
continue
|
||||
x = self.conv2(x, self.cache['conv2'])
|
||||
x = self.norm2(x)
|
||||
x = self.act2(x)
|
||||
out_x.append(x)
|
||||
out_x = torch.cat(out_x, dim = 2)
|
||||
# print(out_x.shape)
|
||||
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
||||
outputs = []
|
||||
for i in range(self.layer_num):
|
||||
outputs.append(self.linear_layers[i](out_x))
|
||||
return outputs
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache = {}
|
||||
self.cache['conv1'] = None
|
||||
self.cache['conv2'] = None
|
||||
self.clip_idx = 0
|
||||
|
||||
def stream_forward(self, video_clip):
|
||||
if self.clip_idx == 0:
|
||||
# self.clear_cache()
|
||||
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
||||
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
||||
x = self.pixel_shuffle(video_clip)
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv1'] = cache1_x
|
||||
x = self.conv1(x, self.cache['conv1'])
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv2'] = cache2_x
|
||||
self.clip_idx += 1
|
||||
return None
|
||||
else:
|
||||
x = self.pixel_shuffle(video_clip)
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv1'] = cache1_x
|
||||
x = self.conv1(x, self.cache['conv1'])
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv2'] = cache2_x
|
||||
x = self.conv2(x, self.cache['conv2'])
|
||||
x = self.norm2(x)
|
||||
x = self.act2(x)
|
||||
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
||||
outputs = []
|
||||
for i in range(self.layer_num):
|
||||
outputs.append(self.linear_layers[i](out_x))
|
||||
self.clip_idx += 1
|
||||
return outputs
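The chunking arithmetic in `Buffer_LQ4x_Proj.forward` can be traced with plain integers. This is a rough sketch under the assumption that frames are represented by their indices; `chunk_plan` is a hypothetical helper, not part of the module.

```python
def chunk_plan(t):
    # Frames are stood in for by their indices 0..t-1.
    frames = list(range(t))
    # The first frame is repeated three extra times in front of the clip.
    padded = [frames[0]] * 3 + frames
    # Same iteration count as forward(): one priming chunk plus (t - 1) // 4.
    iters = 1 + (t - 1) // 4
    return [padded[i * 4:(i + 1) * 4] for i in range(iters)]

plan = chunk_plan(9)
# plan[0] holds only copies of frame 0; in forward() this chunk fills the
# caches and is skipped (the `continue` at i == 0), so it yields no output.
```

With `t = 9` the plan has three chunks of four frames, and only the last two contribute to `out_x`. When `t - 1` is not a multiple of 4, the trailing frames are not covered by any chunk in this loop.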
|
||||
|
||||
|
||||
class Causal_LQ4x_Proj(nn.Module):
|
||||
"""Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.
|
||||
|
||||
Key difference: reads old cache BEFORE writing new cache (truly causal),
|
||||
whereas Buffer_LQ4x_Proj writes cache BEFORE conv call.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim, layer_num=30):
|
||||
super().__init__()
|
||||
self.ff = 1
|
||||
self.hh = 16
|
||||
self.ww = 16
|
||||
self.hidden_dim1 = 2048
|
||||
self.hidden_dim2 = 3072
|
||||
self.layer_num = layer_num
|
||||
|
||||
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
||||
|
||||
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
|
||||
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
||||
self.act1 = nn.SiLU()
|
||||
|
||||
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
|
||||
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
||||
self.act2 = nn.SiLU()
|
||||
|
||||
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
||||
|
||||
self.clip_idx = 0
|
||||
|
||||
def forward(self, video):
|
||||
self.clear_cache()
|
||||
t = video.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
||||
video = torch.cat([first_frame, video], dim=2)
|
||||
|
||||
out_x = []
|
||||
for i in range(iter_):
|
||||
x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
|
||||
self.cache['conv1'] = cache1_x # writes NEW cache AFTER
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if i == 0:
|
||||
self.cache['conv2'] = cache2_x
|
||||
continue
|
||||
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
|
||||
self.cache['conv2'] = cache2_x # writes NEW cache AFTER
|
||||
x = self.norm2(x)
|
||||
x = self.act2(x)
|
||||
out_x.append(x)
|
||||
out_x = torch.cat(out_x, dim=2)
|
||||
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
||||
outputs = []
|
||||
for i in range(self.layer_num):
|
||||
outputs.append(self.linear_layers[i](out_x))
|
||||
return outputs
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache = {}
|
||||
self.cache['conv1'] = None
|
||||
self.cache['conv2'] = None
|
||||
self.clip_idx = 0
|
||||
|
||||
def stream_forward(self, video_clip):
|
||||
if self.clip_idx == 0:
|
||||
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
||||
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
||||
x = self.pixel_shuffle(video_clip)
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
x = self.conv1(x, self.cache['conv1']) # reads OLD (None) cache
|
||||
self.cache['conv1'] = cache1_x # writes AFTER
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
self.cache['conv2'] = cache2_x
|
||||
self.clip_idx += 1
|
||||
return None
|
||||
else:
|
||||
x = self.pixel_shuffle(video_clip)
|
||||
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
x = self.conv1(x, self.cache['conv1']) # reads OLD cache
|
||||
self.cache['conv1'] = cache1_x # writes AFTER
|
||||
x = self.norm1(x)
|
||||
x = self.act1(x)
|
||||
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
x = self.conv2(x, self.cache['conv2']) # reads OLD cache
|
||||
self.cache['conv2'] = cache2_x # writes AFTER
|
||||
x = self.norm2(x)
|
||||
x = self.act2(x)
|
||||
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
||||
outputs = []
|
||||
for i in range(self.layer_num):
|
||||
outputs.append(self.linear_layers[i](out_x))
|
||||
self.clip_idx += 1
|
||||
return outputs
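The ordering difference described in the `Causal_LQ4x_Proj` docstring can be made concrete with a toy model. The sketch below is an illustration only: `toy_conv` merely prepends the cached frames as history, and `CACHE_T = 2` is an assumed value, since the real constant is defined outside the lines shown here.

```python
CACHE_T = 2  # assumed cache length; the real constant lives elsewhere in the file

def toy_conv(chunk, cache):
    # Stand-in for CausalConv3d: cached frames are prepended as history.
    history = cache if cache is not None else []
    return history + chunk

def run(chunks, read_before_write):
    cache, seen = None, []
    for chunk in chunks:
        new_cache = chunk[-CACHE_T:]
        if read_before_write:      # Causal_LQ4x_Proj ordering
            out = toy_conv(chunk, cache)
            cache = new_cache
        else:                      # Buffer_LQ4x_Proj ordering
            cache = new_cache
            out = toy_conv(chunk, cache)
        seen.append(out)
    return seen

chunks = [[0, 1, 2, 3], [4, 5, 6, 7]]
causal = run(chunks, read_before_write=True)
buffered = run(chunks, read_before_write=False)
```

In the read-before-write run, the second chunk sees frames 2 and 3 from the previous chunk as history. In the write-before-read run, each chunk sees its own last two frames instead.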
|
||||
865
flashvsr_arch/models/wan_video_dit.py
Normal file
@@ -0,0 +1,865 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import random
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
assert callable(getattr(flash_attn_interface, "flash_attn_func", None))
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except Exception:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
assert callable(getattr(flash_attn, "flash_attn_func", None))
|
||||
FLASH_ATTN_2_AVAILABLE = True
|
||||
except Exception:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
assert callable(sageattn)
|
||||
SAGE_ATTN_AVAILABLE = True
|
||||
except Exception:
|
||||
SAGE_ATTN_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from .sparse_sage.core import sparse_sageattn
|
||||
assert callable(sparse_sageattn)
|
||||
SPARSE_SAGE_AVAILABLE = True
|
||||
except Exception:
|
||||
try:
|
||||
from sageattn.core import sparse_sageattn
|
||||
assert callable(sparse_sageattn)
|
||||
SPARSE_SAGE_AVAILABLE = True
|
||||
except Exception:
|
||||
SPARSE_SAGE_AVAILABLE = False
|
||||
sparse_sageattn = None
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
print(f"[FlashVSR] Attention backends: sparse_sage={SPARSE_SAGE_AVAILABLE}, "
|
||||
f"flash_attn_3={FLASH_ATTN_3_AVAILABLE}, flash_attn_2={FLASH_ATTN_2_AVAILABLE}, "
|
||||
f"sage_attn={SAGE_ATTN_AVAILABLE}")
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Local / window masks
|
||||
# ----------------------------
|
||||
@torch.no_grad()
|
||||
def build_local_block_mask_shifted_vec(block_h: int,
|
||||
block_w: int,
|
||||
win_h: int = 6,
|
||||
win_w: int = 6,
|
||||
include_self: bool = True,
|
||||
device=None) -> torch.Tensor:
|
||||
device = device or torch.device("cpu")
|
||||
H, W = block_h, block_w
|
||||
r = torch.arange(H, device=device)
|
||||
c = torch.arange(W, device=device)
|
||||
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
||||
r_all = YY.reshape(-1)
|
||||
c_all = XX.reshape(-1)
|
||||
r_half = win_h // 2
|
||||
c_half = win_w // 2
|
||||
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
|
||||
end_r = start_r + win_h - 1
|
||||
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
|
||||
end_c = start_c + win_w - 1
|
||||
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
||||
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
||||
mask = in_row & in_col
|
||||
if not include_self:
|
||||
mask.fill_diagonal_(False)
|
||||
return mask
|
||||
|
||||
@torch.no_grad()
|
||||
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
|
||||
block_w: int,
|
||||
win_h: int = 6,
|
||||
win_w: int = 6,
|
||||
include_self: bool = True,
|
||||
device=None) -> torch.Tensor:
|
||||
device = device or torch.device("cpu")
|
||||
H, W = block_h, block_w
|
||||
r = torch.arange(H, device=device)
|
||||
c = torch.arange(W, device=device)
|
||||
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
||||
r_all = YY.reshape(-1)
|
||||
c_all = XX.reshape(-1)
|
||||
r_half = win_h // 2
|
||||
c_half = win_w // 2
|
||||
start_r = r_all - r_half
|
||||
end_r = start_r + win_h - 1
|
||||
start_c = c_all - c_half
|
||||
end_c = start_c + win_w - 1
|
||||
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
||||
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
||||
mask = in_row & in_col
|
||||
if not include_self:
|
||||
mask.fill_diagonal_(False)
|
||||
return mask
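The two mask builders above differ only in whether the window start is clamped to the grid. A one-dimensional sketch in plain Python shows the effect at the borders; `window_1d` is a hypothetical helper, and the real functions apply the same rule independently to rows and columns with tensors.

```python
def window_1d(n, win, clamp):
    half = win // 2
    rows = []
    for q in range(n):
        start = q - half
        if clamp:
            # Same bounds as torch.clamp(start, 0, n - win).
            start = min(max(start, 0), n - win)
        end = start + win - 1
        rows.append([start <= k <= end for k in range(n)])
    return rows

clamped = window_1d(6, 3, clamp=True)    # build_local_block_mask_shifted_vec
sliding = window_1d(6, 3, clamp=False)   # ..._normal_slide variant
clamped_counts = [sum(r) for r in clamped]
sliding_counts = [sum(r) for r in sliding]
```

With clamping, every query position keeps a full window of `win` keys because the window is shifted inward at the edges. Without clamping, the window runs off the grid and border positions see fewer keys.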
|
||||
|
||||
|
||||
class WindowPartition3D:
|
||||
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
|
||||
@staticmethod
|
||||
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
|
||||
B, F, H, W, C = x.shape
|
||||
wf, wh, ww = win
|
||||
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
|
||||
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
|
||||
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
|
||||
return x.view(-1, wf * wh * ww, C)
|
||||
|
||||
@staticmethod
|
||||
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
|
||||
F, H, W = orig
|
||||
wf, wh, ww = win
|
||||
nf, nh, nw = F // wf, H // wh, W // ww
|
||||
B = windows.size(0) // (nf * nh * nw)
|
||||
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
|
||||
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
|
||||
return x.view(B, F, H, W, -1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_draft_block_mask(batch_size, nheads, seqlen,
|
||||
q_w, k_w, topk=10, local_attn_mask=None):
|
||||
assert batch_size == 1, "Only batch_size=1 supported for now"
|
||||
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
||||
avgpool_q = torch.mean(q_w, dim=1)
|
||||
avgpool_k = torch.mean(k_w, dim=1)
|
||||
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
||||
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
|
||||
q_heads = avgpool_q.permute(1, 0, 2)
|
||||
k_heads = avgpool_k.permute(1, 0, 2)
|
||||
D = avgpool_q.shape[-1]
|
||||
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
|
||||
|
||||
repeat_head = scores.shape[0]
|
||||
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
||||
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
|
||||
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
||||
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
||||
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
||||
local_attn_mask = local_attn_mask.to(torch.float32)
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
||||
scores = scores + local_attn_mask
|
||||
|
||||
attn_map = torch.softmax(scores, dim=-1)
|
||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
||||
loop_num, s1, s2 = attn_map.shape
|
||||
flat = attn_map.reshape(loop_num, -1)
|
||||
n = flat.shape[1]
|
||||
apply_topk = min(flat.shape[1]-1, topk)
|
||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
||||
thresholds = thresholds.unsqueeze(1)
|
||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
||||
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
|
||||
# Fix: unify the variable name used on the line above
|
||||
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
|
||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
|
||||
q_w, k_w, topk=10, local_attn_mask=None):
|
||||
assert batch_size == 1, "Only batch_size=1 supported for now"
|
||||
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
||||
|
||||
avgpool_q = torch.mean(q_w, dim=1)
|
||||
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
||||
q_heads = avgpool_q.permute(1, 0, 2)
|
||||
D = avgpool_q.shape[-1]
|
||||
|
||||
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
|
||||
avgpool_k_split = torch.mean(k_w_split, dim=2)
|
||||
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
|
||||
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
|
||||
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)
|
||||
|
||||
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
|
||||
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
|
||||
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
|
||||
scores = torch.cat([scores_1, scores_2], dim=-1)
|
||||
|
||||
repeat_head = scores.shape[0]
|
||||
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
||||
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
|
||||
|
||||
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
||||
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
||||
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
|
||||
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
||||
|
||||
local_attn_mask = local_attn_mask.to(torch.float32)
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
||||
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
||||
|
||||
assert scores.shape == local_attn_mask.shape, \
|
||||
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
|
||||
|
||||
scores = scores + local_attn_mask
|
||||
attn_map = torch.softmax(scores, dim=-1)
|
||||
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
||||
loop_num, s1, s2 = attn_map.shape
|
||||
flat = attn_map.reshape(loop_num, -1)
|
||||
apply_topk = min(flat.shape[1]-1, topk)
|
||||
|
||||
if apply_topk <= 0:
|
||||
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
|
||||
else:
|
||||
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
||||
thresholds = thresholds.unsqueeze(1)
|
||||
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
||||
|
||||
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
|
||||
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
||||
return mask
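The top-k selection at the end of `generate_draft_block_mask_refined` reduces to a threshold test. The sketch below restates it for a flat Python list; `topk_mask` is a hypothetical name and the scores are made-up values.

```python
def topk_mask(scores, topk):
    # Never ask for more than len - 1 entries, as in the original.
    k = min(len(scores) - 1, topk)
    if k <= 0:
        return [False] * len(scores)
    # The threshold is the (k + 1)-th largest value.
    threshold = sorted(scores, reverse=True)[k]
    # Strict comparison: only entries above the threshold are kept.
    return [s > threshold for s in scores]

mask = topk_mask([0.1, 0.4, 0.2, 0.3], topk=2)
tied = topk_mask([0.5, 0.5, 0.5], topk=1)
```

Because the comparison is strict, entries tied with the threshold are dropped, so fewer than `topk` blocks may be selected when scores repeat.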
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Attention kernels
|
||||
# ----------------------------
|
||||
def _sdpa_fallback(q, k, v, num_heads):
|
||||
"""PyTorch scaled dot-product attention (always available)."""
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
|
||||
|
||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
|
||||
global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
|
||||
|
||||
if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
|
||||
try:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
base_blockmask = attention_mask
|
||||
x = sparse_sageattn(
|
||||
q, k, v,
|
||||
mask_id=base_blockmask.to(torch.int8),
|
||||
is_causal=False,
|
||||
tensor_layout="HND"
|
||||
)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
except Exception:
|
||||
SPARSE_SAGE_AVAILABLE = False
|
||||
print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||
# q,k,v already rearranged to [b, n, s, d] above
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
elif compatibility_mode:
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
elif FLASH_ATTN_3_AVAILABLE:
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
elif FLASH_ATTN_2_AVAILABLE:
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||
x = flash_attn.flash_attn_func(q, k, v)
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
elif SAGE_ATTN_AVAILABLE:
|
||||
try:
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
x = sageattn(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
except Exception:
|
||||
SAGE_ATTN_AVAILABLE = False
|
||||
print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
|
||||
# q,k,v already rearranged to [b, n, s, d] above
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
x = _sdpa_fallback(q, k, v, num_heads)
|
||||
return x
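The branch order in `flash_attention` amounts to a priority list. The sketch below restates it as a pure function so the precedence is easy to check; `pick_backend` and its string labels are hypothetical, and the runtime fallbacks to SDPA inside the `except` branches are not modelled.

```python
def pick_backend(has_mask, enable_sage, sparse_sage, compat, fa3, fa2, sage):
    # Sparse attention is only considered when a block mask is supplied.
    if has_mask and enable_sage and sparse_sage:
        return "sparse_sage"
    if compat:
        return "sdpa"
    if fa3:
        return "flash_attn_3"
    if fa2:
        return "flash_attn_2"
    if sage:
        return "sage_attn"
    return "sdpa"

with_mask = pick_backend(True, True, True, False, True, True, True)
no_mask = pick_backend(False, True, True, False, False, True, True)
nothing = pick_backend(False, True, False, False, False, False, False)
```

Note that in the original function the dense backends do not receive `attention_mask`, so a mask only takes effect on the sparse path.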
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
return (x * (1 + scale) + shift)
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
half_dim = max(dim // 2, 1)
|
||||
scale = torch.arange(half_dim, dtype=torch.float64, device=position.device)
|
||||
inv_freq = torch.pow(10000.0, -scale / half_dim)
|
||||
sinusoid = torch.outer(position.to(torch.float64), inv_freq)
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x.to(position.dtype)
|
||||
|
||||
|
||||
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
||||
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
||||
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
||||
return f_freqs_cis, h_freqs_cis, w_freqs_cis
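The channel split performed by `precompute_freqs_cis_3d` is simple integer arithmetic. As a worked example, assuming a per-head dimension of 128 (an illustrative value, not taken from the source), `split_rope_dims` below is a hypothetical helper that reproduces the three sizes.

```python
def split_rope_dims(dim):
    # Height and width each get dim // 3; the frame axis takes the remainder.
    hw = dim // 3
    return dim - 2 * hw, hw, hw

f_dim, h_dim, w_dim = split_rope_dims(128)
# Each axis yields dim // 2 complex frequencies per position.
complex_channels = f_dim // 2 + h_dim // 2 + w_dim // 2
```

For 128 channels this gives 44, 42 and 42, and the complex frequency counts (22 + 21 + 21) add up to 64, half the head dimension, which is what `rope_apply` needs after viewing channel pairs as complex numbers.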
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
half_dim = max(dim // 2, 1)
|
||||
base = torch.arange(0, dim, 2, dtype=torch.float64)[:half_dim]
|
||||
freqs = torch.pow(theta, -base / max(dim, 1))
|
||||
steps = torch.arange(end, dtype=torch.float64)
|
||||
angles = torch.outer(steps, freqs)
|
||||
return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
orig_dtype = x.dtype
|
||||
reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||
x_complex = torch.view_as_complex(reshaped)
|
||||
freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device)
|
||||
x_out = torch.view_as_real(x_complex * freqs).flatten(2)
|
||||
return x_out.to(orig_dtype)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Norms & Blocks
|
||||
# ----------------------------
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, num_heads, enable_sageattention=True):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.enable_sageattention = enable_sageattention
|
||||
|
||||
def forward(self, q, k, v, attention_mask=None):
|
||||
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention)
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = RMSNorm(dim, eps=eps)
|
||||
self.norm_k = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
|
||||
self.local_attn_mask = None
|
||||
|
||||
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
|
||||
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
||||
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
||||
B, L, D = x.shape
|
||||
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
|
||||
assert f==2, "f must be 2"
|
||||
if is_stream and (pre_cache_k is None or pre_cache_v is None):
|
||||
assert f==6, "f must be 6 at stream start"
|
||||
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
|
||||
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
|
||||
win = (2, 8, 8)
|
||||
q = q.view(B, f, h, w, D)
|
||||
k = k.view(B, f, h, w, D)
|
||||
v = v.view(B, f, h, w, D)
|
||||
|
||||
q_w = WindowPartition3D.partition(q, win)
|
||||
k_w = WindowPartition3D.partition(k, win)
|
||||
v_w = WindowPartition3D.partition(v, win)
|
||||
|
||||
seqlen = f//win[0]
|
||||
one_len = k_w.shape[0] // B // seqlen
|
||||
if pre_cache_k is not None and pre_cache_v is not None:
|
||||
k_w = torch.cat([pre_cache_k, k_w], dim=0)
|
||||
v_w = torch.cat([pre_cache_v, v_w], dim=0)
|
||||
|
||||
block_n = q_w.shape[0] // B
|
||||
block_s = q_w.shape[1]
|
||||
block_n_kv = k_w.shape[0] // B
|
||||
|
||||
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
|
||||
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
||||
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
||||
|
||||
window_size = win[0]*h*w//128
|
||||
|
||||
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
|
||||
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
|
||||
self.local_attn_mask_h = h//8
|
||||
self.local_attn_mask_w = w//8
|
||||
self.local_range = local_range
|
||||
attention_mask = generate_draft_block_mask_refined(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
||||
|
||||
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
|
||||
|
||||
cur_block_n, cur_block_s, _ = k_w.shape
|
||||
cache_num = cur_block_n // one_len
|
||||
if cache_num > kv_len:
|
||||
cache_k = k_w[one_len:, :, :]
|
||||
cache_v = v_w[one_len:, :, :]
|
||||
else:
|
||||
cache_k = k_w
|
||||
cache_v = v_w
|
||||
|
||||
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
|
||||
x = WindowPartition3D.reverse(x, win, (f, h, w))
|
||||
x = x.view(B, f*h*w, D)
|
||||
|
||||
if is_stream:
|
||||
return self.o(x), cache_k, cache_v
|
||||
return self.o(x)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """
    Text-context only; provides a persistent KV cache.
    """
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads, enable_sageattention=False)

        # persistent cache
        self.cache_k = None
        self.cache_v = None

    @torch.no_grad()
    def init_cache(self, ctx: torch.Tensor):
        """ctx: [B, S_ctx, dim], the context after text_embedding."""
        self.cache_k = self.norm_k(self.k(ctx))
        self.cache_v = self.v(ctx)

    def clear_cache(self):
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
        """
        y is the text context (no other branches are handled).
        """
        q = self.norm_q(self.q(x))
        assert self.cache_k is not None and self.cache_v is not None
        k = self.cache_k
        v = self.cache_v

        x = self.attn(q, k, v)
        return self.o(x)
|
||||
|
||||
|
||||
class GateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        return x + gate * residual
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
|
||||
self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
|
||||
self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
||||
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
||||
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
self.gate = GateModule()
|
||||
|
||||
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
|
||||
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
||||
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
|
||||
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
|
||||
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
|
||||
|
||||
x = self.gate(x, gate_msa, self_attn_output)
|
||||
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
|
||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
||||
if is_stream:
|
||||
return x, self_attn_cache_k, self_attn_cache_v
|
||||
return x
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Sequential(
|
||||
nn.LayerNorm(in_dim),
|
||||
nn.Linear(in_dim, in_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.LayerNorm(out_dim)
|
||||
)
|
||||
self.has_pos_emb = has_pos_emb
|
||||
if has_pos_emb:
|
||||
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.has_pos_emb:
|
||||
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
|
||||
def forward(self, x, t_mod):
|
||||
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
||||
return x
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# WanModel (no image branch): the KV cache is created at init time
|
||||
# ----------------------------
|
||||
class WanModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: Tuple[int, int, int],
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool = False,
|
||||
enable_sageattention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.patch_size = patch_size
|
||||
|
||||
# patch embed
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# text / time embed
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
|
||||
head_dim = dim // num_heads
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
self._cross_kv_initialized = False
|
||||
|
||||
    # Optional: manually clear / re-initialize
    def clear_cross_kv(self):
        for blk in self.blocks:
            blk.cross_attn.clear_cache()
        self._cross_kv_initialized = False

    @torch.no_grad()
    def reinit_cross_kv(self, new_context: torch.Tensor):
        ctx_txt = self.text_embedding(new_context)
        for blk in self.blocks:
            blk.cross_attn.init_cache(ctx_txt)
        self._cross_kv_initialized = True
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
||||
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
LQ_latents: Optional[List[torch.Tensor]] = None,
|
||||
train_img: bool = False,
|
||||
topk_ratio: Optional[float] = None,
|
||||
kv_ratio: Optional[float] = None,
|
||||
local_num: Optional[int] = None,
|
||||
is_full_block: bool = False,
|
||||
causal_idx: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# time / text embeds
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
|
||||
# Text embedding would happen here (CrossAttention ignores it when a cache already exists)
|
||||
# context = self.text_embedding(context)
|
||||
|
||||
# patchify the input
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
B = x.shape[0]
|
||||
|
||||
# window / mask hyperparameters
|
||||
win = (2, 8, 8)
|
||||
seqlen = f//win[0]
|
||||
if local_num is None:
|
||||
local_random = random.random()
|
||||
if local_random < 0.3:
|
||||
local_num = seqlen - 3
|
||||
elif local_random < 0.4:
|
||||
local_num = seqlen - 4
|
||||
elif local_random < 0.5:
|
||||
local_num = seqlen - 2
|
||||
else:
|
||||
local_num = seqlen
|
||||
|
||||
window_size = win[0]*h*w//128
|
||||
square_num = window_size*window_size
|
||||
topk_ratio = 2.0
|
||||
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
|
||||
|
||||
if kv_ratio is None:
|
||||
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
|
||||
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
|
||||
|
||||
decay_ratio = random.uniform(0.7, 1.0)
|
||||
|
||||
# RoPE 3D
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
# blocks
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
||||
x += LQ_latents[block_id]
|
||||
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
||||
train_img, block_id, kv_len, is_full_block, False,
|
||||
None, None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
||||
train_img, block_id, kv_len, is_full_block, False,
|
||||
None, None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
|
||||
train_img, block_id, kv_len, is_full_block, False,
|
||||
None, None)
|
||||
|
||||
x = self.head(x, t)
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanModelStateDictConverter()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# State dict converter (keeps the original mapping; has_image_input is ignored)
|
||||
# ----------------------------
|
||||
class WanModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
config = {
|
||||
"model_type": "t2v",
|
||||
"patch_size": (1, 2, 2),
|
||||
"text_len": 512,
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"window_size": (-1, -1),
|
||||
"qk_norm": True,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
# Keep the config returned by the original hash matching; the implementation itself does not use the has_image_input branch
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
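The block-indexed renaming in `from_diffusers` above can be sketched in isolation. This is a minimal, hypothetical helper (the name `rename_block_keys` and the two-entry mapping in the usage note are illustrative assumptions, not part of the diff): it looks up each key with its block index replaced by `0`, then writes the real index back into the translated name. As in the converter above, keys that match nothing are dropped.

```python
def rename_block_keys(state_dict, rename_dict):
    """Rename keys using templates written for block 0, applied to every block index."""
    renamed = {}
    for name, param in state_dict.items():
        # Exact match: non-block keys such as "proj_out.bias".
        if name in rename_dict:
            renamed[rename_dict[name]] = param
            continue
        parts = name.split(".")
        # Swap the second component (the block index) for "0" to form the lookup template.
        template = ".".join(parts[:1] + ["0"] + parts[2:])
        if template in rename_dict:
            target = rename_dict[template].split(".")
            # Put the original block index back into the translated key.
            renamed[".".join(target[:1] + [parts[1]] + target[2:])] = param
    return renamed
```

For example, with the assumed mapping `{"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", "proj_out.bias": "head.head.bias"}`, the key `blocks.7.attn1.to_q.weight` becomes `blocks.7.self_attn.q.weight`. Writing the table once for block 0 keeps it short regardless of `num_layers`.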
847
flashvsr_arch/models/wan_video_vae.py
Normal file
@@ -0,0 +1,847 @@
|
||||
from einops import rearrange, repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

CACHE_T = 2
|
||||
|
||||
|
||||
def check_is_instance(model, module_class):
    if isinstance(model, module_class):
        return True
    if hasattr(model, "module") and isinstance(model.module, module_class):
        return True
    return False
|
||||
|
||||
|
||||
def block_causal_mask(x, block_size):
    # params
    b, n, s, _, device = *x.size(), x.device
    assert s % block_size == 0
    num_blocks = s // block_size

    # build mask
    mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
    for i in range(num_blocks):
        mask[:, :,
             i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
    return mask
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
    """
    Causal 3D convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            # print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim * 2, (3, 1, 1),
|
||||
padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim, (3, 1, 1),
|
||||
stride=(2, 1, 1),
|
||||
padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# chunk has fewer than CACHE_T frames: prepend the last cached frame from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# chunk has fewer than CACHE_T frames: prepend the last cached frame from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
||||
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
#attn_mask=block_causal_mask(q, block_size=h * w)
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# chunk has fewer than CACHE_T frames: prepend the last cached frame from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# chunk has fewer than CACHE_T frames: prepend the last cached frame from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# chunk has fewer than CACHE_T frames: prepend the last cached frame from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# chunk has fewer than 2 frames: prepend the last frame cached from the previous chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if check_is_instance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class VideoVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=96,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
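# NOTE: encode() requires a `scale` argument and returns only `mu`, so this forward() is not callable as written.
mu, log_var = self.encode(x)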
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
return mu
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
return out
|
||||
|
||||
|
||||
def stream_decode(self, z, scale):
|
||||
# Unlike decode(), the cache is not cleared here, so cached features persist across calls.
# self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
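# NOTE: encode() requires a `scale` argument and returns only `mu`, so sample() is not callable as written.
mu, log_var = self.encode(imgs)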
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# print('self._feat_map', len(self._feat_map))
|
||||
# cache encode
|
||||
if self.encoder is not None:
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
# print('self._enc_feat_map', len(self._enc_feat_map))
|
||||
|
||||
|
||||
class WanVideoVAE(nn.Module):
|
||||
|
||||
def __init__(self, z_dim=16, dim=96):
|
||||
super().__init__()
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = torch.tensor(mean)
|
||||
self.std = torch.tensor(std)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 8
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, _, H, W = data.shape
|
||||
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
||||
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
||||
|
||||
h = repeat(h, "H -> H W", H=H, W=W)
|
||||
w = repeat(w, "W -> H W", H=H, W=W)
|
||||
|
||||
mask = torch.stack([h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = hidden_states.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = T * 4 - 3
|
||||
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
||||
).to(dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
target_h = h * self.upsampling_factor
|
||||
target_w = w * self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.clamp_(-1, 1)
|
||||
return values
|
||||
|
||||
|
||||
def tiled_encode(self, video, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = video.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = (T + 3) // 4
|
||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
||||
).to(dtype=video.dtype, device=data_device)
|
||||
|
||||
target_h = h // self.upsampling_factor
|
||||
target_w = w // self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
videos = [video.to("cpu") for video in videos]
|
||||
hidden_states = []
|
||||
for video in videos:
|
||||
video = video.unsqueeze(0)
|
||||
if tiled:
|
||||
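# Convert latent-space tile sizes to pixel space under new names, so the factor of 8 is not compounded on every loop iteration.
pixel_tile_size = (tile_size[0] * 8, tile_size[1] * 8)
pixel_tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
hidden_state = self.tiled_encode(video, device, pixel_tile_size, pixel_tile_stride)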
|
||||
else:
|
||||
hidden_state = self.single_encode(video, device)
|
||||
hidden_state = hidden_state.squeeze(0)
|
||||
hidden_states.append(hidden_state)
|
||||
hidden_states = torch.stack(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
def clear_cache(self):
|
||||
self.model.clear_cache()
|
||||
|
||||
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
hidden_states = [hidden_state for hidden_state in hidden_states]
|
||||
assert len(hidden_states) == 1
|
||||
hidden_state = hidden_states[0]
|
||||
video = self.model.stream_decode(hidden_state, self.scale)
|
||||
return video
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanVideoVAEStateDictConverter()
|
||||
|
||||
|
||||
class WanVideoVAEStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
if 'model_state' in state_dict:
|
||||
state_dict = state_dict['model_state']
|
||||
for name in state_dict:
|
||||
state_dict_['model.' + name] = state_dict[name]
|
||||
return state_dict_
|
||||
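A note on the tiling in `tiled_decode` / `tiled_encode` above: tiles overlap by `size - stride`, each tile gets a linear ramp on the sides that face a neighbouring tile, and the accumulated values are divided by the accumulated weights. The following is a minimal 1D sketch in plain Python rather than the torch code from the diff; `split_tiles`, `ramp_mask`, and `blend_tiles` are hypothetical names used only for illustration.

```python
def split_tiles(length, size, stride):
    """Return (start, end) ranges along one axis, using the same skip rule as the diff."""
    tiles = []
    for start in range(0, length, stride):
        # Skip this tile when the previous one already reached the end of the axis.
        if start - stride >= 0 and start - stride + size >= length:
            continue
        tiles.append((start, start + size))
    return tiles


def ramp_mask(length, at_left_edge, at_right_edge, border):
    """1D blend weights: 1.0 inside, linear ramps on sides that face another tile."""
    mask = [1.0] * length
    if not at_left_edge:
        for i in range(border):
            mask[i] = (i + 1) / border
    if not at_right_edge:
        for i in range(border):
            mask[length - 1 - i] = (i + 1) / border
    return mask


def blend_tiles(signal, size, stride):
    """Process a 1D signal tile by tile (identity here) and blend the overlaps."""
    length = len(signal)
    values = [0.0] * length
    weight = [0.0] * length
    for start, end in split_tiles(length, size, stride):
        tile = signal[start:end]  # slicing clamps at the end of the signal
        mask = ramp_mask(len(tile), start == 0, end >= length, size - stride)
        for i, (v, m) in enumerate(zip(tile, mask)):
            values[start + i] += v * m
            weight[start + i] += m
    # Normalize by the accumulated weights, as `values / weight` does in the diff.
    return [v / w for v, w in zip(values, weight)]


if __name__ == "__main__":
    signal = [float(i) for i in range(10)]
    print(split_tiles(10, 6, 4))
    print(blend_tiles(signal, 6, 4))
```

Because each tile here is passed through unchanged, the blended result should match the input, which makes a quick sanity check for the weighting.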
3
flashvsr_arch/pipelines/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .flashvsr_full import FlashVSRFullPipeline
|
||||
from .flashvsr_tiny import FlashVSRTinyPipeline
|
||||
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline
|
||||
127
flashvsr_arch/pipelines/base.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.transforms import GaussianBlur
|
||||
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.cpu_offload = False
|
||||
self.model_names = []
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width):
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
|
||||
return height, width
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_images(self, images):
|
||||
return [self.preprocess_image(image) for image in images]
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output):
|
||||
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output):
|
||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
||||
return video
|
||||
|
||||
|
||||
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||
if len(latents) > 0:
|
||||
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||
mask = blur(mask)
|
||||
value += latent * mask * scale
|
||||
weight += mask * scale
|
||||
value /= weight
|
||||
return value
|
||||
|
||||
|
||||
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
|
||||
if special_kwargs is None:
|
||||
noise_pred_global = inference_callback(prompt_emb_global)
|
||||
else:
|
||||
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
|
||||
if special_local_kwargs_list is None:
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
||||
else:
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
|
||||
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
||||
return noise_pred
|
||||
|
||||
|
||||
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
|
||||
local_prompts = local_prompts or []
|
||||
masks = masks or []
|
||||
mask_scales = mask_scales or []
|
||||
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
||||
prompt = extended_prompt_dict.get("prompt", prompt)
|
||||
local_prompts += extended_prompt_dict.get("prompts", [])
|
||||
masks += extended_prompt_dict.get("masks", [])
|
||||
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
||||
return prompt, local_prompts, masks, mask_scales
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
self.cpu_offload = True
|
||||
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[]):
|
||||
# only load models to device if cpu_offload is enabled
|
||||
if not self.cpu_offload:
|
||||
return
|
||||
# offload the unneeded models to cpu
|
||||
for model_name in self.model_names:
|
||||
if model_name not in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
# load the needed models to device
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
# free cached CUDA memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
||||
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
return noise
|
||||
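For reference, `check_resize_height_width` in `base.py` rounds each dimension up to the next multiple of its division factor (64 by default in `BasePipeline`, 16 in `FlashVSRFullPipeline`). A minimal standalone sketch of that arithmetic follows; `round_up_to_multiple` is a hypothetical helper name and not part of the diff.

```python
def round_up_to_multiple(value: int, factor: int) -> int:
    """Round value up to the nearest multiple of factor using integer arithmetic."""
    return (value + factor - 1) // factor * factor


if __name__ == "__main__":
    # Sizes that already divide evenly are left unchanged; others are rounded up.
    for height, width, factor in [(480, 832, 16), (1080, 1920, 64)]:
        print(height, width, "->",
              round_up_to_multiple(height, factor),
              round_up_to_multiple(width, factor))
```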
638
flashvsr_arch/pipelines/flashvsr_full.py
Normal file
@@ -0,0 +1,638 @@
|
||||
import types
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
# import pyfiglet
|
||||
|
||||
from ..models.utils import clean_vram
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
|
||||
# -----------------------------
|
||||
# Basic utilities: statistics needed by AdaIN (kept in case they are needed; the pipeline uses wavelet by default)
|
||||
# -----------------------------
|
||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
|
||||
N, C = feat.shape[:2]
|
||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
||||
std = var.sqrt().view(N, C, 1, 1)
|
||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
||||
return mean, std
|
||||
|
||||
|
||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
||||
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N and C must match"
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = _calc_mean_std(style_feat)
|
||||
content_mean, content_std = _calc_mean_std(content_feat)
|
||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Wavelet-style blur and decomposition/reconstruction (used by ColorCorrector)
|
||||
# -----------------------------
|
||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
||||
vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125 ],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
return torch.tensor(vals, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
N, C, H, W = x.shape
|
||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
||||
pad = radius
|
||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
||||
return out
|
||||
|
||||
|
||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
high = torch.zeros_like(x)
|
||||
low = x
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
blurred = _wavelet_blur(low, radius)
|
||||
high = high + (low - blurred)
|
||||
low = blurred
|
||||
return high, low
|
||||
|
||||
|
||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
||||
return c_high + s_low
|
||||
|
||||
# -----------------------------
|
||||
# Safetensors support ---------
|
||||
# -----------------------------
|
||||
st_load_file = None # Define the variable in global scope first
|
||||
try:
|
||||
from safetensors.torch import load_file as st_load_file
|
||||
except ImportError:
|
||||
# st_load_file remains None if import fails
|
||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
||||
|
||||
# -----------------------------
|
||||
# Stateless color correction module (video-friendly, wavelet by default)
|
||||
# -----------------------------
|
||||
class TorchColorCorrectorWavelet(nn.Module):
|
||||
def __init__(self, levels: int = 5):
|
||||
super().__init__()
|
||||
self.levels = levels
|
||||
|
||||
@staticmethod
|
||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
||||
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
|
||||
B, C, f, H, W = x.shape
|
||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
||||
return y, B, f
|
||||
|
||||
@staticmethod
|
||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
||||
BF, C, H, W = y.shape
|
||||
assert BF == B * f
|
||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert hq_image.shape == lq_image.shape, "HQ and LQ shapes must match"
|
||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
|
||||
|
||||
B, C, f, H, W = hq_image.shape
|
||||
if chunk_size is None or chunk_size >= f:
|
||||
hq4, B, f = self._flatten_time(hq_image)
|
||||
lq4, _, _ = self._flatten_time(lq_image)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out = self._unflatten_time(out4, B, f)
|
||||
return out
|
||||
|
||||
outs = []
|
||||
for start in range(0, f, chunk_size):
|
||||
end = min(start + chunk_size, f)
|
||||
hq_chunk = hq_image[:, :, start:end]
|
||||
lq_chunk = lq_image[:, :, start:end]
|
||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
||||
outs.append(out_chunk)
|
||||
out = torch.cat(outs, dim=2)
|
||||
return out
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Simplified pipeline (dit + vae only)
|
||||
# -----------------------------
|
||||
class FlashVSRFullPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['dit', 'vae']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
self.prompt_emb_posi = None
|
||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
||||
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
# Only dit / vae are managed
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
RMS_norm: AutoWrappedModule,
|
||||
CausalConv3d: AutoWrappedModule,
|
||||
Upsample: AutoWrappedModule,
|
||||
torch.nn.SiLU: AutoWrappedModule,
|
||||
torch.nn.Dropout: AutoWrappedModule,
|
||||
},
|
||||
module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
# Optional: unified sequence parallel entry point (disabled by default here)
|
||||
pipe.use_unified_sequence_parallel = False
|
||||
return pipe
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
# -------------------------
|
||||
# New: explicit KV pre-initialization function
|
||||
# -------------------------
|
||||
def init_cross_kv(
|
||||
self,
|
||||
context_tensor: Optional[torch.Tensor] = None,
|
||||
prompt_path = None
|
||||
):
|
||||
"""
Generate the text context from a fixed prompt and initialize the KV cache of every CrossAttention layer in WanModel.
Must be called explicitly once before __call__.
"""
self.load_models_to_device(["dit"])
|
||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
||||
if self.dit is None:
|
||||
raise RuntimeError("Initialize self.dit via fetch_models / from_model_manager first")
|
||||
|
||||
if context_tensor is None:
|
||||
if prompt_path is None:
|
||||
raise ValueError("init_cross_kv: either prompt_path or context_tensor must be provided")
|
||||
|
||||
# --- Safetensors loading logic added here ---
|
||||
prompt_path_lower = prompt_path.lower()
|
||||
if prompt_path_lower.endswith(".safetensors"):
|
||||
if st_load_file is None:
|
||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
||||
|
||||
# Load the tensor from safetensors
|
||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
||||
|
||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
||||
if len(loaded_dict) == 1:
|
||||
ctx = list(loaded_dict.values())[0]
|
||||
elif 'context' in loaded_dict: # Common key for text context
|
||||
ctx = loaded_dict['context']
|
||||
else:
|
||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
||||
|
||||
else:
|
||||
# Default behavior for .pth, .pt, etc.
|
||||
ctx = torch.load(prompt_path, map_location=self.device)
|
||||
|
||||
# --------------------------------------------
|
||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
||||
else:
|
||||
ctx = context_tensor
|
||||
|
||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if self.prompt_emb_posi is None:
|
||||
self.prompt_emb_posi = {}
|
||||
self.prompt_emb_posi['context'] = ctx
|
||||
|
||||
if hasattr(self.dit, "reinit_cross_kv"):
|
||||
self.dit.reinit_cross_kv(ctx)
|
||||
else:
|
||||
raise AttributeError("WanModel is missing the reinit_cross_kv(ctx) method; add it to the model implementation.")
|
||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
||||
self.load_models_to_device([])
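The safetensors branch of `init_cross_kv` has to pick one tensor out of the dict that `load_file` returns. A minimal sketch of that selection rule follows, with plain Python values standing in for tensors. `pick_context` and its `source` parameter are hypothetical names used for illustration only; they are not part of the repository.

```python
def pick_context(loaded: dict, source: str = "<file>"):
    """Select the context entry from a dict of named tensors.

    Mirrors the rule used in init_cross_kv: a single entry wins regardless
    of its key; otherwise the 'context' key is required.
    """
    if len(loaded) == 1:
        return next(iter(loaded.values()))
    if "context" in loaded:
        return loaded["context"]
    raise ValueError(
        f"{source} has {len(loaded)} entries and no 'context' key; "
        "cannot tell which one is the prompt context."
    )


# Plain strings stand in for tensors in this sketch.
print(pick_context({"anything": "ctx-A"}))              # single entry
print(pick_context({"context": "ctx-B", "mask": "m"}))  # explicit key
```

Keeping the rule in one small function would also let both pipeline files share it instead of duplicating the branch.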
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt=None,
|
||||
negative_prompt="",
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="gpu",
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
tiled=True,
|
||||
tile_size=(60, 104),
|
||||
tile_stride=(30, 52),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
LQ_video=None,
|
||||
is_full_block=False,
|
||||
if_buffer=False,
|
||||
topk_ratio=2.0,
|
||||
kv_ratio=3.0,
|
||||
local_range = 9,
|
||||
color_fix = True,
|
||||
unload_dit = False,
|
||||
skip_vae = False,
|
||||
):
|
||||
# Only cfg_scale=1.0 is accepted (consistent with the original code)
|
||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
||||
|
||||
# Requirement: init_cross_kv() must be called first
|
||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
||||
raise RuntimeError(
|
||||
"Cross-attention KV is not initialized. Before calling __call__, run:\n"
|
||||
" pipe.init_cross_kv(prompt_path=your_prompt_path)\n"
|
||||
"or pass a custom context:\n"
|
||||
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
||||
)
|
||||
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Initialize noise
|
||||
if if_buffer:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = noise
|
||||
|
||||
process_total_num = (num_frames - 1) // 8 - 2
|
||||
is_stream = True
|
||||
|
||||
# Clear any existing LQ_proj_in cache
|
||||
if hasattr(self.dit, "LQ_proj_in"):
|
||||
self.dit.LQ_proj_in.clear_cache()
|
||||
|
||||
frames_total = []
|
||||
LQ_pre_idx = 0
|
||||
LQ_cur_idx = 0
|
||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
||||
self.TCDecoder.clean_mem()
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
||||
current_dit_device = next(iter(self.dit.parameters())).device
|
||||
if str(current_dit_device) != str(self.device):
|
||||
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
|
||||
self.dit.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
||||
if cur_process_idx == 0:
|
||||
pre_cache_k = [None] * len(self.dit.blocks)
|
||||
pre_cache_v = [None] * len(self.dit.blocks)
|
||||
LQ_latents = None
|
||||
inner_loop_num = 7
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
||||
cur_latents = latents[:, :, :6, :, :]
|
||||
else:
|
||||
LQ_latents = None
|
||||
inner_loop_num = 2
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
||||
|
||||
# Denoise
|
||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
||||
self.dit,
|
||||
x=cur_latents,
|
||||
timestep=self.timestep,
|
||||
context=None,
|
||||
tea_cache=None,
|
||||
use_unified_sequence_parallel=False,
|
||||
LQ_latents=LQ_latents,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k,
|
||||
pre_cache_v=pre_cache_v,
|
||||
topk_ratio=topk_ratio,
|
||||
kv_ratio=kv_ratio,
|
||||
cur_process_idx=cur_process_idx,
|
||||
t_mod=self.t_mod,
|
||||
t=self.t,
|
||||
local_range = local_range,
|
||||
)
|
||||
|
||||
cur_latents = cur_latents - noise_pred_posi
|
||||
|
||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
||||
|
||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
||||
cur_frames = self.TCDecoder.decode_video(
|
||||
cur_latents.transpose(1, 2),
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cur_LQ_frame
|
||||
).transpose(1, 2).mul_(2).sub_(1)
|
||||
else:
|
||||
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
|
||||
|
||||
# Per-chunk color correction
|
||||
try:
|
||||
if color_fix:
|
||||
cur_frames = self.ColorCorrector(
|
||||
cur_frames.to(device=self.device),
|
||||
cur_LQ_frame,
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=None,
|
||||
method='adain'
|
||||
)
|
||||
except Exception:  # best effort: keep the uncorrected frames if color correction fails
|
||||
pass
|
||||
|
||||
frames_total.append(cur_frames.to('cpu'))
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
del cur_frames, cur_latents, cur_LQ_frame
|
||||
clean_vram()
|
||||
|
||||
frames = torch.cat(frames_total, dim=2)
|
||||
return frames[0]
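The streaming loop in `__call__` slices the LQ video and the latent tensor with index arithmetic that is easy to misread. The sketch below reproduces only that arithmetic in plain Python, so the per-chunk ranges can be inspected without a model. `chunk_schedule` is a hypothetical helper name; the formulas are copied from the loop, and everything else is an illustrative assumption.

```python
def chunk_schedule(num_frames: int):
    """Return (rounded num_frames, per-chunk index ranges) for the streaming loop."""
    # Same rounding as __call__: force num_frames to the form 4k + 1.
    if num_frames % 4 != 1:
        num_frames = (num_frames + 2) // 4 * 4 + 1

    total_chunks = (num_frames - 1) // 8 - 2
    schedule = []
    lq_prev = 0
    for k in range(total_chunks):
        if k == 0:
            # First chunk: 7 inner steps, 6 latent frames.
            lq_cur = (7 - 1) * 4 - 3
            latent = (0, 6)
        else:
            # Later chunks: 2 inner steps, 2 latent frames.
            lq_cur = k * 8 + 21
            latent = (4 + 2 * k, 6 + 2 * k)
        schedule.append({"chunk": k, "lq_frames": (lq_prev, lq_cur), "latent_frames": latent})
        lq_prev = lq_cur
    return num_frames, schedule


rounded, plan = chunk_schedule(81)
for entry in plan[:3]:
    print(entry)
```

With these formulas the first chunk covers 21 LQ frames and each later chunk covers 8, so an 81-frame input yields 8 chunks ending at LQ frame 77.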
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# TeaCache (original logic kept; disabled by default here)
|
||||
# -----------------------------
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
||||
if should_calc:
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step = (self.step + 1) % self.num_inference_steps
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
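The decision logic in `TeaCache.check` accumulates a rescaled relative change and skips the transformer blocks while the running total stays below the threshold. The following is a simplified sketch of that gate using a scalar signal instead of the modulation tensor. `SkipGate`, `poly_eval` and the toy coefficients are assumptions made for illustration; `poly_eval` uses Horner's rule with the highest-degree coefficient first, the same ordering `np.poly1d` expects.

```python
def poly_eval(coefficients, x):
    """Evaluate a polynomial given highest-degree-first coefficients (Horner's rule)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


class SkipGate:
    """Scalar stand-in for the TeaCache skip decision."""

    def __init__(self, num_steps, threshold, coefficients):
        self.num_steps = num_steps
        self.threshold = threshold
        self.coefficients = coefficients
        self.step = 0
        self.accumulated = 0.0
        self.previous = None

    def should_skip(self, signal):
        if self.step == 0 or self.step == self.num_steps - 1:
            # First and last steps are always computed.
            compute = True
            self.accumulated = 0.0
        else:
            rel_change = abs(signal - self.previous) / abs(self.previous)
            self.accumulated += poly_eval(self.coefficients, rel_change)
            compute = self.accumulated >= self.threshold
            if compute:
                self.accumulated = 0.0
        self.previous = signal
        self.step = (self.step + 1) % self.num_steps
        return not compute


# Identity polynomial (1*x + 0) keeps the arithmetic easy to follow.
gate = SkipGate(num_steps=5, threshold=0.25, coefficients=[1.0, 0.0])
print([gate.should_skip(s) for s in [1.0, 1.1, 1.2, 1.5, 1.5]])
```

In the pipeline shown here `tea_cache=None` is passed, so this path is not exercised; the sketch only illustrates what the retained class would do if enabled.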
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Simplified model forward wrapper (no VACE / no motion_controller)
|
||||
# -----------------------------
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
tea_cache: Optional[TeaCache] = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
LQ_latents: Optional[torch.Tensor] = None,
|
||||
is_full_block: bool = False,
|
||||
is_stream: bool = False,
|
||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
||||
topk_ratio: float = 2.0,
|
||||
kv_ratio: float = 3.0,
|
||||
cur_process_idx: int = 0,
|
||||
t_mod : torch.Tensor = None,
|
||||
t : torch.Tensor = None,
|
||||
local_range: int = 9,
|
||||
**kwargs,
|
||||
):
|
||||
# patchify
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
win = (2, 8, 8)
|
||||
seqlen = f // win[0]
|
||||
local_num = seqlen
|
||||
window_size = win[0] * h * w // 128
|
||||
square_num = window_size * window_size
|
||||
topk = int(square_num * topk_ratio) - 1
|
||||
kv_len = int(kv_ratio)
|
||||
|
||||
# RoPE positions (per segment)
|
||||
if cur_process_idx == 0:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
else:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache (disabled by default)
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
||||
|
||||
# Unified sequence parallelism (off by default here)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
# Block stack
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
||||
x = x + LQ_latents[block_id]
|
||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
||||
x, context, t_mod, freqs, f, h, w,
|
||||
local_num, topk,
|
||||
block_id=block_id,
|
||||
kv_len=kv_len,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
||||
local_range = local_range,
|
||||
)
|
||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x, pre_cache_k, pre_cache_v
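The sparse-attention budget computed at the top of `model_fn_wan_video` is plain integer arithmetic on the patch grid `(f, h, w)`. The sketch below repeats that arithmetic in isolation. The function name and the example grid `(6, 30, 52)` are illustrative assumptions, not values confirmed by the source; the divisor 128 is taken verbatim from the code and happens to equal 2·8·8, the product of `win`.

```python
def sparse_attention_budget(f, h, w, topk_ratio=2.0, kv_ratio=3.0, win=(2, 8, 8)):
    """Reproduce the window/top-k bookkeeping from model_fn_wan_video."""
    local_num = f // win[0]                  # number of temporal segments
    window_size = win[0] * h * w // 128      # windows per segment
    square_num = window_size * window_size   # window-to-window pairs
    topk = int(square_num * topk_ratio) - 1
    kv_len = int(kv_ratio)
    return {
        "local_num": local_num,
        "window_size": window_size,
        "topk": topk,
        "kv_len": kv_len,
    }


# Hypothetical patch grid, chosen only to show the integer rounding.
print(sparse_attention_budget(6, 30, 52))
```

Because `window_size` uses floor division, grids whose `2 * h * w` is not a multiple of 128 lose the remainder, which is worth keeping in mind when choosing input resolutions.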
|
||||
625
flashvsr_arch/pipelines/flashvsr_tiny.py
Normal file
@@ -0,0 +1,625 @@
|
||||
import types
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
# import pyfiglet
|
||||
|
||||
from ..models.utils import clean_vram
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
|
||||
# -----------------------------
|
||||
# Basic utilities: statistics needed by AdaIN (the corrector defaults to wavelet, but __call__ passes method='adain')
|
||||
# -----------------------------
|
||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
|
||||
N, C = feat.shape[:2]
|
||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
||||
std = var.sqrt().view(N, C, 1, 1)
|
||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
||||
return mean, std
|
||||
|
||||
|
||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
||||
assert content_feat.shape[:2] == style_feat.shape[:2], "AdaIN: N and C must match"
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = _calc_mean_std(style_feat)
|
||||
content_mean, content_std = _calc_mean_std(content_feat)
|
||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized * style_std.expand(size) + style_mean.expand(size)
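`_adain` re-normalizes each channel of the content so that its mean and standard deviation match the style. A dependency-free sketch of the same computation on a single channel, using Python lists instead of tensors, is shown here. The helper names are hypothetical; the `eps` term and the biased variance follow `_calc_mean_std`.

```python
import math


def mean_std(values, eps=1e-5):
    """Mean and standard deviation with biased variance plus eps, as in _calc_mean_std."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values) + eps
    return mean, math.sqrt(var)


def adain_channel(content, style):
    """Shift and scale one content channel to the style channel's statistics."""
    c_mean, c_std = mean_std(content)
    s_mean, s_std = mean_std(style)
    return [(v - c_mean) / c_std * s_std + s_mean for v in content]


content = [0.0, 1.0, 2.0, 3.0]
style = [10.0, 20.0, 30.0, 40.0]
out = adain_channel(content, style)
print(mean_std(out, eps=0.0))  # close to the style's mean and std
```

In the pipeline the content is the decoded HQ chunk and the style is the matching LQ chunk, so the output keeps the HQ detail while adopting the LQ color statistics.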
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Wavelet-style blur and decomposition/reconstruction (used by ColorCorrector)
|
||||
# -----------------------------
|
||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
||||
vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125 ],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
return torch.tensor(vals, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
N, C, H, W = x.shape
|
||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
||||
pad = radius
|
||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
||||
return out
|
||||
|
||||
|
||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
high = torch.zeros_like(x)
|
||||
low = x
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
blurred = _wavelet_blur(low, radius)
|
||||
high = high + (low - blurred)
|
||||
low = blurred
|
||||
return high, low
|
||||
|
||||
|
||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
||||
return c_high + s_low
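The decomposition above is a telescoping sum: each level subtracts a progressively wider blur, so `high + low` always reproduces the input, and `_wavelet_reconstruct` simply swaps in the style's low band. The sketch below is a 1-D analogue in plain Python, using the separable factor `[0.25, 0.5, 0.25]` of the 3x3 kernel and index clamping in place of replicate padding. The function names are hypothetical and the 1-D reduction is a simplification for illustration.

```python
def blur(signal, radius):
    """Dilated 3-tap blur with edge clamping (1-D analogue of _wavelet_blur)."""
    n = len(signal)
    out = []
    for i in range(n):
        left = signal[max(i - radius, 0)]
        right = signal[min(i + radius, n - 1)]
        out.append(0.25 * left + 0.5 * signal[i] + 0.25 * right)
    return out


def decompose(signal, levels=5):
    """Split a signal into accumulated high-frequency detail and a low-frequency base."""
    high = [0.0] * len(signal)
    low = list(signal)
    for level in range(levels):
        blurred = blur(low, 2 ** level)
        high = [h + (l - b) for h, l, b in zip(high, low, blurred)]
        low = blurred
    return high, low


def reconstruct(content, style, levels=5):
    """Keep the content's detail and take the base (color/brightness) from the style."""
    c_high, _ = decompose(content, levels)
    _, s_low = decompose(style, levels)
    return [h + l for h, l in zip(c_high, s_low)]


x = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0]
high, low = decompose(x, levels=3)
print(max(abs(h + l - v) for h, l, v in zip(high, low, x)))  # tiny rounding error only
```

Reconstructing a signal against itself therefore returns the signal unchanged, which is a convenient sanity check for the 2-D implementation as well.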
|
||||
|
||||
# -----------------------------
|
||||
# Safetensors support ---------
|
||||
# -----------------------------
|
||||
st_load_file = None # Define the variable in global scope first
|
||||
try:
|
||||
from safetensors.torch import load_file as st_load_file
|
||||
except ImportError:
|
||||
# st_load_file remains None if import fails
|
||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
||||
|
||||
# -----------------------------
|
||||
# Stateless color-correction module (video-friendly, defaults to wavelet)
|
||||
# -----------------------------
|
||||
class TorchColorCorrectorWavelet(nn.Module):
|
||||
def __init__(self, levels: int = 5):
|
||||
super().__init__()
|
||||
self.levels = levels
|
||||
|
||||
@staticmethod
|
||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
||||
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
|
||||
B, C, f, H, W = x.shape
|
||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
||||
return y, B, f
|
||||
|
||||
@staticmethod
|
||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
||||
BF, C, H, W = y.shape
|
||||
assert BF == B * f
|
||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert hq_image.shape == lq_image.shape, "HQ and LQ must have the same shape"
|
||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
|
||||
|
||||
B, C, f, H, W = hq_image.shape
|
||||
if chunk_size is None or chunk_size >= f:
|
||||
hq4, B, f = self._flatten_time(hq_image)
|
||||
lq4, _, _ = self._flatten_time(lq_image)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out = self._unflatten_time(out4, B, f)
|
||||
return out
|
||||
|
||||
outs = []
|
||||
for start in range(0, f, chunk_size):
|
||||
end = min(start + chunk_size, f)
|
||||
hq_chunk = hq_image[:, :, start:end]
|
||||
lq_chunk = lq_image[:, :, start:end]
|
||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
||||
outs.append(out_chunk)
|
||||
out = torch.cat(outs, dim=2)
|
||||
return out
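The `chunk_size` path in `forward` walks the time axis in fixed-size slices. Both correction methods operate on each frame independently after `_flatten_time`, so chunking should only bound peak memory rather than change the result. A small sketch of the slice boundaries follows; `time_chunks` is a hypothetical helper, not a function in the repository.

```python
def time_chunks(num_frames, chunk_size=None):
    """Return (start, end) frame ranges the way forward() iterates over time."""
    if chunk_size is None or chunk_size >= num_frames:
        # Single pass over all frames.
        return [(0, num_frames)]
    return [
        (start, min(start + chunk_size, num_frames))
        for start in range(0, num_frames, chunk_size)
    ]


print(time_chunks(10, 4))  # [(0, 4), (4, 8), (8, 10)]
print(time_chunks(10))     # [(0, 10)]
```

The pipeline itself passes `chunk_size=None`, since each streamed chunk is already short.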
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Simplified pipeline (dit + vae only)
|
||||
# -----------------------------
|
||||
class FlashVSRTinyPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['dit', 'vae']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
self.prompt_emb_posi = None
|
||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
||||
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
# Manage only dit / vae
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
# Optional: unified sequence parallel entry point (off by default here)
|
||||
pipe.use_unified_sequence_parallel = False
|
||||
return pipe
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
# -------------------------
|
||||
# New: explicit KV pre-initialization function
|
||||
# -------------------------
|
||||
def init_cross_kv(
|
||||
self,
|
||||
context_tensor: Optional[torch.Tensor] = None,
|
||||
prompt_path = None,
|
||||
):
|
||||
"""
|
||||
Build the text context from a fixed prompt and initialize the KV cache of every CrossAttention layer in WanModel.
|
||||
Must be called explicitly once before __call__.
|
||||
"""
|
||||
self.load_models_to_device(["dit"])
|
||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
||||
|
||||
if self.dit is None:
|
||||
raise RuntimeError("Initialize self.dit first via fetch_models / from_model_manager")
|
||||
|
||||
if context_tensor is None:
|
||||
if prompt_path is None:
|
||||
raise ValueError("init_cross_kv: either prompt_path or context_tensor must be provided")
|
||||
|
||||
# --- Safetensors loading logic added here ---
|
||||
prompt_path_lower = prompt_path.lower()
|
||||
if prompt_path_lower.endswith(".safetensors"):
|
||||
if st_load_file is None:
|
||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
||||
|
||||
# Load the tensor from safetensors
|
||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
||||
|
||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
||||
if len(loaded_dict) == 1:
|
||||
ctx = list(loaded_dict.values())[0]
|
||||
elif 'context' in loaded_dict: # Common key for text context
|
||||
ctx = loaded_dict['context']
|
||||
else:
|
||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
||||
|
||||
else:
|
||||
# Default behavior for .pth, .pt, etc.
|
||||
ctx = torch.load(prompt_path, map_location=self.device)
|
||||
|
||||
# --------------------------------------------
|
||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
||||
else:
|
||||
ctx = context_tensor
|
||||
|
||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if self.prompt_emb_posi is None:
|
||||
self.prompt_emb_posi = {}
|
||||
self.prompt_emb_posi['context'] = ctx
|
||||
|
||||
if hasattr(self.dit, "reinit_cross_kv"):
|
||||
self.dit.reinit_cross_kv(ctx)
|
||||
else:
|
||||
raise AttributeError("WanModel is missing the reinit_cross_kv(ctx) method; add it to the model implementation.")
|
||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
||||
self.load_models_to_device([])
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
def decode_video(self, latents, cond=None, **kwargs):
|
||||
frames = self.TCDecoder.decode_video(
|
||||
latents.transpose(1, 2), # TCDecoder expects (B, F, C, H, W)
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cond
|
||||
).transpose(1, 2).mul_(2).sub_(1) # back to (B, C, F, H, W), range -1 to 1
|
||||
|
||||
return frames
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt=None,
|
||||
negative_prompt="",
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="gpu",
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
tiled=True,
|
||||
tile_size=(60, 104),
|
||||
tile_stride=(30, 52),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="Wan2.1-T2V-14B",
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
LQ_video=None,
|
||||
is_full_block=False,
|
||||
if_buffer=False,
|
||||
topk_ratio=2.0,
|
||||
kv_ratio=3.0,
|
||||
local_range = 9,
|
||||
color_fix = True,
|
||||
unload_dit = False,
|
||||
skip_vae = False,
|
||||
):
|
||||
# Only cfg_scale=1.0 is accepted (consistent with the original code)
|
||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
||||
|
||||
# Requirement: init_cross_kv() must be called first
|
||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
||||
raise RuntimeError(
|
||||
"Cross-attention KV is not initialized. Before calling __call__, run:\n"
|
||||
" pipe.init_cross_kv(prompt_path=your_prompt_path)\n"
|
||||
"or pass a custom context:\n"
|
||||
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
||||
)
|
||||
|
||||
# Size correction
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Initialize noise
|
||||
if if_buffer:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = noise
|
||||
|
||||
process_total_num = (num_frames - 1) // 8 - 2
|
||||
is_stream = True
|
||||
|
||||
# Clear any existing LQ_proj_in cache
|
||||
if hasattr(self.dit, "LQ_proj_in"):
|
||||
self.dit.LQ_proj_in.clear_cache()
|
||||
|
||||
frames_total = []
|
||||
self.TCDecoder.clean_mem()
|
||||
LQ_pre_idx = 0
|
||||
LQ_cur_idx = 0
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
||||
current_dit_device = next(iter(self.dit.parameters())).device
|
||||
if str(current_dit_device) != str(self.device):
|
||||
print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
|
||||
self.dit.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
||||
if cur_process_idx == 0:
|
||||
pre_cache_k = [None] * len(self.dit.blocks)
|
||||
pre_cache_v = [None] * len(self.dit.blocks)
|
||||
LQ_latents = None
|
||||
inner_loop_num = 7
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
||||
cur_latents = latents[:, :, :6, :, :]
|
||||
else:
|
||||
LQ_latents = None
|
||||
inner_loop_num = 2
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
||||
|
||||
# Denoise
|
||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
||||
self.dit,
|
||||
x=cur_latents,
|
||||
timestep=self.timestep,
|
||||
context=None,
|
||||
tea_cache=None,
|
||||
use_unified_sequence_parallel=False,
|
||||
LQ_latents=LQ_latents,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k,
|
||||
pre_cache_v=pre_cache_v,
|
||||
topk_ratio=topk_ratio,
|
||||
kv_ratio=kv_ratio,
|
||||
cur_process_idx=cur_process_idx,
|
||||
t_mod=self.t_mod,
|
||||
t=self.t,
|
||||
local_range = local_range,
|
||||
)
|
||||
|
||||
cur_latents = cur_latents - noise_pred_posi
|
||||
|
||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
||||
cur_frames = self.TCDecoder.decode_video(
|
||||
cur_latents.transpose(1, 2),
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cur_LQ_frame
|
||||
).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
||||
# Per-chunk color correction
|
||||
try:
|
||||
if color_fix:
|
||||
cur_frames = self.ColorCorrector(
|
||||
cur_frames.to(device=self.device),
|
||||
cur_LQ_frame,
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=None,
|
||||
method='adain'
|
||||
)
|
||||
except Exception:  # best effort: keep the uncorrected frames if color correction fails
|
||||
pass
|
||||
|
||||
frames_total.append(cur_frames.to('cpu'))
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
del cur_frames, cur_latents, cur_LQ_frame
|
||||
clean_vram()
|
||||
|
||||
frames = torch.cat(frames_total, dim=2)
|
||||
return frames[0]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# TeaCache (original logic kept; disabled by default here)
|
||||
# -----------------------------
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
||||
if should_calc:
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step = (self.step + 1) % self.num_inference_steps
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Simplified model forward wrapper (no vace / no motion_controller)
|
||||
# -----------------------------
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
tea_cache: Optional[TeaCache] = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
LQ_latents: Optional[torch.Tensor] = None,
|
||||
is_full_block: bool = False,
|
||||
is_stream: bool = False,
|
||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
||||
topk_ratio: float = 2.0,
|
||||
kv_ratio: float = 3.0,
|
||||
cur_process_idx: int = 0,
|
||||
t_mod : torch.Tensor = None,
|
||||
t : torch.Tensor = None,
|
||||
local_range: int = 9,
|
||||
**kwargs,
|
||||
):
|
||||
# patchify
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
win = (2, 8, 8)
|
||||
seqlen = f // win[0]
|
||||
local_num = seqlen
|
||||
window_size = win[0] * h * w // 128
|
||||
square_num = window_size * window_size
|
||||
topk = int(square_num * topk_ratio) - 1
|
||||
kv_len = int(kv_ratio)
|
||||
|
||||
# RoPE positions (per segment)
|
||||
if cur_process_idx == 0:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
else:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache (disabled by default)
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
||||
|
||||
# Unified sequence parallelism (disabled by default here)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
# Block stack
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
||||
x = x + LQ_latents[block_id]
|
||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
||||
x, context, t_mod, freqs, f, h, w,
|
||||
local_num, topk,
|
||||
block_id=block_id,
|
||||
kv_len=kv_len,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
||||
local_range = local_range,
|
||||
)
|
||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x, pre_cache_k, pre_cache_v
|
||||
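The `TeaCache.check` logic above can be hard to follow because it mixes tensor statistics with step bookkeeping. The sketch below is a scalar stand-in, not the class from this diff: `SkipGate`, `poly_eval`, and `simulate` are hypothetical names, and the relative change of the modulated input is passed in as a plain float instead of being computed from `t_mod`. It only illustrates the accumulate-and-threshold rule: the first and last steps always compute, and in between a step is skipped while the rescaled, accumulated change stays below the threshold.

```python
from typing import List, Sequence


def poly_eval(coeffs: Sequence[float], x: float) -> float:
    """Evaluate a polynomial, highest-degree coefficient first (same order as np.poly1d)."""
    acc = 0.0
    for c in coeffs:
        acc = acc * x + c
    return acc


class SkipGate:
    """Hypothetical scalar stand-in for TeaCache.check(): may this step be skipped?"""

    def __init__(self, num_steps: int, threshold: float, coeffs: Sequence[float]):
        self.num_steps = num_steps
        self.threshold = threshold
        self.coeffs = list(coeffs)
        self.step = 0
        self.accumulated = 0.0

    def should_skip(self, rel_change: float) -> bool:
        if self.step == 0 or self.step == self.num_steps - 1:
            # First and last steps are always computed.
            compute = True
            self.accumulated = 0.0
        else:
            # Accumulate the rescaled change; compute once it reaches the threshold.
            self.accumulated += poly_eval(self.coeffs, rel_change)
            compute = not (self.accumulated < self.threshold)
            if compute:
                self.accumulated = 0.0
        self.step = (self.step + 1) % self.num_steps
        return not compute


def simulate(rel_changes: Sequence[float], threshold: float, coeffs: Sequence[float]) -> List[bool]:
    """Return the skip decision for each step of one denoising run."""
    gate = SkipGate(len(rel_changes), threshold, coeffs)
    return [gate.should_skip(r) for r in rel_changes]
```

With the identity polynomial `[1.0, 0.0]` (an assumption chosen for readability, not one of the Wan2.1 coefficient sets), six steps of relative change 0.02 and a threshold of 0.05 skip steps 1, 2, and 4. Since both pipelines in this diff pass `tea_cache=None`, this path is inactive here.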
619
flashvsr_arch/pipelines/flashvsr_tiny_long.py
Normal file
@@ -0,0 +1,619 @@
|
||||
import types
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
# import pyfiglet
|
||||
|
||||
from ..models.utils import clean_vram
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Basic utilities: statistics needed by AdaIN (ColorCorrector.forward defaults to wavelet; __call__ below passes method='adain')
|
||||
# -----------------------------
|
||||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert feat.dim() == 4, 'feat must be (N, C, H, W)'
|
||||
N, C = feat.shape[:2]
|
||||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
||||
std = var.sqrt().view(N, C, 1, 1)
|
||||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
||||
return mean, std
|
||||
|
||||
|
||||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
||||
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N and C must match"
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = _calc_mean_std(style_feat)
|
||||
content_mean, content_std = _calc_mean_std(content_feat)
|
||||
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Wavelet-style blur and decomposition/reconstruction (used by ColorCorrector)
|
||||
# -----------------------------
|
||||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
||||
vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125 ],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
return torch.tensor(vals, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
N, C, H, W = x.shape
|
||||
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
||||
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
||||
pad = radius
|
||||
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
||||
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
||||
return out
|
||||
|
||||
|
||||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 4, 'x must be (N, C, H, W)'
|
||||
high = torch.zeros_like(x)
|
||||
low = x
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
blurred = _wavelet_blur(low, radius)
|
||||
high = high + (low - blurred)
|
||||
low = blurred
|
||||
return high, low
|
||||
|
||||
|
||||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
||||
c_high, _ = _wavelet_decompose(content, levels=levels)
|
||||
_, s_low = _wavelet_decompose(style, levels=levels)
|
||||
return c_high + s_low
|
||||
|
||||
# -----------------------------
|
||||
# Safetensors support ---------
|
||||
# -----------------------------
|
||||
st_load_file = None # Define the variable in global scope first
|
||||
try:
|
||||
from safetensors.torch import load_file as st_load_file
|
||||
except ImportError:
|
||||
# st_load_file remains None if import fails
|
||||
print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
||||
|
||||
# -----------------------------
|
||||
# Stateless color-correction module (video-friendly, defaults to wavelet)
|
||||
# -----------------------------
|
||||
class TorchColorCorrectorWavelet(nn.Module):
|
||||
def __init__(self, levels: int = 5):
|
||||
super().__init__()
|
||||
self.levels = levels
|
||||
|
||||
@staticmethod
|
||||
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
||||
assert x.dim() == 5, 'input must be (B, C, f, H, W)'
|
||||
B, C, f, H, W = x.shape
|
||||
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
||||
return y, B, f
|
||||
|
||||
@staticmethod
|
||||
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
||||
BF, C, H, W = y.shape
|
||||
assert BF == B * f
|
||||
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
lq_image: torch.Tensor, # (B, C, f, H, W)
|
||||
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
||||
method: Literal['wavelet', 'adain'] = 'wavelet',
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert hq_image.shape == lq_image.shape, "HQ and LQ shapes must match"
|
||||
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "input must be (B, 3, f, H, W)"
|
||||
|
||||
B, C, f, H, W = hq_image.shape
|
||||
if chunk_size is None or chunk_size >= f:
|
||||
hq4, B, f = self._flatten_time(hq_image)
|
||||
lq4, _, _ = self._flatten_time(lq_image)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out = self._unflatten_time(out4, B, f)
|
||||
return out
|
||||
|
||||
outs = []
|
||||
for start in range(0, f, chunk_size):
|
||||
end = min(start + chunk_size, f)
|
||||
hq_chunk = hq_image[:, :, start:end]
|
||||
lq_chunk = lq_image[:, :, start:end]
|
||||
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
||||
lq4, _, _ = self._flatten_time(lq_chunk)
|
||||
if method == 'wavelet':
|
||||
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
||||
elif method == 'adain':
|
||||
out4 = _adain(hq4, lq4)
|
||||
else:
|
||||
raise ValueError(f"unknown method: {method}")
|
||||
out4 = torch.clamp(out4, *clip_range)
|
||||
out_chunk = self._unflatten_time(out4, B_, f_)
|
||||
outs.append(out_chunk)
|
||||
out = torch.cat(outs, dim=2)
|
||||
return out
|
||||
|
||||
# -----------------------------
|
||||
# Simplified pipeline (dit + vae only)
|
||||
# -----------------------------
|
||||
class FlashVSRTinyLongPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['dit', 'vae']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
self.prompt_emb_posi = None
|
||||
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
||||
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
# Manage only dit / vae
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config=dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
# Optional: entry point for unified sequence parallelism (disabled by default here)
|
||||
pipe.use_unified_sequence_parallel = False
|
||||
return pipe
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
# -------------------------
|
||||
# New: explicit KV pre-initialization function
|
||||
# -------------------------
|
||||
def init_cross_kv(
|
||||
self,
|
||||
context_tensor: Optional[torch.Tensor] = None,
|
||||
prompt_path = None,
|
||||
):
|
||||
"""
Build the text context from a fixed prompt and initialize the KV cache of every CrossAttention layer in WanModel.
Must be called explicitly once before __call__.
"""
self.load_models_to_device(["dit"])
|
||||
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
||||
|
||||
if self.dit is None:
|
||||
raise RuntimeError("Initialize self.dit via fetch_models / from_model_manager first")
|
||||
|
||||
if context_tensor is None:
|
||||
if prompt_path is None:
|
||||
raise ValueError("init_cross_kv: provide either prompt_path or context_tensor")
|
||||
|
||||
# --- Safetensors loading logic added here ---
|
||||
prompt_path_lower = prompt_path.lower()
|
||||
if prompt_path_lower.endswith(".safetensors"):
|
||||
if st_load_file is None:
|
||||
raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")
|
||||
|
||||
# Load the tensor from safetensors
|
||||
loaded_dict = st_load_file(prompt_path, device=self.device)
|
||||
|
||||
# Safetensors loads a dict. Assuming the context tensor is the only or primary key.
|
||||
if len(loaded_dict) == 1:
|
||||
ctx = list(loaded_dict.values())[0]
|
||||
elif 'context' in loaded_dict: # Common key for text context
|
||||
ctx = loaded_dict['context']
|
||||
else:
|
||||
raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
|
||||
|
||||
else:
|
||||
# Default behavior for .pth, .pt, etc.
|
||||
ctx = torch.load(prompt_path, map_location=self.device)
|
||||
|
||||
# --------------------------------------------
|
||||
# ctx = torch.load(prompt_path, map_location=self.device)
|
||||
else:
|
||||
ctx = context_tensor
|
||||
|
||||
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if self.prompt_emb_posi is None:
|
||||
self.prompt_emb_posi = {}
|
||||
self.prompt_emb_posi['context'] = ctx
|
||||
|
||||
if hasattr(self.dit, "reinit_cross_kv"):
|
||||
self.dit.reinit_cross_kv(ctx)
|
||||
else:
|
||||
raise AttributeError("WanModel is missing the reinit_cross_kv(ctx) method; please add it to the model implementation.")
|
||||
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
||||
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
||||
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
||||
self.load_models_to_device([])
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
def decode_video(self, latents, cond=None, **kwargs):
|
||||
frames = self.TCDecoder.decode_video(
|
||||
latents.transpose(1, 2), # TCDecoder expects (B, F, C, H, W)
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cond
|
||||
).transpose(1, 2).mul_(2).sub_(1) # back to (B, C, F, H, W), range -1 to 1
|
||||
|
||||
return frames
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt=None,
|
||||
negative_prompt="",
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="gpu",
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
tiled=True,
|
||||
tile_size=(60, 104),
|
||||
tile_stride=(30, 52),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
LQ_video=None,
|
||||
is_full_block=False,
|
||||
if_buffer=False,
|
||||
topk_ratio=2.0,
|
||||
kv_ratio=3.0,
|
||||
local_range = 9,
|
||||
color_fix = True,
|
||||
unload_dit = False,
|
||||
skip_vae = False,
|
||||
):
|
||||
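# Only cfg_scale == 1.0 is accepted (same as the original code); the signature default is 5.0, so callers must pass cfg_scale=1.0 explicitly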
|
||||
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
||||
|
||||
# Requirement: init_cross_kv() must be called first
|
||||
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
||||
raise RuntimeError(
"Cross-Attn KV is not initialized. Call this before __call__:\n"
" pipe.init_cross_kv(prompt_path=your_prompt_path)\n"
"or pass a custom context:\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
|
||||
|
||||
# Size correction
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Initialize noise
|
||||
if if_buffer:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
else:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = noise
|
||||
|
||||
process_total_num = (num_frames - 1) // 8 - 2
|
||||
is_stream = True
|
||||
|
||||
# Clear any existing LQ_proj_in cache
|
||||
if hasattr(self.dit, "LQ_proj_in"):
|
||||
self.dit.LQ_proj_in.clear_cache()
|
||||
|
||||
frames_total = []
|
||||
LQ_pre_idx = 0
|
||||
LQ_cur_idx = 0
|
||||
self.TCDecoder.clean_mem()
|
||||
|
||||
with torch.no_grad():
|
||||
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
||||
if cur_process_idx == 0:
|
||||
pre_cache_k = [None] * len(self.dit.blocks)
|
||||
pre_cache_v = [None] * len(self.dit.blocks)
|
||||
LQ_latents = None
|
||||
inner_loop_num = 7
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
||||
cur_latents = latents[:, :, :6, :, :]
|
||||
else:
|
||||
LQ_latents = None
|
||||
inner_loop_num = 2
|
||||
for inner_idx in range(inner_loop_num):
|
||||
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
||||
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
||||
) if LQ_video is not None else None
|
||||
if cur is None:
|
||||
continue
|
||||
if LQ_latents is None:
|
||||
LQ_latents = cur
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
||||
|
||||
# Inference (no motion_controller / vace)
|
||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
||||
self.dit,
|
||||
x=cur_latents,
|
||||
timestep=self.timestep,
|
||||
context=None,
|
||||
tea_cache=None,
|
||||
use_unified_sequence_parallel=False,
|
||||
LQ_latents=LQ_latents,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k,
|
||||
pre_cache_v=pre_cache_v,
|
||||
topk_ratio=topk_ratio,
|
||||
kv_ratio=kv_ratio,
|
||||
cur_process_idx=cur_process_idx,
|
||||
t_mod=self.t_mod,
|
||||
t=self.t,
|
||||
local_range = local_range,
|
||||
)
|
||||
|
||||
# Update latent
|
||||
cur_latents = cur_latents - noise_pred_posi
|
||||
|
||||
# Decode
|
||||
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
|
||||
cur_frames = self.TCDecoder.decode_video(
|
||||
cur_latents.transpose(1, 2),
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
||||
# Color correction (called with method='adain'; wavelet is the corrector's default)
|
||||
try:
|
||||
if color_fix:
|
||||
cur_frames = self.ColorCorrector(
|
||||
cur_frames.to(device=self.device),
|
||||
cur_LQ_frame,
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=None,
|
||||
method='adain'
|
||||
)
|
||||
except Exception:  # color correction is best-effort; fall back to the uncorrected frames
|
||||
pass
|
||||
|
||||
frames_total.append(cur_frames.to('cpu'))
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
del cur_frames, cur_latents, cur_LQ_frame
|
||||
clean_vram()
|
||||
|
||||
frames = torch.cat(frames_total, dim=2)
|
||||
return frames[0]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# TeaCache (original logic kept; disabled by default here)
|
||||
# -----------------------------
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
||||
if should_calc:
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step = (self.step + 1) % self.num_inference_steps
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Simplified model forward wrapper (no vace / no motion_controller)
|
||||
# -----------------------------
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
tea_cache: Optional[TeaCache] = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
LQ_latents: Optional[torch.Tensor] = None,
|
||||
is_full_block: bool = False,
|
||||
is_stream: bool = False,
|
||||
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
||||
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
||||
topk_ratio: float = 2.0,
|
||||
kv_ratio: float = 3.0,
|
||||
cur_process_idx: int = 0,
|
||||
t_mod : torch.Tensor = None,
|
||||
t : torch.Tensor = None,
|
||||
local_range: int = 9,
|
||||
**kwargs,
|
||||
):
|
||||
# patchify
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
win = (2, 8, 8)
|
||||
seqlen = f // win[0]
|
||||
local_num = seqlen
|
||||
window_size = win[0] * h * w // 128
|
||||
square_num = window_size * window_size
|
||||
topk = int(square_num * topk_ratio) - 1
|
||||
kv_len = int(kv_ratio)
|
||||
|
||||
# RoPE positions (per segment)
|
||||
if cur_process_idx == 0:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
else:
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache (disabled by default)
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
||||
|
||||
# Unified sequence parallelism (disabled by default here)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
# Block stack
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if LQ_latents is not None and block_id < len(LQ_latents):
|
||||
x = x + LQ_latents[block_id]
|
||||
x, last_pre_cache_k, last_pre_cache_v = block(
|
||||
x, context, t_mod, freqs, f, h, w,
|
||||
local_num, topk,
|
||||
block_id=block_id,
|
||||
kv_len=kv_len,
|
||||
is_full_block=is_full_block,
|
||||
is_stream=is_stream,
|
||||
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
||||
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
||||
local_range = local_range,
|
||||
)
|
||||
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
||||
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x, pre_cache_k, pre_cache_v
|
||||
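The index arithmetic in `FlashVSRTinyLongPipeline.__call__` above is dense, so a small pure-Python sketch of the streaming plan may help reviewers. It only reproduces the slice bounds read off the loop (no tensors, no model); `round_num_frames` and `chunk_plan` are hypothetical helper names, and all ranges are half-open `(start, end)` pairs.

```python
from typing import Dict, List, Tuple


def round_num_frames(num_frames: int) -> int:
    """Mirror the pipeline's rounding so that num_frames % 4 == 1."""
    if num_frames % 4 != 1:
        num_frames = (num_frames + 2) // 4 * 4 + 1
    return num_frames


def chunk_plan(num_frames: int) -> List[Dict[str, Tuple[int, int]]]:
    """Per streaming step: latent slice, LQ frames fed to LQ_proj_in, LQ frames decoded."""
    process_total_num = (num_frames - 1) // 8 - 2
    plan = []
    lq_pre_idx = 0
    for idx in range(process_total_num):
        if idx == 0:
            # Warm-up step: seven projector calls cover LQ frames 0..24,
            # and the first six latent frames are denoised together.
            lq_fed = (0, 7 * 4 - 3)
            lq_cur_idx = (7 - 1) * 4 - 3
            latent = (0, 6)
        else:
            # Steady state: two projector calls (4 frames each), two latent frames.
            lq_fed = (idx * 8 + 17, idx * 8 + 25)
            lq_cur_idx = idx * 8 + 21
            latent = (4 + idx * 2, 6 + idx * 2)
        plan.append({
            "latent": latent,
            "lq_fed": lq_fed,
            "lq_decoded": (lq_pre_idx, lq_cur_idx),
        })
        lq_pre_idx = lq_cur_idx
    return plan
```

For `num_frames=81` this yields eight steps: the first decodes LQ frames 0 to 20 and every later step decodes eight more, so the concatenated output covers 77 frames while the projector has consumed all 81.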
1
flashvsr_arch/schedulers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .flow_match import FlowMatchScheduler
|
||||
79
flashvsr_arch/schedulers/flow_match.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
||||
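The scheduler arithmetic above (linear sigmas, the shift remap in `set_timesteps`, the interpolation in `add_noise`, and the Euler update in `step`) can be checked without PyTorch. The sketch below mirrors those formulas with plain floats. The function names and the default values (`shift=3.0`, `sigma_max=1.0`, `sigma_min=0.0`) are assumptions made for this illustration and are not taken from the diff.

```python
def shifted_sigmas(num_steps, shift=3.0, sigma_max=1.0, sigma_min=0.0, extra_one_step=True):
    """Linearly spaced sigmas followed by the shift remap used in set_timesteps."""
    # Equivalent of torch.linspace(sigma_max, sigma_min, n); with extra_one_step
    # one extra point is generated and the final one is dropped.
    n = num_steps + 1 if extra_one_step else num_steps
    if n == 1:
        raw = [sigma_max]
    else:
        raw = [sigma_max + (sigma_min - sigma_max) * i / (n - 1) for i in range(n)]
    if extra_one_step:
        raw = raw[:-1]
    # sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    return [shift * s / (1 + (shift - 1) * s) for s in raw]


def add_noise(clean, noise, sigma):
    """Interpolate between the clean sample and noise, as in add_noise."""
    return (1 - sigma) * clean + sigma * noise


def euler_step(sample, model_output, sigma, sigma_next):
    """One update of the form used in step: move along the predicted velocity."""
    return sample + model_output * (sigma_next - sigma)


# With a perfect velocity prediction (noise - clean, the training target),
# a single step from sigma to 0 recovers the clean value.
clean, noise, sigma = 2.0, -1.0, 0.75
noisy = add_noise(clean, noise, sigma)
recovered = euler_step(noisy, noise - clean, sigma, 0.0)
```

This also shows why the wrapper can call the pipeline with `num_inference_steps=1`: if the model output is close to `noise - sample`, one step with `to_final=True` already lands near the clean sample.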
1
flashvsr_arch/vram_management/__init__.py
Normal file
@@ -0,0 +1 @@
from .layers import *
95
flashvsr_arch/vram_management/layers.py
Normal file
@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device


def cast_to(weight, dtype, device):
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight)
    return r


class AutoWrappedModule(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        super().__init__()
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            module = self.module
        else:
            module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
        return module(*args, **kwargs)


class AutoWrappedLinear(torch.nn.Linear):
    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, x, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            weight, bias = self.weight, self.bias
        else:
            weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        return torch.nn.functional.linear(x, weight, bias)


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    for name, module in model.named_children():
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                num_param = sum(p.numel() for p in module.parameters())
                if max_num_param is not None and total_num_param + num_param > max_num_param:
                    module_config_ = overflow_module_config
                else:
                    module_config_ = module_config
                module_ = target_module(module, **module_config_)
                setattr(model, name, module_)
                total_num_param += num_param
                break
        else:
            total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
    return total_num_param


def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
    model.vram_management_enabled = True
||||
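The parameter budget in `enable_vram_management_recursively` is easy to misread: the running total grows even for modules that were pushed to the overflow config, so once the budget is exceeded every later module overflows as well. The sketch below reproduces only that bookkeeping on a flat list of parameter counts, standing in for the depth-first traversal order. `assign_configs` and the `"main"` / `"overflow"` labels are hypothetical names chosen for this illustration.

```python
def assign_configs(param_counts, max_num_param=None):
    """Return 'main' or 'overflow' for each wrapped module, in traversal order."""
    total = 0
    assignment = []
    for count in param_counts:
        # Same test as in the diff: overflow if adding this module would
        # push the running total past the budget.
        over_budget = max_num_param is not None and total + count > max_num_param
        assignment.append("overflow" if over_budget else "main")
        # The total is incremented regardless of which config was chosen.
        total += count
    return assignment
```

With `max_num_param=None`, which is what `FlashVSRModel` requests through `num_persistent_param_in_dit=None`, no module is ever treated as overflow in this sketch.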
248
inference.py
@@ -1,8 +1,11 @@
import logging
import math
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
@@ -621,3 +624,248 @@ class GIMMVFIModel:
|
||||
results.append(torch.clamp(unpadded, 0, 1))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# FlashVSR model wrapper (4x video super-resolution)
# ---------------------------------------------------------------------------

class FlashVSRModel:
    """Inference wrapper for FlashVSR diffusion-based video super-resolution.

    Supports three pipeline modes:
    - full: Standard VAE decode, highest quality
    - tiny: TCDecoder decode, faster
    - tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
    """

    # Minimum input frame count required by the pipeline
    MIN_FRAMES = 21

    def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
        from safetensors.torch import load_file
        from .flashvsr_arch import (
            ModelManager, FlashVSRFullPipeline,
            FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
        )
        from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
        from .flashvsr_arch.models.TCDecoder import build_tcdecoder

        self.mode = mode
        self.device = device
        self.dtype = dtype

        dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
        vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
        lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
        tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
        prompt_path = os.path.join(model_dir, "Prompt.safetensors")

        mm = ModelManager(torch_dtype=dtype, device="cpu")

        if mode == "full":
            mm.load_models([dit_path, vae_path])
            self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
            self.pipe.vae.model.encoder = None
            self.pipe.vae.model.conv1 = None
        else:
            mm.load_models([dit_path])
            Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
            self.pipe = Pipeline.from_model_manager(mm, device=device)

        # TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
        self.pipe.TCDecoder = build_tcdecoder(
            [512, 256, 128, 128], device, dtype, 16 + 768,
        )
        self.pipe.TCDecoder.load_state_dict(
            load_file(tcd_path, device=device), strict=False,
        )
        self.pipe.TCDecoder.clean_mem()

        # LQ frame projection — Causal variant for FlashVSR v1.1
        self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
        if os.path.exists(lq_path):
            lq_sd = load_file(lq_path, device="cpu")
            cleaned = {}
            for k, v in lq_sd.items():
                cleaned[k.removeprefix("LQ_proj_in.")] = v
            self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
        self.pipe.denoising_model().LQ_proj_in.to(device)

        self.pipe.to(device, dtype)
        self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
        self.pipe.init_cross_kv(prompt_path=prompt_path)
        self.pipe.load_models_to_device([])  # offload to CPU

    def to(self, device):
        self.device = device
        self.pipe.device = device
        return self

    def load_to_device(self):
        """Load models to the compute device for inference."""
        names = ["dit", "vae"] if self.mode == "full" else ["dit"]
        self.pipe.load_models_to_device(names)

    def offload(self):
        """Offload models to CPU."""
        self.pipe.load_models_to_device([])

    def clear_caches(self):
        if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
            self.pipe.denoising_model().LQ_proj_in.clear_cache()
        if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
            self.pipe.vae.clear_cache()
        if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
            self.pipe.TCDecoder.clean_mem()

    # ------------------------------------------------------------------
    # Frame preprocessing / postprocessing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_dims(w, h, scale, align=128):
        sw, sh = w * scale, h * scale
        tw = math.ceil(sw / align) * align
        th = math.ceil(sh / align) * align
        return sw, sh, tw, th

    @staticmethod
    def _restore_video_sequence(result, expected):
        """Trim or pad pipeline output to the expected frame count.

        Extra frames are dropped; missing frames are filled by repeating the last frame.
        """
        if result.shape[0] > expected:
            result = result[:expected]
        elif result.shape[0] < expected:
            pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
            result = torch.cat([result, pad], dim=0)
        return result

    @staticmethod
    def _next_8n5(n, minimum=21):
        """Next integer >= n of the form 8k+5 (minimum 21)."""
        if n < minimum:
            return minimum
        return ((n - 5 + 7) // 8) * 8 + 5

    def _prepare_video(self, frames, scale):
        """Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].

        Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
        1. Bicubic-upscale each frame to target resolution
        2. Centered symmetric padding to 128-pixel alignment (reflect mode)
        3. Normalize to [-1, 1]
        4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
        5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)

        Returns:
            video: [1, C, F_padded, H, W] tensor
            th, tw: padded spatial dimensions
            nf: padded frame count
            sh, sw: actual (unpadded) spatial dimensions
            pad_top, pad_left: spatial padding offsets for output cropping
        """
        N, H, W, C = frames.shape
        sw, sh, tw, th = self._compute_dims(W, H, scale)

        # Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
        N_padded = self._next_8n5(N)

        # Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
        target = N_padded + 4

        # Centered spatial padding offsets
        pad_top = (th - sh) // 2
        pad_bottom = th - sh - pad_top
        pad_left = (tw - sw) // 2
        pad_right = tw - sw - pad_left

        processed = []
        for i in range(target):
            frame_idx = min(i, N - 1)  # clamp to last real frame
            frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
            upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
            if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
                # Centered reflect padding (matches naxci1 reference)
                try:
                    upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
                except RuntimeError:
                    # Reflect requires pad < input size; fall back to replicate
                    upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
            normalized = upscaled * 2.0 - 1.0
            processed.append(normalized.squeeze(0).cpu().to(self.dtype))

        video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
        nf = video.shape[2]

        return video, th, tw, nf, sh, sw, pad_top, pad_left

    @staticmethod
    def _to_frames(video):
        """Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
        from einops import rearrange
        v = video.squeeze(0) if video.dim() == 5 else video
        v = rearrange(v, "C F H W -> F H W C")
        return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)

    # ------------------------------------------------------------------
    # Main upscale method
    # ------------------------------------------------------------------

    @torch.no_grad()
    def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
                topk_ratio=2.0, kv_ratio=3.0, local_range=11,
                color_fix=True, unload_dit=False, seed=1,
                progress_bar_cmd=None):
        """Upscale video frames with FlashVSR.

        Args:
            frames: [F, H, W, C] float32 [0, 1] with F >= 21
            scale: Upscaling factor (2 or 4)
            tiled: Enable VAE tiled decode (saves VRAM)
            tile_size: (H, W) tile size for VAE tiling
            topk_ratio: Sparse attention ratio (higher = faster, less detail)
            kv_ratio: KV cache ratio (higher = more quality, more VRAM)
            local_range: Local attention window (9=sharp, 11=stable)
            color_fix: Apply wavelet color correction
            unload_dit: Offload DiT before VAE decode (saves VRAM)
            seed: Random seed
            progress_bar_cmd: Callable wrapping an iterable for progress display

        Returns:
            [F, H*scale, W*scale, C] float32 [0, 1]
        """
        if progress_bar_cmd is None:
            from tqdm import tqdm
            progress_bar_cmd = tqdm

        original_count = frames.shape[0]

        # Prepare video tensor (bicubic upscale + centered pad)
        video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)

        # Move LQ video to compute device (except for "long" mode which streams)
        if "long" not in self.pipe.__class__.__name__.lower():
            video = video.to(self.pipe.device)

        # Run pipeline
        out = self.pipe(
            prompt="", negative_prompt="",
            cfg_scale=1.0, num_inference_steps=1,
            seed=seed, tiled=tiled, tile_size=tile_size,
            progress_bar_cmd=progress_bar_cmd,
            LQ_video=video,
            num_frames=nf, height=th, width=tw,
            is_full_block=False, if_buffer=True,
            topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
            kv_ratio=kv_ratio, local_range=local_range,
            color_fix=color_fix, unload_dit=unload_dit,
        )

        # Convert to ComfyUI format with centered spatial crop
        result = self._to_frames(out).cpu()
        result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]

        # Trim to original frame count
        result = self._restore_video_sequence(result, original_count)

        return result
|
||||
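The padding arithmetic in `_compute_dims`, `_next_8n5` and `_prepare_video` can be worked through with concrete numbers. The following sketch copies those formulas into standalone functions, with no tensors involved. The helper names `padded_frame_count` and `spatial_padding` and the 480x270 example input are assumptions made for this illustration.

```python
import math


def next_8n5(n, minimum=21):
    """Smallest integer >= n of the form 8k+5, never below `minimum`."""
    if n < minimum:
        return minimum
    return ((n - 5 + 7) // 8) * 8 + 5


def padded_frame_count(n):
    """Two-stage temporal padding: round up to 8k+5, then add 4 (gives 8k+1)."""
    return next_8n5(n) + 4


def spatial_padding(w, h, scale, align=128):
    """Upscaled size, aligned size, and centered padding offsets."""
    sw, sh = w * scale, h * scale
    tw = math.ceil(sw / align) * align
    th = math.ceil(sh / align) * align
    pad_left = (tw - sw) // 2
    pad_top = (th - sh) // 2
    return {
        "scaled": (sw, sh),
        "aligned": (tw, th),
        "pad": (pad_left, tw - sw - pad_left, pad_top, th - sh - pad_top),
    }


# A 30-frame 480x270 clip at 4x: 30 -> 37 -> 41 frames, 1920x1080 -> 1920x1152.
frames_41 = padded_frame_count(30)
dims = spatial_padding(480, 270, 4)
```

The pad tuple is ordered `(left, right, top, bottom)`, the same order `F.pad` receives in `_prepare_video`.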
418
nodes.py
@@ -8,7 +8,7 @@ import torch
 import folder_paths
 from comfy.utils import ProgressBar

-from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel
+from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel
 from .bim_vfi_arch import clear_backwarp_cache
 from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
 from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
@@ -1507,3 +1507,419 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
|
||||
return (result, model)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# FlashVSR nodes (4x video super-resolution)
# ---------------------------------------------------------------------------

FLASHVSR_HF_REPO = "1038lab/FlashVSR"
FLASHVSR_REQUIRED_FILES = [
    "FlashVSR1_1.safetensors",
    "Wan2.1_VAE.safetensors",
    "LQ_proj_in.safetensors",
    "TCDecoder.safetensors",
    "Prompt.safetensors",
]

# Check common locations so we reuse models from 1038lab/ComfyUI-FlashVSR
FLASHVSR_MODEL_DIR = None
for _dirname in ("FlashVSR", "flashvsr"):
    _candidate = os.path.join(folder_paths.models_dir, _dirname)
    if os.path.isdir(_candidate) and all(
        os.path.exists(os.path.join(_candidate, f)) for f in FLASHVSR_REQUIRED_FILES
    ):
        FLASHVSR_MODEL_DIR = _candidate
        break
if FLASHVSR_MODEL_DIR is None:
    # Default to "FlashVSR" (matches 1038lab convention)
    FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "FlashVSR")


def download_flashvsr_models(model_dir):
    """Download FlashVSR checkpoints from HuggingFace if missing."""
    from huggingface_hub import snapshot_download

    missing = [f for f in FLASHVSR_REQUIRED_FILES
               if not os.path.exists(os.path.join(model_dir, f))]
    if not missing:
        return

    os.makedirs(model_dir, exist_ok=True)
    logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. Downloading from HuggingFace...")
    snapshot_download(
        repo_id=FLASHVSR_HF_REPO,
        local_dir=model_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
    )

    still_missing = [f for f in FLASHVSR_REQUIRED_FILES
                     if not os.path.exists(os.path.join(model_dir, f))]
    if still_missing:
        raise FileNotFoundError(
            f"[FlashVSR] Failed to download: {', '.join(still_missing)}. "
            f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}"
        )
    logger.info("[FlashVSR] All checkpoints downloaded successfully.")


class _FlashVSRProgressBar:
    """Wrap an iterable with a ComfyUI ProgressBar."""

    def __init__(self, total, pbar, step_ref):
        self.total = total
        self.pbar = pbar
        self.step_ref = step_ref

    def __call__(self, iterable):
        return self._Wrapper(iterable, self.pbar, self.step_ref)

    class _Wrapper:
        def __init__(self, iterable, pbar, step_ref):
            self.iterable = iterable
            self.pbar = pbar
            self.step_ref = step_ref
            self._iter = iter(iterable)

        def __iter__(self):
            return self

        def __next__(self):
            val = next(self._iter)
            self.step_ref[0] += 1
            self.pbar.update_absolute(self.step_ref[0])
            return val

        def __len__(self):
            return len(self.iterable)


class LoadFlashVSRModel:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mode": (["tiny", "tiny-long", "full"], {
                    "default": "tiny",
                    "tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. "
                               "Tiny-long: streaming TCDecoder, lowest VRAM for long videos. "
                               "Full: standard VAE decode, highest quality but more VRAM.",
                }),
                "precision": (["bf16", "fp16"], {
                    "default": "bf16",
                    "tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.",
                }),
            }
        }

    RETURN_TYPES = ("FLASHVSR_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/FlashVSR"

    def load_model(self, mode, precision):
        download_flashvsr_models(FLASHVSR_MODEL_DIR)

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if precision == "bf16" else torch.float16

        wrapper = FlashVSRModel(
            model_dir=FLASHVSR_MODEL_DIR,
            mode=mode,
            device=device,
            dtype=dtype,
        )

        logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})")
        return (wrapper,)


class FlashVSRUpscale:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input video frames. Minimum 21 frames required.",
                }),
                "model": ("FLASHVSR_MODEL", {
                    "tooltip": "FlashVSR model from the Load FlashVSR Model node.",
                }),
                "scale": ("INT", {
                    "default": 4, "min": 2, "max": 4, "step": 2,
                    "tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.",
                }),
                "frame_chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). "
                               "Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).",
                }),
                "tiled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.",
                }),
                "tile_size_h": ("INT", {
                    "default": 60, "min": 16, "max": 256, "step": 4,
                    "tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.",
                }),
                "tile_size_w": ("INT", {
                    "default": 104, "min": 16, "max": 256, "step": 4,
                    "tooltip": "VAE tile width (in latent space). Larger = faster but more VRAM.",
                }),
                "topk_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                    "tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.",
                }),
                "kv_ratio": ("FLOAT", {
                    "default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
                    "tooltip": "KV cache ratio. Higher = better quality, more VRAM. 3.0 recommended.",
                }),
                "local_range": ([9, 11], {
                    "default": 11,
                    "tooltip": "Local attention window. 9=sharper details, 11=more temporal stability (recommended).",
                }),
                "color_fix": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Apply color correction to prevent color shifts from the diffusion process.",
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.",
                }),
                "seed": ("INT", {
                    "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
                    "tooltip": "Random seed for the diffusion process.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "upscale"
    CATEGORY = "video/FlashVSR"

    def upscale(self, images, model, scale, frame_chunk_size,
                tiled, tile_size_h, tile_size_w,
                topk_ratio, kv_ratio, local_range,
                color_fix, unload_dit, seed):
        num_frames = images.shape[0]
        if num_frames < FlashVSRModel.MIN_FRAMES:
            raise ValueError(
                f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}"
            )

        tile_size = (tile_size_h, tile_size_w)

        # Build frame chunks
        if frame_chunk_size < FlashVSRModel.MIN_FRAMES or frame_chunk_size >= num_frames:
            chunks = [(0, num_frames)]
        else:
            chunks = []
            start = 0
            while start < num_frames:
                end = min(start + frame_chunk_size, num_frames)
                chunks.append((start, end))
                if end == num_frames:
                    break
                start = end
            # If the last chunk is too small, merge it into the previous one
            if len(chunks) > 1 and (chunks[-1][1] - chunks[-1][0]) < FlashVSRModel.MIN_FRAMES:
                prev_start = chunks[-2][0]
                last_end = chunks[-1][1]
                chunks = chunks[:-2]
                chunks.append((prev_start, last_end))

        # Estimate total pipeline steps for progress bar
        # Mirrors _prepare_video two-stage padding: next_8n5(N) + 4
        def _next_8n5(n, minimum=21):
            return minimum if n < minimum else ((n - 5 + 7) // 8) * 8 + 5

        total_steps = 0
        for cs, ce in chunks:
            n = ce - cs
            target = _next_8n5(n) + 4  # always 8k+1
            total_steps += max(1, (target - 1) // 8 - 2)

        pbar = ProgressBar(total_steps)
        step_ref = [0]
        progress = _FlashVSRProgressBar(total_steps, pbar, step_ref)

        model.load_to_device()

        result_chunks = []
        for chunk_start, chunk_end in chunks:
            chunk_frames = images[chunk_start:chunk_end]

            chunk_result = model.upscale(
                chunk_frames,
                scale=scale, tiled=tiled, tile_size=tile_size,
                topk_ratio=topk_ratio, kv_ratio=kv_ratio,
                local_range=local_range, color_fix=color_fix,
                unload_dit=unload_dit, seed=seed,
                progress_bar_cmd=progress,
            )
            result_chunks.append(chunk_result)
            model.clear_caches()

        model.offload()
        from .flashvsr_arch.models.utils import clean_vram
        clean_vram()

        return (torch.cat(result_chunks, dim=0),)
|
||||
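The chunk-splitting rule in `FlashVSRUpscale.upscale` has two edge cases worth seeing with numbers: a chunk size below the 21-frame minimum disables chunking entirely, and a short final chunk is merged into its predecessor. The sketch below isolates that logic as a pure function. `build_chunks` is a hypothetical helper name, not something defined in this PR.

```python
def build_chunks(num_frames, chunk_size, min_frames=21):
    """Split [0, num_frames) into (start, end) ranges of at least min_frames."""
    # Chunking is skipped when the requested size is too small to be valid
    # or already covers the whole clip.
    if chunk_size < min_frames or chunk_size >= num_frames:
        return [(0, num_frames)]

    chunks = []
    start = 0
    while start < num_frames:
        end = min(start + chunk_size, num_frames)
        chunks.append((start, end))
        start = end

    # A trailing chunk shorter than min_frames is folded into the previous one,
    # so the last chunk may be longer than chunk_size.
    if len(chunks) > 1 and chunks[-1][1] - chunks[-1][0] < min_frames:
        prev_start = chunks[-2][0]
        last_end = chunks[-1][1]
        chunks = chunks[:-2] + [(prev_start, last_end)]
    return chunks
```

For 100 frames and a chunk size of 33, the final one-frame remainder is merged, so the last chunk covers 34 frames. Peak VRAM is therefore bounded by roughly `chunk_size + 20` frames rather than `chunk_size` exactly.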
class FlashVSRSegmentUpscale:
    """Process a numbered segment with temporal overlap and crossfade blending.

    Chain multiple instances with Save nodes between them to bound peak RAM.
    The model pass-through forces sequential execution so each segment
    saves and frees RAM before the next starts.

    Crossfade blending within the overlap region:
    - First (overlap - blend) frames: warmup only, discarded from output
    - Last blend frames: linear alpha crossfade with previous segment's tail
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Full input video frames. Minimum 21 frames required.",
                }),
                "model": ("FLASHVSR_MODEL", {
                    "tooltip": "FlashVSR model from Load FlashVSR Model. "
                               "Chain the model output to the next segment node for sequential execution.",
                }),
                "segment_index": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Which segment to process (0-based).",
                }),
                "segment_size": ("INT", {
                    "default": 100, "min": 21, "max": 10000, "step": 1,
                    "tooltip": "Number of input frames per segment.",
                }),
                "overlap_frames": ("INT", {
                    "default": 8, "min": 0, "max": 100, "step": 1,
                    "tooltip": "Number of overlapping frames between adjacent segments. "
                               "These frames provide temporal context and crossfade blending.",
                }),
                "blend_frames": ("INT", {
                    "default": 4, "min": 0, "max": 50, "step": 1,
                    "tooltip": "Number of frames within the overlap region to crossfade. "
                               "Must be <= overlap_frames. The rest of the overlap is warmup (discarded).",
                }),
                "scale": ("INT", {
                    "default": 4, "min": 2, "max": 4, "step": 2,
                    "tooltip": "Upscaling factor.",
                }),
                "tiled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable VAE tiled decode.",
                }),
                "tile_size_h": ("INT", {
                    "default": 60, "min": 16, "max": 256, "step": 4,
                }),
                "tile_size_w": ("INT", {
                    "default": 104, "min": 16, "max": 256, "step": 4,
                }),
                "topk_ratio": ("FLOAT", {
                    "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1,
                }),
                "kv_ratio": ("FLOAT", {
                    "default": 3.0, "min": 1.0, "max": 4.0, "step": 0.1,
                }),
                "local_range": ([9, 11], {
                    "default": 11,
                }),
                "color_fix": ("BOOLEAN", {
                    "default": True,
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                }),
                "seed": ("INT", {
                    "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF,
                }),
            }
        }

    RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "upscale"
    CATEGORY = "video/FlashVSR"

    def upscale(self, images, model, segment_index, segment_size,
                overlap_frames, blend_frames, scale,
                tiled, tile_size_h, tile_size_w,
                topk_ratio, kv_ratio, local_range,
                color_fix, unload_dit, seed):
        total_input = images.shape[0]
        blend_frames = min(blend_frames, overlap_frames)

        # The UI allows overlap_frames up to 100 while segment_size can be as low as 21;
        # a non-positive stride would make every segment start at or before frame 0.
        if overlap_frames >= segment_size:
            raise ValueError(
                f"overlap_frames ({overlap_frames}) must be smaller than segment_size ({segment_size})"
            )

        # Clear stale overlap data from previous workflow runs
        if segment_index == 0:
            model._overlap_tail = None

        # Compute segment boundaries
        stride = segment_size - overlap_frames
        start = segment_index * stride
        end = min(start + segment_size, total_input)

        if start >= total_input:
            # Past the end
            return (images[:1], model)

        # Ensure minimum frame count
        actual_size = end - start
        if actual_size < FlashVSRModel.MIN_FRAMES:
            start = max(0, end - FlashVSRModel.MIN_FRAMES)
            actual_size = end - start

        segment_frames = images[start:end]

        tile_size = (tile_size_h, tile_size_w)

        model.load_to_device()

        result = model.upscale(
            segment_frames,
            scale=scale, tiled=tiled, tile_size=tile_size,
            topk_ratio=topk_ratio, kv_ratio=kv_ratio,
            local_range=local_range, color_fix=color_fix,
            unload_dit=unload_dit, seed=seed,
        )

        model.clear_caches()
        model.offload()
        from .flashvsr_arch.models.utils import clean_vram
        clean_vram()

        # Handle crossfade blending with previous segment's tail
        if segment_index > 0 and overlap_frames > 0 and hasattr(model, '_overlap_tail'):
            prev_tail = model._overlap_tail  # [blend_frames, H, W, C] on CPU

            # The overlap region in result: first overlap_frames of the upscaled output
            # Within overlap: first (overlap - blend) frames are warmup (discard)
            # last blend_frames frames: crossfade with prev_tail
            warmup = overlap_frames - blend_frames

            if blend_frames > 0 and prev_tail is not None:
                # Linear alpha ramp for crossfade
                alpha = torch.linspace(0, 1, blend_frames).view(-1, 1, 1, 1)
                blended = (1.0 - alpha) * prev_tail + alpha * result[warmup:warmup + blend_frames]
                result = torch.cat([blended, result[overlap_frames:]], dim=0)
            else:
                result = result[overlap_frames:]
        elif segment_index > 0 and overlap_frames > 0:
            # No previous tail stored, just skip overlap
            result = result[overlap_frames:]

        # Store tail frames for next segment's crossfade
        if overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames:
            model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16)
        else:
            model._overlap_tail = None

        return (result, model)
|
||||
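The segment arithmetic and the linear crossfade in `FlashVSRSegmentUpscale.upscale` can be illustrated with scalars standing in for frames. In the sketch below, `segment_bounds`, `linear_ramp` and `crossfade` are hypothetical helper names. The formulas follow the diff (stride = segment size minus overlap, alpha ramp from 0 to 1), but the clamping to the 21-frame minimum is left out for brevity.

```python
def segment_bounds(index, segment_size, overlap, total):
    """(start, end) of a segment; consecutive segments share `overlap` frames."""
    stride = segment_size - overlap
    start = index * stride
    end = min(start + segment_size, total)
    return start, end


def linear_ramp(n):
    """Plain-Python equivalent of torch.linspace(0, 1, n)."""
    if n == 1:
        return [0.0]
    return [i / (n - 1) for i in range(n)]


def crossfade(prev_tail, current_head):
    """Blend the previous segment's tail into the current segment's head."""
    alphas = linear_ramp(len(prev_tail))
    return [(1.0 - a) * p + a * c for a, p, c in zip(alphas, prev_tail, current_head)]


# 250 input frames, segments of 100 with 8 frames of overlap.
bounds = [segment_bounds(i, 100, 8, 250) for i in range(3)]
# Scalars stand in for frames: the tail is "all previous", the head "all current".
blended = crossfade([1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
```

Because the ramp includes both endpoints, the first blended frame is taken entirely from the previous segment and the last entirely from the current one; only the frames in between are true mixtures.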
@@ -4,3 +4,4 @@ yacs
easydict
einops
huggingface_hub
safetensors