diff --git a/README.md b/README.md index a85fc35..bb8a9c9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI +# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI + GIMM-VFI + FlashVSR -ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM. +ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024), and [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) (NeurIPS 2024), plus video super-resolution using [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) (arXiv 2025). Designed for long videos with thousands of frames — processes them without running out of VRAM. ## Which model should I use? @@ -18,6 +18,21 @@ ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://githu **TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. Use **GIMM-VFI** when you want 4x or 8x interpolation without recursive passes — it generates all intermediate frames in a single forward pass per pair. +### Video Super-Resolution + +FlashVSR is a different category — **spatial upscaling** rather than temporal interpolation. It can be combined with any of the VFI models above. + +| | FlashVSR | +|---|----------| +| **Task** | 4x video super-resolution | +| **Architecture** | Wan 2.1-1.3B DiT + VAE (diffusion-based) | +| **Modes** | Full (best quality), Tiny (fast), Tiny-Long (streaming, lowest VRAM) | +| **VRAM** | ~8–12 GB (tiled, tiny mode) / ~16–24 GB (full mode) | +| **Params** | ~1.3B (DiT) + ~200M (VAE) | +| **Min input** | 21 frames | +| **Paper** | arXiv 2510.12747 | +| **License** | Apache 2.0 | + ## Nodes ### BIM-VFI @@ -136,7 +151,61 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate, p Same as GIMM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate. -**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 +**Output frame count (VFI models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 + +### FlashVSR + +FlashVSR does **4x video super-resolution** (spatial upscaling), not frame interpolation. It uses a diffusion-based approach built on Wan 2.1-1.3B for temporally coherent upscaling. + +#### Load FlashVSR Model + +Downloads checkpoints from HuggingFace (~7.5 GB) on first use to `ComfyUI/models/flashvsr/`. 
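For reference, the first-use download is roughly equivalent to the sketch below (illustrative only — the loader node does this for you; it assumes `huggingface_hub` is installed and uses the checkpoint names listed further down):

```python
# Hypothetical manual pre-download — the Load FlashVSR Model node normally handles this on first use.
from huggingface_hub import hf_hub_download

checkpoints = [
    "FlashVSR1_1.safetensors", "Wan2.1_VAE.safetensors",
    "LQ_proj_in.safetensors", "TCDecoder.safetensors", "Prompt.safetensors",
]
for name in checkpoints:
    hf_hub_download(repo_id="1038lab/FlashVSR", filename=name,
                    local_dir="ComfyUI/models/flashvsr")
```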
+ +| Input | Description | +|-------|-------------| +| **mode** | Pipeline mode: `tiny` (fast TCDecoder decode), `tiny-long` (streaming TCDecoder, lowest VRAM for long videos), `full` (standard VAE decode, best quality) | +| **precision** | `bf16` (faster on modern GPUs) or `fp16` (for older GPUs) | + +Checkpoints (auto-downloaded from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR)): +| Checkpoint | Size | Description | +|-----------|------|-------------| +| `FlashVSR1_1.safetensors` | ~5 GB | Main DiT model (v1.1) | +| `Wan2.1_VAE.safetensors` | ~2 GB | Video VAE | +| `LQ_proj_in.safetensors` | ~50 MB | Low-quality frame projection | +| `TCDecoder.safetensors` | ~200 MB | Tiny conditional decoder (for tiny/tiny-long modes) | +| `Prompt.safetensors` | ~1 MB | Precomputed text embeddings | + +#### FlashVSR Upscale + +Upscales an image batch with 4x spatial super-resolution. + +| Input | Description | +|-------|-------------| +| **images** | Input video frames (minimum 21 frames) | +| **model** | Model from the loader node | +| **scale** | Upscaling factor: 2x or 4x (4x is native resolution) | +| **frame_chunk_size** | Process in chunks of N frames to bound VRAM (0 = all at once). Recommended: 33 or 65. Each chunk must be >= 21 frames | +| **tiled** | Enable tiled VAE decode (reduces VRAM significantly) | +| **tile_size_h / tile_size_w** | VAE tile dimensions in latent space (default 60/104) | +| **topk_ratio** | Sparse attention ratio. Higher = faster, may lose fine detail (default 2.0) | +| **kv_ratio** | KV cache ratio. Higher = better quality, more VRAM (default 2.0) | +| **local_range** | Local attention window: 9 = sharper details, 11 = more temporal stability | +| **color_fix** | Apply wavelet color correction to prevent color shifts | +| **unload_dit** | Offload DiT to CPU before VAE decode (saves VRAM, slower) | +| **seed** | Random seed for the diffusion process | + +#### FlashVSR Segment Upscale + +Same as FlashVSR Upscale but processes a single segment of the input. Chain multiple instances with Save nodes between them to bound peak RAM. The model pass-through output forces sequential execution. + +| Input | Description | +|-------|-------------| +| **segment_index** | Which segment to process (0-based) | +| **segment_size** | Number of input frames per segment (minimum 21) | +| **overlap_frames** | Overlapping frames between adjacent segments for temporal context and crossfade blending | +| **blend_frames** | Number of frames within the overlap to crossfade (must be <= overlap_frames) | + +Plus all the same upscale parameters as FlashVSR Upscale. ## Installation @@ -147,7 +216,7 @@ cd ComfyUI/custom_nodes git clone https://github.com/your-user/ComfyUI-Tween.git ``` -Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version. +Dependencies (`gdown`, `cupy`, `timm`, `omegaconf`, `easydict`, `yacs`, `einops`, `huggingface_hub`, `safetensors`) are auto-installed on first load. The correct `cupy` variant is detected from your PyTorch CUDA version. > **Warning:** `cupy` is a large package (~800MB) and compilation/installation can take several minutes. The first ComfyUI startup after installing this node may appear to hang while `cupy` installs in the background. Check the console log for progress. If auto-install fails (e.g. 
missing build tools in Docker), install manually with: > ```bash @@ -168,7 +237,8 @@ python install.py - `timm` (for EMA-VFI and SGM-VFI) - `gdown` (for BIM-VFI/EMA-VFI/SGM-VFI model auto-download) - `omegaconf`, `easydict`, `yacs`, `einops` (for GIMM-VFI) -- `huggingface_hub` (for GIMM-VFI model auto-download) +- `huggingface_hub` (for GIMM-VFI and FlashVSR model auto-download) +- `safetensors` (for FlashVSR checkpoint loading) ## VRAM Guide @@ -181,7 +251,7 @@ python install.py ## Acknowledgments -This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, and the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU). GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, and `gimm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths). +This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU, the [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) implementation by S-Lab (NTU), and [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) by OpenImagingLab. GIMM-VFI architecture files in `gimm_vfi_arch/` are adapted from [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI) with safetensors checkpoints from [Kijai/GIMM-VFI_safetensors](https://huggingface.co/Kijai/GIMM-VFI_safetensors). FlashVSR architecture files in `flashvsr_arch/` are adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR) (a diffsynth subset) with safetensors checkpoints from [1038lab/FlashVSR](https://huggingface.co/1038lab/FlashVSR). Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, `sgm_vfi_arch/`, `gimm_vfi_arch/`, and `flashvsr_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, dtype safety patches, inference-only paths). **BiM-VFI:** > Wonyong Seo, Jihyong Oh, and Munchurl Kim. @@ -243,6 +313,21 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF } ``` +**FlashVSR:** +> Junhao Zhuang, Ting-Che Lin, Xin Zhong, Zhihong Pan, Chun Yuan, and Ailing Zeng. +> "FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer." +> *arXiv preprint arXiv:2510.12747*, 2025. 
+> [[arXiv]](https://arxiv.org/abs/2510.12747) [[GitHub]](https://github.com/OpenImagingLab/FlashVSR) + +```bibtex +@article{zhuang2025flashvsr, + title={FlashVSR: Efficient Real-World Video Super-Resolution via Distilled Diffusion Transformer}, + author={Zhuang, Junhao and Lin, Ting-Che and Zhong, Xin and Pan, Zhihong and Yuan, Chun and Zeng, Ailing}, + journal={arXiv preprint arXiv:2510.12747}, + year={2025} +} +``` + ## License The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details. @@ -252,3 +337,5 @@ The EMA-VFI model weights and architecture code are released under the [Apache 2 The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details. The GIMM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/GSeanCDAT/GIMM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/GSeanCDAT/GIMM-VFI) for details. ComfyUI adaptation based on [kijai/ComfyUI-GIMM-VFI](https://github.com/kijai/ComfyUI-GIMM-VFI). + +The FlashVSR model weights and architecture code are released under the [Apache 2.0 License](https://github.com/OpenImagingLab/FlashVSR/blob/main/LICENSE). See the [original repository](https://github.com/OpenImagingLab/FlashVSR) for details. Architecture files adapted from [1038lab/ComfyUI-FlashVSR](https://github.com/1038lab/ComfyUI-FlashVSR). diff --git a/__init__.py b/__init__.py index 8aedf02..f7808af 100644 --- a/__init__.py +++ b/__init__.py @@ -34,8 +34,8 @@ def _auto_install_deps(): except Exception as e: logger.warning(f"[Tween] Could not auto-install cupy: {e}") - # GIMM-VFI dependencies - for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub"): + # GIMM-VFI + FlashVSR dependencies + for pkg in ("omegaconf", "yacs", "easydict", "einops", "huggingface_hub", "safetensors"): try: __import__(pkg) except ImportError: @@ -50,6 +50,7 @@ from .nodes import ( LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate, LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate, LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, + LoadFlashVSRModel, FlashVSRUpscale, FlashVSRSegmentUpscale, ) WEB_DIRECTORY = "./web" @@ -68,6 +69,9 @@ NODE_CLASS_MAPPINGS = { "LoadGIMMVFIModel": LoadGIMMVFIModel, "GIMMVFIInterpolate": GIMMVFIInterpolate, "GIMMVFISegmentInterpolate": GIMMVFISegmentInterpolate, + "LoadFlashVSRModel": LoadFlashVSRModel, + "FlashVSRUpscale": FlashVSRUpscale, + "FlashVSRSegmentUpscale": FlashVSRSegmentUpscale, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -84,4 +88,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadGIMMVFIModel": "Load GIMM-VFI Model", "GIMMVFIInterpolate": "GIMM-VFI Interpolate", "GIMMVFISegmentInterpolate": "GIMM-VFI Segment Interpolate", + "LoadFlashVSRModel": "Load FlashVSR Model", + "FlashVSRUpscale": "FlashVSR Upscale", + "FlashVSRSegmentUpscale": "FlashVSR Segment Upscale", } diff --git a/flashvsr_arch/__init__.py b/flashvsr_arch/__init__.py new file mode 100644 index 0000000..92918e0 --- /dev/null +++ b/flashvsr_arch/__init__.py @@ -0,0 +1,4 @@ +from .models.model_manager import ModelManager +from .pipelines import FlashVSRFullPipeline, 
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline +from .models.utils import clean_vram, Buffer_LQ4x_Proj +from .models.TCDecoder import build_tcdecoder diff --git a/flashvsr_arch/configs/__init__.py b/flashvsr_arch/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flashvsr_arch/configs/model_config.py b/flashvsr_arch/configs/model_config.py new file mode 100644 index 0000000..2e8aa57 --- /dev/null +++ b/flashvsr_arch/configs/model_config.py @@ -0,0 +1,21 @@ +from ..models.wan_video_dit import WanModel +from ..models.wan_video_vae import WanVideoVAE + +model_loader_configs = [ + # (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource) + (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), + (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), + (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), + (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), + (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), +] +huggingface_model_loader_configs = [ +] +patch_model_loader_configs = [ +] diff --git a/flashvsr_arch/models/TCDecoder.py b/flashvsr_arch/models/TCDecoder.py new file mode 100644 index 0000000..ca825c3 --- /dev/null +++ b/flashvsr_arch/models/TCDecoder.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned) +- Encoder removed +- Transplant/widening helpers removed +- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple +from einops import rearrange +import torch.nn.init as init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +# ---------------------------- +# Utility / building blocks +# ---------------------------- + +class IdentityConv2d(nn.Conv2d): + """Same-shape Conv2d initialized to identity (Dirac).""" + def __init__(self, C, kernel_size=3, bias=False): + pad = kernel_size // 2 + super().__init__(C, C, kernel_size, padding=pad, bias=bias) + with torch.no_grad(): + init.dirac_(self.weight) + if self.bias is not None: + self.bias.zero_() + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential( + conv(n_in * 2, n_out), nn.ReLU(inplace=True), + conv(n_out, n_out), nn.ReLU(inplace=True), + conv(n_out, n_out) + ) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = nn.ReLU(inplace=True) + def forward(self, x, past): + return 
self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +class PixelShuffle3d(nn.Module): + def __init__(self, ff, hh, ww): + super().__init__() + self.ff = ff + self.hh = hh + self.ww = ww + def forward(self, x): + # x: (B, C, F, H, W) + B, C, F, H, W = x.shape + if F % self.ff != 0: + first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1) + x = torch.cat([first_frame, x], dim=2) + return rearrange( + x, + 'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w', + ff=self.ff, hh=self.hh, ww=self.ww + ).transpose(1, 2) + +# ---------------------------- +# Generic NTCHW graph executor (kept; used by decoder) +# ---------------------------- + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None): + """ + Apply a sequential model with memblocks to the given input. + Args: + - model: nn.Sequential of blocks to apply + - x: input data, of dimensions NTCHW + - parallel: if True, parallelize over timesteps (fast but uses O(T) memory) + if False, each timestep will be processed sequentially (slow but uses O(1) memory) + - show_progress_bar: if True, enables tqdm progressbar display + + Returns NTCHW tensor of output data. + """ + assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" + N, T, C, H, W = x.shape + if parallel: + x = x.reshape(N*T, C, H, W) + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + NT, C, H, W = x.shape + T = NT // N + _x = x.reshape(N, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + NT, C, H, W = x.shape + T = NT // N + x = x.view(N, T, C, H, W) + else: + out = [] + work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] + progress_bar = tqdm(range(T), disable=not show_progress_bar) + while work_queue: + xt, i = work_queue.pop(0) + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt + else: + xt_new = b(xt, mem[i]) + mem[i].copy_(xt) + work_queue.insert(0, TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt) + if len(mem[i]) > b.stride: + raise ValueError("TPool internal state invalid.") + elif len(mem[i]) == b.stride: + N_, C_, H_, W_ = xt.shape + xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_)) + mem[i] = [] + work_queue.insert(0, TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C_, H_, W_ = xt.shape + for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)): + work_queue.insert(0, TWorkItem(xt_next, i+1)) + else: + xt = b(xt) + work_queue.insert(0, TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x, mem + +# ---------------------------- +# Decoder-only TAEHV +# ---------------------------- + +class TAEHV(nn.Module): + image_channels = 3 + def __init__( + self, + 
checkpoint_path="taehv.pth", + decoder_time_upscale=(True, True), + decoder_space_upscale=(True, True, True), + channels = [256, 128, 64, 64], + latent_channels = 16 + ): + """Initialize TAEHV (decoder-only) with built-in deepening after every ReLU. + Deepening config: how_many_each=1, k=3 (fixed as requested). + """ + super().__init__() + self.latent_channels = latent_channels + n_f = channels + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + + # Build the decoder "skeleton" + base_decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True), + + MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), + TGrow(n_f[0], 1), + conv(n_f[0], n_f[1], bias=False), + + MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), + TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), + conv(n_f[1], n_f[2], bias=False), + + MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), + TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), + conv(n_f[2], n_f[3], bias=False), + + nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels), + ) + + # Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU + self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3) + + self.pixel_shuffle = PixelShuffle3d(4, 8, 8) + + if checkpoint_path is not None: + missing_keys = self.load_state_dict( + self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)), + strict=False + ) + print('missing_keys', missing_keys) + + # Initialize decoder mem state + self.mem = [None] * len(self.decoder) + + @staticmethod + def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential: + """Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU).""" + new_layers = [] + for b in decoder: + new_layers.append(b) + if isinstance(b, nn.ReLU): + # Deduce channel count from preceding layer + C = None + if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d): + C = new_layers[-2].out_channels + elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock): + C = new_layers[-2].conv[-1].out_channels + if C is not None: + for _ in range(how_many_each): + new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False)) + new_layers.append(nn.ReLU(inplace=True)) + return nn.Sequential(*new_layers) + + def patch_tgrow_layers(self, sd): + """Patch TGrow layers to use a smaller kernel if needed (decoder-only).""" + new_sd = self.state_dict() + for i, layer in enumerate(self.decoder): + if isinstance(layer, TGrow): + key = f"decoder.{i}.conv.weight" + if key in sd and sd[key].shape[0] > new_sd[key].shape[0]: + sd[key] = sd[key][-new_sd[key].shape[0]:] + return sd + + def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None): + """Decode a sequence of frames from latents. + x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1]. 
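        If `cond` is provided it is pixel-shuffled and concatenated to `x` along the channel
        axis before decoding. On the first call (fresh memory) the leading warm-up frames
        (`frames_to_trim`) are dropped from the output; subsequent streaming calls return all frames.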
+ """ + trim_flag = self.mem[-8] is None # keeps original relative check + + if cond is not None: + x = torch.cat([self.pixel_shuffle(cond), x], dim=2) + + x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem) + + if trim_flag: + return x[:, self.frames_to_trim:] + return x + + def forward(self, *args, **kwargs): + raise NotImplementedError("Decoder-only model: call decode_video(...) instead.") + + def clean_mem(self): + self.mem = [None] * len(self.decoder) + +class DotDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + +class TAEW2_1DiffusersWrapper(nn.Module): + def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]): + super().__init__() + self.dtype = torch.bfloat16 + self.device = "cuda" + self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype) + self.temperal_downsample = [True, True, False] # [sic] + self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16)) + + def decode(self, latents, return_dict=None): + n, c, t, h, w = latents.shape + return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),) + + def stream_decode_with_cond(self, latents, tiled=False, cond=None): + n, c, t, h, w = latents.shape + return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1) + + def clean_mem(self): + self.taehv.clean_mem() + +# ---------------------------- +# Simplified builder (no small, no transplant, no post-hoc deepening) +# ---------------------------- + +def build_tcdecoder(new_channels = [512, 256, 128, 128], + device="cuda", + dtype=torch.bfloat16, + new_latent_channels=None): + """ + 构建“更宽”的 decoder;深度增强(IdentityConv2d+ReLU)已在 TAEHV 内部完成。 + - 不创建 small / 不做移植 + - base_ckpt_path 参数保留但不使用(接口兼容) + + 返回:big (单个模型) + """ + if new_latent_channels is not None: + big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train() + else: + big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train() + + big.clean_mem() + return big diff --git a/flashvsr_arch/models/__init__.py b/flashvsr_arch/models/__init__.py new file mode 100644 index 0000000..30499f3 --- /dev/null +++ b/flashvsr_arch/models/__init__.py @@ -0,0 +1 @@ +from .model_manager import * diff --git a/flashvsr_arch/models/model_manager.py b/flashvsr_arch/models/model_manager.py new file mode 100644 index 0000000..8dfa9b6 --- /dev/null +++ b/flashvsr_arch/models/model_manager.py @@ -0,0 +1,402 @@ +import os, torch, json, importlib +from typing import List + +from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs +from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix + +def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): + loaded_model_names, loaded_models = [], [] + for model_name, model_class in zip(model_names, model_classes): + #print(f" model_name: {model_name} model_class: {model_class.__name__}") + state_dict_converter = model_class.state_dict_converter() + if model_resource == "civitai": + state_dict_results = state_dict_converter.from_civitai(state_dict) + elif model_resource == "diffusers": + state_dict_results = state_dict_converter.from_diffusers(state_dict) + if isinstance(state_dict_results, tuple): + 
model_state_dict, extra_kwargs = state_dict_results + #print(f" This model is initialized with extra kwargs: {extra_kwargs}") + else: + model_state_dict, extra_kwargs = state_dict_results, {} + torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype + with init_weights_on_device(): + model = model_class(**extra_kwargs) + if hasattr(model, "eval"): + model = model.eval() + model.load_state_dict(model_state_dict, assign=True) + model = model.to(dtype=torch_dtype, device=device) + loaded_model_names.append(model_name) + loaded_models.append(model) + return loaded_model_names, loaded_models + + +def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device): + loaded_model_names, loaded_models = [], [] + for model_name, model_class in zip(model_names, model_classes): + if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]: + model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval() + else: + model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype) + if torch_dtype == torch.float16 and hasattr(model, "half"): + model = model.half() + try: + model = model.to(device=device) + except: + pass + loaded_model_names.append(model_name) + loaded_models.append(model) + return loaded_model_names, loaded_models + + +def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device): + #print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}") + base_state_dict = base_model.state_dict() + base_model.to("cpu") + del base_model + model = model_class(**extra_kwargs) + model.load_state_dict(base_state_dict, strict=False) + model.load_state_dict(state_dict, strict=False) + model.to(dtype=torch_dtype, device=device) + return model + + +def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device): + loaded_model_names, loaded_models = [], [] + for model_name, model_class in zip(model_names, model_classes): + while True: + for model_id in range(len(model_manager.model)): + base_model_name = model_manager.model_name[model_id] + if base_model_name == model_name: + base_model_path = model_manager.model_path[model_id] + base_model = model_manager.model[model_id] + print(f" Adding patch model to {base_model_name} ({base_model_path})") + patched_model = load_single_patch_model_from_single_file( + state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device) + loaded_model_names.append(base_model_name) + loaded_models.append(patched_model) + model_manager.model.pop(model_id) + model_manager.model_path.pop(model_id) + model_manager.model_name.pop(model_id) + break + else: + break + return loaded_model_names, loaded_models + + + +class ModelDetectorTemplate: + def __init__(self): + pass + + def match(self, file_path="", state_dict={}): + return False + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): + return [], [] + + + +class ModelDetectorFromSingleFile: + def __init__(self, model_loader_configs=[]): + self.keys_hash_with_shape_dict = {} + self.keys_hash_dict = {} + for metadata in model_loader_configs: + self.add_model_metadata(*metadata) + + + def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource): + self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource) + if 
keys_hash is not None: + self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource) + + + def match(self, file_path="", state_dict={}): + if isinstance(file_path, str) and os.path.isdir(file_path): + return False + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) + if keys_hash_with_shape in self.keys_hash_with_shape_dict: + return True + keys_hash = hash_state_dict_keys(state_dict, with_shape=False) + if keys_hash in self.keys_hash_dict: + return True + return False + + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + + # Load models with strict matching + keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) + if keys_hash_with_shape in self.keys_hash_with_shape_dict: + model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape] + loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device) + return loaded_model_names, loaded_models + + # Load models without strict matching + # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture) + keys_hash = hash_state_dict_keys(state_dict, with_shape=False) + if keys_hash in self.keys_hash_dict: + model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash] + loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device) + return loaded_model_names, loaded_models + + return loaded_model_names, loaded_models + + + +class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile): + def __init__(self, model_loader_configs=[]): + super().__init__(model_loader_configs) + + + def match(self, file_path="", state_dict={}): + if isinstance(file_path, str) and os.path.isdir(file_path): + return False + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + splited_state_dict = split_state_dict_with_prefix(state_dict) + for sub_state_dict in splited_state_dict: + if super().match(file_path, sub_state_dict): + return True + return False + + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): + # Split the state_dict and load from each component + splited_state_dict = split_state_dict_with_prefix(state_dict) + valid_state_dict = {} + for sub_state_dict in splited_state_dict: + if super().match(file_path, sub_state_dict): + valid_state_dict.update(sub_state_dict) + if super().match(file_path, valid_state_dict): + loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype) + else: + loaded_model_names, loaded_models = [], [] + for sub_state_dict in splited_state_dict: + if super().match(file_path, sub_state_dict): + loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype) + loaded_model_names += loaded_model_names_ + loaded_models += loaded_models_ + return loaded_model_names, loaded_models + + + +class ModelDetectorFromHuggingfaceFolder: + def __init__(self, model_loader_configs=[]): + self.architecture_dict = {} + for metadata in model_loader_configs: + self.add_model_metadata(*metadata) + + + def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture): + 
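        # Map a HuggingFace "architectures" entry to its loader library, target model name, and optional redirected class.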
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture) + + + def match(self, file_path="", state_dict={}): + if not isinstance(file_path, str) or os.path.isfile(file_path): + return False + file_list = os.listdir(file_path) + if "config.json" not in file_list: + return False + with open(os.path.join(file_path, "config.json"), "r") as f: + config = json.load(f) + if "architectures" not in config and "_class_name" not in config: + return False + return True + + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): + with open(os.path.join(file_path, "config.json"), "r") as f: + config = json.load(f) + loaded_model_names, loaded_models = [], [] + architectures = config["architectures"] if "architectures" in config else [config["_class_name"]] + for architecture in architectures: + huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture] + if redirected_architecture is not None: + architecture = redirected_architecture + model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture) + loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device) + loaded_model_names += loaded_model_names_ + loaded_models += loaded_models_ + return loaded_model_names, loaded_models + + + +class ModelDetectorFromPatchedSingleFile: + def __init__(self, model_loader_configs=[]): + self.keys_hash_with_shape_dict = {} + for metadata in model_loader_configs: + self.add_model_metadata(*metadata) + + + def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs): + self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs) + + + def match(self, file_path="", state_dict={}): + if not isinstance(file_path, str) or os.path.isdir(file_path): + return False + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) + if keys_hash_with_shape in self.keys_hash_with_shape_dict: + return True + return False + + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs): + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + + # Load models with strict matching + loaded_model_names, loaded_models = [], [] + keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True) + if keys_hash_with_shape in self.keys_hash_with_shape_dict: + model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape] + loaded_model_names_, loaded_models_ = load_patch_model_from_single_file( + state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device) + loaded_model_names += loaded_model_names_ + loaded_models += loaded_models_ + return loaded_model_names, loaded_models + + + +class ModelManager: + def __init__( + self, + torch_dtype=torch.float16, + device="cuda", + file_path_list: List[str] = [], + ): + self.torch_dtype = torch_dtype + self.device = device + self.model = [] + self.model_path = [] + self.model_name = [] + self.model_detector = [ + ModelDetectorFromSingleFile(model_loader_configs), + ModelDetectorFromSplitedSingleFile(model_loader_configs), + ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs), + ModelDetectorFromPatchedSingleFile(patch_model_loader_configs), + ] + self.load_models(file_path_list) + + + def 
load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None): + print(f"Loading models from file: {file_path}") + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device) + for model_name, model in zip(model_names, models): + self.model.append(model) + self.model_path.append(file_path) + self.model_name.append(model_name) + #print(f" The following models are loaded: {model_names}.") + + + def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]): + print(f"Loading models from folder: {file_path}") + model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device) + for model_name, model in zip(model_names, models): + self.model.append(model) + self.model_path.append(file_path) + self.model_name.append(model_name) + #print(f" The following models are loaded: {model_names}.") + + + def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}): + print(f"Loading patch models from file: {file_path}") + model_names, models = load_patch_model_from_single_file( + state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device) + for model_name, model in zip(model_names, models): + self.model.append(model) + self.model_path.append(file_path) + self.model_name.append(model_name) + print(f" The following patched models are loaded: {model_names}.") + + + def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): + if isinstance(file_path, list): + for file_path_ in file_path: + self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha) + else: + print(f"Loading LoRA models from file: {file_path}") + is_loaded = False + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): + for lora in get_lora_loaders(): + match_results = lora.match(model, state_dict) + if match_results is not None: + print(f" Adding LoRA to {model_name} ({model_path}).") + lora_prefix, model_resource = match_results + lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) + is_loaded = True + break + if not is_loaded: + print(f" Cannot load LoRA: {file_path}") + + + def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): + #print(f"Loading models from: {file_path}") + if device is None: device = self.device + if torch_dtype is None: torch_dtype = self.torch_dtype + if isinstance(file_path, list): + state_dict = {} + for path in file_path: + state_dict.update(load_state_dict(path)) + elif os.path.isfile(file_path): + state_dict = load_state_dict(file_path) + else: + state_dict = None + for model_detector in self.model_detector: + if model_detector.match(file_path, state_dict): + model_names, models = model_detector.load( + file_path, state_dict, + device=device, torch_dtype=torch_dtype, + allowed_model_names=model_names, model_manager=self + ) + for model_name, model in zip(model_names, models): + self.model.append(model) + self.model_path.append(file_path) + self.model_name.append(model_name) + #print(f" The following models are loaded: {model_names}.") + break + else: + print(f" We cannot detect the model type. 
No models are loaded.") + + + def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None): + for file_path in file_path_list: + self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) + + + def fetch_model(self, model_name, file_path=None, require_model_path=False): + fetched_models = [] + fetched_model_paths = [] + for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): + if file_path is not None and file_path != model_path: + continue + if model_name == model_name_: + fetched_models.append(model) + fetched_model_paths.append(model_path) + if len(fetched_models) == 0: + #print(f"No {model_name} models available.") + return None + if len(fetched_models) == 1: + print(f"Using {model_name} from {fetched_model_paths[0]}") + else: + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}") + if require_model_path: + return fetched_models[0], fetched_model_paths[0] + else: + return fetched_models[0] + + + def to(self, device): + for model in self.model: + model.to(device) + diff --git a/flashvsr_arch/models/utils.py b/flashvsr_arch/models/utils.py new file mode 100644 index 0000000..00b5313 --- /dev/null +++ b/flashvsr_arch/models/utils.py @@ -0,0 +1,360 @@ +import torch, os, gc +from safetensors import safe_open +from contextlib import contextmanager +from einops import rearrange, repeat +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import time +import hashlib + +CACHE_T = 2 + +@contextmanager +def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): + + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + +def load_state_dict_from_folder(file_path, torch_dtype=None): + state_dict = {} + for file_name in os.listdir(file_path): + if "." 
in file_name and file_name.split(".")[-1] in [ + "safetensors", "bin", "ckpt", "pth", "pt" + ]: + state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype)) + return state_dict + + +def load_state_dict(file_path, torch_dtype=None): + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None): + state_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None): + state_dict = torch.load(file_path, map_location="cpu", weights_only=True) + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def search_for_embeddings(state_dict): + embeddings = [] + for k in state_dict: + if isinstance(state_dict[k], torch.Tensor): + embeddings.append(state_dict[k]) + elif isinstance(state_dict[k], dict): + embeddings += search_for_embeddings(state_dict[k]) + return embeddings + + +def search_parameter(param, state_dict): + for name, param_ in state_dict.items(): + if param.numel() == param_.numel(): + if param.shape == param_.shape: + if torch.dist(param, param_) < 1e-3: + return name + else: + if torch.dist(param.flatten(), param_.flatten()) < 1e-3: + return name + return None + + +def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): + matched_keys = set() + with torch.no_grad(): + for name in source_state_dict: + rename = search_parameter(source_state_dict[name], target_state_dict) + if rename is not None: + print(f'"{name}": "{rename}",') + matched_keys.add(rename) + elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0: + length = source_state_dict[name].shape[0] // 3 + rename = [] + for i in range(3): + rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict)) + if None not in rename: + print(f'"{name}": {rename},') + for rename_ in rename: + matched_keys.add(rename_) + for name in target_state_dict: + if name not in matched_keys: + print("Cannot find", name, target_state_dict[name].shape) + + +def search_for_files(folder, extensions): + files = [] + if os.path.isdir(folder): + for file in sorted(os.listdir(folder)): + files += search_for_files(os.path.join(folder, file), extensions) + elif os.path.isfile(folder): + for extension in extensions: + if folder.endswith(extension): + files.append(folder) + break + return files + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def split_state_dict_with_prefix(state_dict): + keys = sorted([key for key in state_dict if isinstance(key, str)]) + prefix_dict = {} + for key in keys: + prefix = key if "." 
not in key else key.split(".")[0] + if prefix not in prefix_dict: + prefix_dict[prefix] = [] + prefix_dict[prefix].append(key) + state_dicts = [] + for prefix, keys in prefix_dict.items(): + sub_state_dict = {key: state_dict[key] for key in keys} + state_dicts.append(sub_state_dict) + return state_dicts + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + +def clean_vram(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if torch.mps.is_available(): + torch.mps.empty_cache() + +def get_device_list(): + devs = [] + if torch.cuda.is_available(): + devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())] + + if torch.mps.is_available(): + devs += [f"mps:{i}" for i in range(torch.mps.device_count())] + + return devs + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + # print(cache_x.shape, x.shape) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + # print('cache!') + x = F.pad(x, padding, mode='replicate') # mode='replicate' + # print(x[0,0,:,0,0]) + + return super().forward(x) + +class PixelShuffle3d(nn.Module): + def __init__(self, ff, hh, ww): + super().__init__() + self.ff = ff + self.hh = hh + self.ww = ww + + def forward(self, x): + # x: (B, C, F, H, W) + return rearrange(x, + 'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w', + ff=self.ff, hh=self.hh, ww=self.ww) + +class Buffer_LQ4x_Proj(nn.Module): + + def __init__(self, in_dim, out_dim, layer_num=30): + super().__init__() + self.ff = 1 + self.hh = 16 + self.ww = 16 + self.hidden_dim1 = 2048 + self.hidden_dim2 = 3072 + self.layer_num = layer_num + + self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) + + self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w + self.norm1 = RMS_norm(self.hidden_dim1, images=False) + self.act1 = nn.SiLU() + + self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w + self.norm2 = RMS_norm(self.hidden_dim2, images=False) + self.act2 = nn.SiLU() + + self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]) + + self.clip_idx = 0 + + def forward(self, video): + self.clear_cache() + # x: (B, C, F, H, W) + + t = video.shape[2] + iter_ = 1 + (t - 1) // 4 + first_frame = video[:, :, :1, :, 
:].repeat(1, 1, 3, 1, 1) + video = torch.cat([first_frame, video], dim=2) + # print(video.shape) + + out_x = [] + for i in range(iter_): + x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:]) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv1'] = cache1_x + x = self.conv1(x, self.cache['conv1']) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv2'] = cache2_x + if i == 0: + continue + x = self.conv2(x, self.cache['conv2']) + x = self.norm2(x) + x = self.act2(x) + out_x.append(x) + out_x = torch.cat(out_x, dim = 2) + # print(out_x.shape) + out_x = rearrange(out_x, 'b c f h w -> b (f h w) c') + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + return outputs + + def clear_cache(self): + self.cache = {} + self.cache['conv1'] = None + self.cache['conv2'] = None + self.clip_idx = 0 + + def stream_forward(self, video_clip): + if self.clip_idx == 0: + # self.clear_cache() + first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) + video_clip = torch.cat([first_frame, video_clip], dim=2) + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv1'] = cache1_x + x = self.conv1(x, self.cache['conv1']) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv2'] = cache2_x + self.clip_idx += 1 + return None + else: + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv1'] = cache1_x + x = self.conv1(x, self.cache['conv1']) + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv2'] = cache2_x + x = self.conv2(x, self.cache['conv2']) + x = self.norm2(x) + x = self.act2(x) + out_x = rearrange(x, 'b c f h w -> b (f h w) c') + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + self.clip_idx += 1 + return outputs + \ No newline at end of file diff --git a/flashvsr_arch/models/wan_video_dit.py b/flashvsr_arch/models/wan_video_dit.py new file mode 100644 index 0000000..cc06947 --- /dev/null +++ b/flashvsr_arch/models/wan_video_dit.py @@ -0,0 +1,834 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import random +import os +import time +from typing import Tuple, Optional, List +from einops import rearrange +from .utils import hash_state_dict_keys + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +try: + from sageattn.core import sparse_sageattn + SPARSE_SAGE_AVAILABLE = True +except ModuleNotFoundError: + SPARSE_SAGE_AVAILABLE = False + sparse_sageattn = None +from PIL import Image +import numpy as np + + +# ---------------------------- +# Local / window masks +# ---------------------------- +@torch.no_grad() +def build_local_block_mask_shifted_vec(block_h: int, + block_w: int, + win_h: int = 6, + win_w: int = 6, + include_self: bool = True, + device=None) -> torch.Tensor: + device = device or torch.device("cpu") + H, W = block_h, block_w + r = torch.arange(H, device=device) + c = torch.arange(W, device=device) + YY, XX = torch.meshgrid(r, c, indexing="ij") + r_all = YY.reshape(-1) + c_all = 
XX.reshape(-1) + r_half = win_h // 2 + c_half = win_w // 2 + start_r = torch.clamp(r_all - r_half, 0, H - win_h) + end_r = start_r + win_h - 1 + start_c = torch.clamp(c_all - c_half, 0, W - win_w) + end_c = start_c + win_w - 1 + in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None]) + in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None]) + mask = in_row & in_col + if not include_self: + mask.fill_diagonal_(False) + return mask + +@torch.no_grad() +def build_local_block_mask_shifted_vec_normal_slide(block_h: int, + block_w: int, + win_h: int = 6, + win_w: int = 6, + include_self: bool = True, + device=None) -> torch.Tensor: + device = device or torch.device("cpu") + H, W = block_h, block_w + r = torch.arange(H, device=device) + c = torch.arange(W, device=device) + YY, XX = torch.meshgrid(r, c, indexing="ij") + r_all = YY.reshape(-1) + c_all = XX.reshape(-1) + r_half = win_h // 2 + c_half = win_w // 2 + start_r = r_all - r_half + end_r = start_r + win_h - 1 + start_c = c_all - c_half + end_c = start_c + win_w - 1 + in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None]) + in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None]) + mask = in_row & in_col + if not include_self: + mask.fill_diagonal_(False) + return mask + + +class WindowPartition3D: + """Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C).""" + @staticmethod + def partition(x: torch.Tensor, win: Tuple[int, int, int]): + B, F, H, W, C = x.shape + wf, wh, ww = win + assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size." + x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous() + return x.view(-1, wf * wh * ww, C) + + @staticmethod + def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]): + F, H, W = orig + wf, wh, ww = win + nf, nh, nw = F // wf, H // wh, W // ww + B = windows.size(0) // (nf * nh * nw) + x = windows.view(B, nf, nh, nw, wf, wh, ww, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous() + return x.view(B, F, H, W, -1) + + +@torch.no_grad() +def generate_draft_block_mask(batch_size, nheads, seqlen, + q_w, k_w, topk=10, local_attn_mask=None): + assert batch_size == 1, "Only batch_size=1 supported for now" + assert local_attn_mask is not None, "local_attn_mask must be provided" + avgpool_q = torch.mean(q_w, dim=1) + avgpool_k = torch.mean(k_w, dim=1) + avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads) + avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads) + q_heads = avgpool_q.permute(1, 0, 2) + k_heads = avgpool_k.permute(1, 0, 2) + D = avgpool_q.shape[-1] + scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D) + + repeat_head = scores.shape[0] + repeat_len = scores.shape[1] // local_attn_mask.shape[0] + repeat_num = scores.shape[2] // local_attn_mask.shape[1] + local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1) + local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)') + local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1) + local_attn_mask = local_attn_mask.to(torch.float32) + local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf')) + local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) + scores = scores + local_attn_mask + + attn_map = torch.softmax(scores, dim=-1) + attn_map = rearrange(attn_map, 'h (it 
s1) s2 -> (h it) s1 s2', it=seqlen) + loop_num, s1, s2 = attn_map.shape + flat = attn_map.reshape(loop_num, -1) + n = flat.shape[1] + apply_topk = min(flat.shape[1]-1, topk) + thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1] + thresholds = thresholds.unsqueeze(1) + mask_new = (flat > thresholds).reshape(loop_num, s1, s2) + mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note + # 修正:上行变量名统一 + # mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new + mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) + mask = mask.repeat_interleave(2, dim=-1) + return mask + + +@torch.no_grad() +def generate_draft_block_mask_refined(batch_size, nheads, seqlen, + q_w, k_w, topk=10, local_attn_mask=None): + assert batch_size == 1, "Only batch_size=1 supported for now" + assert local_attn_mask is not None, "local_attn_mask must be provided" + + avgpool_q = torch.mean(q_w, dim=1) + avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads) + q_heads = avgpool_q.permute(1, 0, 2) + + k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2]) + avgpool_k_split = torch.mean(k_w_split, dim=2) + avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2) + avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads) + k_heads = avgpool_k_refined.permute(1, 0, 2) + + D = avgpool_q.shape[-1] + scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D) + + repeat_head = scores.shape[0] + num_q_blocks_local = local_attn_mask.shape[0] + num_k_blocks_local = local_attn_mask.shape[1] + + local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1) + + repeat_len = scores.shape[1] // local_attn_mask.shape[0] + repeat_num = scores.shape[2] // local_attn_mask.shape[1] + + local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1) + + local_attn_mask = local_attn_mask.to(torch.float32) + local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf')) + local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0) + + assert scores.shape == local_attn_mask.shape + + scores = scores + local_attn_mask + attn_map = torch.softmax(scores, dim=-1) + attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen) # it=seqlen可能需要调整,取决于seqlen的含义 + loop_num, s1, s2 = attn_map.shape + flat = attn_map.reshape(loop_num, -1) + apply_topk = min(flat.shape[1]-1, topk) + thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1] + thresholds = thresholds.unsqueeze(1) + mask_new = (flat > thresholds).reshape(loop_num, s1, s2) + mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) + + mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + return mask + + +# ---------------------------- +# Attention kernels +# ---------------------------- +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True): + if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE: + seqlen = q.shape[1] + seqlen_kv = k.shape[1] + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + base_blockmask = attention_mask + x = sparse_sageattn( + q, k, v, + mask_id=base_blockmask.to(torch.int8), + is_causal=False, + tensor_layout="HND" + ) + x = rearrange(x, "b 
n s d -> b s (n d)", n=num_heads) + elif compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + half_dim = max(dim // 2, 1) + scale = torch.arange(half_dim, dtype=torch.float32, device=position.device) + inv_freq = torch.pow(10000.0, -scale / half_dim) + sinusoid = torch.outer(position.to(torch.float32), inv_freq) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + half_dim = max(dim // 2, 1) + base = torch.arange(0, dim, 2, dtype=torch.float32)[:half_dim] + freqs = torch.pow(theta, -base / max(dim, 1)) + steps = torch.arange(end, dtype=torch.float32) + angles = torch.outer(steps, freqs) + return torch.polar(torch.ones_like(angles), angles) + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + orig_dtype = x.dtype + work_dtype = torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype + reshaped = x.to(work_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) + x_complex = torch.view_as_complex(reshaped) + freqs = freqs.to(dtype=x_complex.dtype, device=x_complex.device) + x_out = torch.view_as_real(x_complex * freqs).flatten(2) + return x_out.to(orig_dtype) + + +# ---------------------------- +# Norms & Blocks +# ---------------------------- +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, 
x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads, enable_sageattention=True): + super().__init__() + self.num_heads = num_heads + self.enable_sageattention = enable_sageattention + + def forward(self, q, k, v, attention_mask=None): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, enable_sageattention=self.enable_sageattention) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention) + self.local_attn_mask = None + + def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None, + train_img=False, block_id=None, kv_len=None, is_full_block=False, + is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9): + B, L, D = x.shape + if is_stream and pre_cache_k is not None and pre_cache_v is not None: + assert f==2, "f must be 2" + if is_stream and (pre_cache_k is None or pre_cache_v is None): + assert f==6, " start f must be 6" + assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)." + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + + win = (2, 8, 8) + q = q.view(B, f, h, w, D) + k = k.view(B, f, h, w, D) + v = v.view(B, f, h, w, D) + + q_w = WindowPartition3D.partition(q, win) + k_w = WindowPartition3D.partition(k, win) + v_w = WindowPartition3D.partition(v, win) + + seqlen = f//win[0] + one_len = k_w.shape[0] // B // seqlen + if pre_cache_k is not None and pre_cache_v is not None: + k_w = torch.cat([pre_cache_k, k_w], dim=0) + v_w = torch.cat([pre_cache_v, v_w], dim=0) + + block_n = q_w.shape[0] // B + block_s = q_w.shape[1] + block_n_kv = k_w.shape[0] // B + + reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s) + reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s) + reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s) + + window_size = win[0]*h*w//128 + + if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range: + self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device) + self.local_attn_mask_h = h//8 + self.local_attn_mask_w = w//8 + self.local_range = local_range + attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) + + x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask) + + cur_block_n, cur_block_s, _ = k_w.shape + cache_num = cur_block_n // one_len + if cache_num > kv_len: + cache_k = k_w[one_len:, :, :] + cache_v = v_w[one_len:, :, :] + else: + cache_k = k_w + cache_v = v_w + + x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s) + x 
= WindowPartition3D.reverse(x, win, (f, h, w))
+        x = x.view(B, f*h*w, D)
+
+        if is_stream:
+            return self.o(x), cache_k, cache_v
+        return self.o(x)
+
+
+class CrossAttention(nn.Module):
+    """
+    Cross attention over the text context only; keeps a persistent KV cache.
+    """
+    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+
+        self.q = nn.Linear(dim, dim)
+        self.k = nn.Linear(dim, dim)
+        self.v = nn.Linear(dim, dim)
+        self.o = nn.Linear(dim, dim)
+
+        self.norm_q = RMSNorm(dim, eps=eps)
+        self.norm_k = RMSNorm(dim, eps=eps)
+
+        self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
+
+        # persistent KV cache
+        self.cache_k = None
+        self.cache_v = None
+
+    @torch.no_grad()
+    def init_cache(self, ctx: torch.Tensor):
+        """ctx: [B, S_ctx, dim] — the context after text_embedding."""
+        self.cache_k = self.norm_k(self.k(ctx))
+        self.cache_v = self.v(ctx)
+
+    def clear_cache(self):
+        self.cache_k = None
+        self.cache_v = None
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
+        """
+        y is the text context; it is unused at call time — K/V come from the
+        persistent cache populated by init_cache().
+        """
+        q = self.norm_q(self.q(x))
+        assert self.cache_k is not None and self.cache_v is not None
+        k = self.cache_k
+        v = self.cache_v
+
+        x = self.attn(q, k, v)
+        return self.o(x)
+
+
+class GateModule(nn.Module):
+    def __init__(self,):
+        super().__init__()
+
+    def forward(self, x, gate, residual):
+        return x + gate * residual
+
+
+class DiTBlock(nn.Module):
+    def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.ffn_dim = ffn_dim
+
+        self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
+        self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
+
+        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm3 = nn.LayerNorm(dim, eps=eps)
+        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
+            approximate='tanh'), nn.Linear(ffn_dim, dim))
+        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+        self.gate = GateModule()
+
+    def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
+                train_img=False, block_id=None, kv_len=None, is_full_block=False,
+                is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
+        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+        self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
+            input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
+            kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
+            pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
+
+        x = self.gate(x, gate_msa, self_attn_output)
+        x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
+        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        x = self.gate(x, gate_mlp, self.ffn(input_x))
+        if is_stream:
+            return x, self_attn_cache_k, self_attn_cache_v
+        return x
+
+
+class MLP(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, has_pos_emb=False):
+        super().__init__()
+        self.proj = torch.nn.Sequential(
+            nn.LayerNorm(in_dim),
+            nn.Linear(in_dim, in_dim),
+            nn.GELU(),
+            nn.Linear(in_dim, out_dim),
+            
nn.LayerNorm(out_dim)
+        )
+        self.has_pos_emb = has_pos_emb
+        if has_pos_emb:
+            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
+
+    def forward(self, x):
+        if self.has_pos_emb:
+            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
+        return self.proj(x)
+
+
+class Head(nn.Module):
+    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
+        super().__init__()
+        self.dim = dim
+        self.patch_size = patch_size
+        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
+        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+    def forward(self, x, t_mod):
+        shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
+        x = (self.head(self.norm(x) * (1 + scale) + shift))
+        return x
+
+
+# ----------------------------
+# WanModel (no image branch) — the cross-attention KV cache is created at init
+# ----------------------------
+class WanModel(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        in_dim: int,
+        ffn_dim: int,
+        out_dim: int,
+        text_dim: int,
+        freq_dim: int,
+        eps: float,
+        patch_size: Tuple[int, int, int],
+        num_heads: int,
+        num_layers: int,
+        has_image_input: bool = False,
+        enable_sageattention: bool = True,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.freq_dim = freq_dim
+        self.patch_size = patch_size
+
+        # patch embed
+        self.patch_embedding = nn.Conv3d(
+            in_dim, dim, kernel_size=patch_size, stride=patch_size)
+
+        # text / time embed
+        self.text_embedding = nn.Sequential(
+            nn.Linear(text_dim, dim),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(dim, dim)
+        )
+        self.time_embedding = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim)
+        )
+        self.time_projection = nn.Sequential(
+            nn.SiLU(), nn.Linear(dim, dim * 6))
+
+        # blocks
+        self.blocks = nn.ModuleList([
+            DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
+            for _ in range(num_layers)
+        ])
+        self.head = Head(dim, out_dim, patch_size, eps)
+
+        head_dim = dim // num_heads
+        self.freqs = precompute_freqs_cis_3d(head_dim)
+
+        self._cross_kv_initialized = False
+
+    # Optional: manually clear / re-initialize the cross-attention KV caches
+    def clear_cross_kv(self):
+        for blk in self.blocks:
+            blk.cross_attn.clear_cache()
+        self._cross_kv_initialized = False
+
+    @torch.no_grad()
+    def reinit_cross_kv(self, new_context: torch.Tensor):
+        ctx_txt = self.text_embedding(new_context)
+        for blk in self.blocks:
+            blk.cross_attn.init_cache(ctx_txt)
+        self._cross_kv_initialized = True
+
+    def patchify(self, x: torch.Tensor):
+        x = self.patch_embedding(x)
+        grid_size = x.shape[2:]
+        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+        return x, grid_size  # x, grid_size: (f, h, w)
+
+    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
+        return rearrange(
+            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
+            f=grid_size[0], h=grid_size[1], w=grid_size[2],
+            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
+        )
+
+    def forward(self,
+                x: torch.Tensor,
+                timestep: torch.Tensor,
+                context: torch.Tensor,
+                use_gradient_checkpointing: bool = False,
+                use_gradient_checkpointing_offload: bool = False,
+                LQ_latents: Optional[List[torch.Tensor]] = None,
+                train_img: bool = False,
+                topk_ratio: Optional[float] = None,
+                kv_ratio: Optional[float] = None,
+                local_num: Optional[int] = None,
+                is_full_block: bool = False,
+                causal_idx: Optional[int] = None,
+                **kwargs,
+                ):
+        # time / text embeds
+        t = self.time_embedding(
+            
sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + + # 这里仍会嵌入 text(CrossAttention 若已有缓存会忽略它) + # context = self.text_embedding(context) + + # 输入打补丁 + x, (f, h, w) = self.patchify(x) + B = x.shape[0] + + # window / masks 超参 + win = (2, 8, 8) + seqlen = f//win[0] + if local_num is None: + local_random = random.random() + if local_random < 0.3: + local_num = seqlen - 3 + elif local_random < 0.4: + local_num = seqlen - 4 + elif local_random < 0.5: + local_num = seqlen - 2 + else: + local_num = seqlen + + window_size = win[0]*h*w//128 + square_num = window_size*window_size + topk_ratio = 2.0 + topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1) + + if kv_ratio is None: + kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2 + kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1) + + decay_ratio = random.uniform(0.7, 1.0) + + # RoPE 3D + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # blocks + for block_id, block in enumerate(self.blocks): + if LQ_latents is not None and block_id < len(LQ_latents): + x += LQ_latents[block_id] + + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + @staticmethod + def state_dict_converter(): + return WanModelStateDictConverter() + + +# ---------------------------- +# State dict converter(保持原映射;已忽略 has_image_input 使用) +# ---------------------------- +class WanModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": 
"blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict_, config + + def from_civitai(self, state_dict): + state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} + # 保留原有哈希匹配返回的 config;实现本身不使用 has_image_input 分支 + if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 
1e-6} + elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False} + else: + config = {} + return state_dict, config + \ No newline at end of file diff --git a/flashvsr_arch/models/wan_video_vae.py b/flashvsr_arch/models/wan_video_vae.py new file mode 100644 index 0000000..2490177 --- /dev/null +++ b/flashvsr_arch/models/wan_video_vae.py @@ -0,0 +1,847 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + # print('cache_x.shape', cache_x.shape, 'x.shape', x.shape) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, 
out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, 
feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + 
cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if check_is_instance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + + def stream_decode(self, z, scale): + # self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + 
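# NOTE (added comment, inferred from the code above): unlike decode(), stream_decode()
+    # does not call clear_cache(), so the decoder feature cache (self._feat_map) persists
+    # across calls and consecutive latent chunks can be decoded as a causal stream.
+
+    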
def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # print('self._feat_map', len(self._feat_map)) + # cache encode + if self.encoder is not None: + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + # print('self._enc_feat_map', len(self._enc_feat_map)) + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16, dim=96): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + 
target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.clamp_(-1, 1) + return values + + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x + + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.clamp_(-1, 1) + + + def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * 8, tile_size[1] * 8) + tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos + + def clear_cache(self): + self.model.clear_cache() + + def 
stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state for hidden_state in hidden_states] + assert len(hidden_states) == 1 + hidden_state = hidden_states[0] + video = self.model.stream_decode(hidden_state, self.scale) + return video + + + @staticmethod + def state_dict_converter(): + return WanVideoVAEStateDictConverter() + + +class WanVideoVAEStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ diff --git a/flashvsr_arch/pipelines/__init__.py b/flashvsr_arch/pipelines/__init__.py new file mode 100644 index 0000000..21d4390 --- /dev/null +++ b/flashvsr_arch/pipelines/__init__.py @@ -0,0 +1,3 @@ +from .flashvsr_full import FlashVSRFullPipeline +from .flashvsr_tiny import FlashVSRTinyPipeline +from .flashvsr_tiny_long import FlashVSRTinyLongPipeline \ No newline at end of file diff --git a/flashvsr_arch/pipelines/base.py b/flashvsr_arch/pipelines/base.py new file mode 100644 index 0000000..1fe3689 --- /dev/null +++ b/flashvsr_arch/pipelines/base.py @@ -0,0 +1,127 @@ +import torch +import numpy as np +from PIL import Image +from torchvision.transforms import GaussianBlur + + + +class BasePipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64): + super().__init__() + self.device = device + self.torch_dtype = torch_dtype + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.cpu_offload = False + self.model_names = [] + + + def check_resize_height_width(self, height, width): + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + print(f"The width cannot be evenly divided by {self.width_division_factor}. 
We round it up to {width}.") + return height, width + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def preprocess_images(self, images): + return [self.preprocess_image(image) for image in images] + + + def vae_output_to_image(self, vae_output): + image = vae_output[0].cpu().float().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + def vae_output_to_video(self, vae_output): + video = vae_output.cpu().permute(1, 2, 0).numpy() + video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video] + return video + + + def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0): + if len(latents) > 0: + blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma) + height, width = value.shape[-2:] + weight = torch.ones_like(value) + for latent, mask, scale in zip(latents, masks, scales): + mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device) + mask = blur(mask) + value += latent * mask * scale + weight += mask * scale + value /= weight + return value + + + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None): + if special_kwargs is None: + noise_pred_global = inference_callback(prompt_emb_global) + else: + noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) + if special_local_kwargs_list is None: + noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] + else: + noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)] + noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) + return noise_pred + + + def extend_prompt(self, prompt, local_prompts, masks, mask_scales): + local_prompts = local_prompts or [] + masks = masks or [] + mask_scales = mask_scales or [] + extended_prompt_dict = self.prompter.extend_prompt(prompt) + prompt = extended_prompt_dict.get("prompt", prompt) + local_prompts += extended_prompt_dict.get("prompts", []) + masks += extended_prompt_dict.get("masks", []) + mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) + return prompt, local_prompts, masks, mask_scales + + + def enable_cpu_offload(self): + self.cpu_offload = True + + + def load_models_to_device(self, loadmodel_names=[]): + # only load models to device if cpu_offload is enabled + if not self.cpu_offload: + return + # offload the unneeded models to cpu + for model_name in self.model_names: + if model_name not in loadmodel_names: + model = getattr(self, model_name) + if model is not None: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + else: + model.cpu() + # load the needed models to device + for model_name in loadmodel_names: + model = getattr(self, model_name) + if model is not None: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + else: + 
model.to(self.device) + # fresh the cuda cache + torch.cuda.empty_cache() + + + def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): + generator = None if seed is None else torch.Generator(device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + return noise diff --git a/flashvsr_arch/pipelines/flashvsr_full.py b/flashvsr_arch/pipelines/flashvsr_full.py new file mode 100644 index 0000000..e26fda0 --- /dev/null +++ b/flashvsr_arch/pipelines/flashvsr_full.py @@ -0,0 +1,636 @@ +import types +import os +import time +from typing import Optional, Tuple, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from PIL import Image +from tqdm import tqdm +# import pyfiglet + +from ..models.utils import clean_vram +from ..models import ModelManager +from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline + +# ----------------------------- +# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet) +# ----------------------------- +def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: + assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)' + N, C = feat.shape[:2] + var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps + std = var.sqrt().view(N, C, 1, 1) + mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return mean, std + + +def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: + assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配" + size = content_feat.size() + style_mean, style_std = _calc_mean_std(style_feat) + content_mean, content_std = _calc_mean_std(content_feat) + normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized * style_std.expand(size) + style_mean.expand(size) + + +# ----------------------------- +# 小波式模糊与分解/重构(ColorCorrector 用) +# ----------------------------- +def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: + vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125 ], + [0.0625, 0.125, 0.0625], + ] + return torch.tensor(vals, dtype=dtype, device=device) + + +def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + N, C, H, W = x.shape + base = _make_gaussian3x3_kernel(x.dtype, x.device) + weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) + pad = radius + x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate') + out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) + return out + + +def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + high = torch.zeros_like(x) + low = x + for i in range(levels): + radius = 2 ** i + blurred = _wavelet_blur(low, radius) + high = high + (low - blurred) + low = blurred + return high, low + + +def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: + c_high, _ = _wavelet_decompose(content, levels=levels) + _, s_low = _wavelet_decompose(style, levels=levels) + return c_high + s_low + +# ----------------------------- +# Safetensors support --------- +# ----------------------------- +st_load_file = None # Define the variable in global scope first +try: + from 
safetensors.torch import load_file as st_load_file +except ImportError: + # st_load_file remains None if import fails + print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.") + +# ----------------------------- +# 无状态颜色矫正模块(视频友好,默认 wavelet) +# ----------------------------- +class TorchColorCorrectorWavelet(nn.Module): + def __init__(self, levels: int = 5): + super().__init__() + self.levels = levels + + @staticmethod + def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + assert x.dim() == 5, '输入必须是 (B, C, f, H, W)' + B, C, f, H, W = x.shape + y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) + return y, B, f + + @staticmethod + def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: + BF, C, H, W = y.shape + assert BF == B * f + return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) + + def forward( + self, + hq_image: torch.Tensor, # (B, C, f, H, W) + lq_image: torch.Tensor, # (B, C, f, H, W) + clip_range: Tuple[float, float] = (-1.0, 1.0), + method: Literal['wavelet', 'adain'] = 'wavelet', + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致" + assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)" + + B, C, f, H, W = hq_image.shape + if chunk_size is None or chunk_size >= f: + hq4, B, f = self._flatten_time(hq_image) + lq4, _, _ = self._flatten_time(lq_image) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out = self._unflatten_time(out4, B, f) + return out + + outs = [] + for start in range(0, f, chunk_size): + end = min(start + chunk_size, f) + hq_chunk = hq_image[:, :, start:end] + lq_chunk = lq_image[:, :, start:end] + hq4, B_, f_ = self._flatten_time(hq_chunk) + lq4, _, _ = self._flatten_time(lq_chunk) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out_chunk = self._unflatten_time(out4, B_, f_) + outs.append(out_chunk) + out = torch.cat(outs, dim=2) + return out + + +# ----------------------------- +# 简化版 Pipeline(仅 dit + vae) +# ----------------------------- +class FlashVSRFullPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 + self.use_unified_sequence_parallel = False + self.prompt_emb_posi = None + self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) + + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + # 仅管理 dit / vae + dtype = next(iter(self.dit.parameters())).dtype + from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear + enable_vram_management( + self.dit, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + 
computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + RMS_norm: AutoWrappedModule, + CausalConv3d: AutoWrappedModule, + Upsample: AutoWrappedModule, + torch.nn.SiLU: AutoWrappedModule, + torch.nn.Dropout: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def fetch_models(self, model_manager: ModelManager): + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + # 可选:统一序列并行入口(此处默认关闭) + pipe.use_unified_sequence_parallel = False + return pipe + + def denoising_model(self): + return self.dit + + # ------------------------- + # 新增:显式 KV 预初始化函数 + # ------------------------- + def init_cross_kv( + self, + context_tensor: Optional[torch.Tensor] = None, + prompt_path = None + ): + self.load_models_to_device(["dit"]) + """ + 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 + 必须在 __call__ 前显式调用一次。 + """ + #prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" + if self.dit is None: + raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit") + + if context_tensor is None: + if prompt_path is None: + raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一") + + # --- Safetensors loading logic added here --- + prompt_path_lower = prompt_path.lower() + if prompt_path_lower.endswith(".safetensors"): + if st_load_file is None: + raise ImportError("The 'safetensors' library must be installed to load .safetensors files.") + + # Load the tensor from safetensors + loaded_dict = st_load_file(prompt_path, device=self.device) + + # Safetensors loads a dict. Assuming the context tensor is the only or primary key. + if len(loaded_dict) == 1: + ctx = list(loaded_dict.values())[0] + elif 'context' in loaded_dict: # Common key for text context + ctx = loaded_dict['context'] + else: + raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).") + + else: + # Default behavior for .pth, .pt, etc. 
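+                # Note: torch.load unpickles arbitrary Python objects, so only trusted
+                # checkpoints should be loaded here; on PyTorch versions that support it,
+                # weights_only=True is a safer option for plain tensors. The .safetensors
+                # branch above avoids pickle entirely, which is why Prompt.safetensors is
+                # the preferred format.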
+ ctx = torch.load(prompt_path, map_location=self.device) + + # -------------------------------------------- + # ctx = torch.load(prompt_path, map_location=self.device) + else: + ctx = context_tensor + + ctx = ctx.to(dtype=self.torch_dtype, device=self.device) + + if self.prompt_emb_posi is None: + self.prompt_emb_posi = {} + self.prompt_emb_posi['context'] = ctx + + if hasattr(self.dit, "reinit_cross_kv"): + self.dit.reinit_cross_kv(ctx) + else: + raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。") + self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype) + self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) + self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) + # Scheduler + self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) + self.load_models_to_device([]) + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} + + def prepare_extra_input(self, latents=None): + return {} + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return frames + + @torch.no_grad() + def __call__( + self, + prompt=None, + negative_prompt="", + denoising_strength=1.0, + seed=None, + rand_device="gpu", + height=480, + width=832, + num_frames=81, + cfg_scale=5.0, + num_inference_steps=50, + sigma_shift=5.0, + tiled=True, + tile_size=(60, 104), + tile_stride=(30, 52), + tea_cache_l1_thresh=None, + tea_cache_model_id="Wan2.1-T2V-1.3B", + progress_bar_cmd=tqdm, + progress_bar_st=None, + LQ_video=None, + is_full_block=False, + if_buffer=False, + topk_ratio=2.0, + kv_ratio=3.0, + local_range = 9, + color_fix = True, + unload_dit = False, + skip_vae = False, + ): + # 只接受 cfg=1.0(与原代码一致) + assert cfg_scale == 1.0, "cfg_scale must be 1.0" + + # 要求:必须先 init_cross_kv() + if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi: + raise RuntimeError( + "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n" + " pipe.init_cross_kv()\n" + "或传入自定义 context:\n" + " pipe.init_cross_kv(context_tensor=your_context_tensor)" + ) + + if num_frames % 4 != 1: + num_frames = (num_frames + 2) // 4 * 4 + 1 + print(f"Only `num_frames % 4 != 1` is acceptable. 
We round it up to {num_frames}.") + + # Tiler 参数 + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + # 初始化噪声 + if if_buffer: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + else: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + # noise = noise.to(dtype=self.torch_dtype, device=self.device) + latents = noise + + process_total_num = (num_frames - 1) // 8 - 2 + is_stream = True + + # 清理可能存在的 LQ_proj_in cache + if hasattr(self.dit, "LQ_proj_in"): + self.dit.LQ_proj_in.clear_cache() + + latents_total = [] + self.vae.clear_cache() + + if unload_dit and hasattr(self, 'dit') and self.dit is not None: + current_dit_device = next(iter(self.dit.parameters())).device + if str(current_dit_device) != str(self.device): + print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...") + self.dit.to(self.device) + + with torch.no_grad(): + for cur_process_idx in progress_bar_cmd(range(process_total_num)): + if cur_process_idx == 0: + pre_cache_k = [None] * len(self.dit.blocks) + pre_cache_v = [None] * len(self.dit.blocks) + LQ_latents = None + inner_loop_num = 7 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + cur_latents = latents[:, :, :6, :, :] + else: + LQ_latents = None + inner_loop_num = 2 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] + + # 推理(无 motion_controller / vace) + noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( + self.dit, + x=cur_latents, + timestep=self.timestep, + context=None, + tea_cache=None, + use_unified_sequence_parallel=False, + LQ_latents=LQ_latents, + is_full_block=is_full_block, + is_stream=is_stream, + pre_cache_k=pre_cache_k, + pre_cache_v=pre_cache_v, + topk_ratio=topk_ratio, + kv_ratio=kv_ratio, + cur_process_idx=cur_process_idx, + t_mod=self.t_mod, + t=self.t, + local_range = local_range, + ) + + # 更新 latent + cur_latents = cur_latents - noise_pred_posi + latents_total.append(cur_latents) + + if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + try: + del pre_cache_k, pre_cache_v + except NameError: + pass + print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") + self.dit.to('cpu') + clean_vram() + + latents = torch.cat(latents_total, dim=2) + + del latents_total + clean_vram() + + if skip_vae: + return latents + + # Decode + print("[FlashVSR] Starting VAE decoding...") + frames = self.decode_video(latents, **tiler_kwargs) + + # 颜色校正(wavelet) + try: + if color_fix: + frames = self.ColorCorrector( + frames.to(device=LQ_video.device), + LQ_video[:, :, 
:frames.shape[2], :, :], + clip_range=(-1, 1), + chunk_size=16, + method='adain' + ) + except: + pass + + return frames[0] + + +# ----------------------------- +# TeaCache(保留原逻辑;此处默认不启用) +# ----------------------------- +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh) + if should_calc: + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step = (self.step + 1) % self.num_inference_steps + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +# ----------------------------- +# 简化版模型前向封装(无 vace / 无 motion_controller) +# ----------------------------- +def model_fn_wan_video( + dit: WanModel, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + tea_cache: Optional[TeaCache] = None, + use_unified_sequence_parallel: bool = False, + LQ_latents: Optional[torch.Tensor] = None, + is_full_block: bool = False, + is_stream: bool = False, + pre_cache_k: Optional[list[torch.Tensor]] = None, + pre_cache_v: Optional[list[torch.Tensor]] = None, + topk_ratio: float = 2.0, + kv_ratio: float = 3.0, + cur_process_idx: int = 0, + t_mod : torch.Tensor = None, + t : torch.Tensor = None, + local_range: int = 9, + **kwargs, +): + # patchify + x, (f, h, w) = dit.patchify(x) + + win = (2, 8, 8) + seqlen = f // win[0] + local_num = seqlen + window_size = win[0] * h * w // 128 + square_num = window_size * window_size + topk = int(square_num * topk_ratio) - 1 + kv_len = int(kv_ratio) + + # RoPE 位置(分段) + if cur_process_idx == 0: + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, 
-1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + else: + freqs = torch.cat([ + dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # TeaCache(默认不启用) + tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False + + # 统一序列并行(此处默认关闭) + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + + # Block 堆叠 + if tea_cache_update: + x = tea_cache.update(x) + else: + for block_id, block in enumerate(dit.blocks): + if LQ_latents is not None and block_id < len(LQ_latents): + x = x + LQ_latents[block_id] + x, last_pre_cache_k, last_pre_cache_v = block( + x, context, t_mod, freqs, f, h, w, + local_num, topk, + block_id=block_id, + kv_len=kv_len, + is_full_block=is_full_block, + is_stream=is_stream, + pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None, + pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None, + local_range = local_range, + ) + if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k + if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v + + x = dit.head(x, t) + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import get_sp_group + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = dit.unpatchify(x, (f, h, w)) + return x, pre_cache_k, pre_cache_v diff --git a/flashvsr_arch/pipelines/flashvsr_tiny.py b/flashvsr_arch/pipelines/flashvsr_tiny.py new file mode 100644 index 0000000..60ed390 --- /dev/null +++ b/flashvsr_arch/pipelines/flashvsr_tiny.py @@ -0,0 +1,633 @@ +import types +import os +import time +from typing import Optional, Tuple, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from PIL import Image +from tqdm import tqdm +# import pyfiglet + +from ..models.utils import clean_vram +from ..models import ModelManager +from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline + +# ----------------------------- +# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet) +# ----------------------------- +def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: + assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)' + N, C = feat.shape[:2] + var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps + std = var.sqrt().view(N, C, 1, 1) + mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return mean, std + + +def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: + assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配" + size = content_feat.size() + style_mean, style_std = _calc_mean_std(style_feat) + content_mean, content_std = _calc_mean_std(content_feat) + normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return 
normalized * style_std.expand(size) + style_mean.expand(size) + + +# ----------------------------- +# 小波式模糊与分解/重构(ColorCorrector 用) +# ----------------------------- +def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: + vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125 ], + [0.0625, 0.125, 0.0625], + ] + return torch.tensor(vals, dtype=dtype, device=device) + + +def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + N, C, H, W = x.shape + base = _make_gaussian3x3_kernel(x.dtype, x.device) + weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) + pad = radius + x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate') + out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) + return out + + +def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + high = torch.zeros_like(x) + low = x + for i in range(levels): + radius = 2 ** i + blurred = _wavelet_blur(low, radius) + high = high + (low - blurred) + low = blurred + return high, low + + +def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: + c_high, _ = _wavelet_decompose(content, levels=levels) + _, s_low = _wavelet_decompose(style, levels=levels) + return c_high + s_low + +# ----------------------------- +# Safetensors support --------- +# ----------------------------- +st_load_file = None # Define the variable in global scope first +try: + from safetensors.torch import load_file as st_load_file +except ImportError: + # st_load_file remains None if import fails + print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.") + +# ----------------------------- +# 无状态颜色矫正模块(视频友好,默认 wavelet) +# ----------------------------- +class TorchColorCorrectorWavelet(nn.Module): + def __init__(self, levels: int = 5): + super().__init__() + self.levels = levels + + @staticmethod + def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + assert x.dim() == 5, '输入必须是 (B, C, f, H, W)' + B, C, f, H, W = x.shape + y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) + return y, B, f + + @staticmethod + def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: + BF, C, H, W = y.shape + assert BF == B * f + return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) + + def forward( + self, + hq_image: torch.Tensor, # (B, C, f, H, W) + lq_image: torch.Tensor, # (B, C, f, H, W) + clip_range: Tuple[float, float] = (-1.0, 1.0), + method: Literal['wavelet', 'adain'] = 'wavelet', + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致" + assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)" + + B, C, f, H, W = hq_image.shape + if chunk_size is None or chunk_size >= f: + hq4, B, f = self._flatten_time(hq_image) + lq4, _, _ = self._flatten_time(lq_image) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out = self._unflatten_time(out4, B, f) + return out + + outs = [] + for start in range(0, f, chunk_size): + end = min(start + chunk_size, f) + hq_chunk = hq_image[:, :, start:end] + lq_chunk = lq_image[:, :, start:end] + hq4, B_, f_ = self._flatten_time(hq_chunk) + lq4, _, _ = self._flatten_time(lq_chunk) + if method == 
'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out_chunk = self._unflatten_time(out4, B_, f_) + outs.append(out_chunk) + out = torch.cat(outs, dim=2) + return out + + +# ----------------------------- +# 简化版 Pipeline(仅 dit + vae) +# ----------------------------- +class FlashVSRTinyPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 + self.use_unified_sequence_parallel = False + self.prompt_emb_posi = None + self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) + + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + # 仅管理 dit / vae + dtype = next(iter(self.dit.parameters())).dtype + from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear + enable_vram_management( + self.dit, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def fetch_models(self, model_manager: ModelManager): + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + # 可选:统一序列并行入口(此处默认关闭) + pipe.use_unified_sequence_parallel = False + return pipe + + def denoising_model(self): + return self.dit + + # ------------------------- + # 新增:显式 KV 预初始化函数 + # ------------------------- + def init_cross_kv( + self, + context_tensor: Optional[torch.Tensor] = None, + prompt_path = None, + ): + self.load_models_to_device(["dit"]) + """ + 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 + 必须在 __call__ 前显式调用一次。 + """ + #prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" + + if self.dit is None: + raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit") + + if context_tensor is None: + if prompt_path is None: + raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一") + + # --- Safetensors loading logic added here --- + prompt_path_lower = prompt_path.lower() + if prompt_path_lower.endswith(".safetensors"): + if st_load_file is None: + raise ImportError("The 'safetensors' library must be installed to load .safetensors files.") + + # Load the tensor from safetensors + loaded_dict = st_load_file(prompt_path, device=self.device) + + # 
Safetensors loads a dict. Assuming the context tensor is the only or primary key. + if len(loaded_dict) == 1: + ctx = list(loaded_dict.values())[0] + elif 'context' in loaded_dict: # Common key for text context + ctx = loaded_dict['context'] + else: + raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).") + + else: + # Default behavior for .pth, .pt, etc. + ctx = torch.load(prompt_path, map_location=self.device) + + # -------------------------------------------- + # ctx = torch.load(prompt_path, map_location=self.device) + else: + ctx = context_tensor + + ctx = ctx.to(dtype=self.torch_dtype, device=self.device) + + if self.prompt_emb_posi is None: + self.prompt_emb_posi = {} + self.prompt_emb_posi['context'] = ctx + + if hasattr(self.dit, "reinit_cross_kv"): + self.dit.reinit_cross_kv(ctx) + else: + raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。") + self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype) + self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) + self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) + # Scheduler + self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) + self.load_models_to_device([]) + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} + + def prepare_extra_input(self, latents=None): + return {} + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return frames + + def decode_video(self, latents, cond=None, **kwargs): + frames = self.TCDecoder.decode_video( + latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W) + parallel=False, + show_progress_bar=False, + cond=cond + ).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1 + + return frames + + @torch.no_grad() + def __call__( + self, + prompt=None, + negative_prompt="", + denoising_strength=1.0, + seed=None, + rand_device="gpu", + height=480, + width=832, + num_frames=81, + cfg_scale=5.0, + num_inference_steps=50, + sigma_shift=5.0, + tiled=True, + tile_size=(60, 104), + tile_stride=(30, 52), + tea_cache_l1_thresh=None, + tea_cache_model_id="Wan2.1-T2V-14B", + progress_bar_cmd=tqdm, + progress_bar_st=None, + LQ_video=None, + is_full_block=False, + if_buffer=False, + topk_ratio=2.0, + kv_ratio=3.0, + local_range = 9, + color_fix = True, + unload_dit = False, + skip_vae = False, + ): + # 只接受 cfg=1.0(与原代码一致) + assert cfg_scale == 1.0, "cfg_scale must be 1.0" + + # 要求:必须先 init_cross_kv() + if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi: + raise RuntimeError( + "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n" + " pipe.init_cross_kv()\n" + "或传入自定义 context:\n" + " pipe.init_cross_kv(context_tensor=your_context_tensor)" + ) + + # 尺寸修正 + height, width = self.check_resize_height_width(height, width) + if num_frames % 4 != 1: + num_frames = (num_frames + 2) // 4 * 4 + 1 + print(f"Only `num_frames % 4 != 1` is acceptable. 
We round it up to {num_frames}.") + + # Tiler 参数 + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + # 初始化噪声 + if if_buffer: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + else: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + # noise = noise.to(dtype=self.torch_dtype, device=self.device) + latents = noise + + process_total_num = (num_frames - 1) // 8 - 2 + is_stream = True + + # 清理可能存在的 LQ_proj_in cache + if hasattr(self.dit, "LQ_proj_in"): + self.dit.LQ_proj_in.clear_cache() + + latents_total = [] + self.TCDecoder.clean_mem() + LQ_pre_idx = 0 + LQ_cur_idx = 0 + + if unload_dit and hasattr(self, 'dit') and self.dit is not None: + current_dit_device = next(iter(self.dit.parameters())).device + if str(current_dit_device) != str(self.device): + print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...") + self.dit.to(self.device) + + with torch.no_grad(): + for cur_process_idx in progress_bar_cmd(range(process_total_num)): + if cur_process_idx == 0: + pre_cache_k = [None] * len(self.dit.blocks) + pre_cache_v = [None] * len(self.dit.blocks) + LQ_latents = None + inner_loop_num = 7 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = (inner_loop_num-1)*4-3 + cur_latents = latents[:, :, :6, :, :] + else: + LQ_latents = None + inner_loop_num = 2 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 + cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] + + # 推理(无 motion_controller / vace) + noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( + self.dit, + x=cur_latents, + timestep=self.timestep, + context=None, + tea_cache=None, + use_unified_sequence_parallel=False, + LQ_latents=LQ_latents, + is_full_block=is_full_block, + is_stream=is_stream, + pre_cache_k=pre_cache_k, + pre_cache_v=pre_cache_v, + topk_ratio=topk_ratio, + kv_ratio=kv_ratio, + cur_process_idx=cur_process_idx, + t_mod=self.t_mod, + t=self.t, + local_range = local_range, + ) + + # 更新 latent + cur_latents = cur_latents - noise_pred_posi + latents_total.append(cur_latents) + LQ_pre_idx = LQ_cur_idx + + if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + try: + del pre_cache_k, pre_cache_v + except NameError: + pass + print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") + self.dit.to('cpu') + clean_vram() + + latents = torch.cat(latents_total, dim=2) + + del latents_total + clean_vram() + + if skip_vae: + return latents + + # Decode + print("[FlashVSR] Starting VAE decoding...") + frames = 
self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1) + + # 颜色校正(wavelet) + try: + if color_fix: + frames = self.ColorCorrector( + frames.to(device=LQ_video.device), + LQ_video[:, :, :frames.shape[2], :, :], + clip_range=(-1, 1), + chunk_size=16, + method='adain' + ) + except: + pass + + return frames[0] + + +# ----------------------------- +# TeaCache(保留原逻辑;此处默认不启用) +# ----------------------------- +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh) + if should_calc: + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step = (self.step + 1) % self.num_inference_steps + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +# ----------------------------- +# 简化版模型前向封装(无 vace / 无 motion_controller) +# ----------------------------- +def model_fn_wan_video( + dit: WanModel, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + tea_cache: Optional[TeaCache] = None, + use_unified_sequence_parallel: bool = False, + LQ_latents: Optional[torch.Tensor] = None, + is_full_block: bool = False, + is_stream: bool = False, + pre_cache_k: Optional[list[torch.Tensor]] = None, + pre_cache_v: Optional[list[torch.Tensor]] = None, + topk_ratio: float = 2.0, + kv_ratio: float = 3.0, + cur_process_idx: int = 0, + t_mod : torch.Tensor = None, + t : torch.Tensor = None, + local_range: int = 9, + **kwargs, +): + # patchify + x, (f, h, w) = dit.patchify(x) + + win = (2, 8, 8) + seqlen = f // win[0] + local_num = seqlen + window_size = win[0] * h * w // 128 + square_num = window_size * window_size + topk = 
int(square_num * topk_ratio) - 1 + kv_len = int(kv_ratio) + + # RoPE 位置(分段) + if cur_process_idx == 0: + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + else: + freqs = torch.cat([ + dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # TeaCache(默认不启用) + tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False + + # 统一序列并行(此处默认关闭) + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + + # Block 堆叠 + if tea_cache_update: + x = tea_cache.update(x) + else: + for block_id, block in enumerate(dit.blocks): + if LQ_latents is not None and block_id < len(LQ_latents): + x = x + LQ_latents[block_id] + x, last_pre_cache_k, last_pre_cache_v = block( + x, context, t_mod, freqs, f, h, w, + local_num, topk, + block_id=block_id, + kv_len=kv_len, + is_full_block=is_full_block, + is_stream=is_stream, + pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None, + pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None, + local_range = local_range, + ) + if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k + if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v + + x = dit.head(x, t) + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import get_sp_group + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = dit.unpatchify(x, (f, h, w)) + return x, pre_cache_k, pre_cache_v diff --git a/flashvsr_arch/pipelines/flashvsr_tiny_long.py b/flashvsr_arch/pipelines/flashvsr_tiny_long.py new file mode 100644 index 0000000..e89e8f3 --- /dev/null +++ b/flashvsr_arch/pipelines/flashvsr_tiny_long.py @@ -0,0 +1,619 @@ +import types +import os +import time +from typing import Optional, Tuple, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from PIL import Image +from tqdm import tqdm +# import pyfiglet + +from ..models.utils import clean_vram +from ..models import ModelManager +from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline + + +# ----------------------------- +# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet) +# ----------------------------- +def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: + assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)' + N, C = feat.shape[:2] + var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps + std = var.sqrt().view(N, C, 1, 1) + mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return mean, std + + +def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: + assert 
content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配" + size = content_feat.size() + style_mean, style_std = _calc_mean_std(style_feat) + content_mean, content_std = _calc_mean_std(content_feat) + normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized * style_std.expand(size) + style_mean.expand(size) + + +# ----------------------------- +# 小波式模糊与分解/重构(ColorCorrector 用) +# ----------------------------- +def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: + vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125 ], + [0.0625, 0.125, 0.0625], + ] + return torch.tensor(vals, dtype=dtype, device=device) + + +def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + N, C, H, W = x.shape + base = _make_gaussian3x3_kernel(x.dtype, x.device) + weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) + pad = radius + x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate') + out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) + return out + + +def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + high = torch.zeros_like(x) + low = x + for i in range(levels): + radius = 2 ** i + blurred = _wavelet_blur(low, radius) + high = high + (low - blurred) + low = blurred + return high, low + + +def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: + c_high, _ = _wavelet_decompose(content, levels=levels) + _, s_low = _wavelet_decompose(style, levels=levels) + return c_high + s_low + +# ----------------------------- +# Safetensors support --------- +# ----------------------------- +st_load_file = None # Define the variable in global scope first +try: + from safetensors.torch import load_file as st_load_file +except ImportError: + # st_load_file remains None if import fails + print("Warning: 'safetensors' not installed. 
Safetensors (.safetensors) files cannot be loaded.") + +# ----------------------------- +# 无状态颜色矫正模块(视频友好,默认 wavelet) +# ----------------------------- +class TorchColorCorrectorWavelet(nn.Module): + def __init__(self, levels: int = 5): + super().__init__() + self.levels = levels + + @staticmethod + def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + assert x.dim() == 5, '输入必须是 (B, C, f, H, W)' + B, C, f, H, W = x.shape + y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) + return y, B, f + + @staticmethod + def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: + BF, C, H, W = y.shape + assert BF == B * f + return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) + + def forward( + self, + hq_image: torch.Tensor, # (B, C, f, H, W) + lq_image: torch.Tensor, # (B, C, f, H, W) + clip_range: Tuple[float, float] = (-1.0, 1.0), + method: Literal['wavelet', 'adain'] = 'wavelet', + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致" + assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)" + + B, C, f, H, W = hq_image.shape + if chunk_size is None or chunk_size >= f: + hq4, B, f = self._flatten_time(hq_image) + lq4, _, _ = self._flatten_time(lq_image) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out = self._unflatten_time(out4, B, f) + return out + + outs = [] + for start in range(0, f, chunk_size): + end = min(start + chunk_size, f) + hq_chunk = hq_image[:, :, start:end] + lq_chunk = lq_image[:, :, start:end] + hq4, B_, f_ = self._flatten_time(hq_chunk) + lq4, _, _ = self._flatten_time(lq_chunk) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out_chunk = self._unflatten_time(out4, B_, f_) + outs.append(out_chunk) + out = torch.cat(outs, dim=2) + return out + +# ----------------------------- +# 简化版 Pipeline(仅 dit + vae) +# ----------------------------- +class FlashVSRTinyLongPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 + self.use_unified_sequence_parallel = False + self.prompt_emb_posi = None + self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) + + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + # 仅管理 dit / vae + dtype = next(iter(self.dit.parameters())).dtype + from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear + enable_vram_management( + self.dit, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config=dict( + 
offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def fetch_models(self, model_manager: ModelManager): + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + # 可选:统一序列并行入口(此处默认关闭) + pipe.use_unified_sequence_parallel = False + return pipe + + def denoising_model(self): + return self.dit + + # ------------------------- + # 新增:显式 KV 预初始化函数 + # ------------------------- + def init_cross_kv( + self, + context_tensor: Optional[torch.Tensor] = None, + prompt_path = None, + ): + self.load_models_to_device(["dit"]) + """ + 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 + 必须在 __call__ 前显式调用一次。 + """ + #prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" + + if self.dit is None: + raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit") + + if context_tensor is None: + if prompt_path is None: + raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一") + + # --- Safetensors loading logic added here --- + prompt_path_lower = prompt_path.lower() + if prompt_path_lower.endswith(".safetensors"): + if st_load_file is None: + raise ImportError("The 'safetensors' library must be installed to load .safetensors files.") + + # Load the tensor from safetensors + loaded_dict = st_load_file(prompt_path, device=self.device) + + # Safetensors loads a dict. Assuming the context tensor is the only or primary key. + if len(loaded_dict) == 1: + ctx = list(loaded_dict.values())[0] + elif 'context' in loaded_dict: # Common key for text context + ctx = loaded_dict['context'] + else: + raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).") + + else: + # Default behavior for .pth, .pt, etc. 
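+                # As in the full/tiny pipelines: torch.load unpickles arbitrary objects,
+                # so only trusted .pth/.pt files belong here; the .safetensors branch above
+                # is the safer path.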
+ ctx = torch.load(prompt_path, map_location=self.device) + + # -------------------------------------------- + # ctx = torch.load(prompt_path, map_location=self.device) + else: + ctx = context_tensor + + ctx = ctx.to(dtype=self.torch_dtype, device=self.device) + + if self.prompt_emb_posi is None: + self.prompt_emb_posi = {} + self.prompt_emb_posi['context'] = ctx + + if hasattr(self.dit, "reinit_cross_kv"): + self.dit.reinit_cross_kv(ctx) + else: + raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。") + self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype) + self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) + self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) + # Scheduler + self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) + self.load_models_to_device([]) + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} + + def prepare_extra_input(self, latents=None): + return {} + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return frames + + def decode_video(self, latents, cond=None, **kwargs): + frames = self.TCDecoder.decode_video( + latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W) + parallel=False, + show_progress_bar=False, + cond=cond + ).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1 + + return frames + + @torch.no_grad() + def __call__( + self, + prompt=None, + negative_prompt="", + denoising_strength=1.0, + seed=None, + rand_device="gpu", + height=480, + width=832, + num_frames=81, + cfg_scale=5.0, + num_inference_steps=50, + sigma_shift=5.0, + tiled=True, + tile_size=(60, 104), + tile_stride=(30, 52), + tea_cache_l1_thresh=None, + tea_cache_model_id="Wan2.1-T2V-1.3B", + progress_bar_cmd=tqdm, + progress_bar_st=None, + LQ_video=None, + is_full_block=False, + if_buffer=False, + topk_ratio=2.0, + kv_ratio=3.0, + local_range = 9, + color_fix = True, + unload_dit = False, + skip_vae = False, + ): + # 只接受 cfg=1.0(与原代码一致) + assert cfg_scale == 1.0, "cfg_scale must be 1.0" + + # 要求:必须先 init_cross_kv() + if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi: + raise RuntimeError( + "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n" + " pipe.init_cross_kv()\n" + "或传入自定义 context:\n" + " pipe.init_cross_kv(context_tensor=your_context_tensor)" + ) + + # 尺寸修正 + height, width = self.check_resize_height_width(height, width) + if num_frames % 4 != 1: + num_frames = (num_frames + 2) // 4 * 4 + 1 + print(f"Only `num_frames % 4 != 1` is acceptable. 
We round it up to {num_frames}.")
+
+        # Tiler parameters
+        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+        # Initialize noise
+        if if_buffer:
+            noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        else:
+            noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+        # noise = noise.to(dtype=self.torch_dtype, device=self.device)
+        latents = noise
+
+        # Number of streaming denoise/decode steps over the padded frame sequence
+        process_total_num = (num_frames - 1) // 8 - 2
+        is_stream = True
+
+        # Clear any existing LQ_proj_in cache
+        if hasattr(self.dit, "LQ_proj_in"):
+            self.dit.LQ_proj_in.clear_cache()
+
+        frames_total = []
+        LQ_pre_idx = 0
+        LQ_cur_idx = 0
+        self.TCDecoder.clean_mem()
+
+        with torch.no_grad():
+            for cur_process_idx in progress_bar_cmd(range(process_total_num)):
+                if cur_process_idx == 0:
+                    # First step: warm up the per-block KV caches and the LQ projection
+                    # over the first 7 LQ windows, denoising the first 6 latent frames.
+                    pre_cache_k = [None] * len(self.dit.blocks)
+                    pre_cache_v = [None] * len(self.dit.blocks)
+                    LQ_latents = None
+                    inner_loop_num = 7
+                    for inner_idx in range(inner_loop_num):
+                        cur = self.denoising_model().LQ_proj_in.stream_forward(
+                            LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
+                        ) if LQ_video is not None else None
+                        if cur is None:
+                            continue
+                        if LQ_latents is None:
+                            LQ_latents = cur
+                        else:
+                            for layer_idx in range(len(LQ_latents)):
+                                LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
+                    LQ_cur_idx = (inner_loop_num-1)*4-3
+                    cur_latents = latents[:, :, :6, :, :]
+                else:
+                    # Subsequent steps: feed 2 new LQ windows and denoise the next 2 latent frames.
+                    LQ_latents = None
+                    inner_loop_num = 2
+                    for inner_idx in range(inner_loop_num):
+                        cur = self.denoising_model().LQ_proj_in.stream_forward(
+                            LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
+                        ) if LQ_video is not None else None
+                        if cur is None:
+                            continue
+                        if LQ_latents is None:
+                            LQ_latents = cur
+                        else:
+                            for layer_idx in range(len(LQ_latents)):
+                                LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
+                    LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
+                    cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
+
+                # Inference (no motion_controller / vace)
+                noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
+                    self.dit,
+                    x=cur_latents,
+                    timestep=self.timestep,
+                    context=None,
+                    tea_cache=None,
+                    use_unified_sequence_parallel=False,
+                    LQ_latents=LQ_latents,
+                    is_full_block=is_full_block,
+                    is_stream=is_stream,
+                    pre_cache_k=pre_cache_k,
+                    pre_cache_v=pre_cache_v,
+                    topk_ratio=topk_ratio,
+                    kv_ratio=kv_ratio,
+                    cur_process_idx=cur_process_idx,
+                    t_mod=self.t_mod,
+                    t=self.t,
+                    local_range=local_range,
+                )
+
+                # Update latents
+                cur_latents = cur_latents - noise_pred_posi
+
+                # Decode
+                cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
+                cur_frames = self.TCDecoder.decode_video(
+                    cur_latents.transpose(1, 2),
+                    parallel=False,
+                    show_progress_bar=False,
+                    cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
+
+                # Color correction (wavelet)
+                try:
+                    if color_fix:
+                        cur_frames = self.ColorCorrector(
+                            cur_frames.to(device=self.device),
+                            cur_LQ_frame,
+                            clip_range=(-1, 1),
+                            chunk_size=None,
+                            method='adain'
+                        )
+                except Exception:
+                    # Color correction is best-effort; fall back to the uncorrected frames.
+                    pass
+
+                frames_total.append(cur_frames.to('cpu'))
+                LQ_pre_idx = LQ_cur_idx
+
+                del cur_frames, cur_latents, cur_LQ_frame
+                clean_vram()
+
+        frames = torch.cat(frames_total, dim=2)
+        return frames[0]
+
+
+# -----------------------------
+# TeaCache (original logic kept; disabled by default here)
+# -----------------------------
+class TeaCache:
+    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+        self.num_inference_steps = num_inference_steps
+        self.step = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.rel_l1_thresh = rel_l1_thresh
+        self.previous_residual = None
+        self.previous_hidden_states = None
+
+        self.coefficients_dict = {
+            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+        }
+        if model_id not in self.coefficients_dict:
+            supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+        self.coefficients = self.coefficients_dict[model_id]
+
+    def check(self, dit: WanModel, x, t_mod):
+        modulated_inp = t_mod.clone()
+        if self.step == 0 or self.step == self.num_inference_steps - 1:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            coefficients = self.coefficients
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
+            if should_calc:
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.step = (self.step + 1) % self.num_inference_steps
+        if should_calc:
+            self.previous_hidden_states = x.clone()
+        return not should_calc
+
+    def store(self, hidden_states):
+        self.previous_residual = hidden_states - self.previous_hidden_states
+        self.previous_hidden_states = None
+
+    def update(self, hidden_states):
+        hidden_states = hidden_states + self.previous_residual
+        return hidden_states
+
+
+# -----------------------------
+# Simplified model forward wrapper (no vace / no motion_controller)
+# -----------------------------
+def model_fn_wan_video(
+    dit: WanModel,
+    x: torch.Tensor,
+    timestep: torch.Tensor,
+    context: torch.Tensor,
+    tea_cache: Optional[TeaCache] = None,
+    use_unified_sequence_parallel: bool = False,
+    LQ_latents: Optional[torch.Tensor] = None,
+    is_full_block: bool = False,
+    is_stream: bool = False,
+    pre_cache_k: Optional[list[torch.Tensor]] = None,
+    pre_cache_v: Optional[list[torch.Tensor]] = None,
+    topk_ratio: float = 2.0,
+    kv_ratio: float = 3.0,
+    cur_process_idx: int = 0,
+    t_mod: torch.Tensor = None,
+    t: torch.Tensor = None,
+    local_range: int = 9,
+    **kwargs,
+):
+    # patchify
+    x, (f, h, w) = dit.patchify(x)
+
+    win = (2, 8, 8)
+    seqlen = f // win[0]
+    local_num = seqlen
+    window_size = win[0] * h * w // 128
+    square_num = window_size * window_size
+    topk = int(square_num * topk_ratio) - 1
+    kv_len = int(kv_ratio)
+
+    # RoPE positions (segmented)
+    if cur_process_idx == 0:
+        freqs = torch.cat([
+            dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+    else:
+        freqs = torch.cat([
+            dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+    # TeaCache (disabled by default)
+    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
+
+    # Unified sequence parallel (disabled by default here)
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+
+    # Block stack
+    if tea_cache_update:
+        x = tea_cache.update(x)
+    else:
+        for block_id, block in enumerate(dit.blocks):
+            if LQ_latents is not None and block_id < len(LQ_latents):
+                x = x + LQ_latents[block_id]
+            x, last_pre_cache_k, last_pre_cache_v = block(
+                x, context, t_mod, freqs, f, h, w,
+                local_num, topk,
+                block_id=block_id,
+                kv_len=kv_len,
+                is_full_block=is_full_block,
+                is_stream=is_stream,
+                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
+                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
+                local_range=local_range,
+            )
+            if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
+            if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
+
+    x = dit.head(x, t)
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import get_sp_group
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = get_sp_group().all_gather(x, dim=1)
+    x = dit.unpatchify(x, (f, h, w))
+    return x, pre_cache_k, pre_cache_v
diff --git a/flashvsr_arch/schedulers/__init__.py b/flashvsr_arch/schedulers/__init__.py
new file mode 100644
index 0000000..f349a86
--- /dev/null
+++ b/flashvsr_arch/schedulers/__init__.py
@@ -0,0 +1 @@
+from .flow_match import FlowMatchScheduler
diff --git a/flashvsr_arch/schedulers/flow_match.py b/flashvsr_arch/schedulers/flow_match.py
new file mode 100644
index 0000000..bd0cb08
--- /dev/null
+++ b/flashvsr_arch/schedulers/flow_match.py
@@ -0,0 +1,79 @@
+import torch
+
+
+
+class FlowMatchScheduler():
+
+    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+        self.num_train_timesteps = num_train_timesteps
+        self.shift = shift
+        self.sigma_max = sigma_max
+        self.sigma_min = sigma_min
+        self.inverse_timesteps = inverse_timesteps
+        self.extra_one_step = extra_one_step
+        self.reverse_sigmas = reverse_sigmas
+        self.set_timesteps(num_inference_steps)
+
+
+    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+        if shift is not None:
+            self.shift = shift
+        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+        if self.extra_one_step:
+            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+        else:
+            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+        if self.inverse_timesteps:
+            self.sigmas = torch.flip(self.sigmas, dims=[0])
+        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+        if self.reverse_sigmas:
+            self.sigmas = 1 - self.sigmas
+        self.timesteps = self.sigmas * self.num_train_timesteps
+        if training:
+            x = self.timesteps
+            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+            y_shifted = y - y.min()
+            bsmntw_weighing = y_shifted * (num_inference_steps /
y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + + def add_noise(self, original_samples, noise, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + weights = self.linear_timesteps_weights[timestep_id] + return weights diff --git a/flashvsr_arch/vram_management/__init__.py b/flashvsr_arch/vram_management/__init__.py new file mode 100644 index 0000000..5bf38c8 --- /dev/null +++ b/flashvsr_arch/vram_management/__init__.py @@ -0,0 +1 @@ +from .layers import * diff --git a/flashvsr_arch/vram_management/layers.py b/flashvsr_arch/vram_management/layers.py new file mode 100644 index 0000000..fafaef7 --- /dev/null +++ b/flashvsr_arch/vram_management/layers.py @@ -0,0 +1,95 @@ +import torch, copy +from ..models.utils import init_weights_on_device + + +def cast_to(weight, dtype, device): + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + +class AutoWrappedModule(torch.nn.Module): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): + super().__init__() + self.module = module.to(dtype=offload_dtype, device=offload_device) + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.state = 0 + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.module.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def forward(self, *args, **kwargs): + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + module = self.module + else: + module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + return module(*args, **kwargs) + + +class AutoWrappedLinear(torch.nn.Linear): + def __init__(self, module: torch.nn.Linear, offload_dtype, 
offload_device, onload_dtype, onload_device, computation_dtype, computation_device): + with init_weights_on_device(device=torch.device("meta")): + super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) + self.weight = module.weight + self.bias = module.bias + self.offload_dtype = offload_dtype + self.offload_device = offload_device + self.onload_dtype = onload_dtype + self.onload_device = onload_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.state = 0 + + def offload(self): + if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def forward(self, x, *args, **kwargs): + if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: + weight, bias = self.weight, self.bias + else: + weight = cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) + return torch.nn.functional.linear(x, weight, bias) + + +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): + for name, module in model.named_children(): + for source_module, target_module in module_map.items(): + if isinstance(module, source_module): + num_param = sum(p.numel() for p in module.parameters()) + if max_num_param is not None and total_num_param + num_param > max_num_param: + module_config_ = overflow_module_config + else: + module_config_ = module_config + module_ = target_module(module, **module_config_) + setattr(model, name, module_) + total_num_param += num_param + break + else: + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) + return total_num_param + + +def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): + enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) + model.vram_management_enabled = True + diff --git a/inference.py b/inference.py index c9d71c9..a4809a2 100644 --- a/inference.py +++ b/inference.py @@ -1,8 +1,11 @@ import logging +import math +import os from functools import partial import torch import torch.nn as nn +import torch.nn.functional as F from .bim_vfi_arch import BiMVFI from .ema_vfi_arch import feature_extractor as ema_feature_extractor @@ -621,3 +624,241 @@ class GIMMVFIModel: results.append(torch.clamp(unpadded, 0, 1)) return results + + +# --------------------------------------------------------------------------- +# FlashVSR model wrapper (4x video super-resolution) +# --------------------------------------------------------------------------- + +class FlashVSRModel: + """Inference wrapper for FlashVSR diffusion-based video super-resolution. 
+ + Supports three pipeline modes: + - full: Standard VAE decode, highest quality + - tiny: TCDecoder decode, faster + - tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos + """ + + # Minimum input frame count required by the pipeline + MIN_FRAMES = 21 + + def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16): + from safetensors.torch import load_file + from .flashvsr_arch import ( + ModelManager, FlashVSRFullPipeline, + FlashVSRTinyPipeline, FlashVSRTinyLongPipeline, + ) + from .flashvsr_arch.models.utils import Buffer_LQ4x_Proj + from .flashvsr_arch.models.TCDecoder import build_tcdecoder + + self.mode = mode + self.device = device + self.dtype = dtype + + dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors") + vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors") + lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors") + tcd_path = os.path.join(model_dir, "TCDecoder.safetensors") + prompt_path = os.path.join(model_dir, "Prompt.safetensors") + + mm = ModelManager(torch_dtype=dtype, device="cpu") + + if mode == "full": + mm.load_models([dit_path, vae_path]) + self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device) + self.pipe.vae.model.encoder = None + self.pipe.vae.model.conv1 = None + else: + mm.load_models([dit_path]) + Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline + self.pipe = Pipeline.from_model_manager(mm, device=device) + self.pipe.TCDecoder = build_tcdecoder( + [512, 256, 128, 128], device, dtype, 16 + 768, + ) + self.pipe.TCDecoder.load_state_dict( + load_file(tcd_path, device=device), strict=False, + ) + self.pipe.TCDecoder.clean_mem() + + # LQ frame projection + self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype) + if os.path.exists(lq_path): + lq_sd = load_file(lq_path, device="cpu") + cleaned = {} + for k, v in lq_sd.items(): + cleaned[k.removeprefix("LQ_proj_in.")] = v + self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True) + self.pipe.denoising_model().LQ_proj_in.to(device) + + self.pipe.to(device, dtype) + self.pipe.enable_vram_management(num_persistent_param_in_dit=None) + self.pipe.init_cross_kv(prompt_path=prompt_path) + self.pipe.load_models_to_device([]) # offload to CPU + + def to(self, device): + self.device = device + self.pipe.device = device + return self + + def load_to_device(self): + """Load models to the compute device for inference.""" + names = ["dit", "vae"] if self.mode == "full" else ["dit"] + self.pipe.load_models_to_device(names) + + def offload(self): + """Offload models to CPU.""" + self.pipe.load_models_to_device([]) + + def clear_caches(self): + if hasattr(self.pipe.denoising_model(), "LQ_proj_in"): + self.pipe.denoising_model().LQ_proj_in.clear_cache() + if hasattr(self.pipe, "vae") and self.pipe.vae is not None: + self.pipe.vae.clear_cache() + + # ------------------------------------------------------------------ + # Frame preprocessing / postprocessing helpers + # ------------------------------------------------------------------ + + @staticmethod + def _compute_dims(w, h, scale, align=128): + sw, sh = w * scale, h * scale + tw = math.ceil(sw / align) * align + th = math.ceil(sh / align) * align + return sw, sh, tw, th + + @staticmethod + def _pad_video_5d(video): + """Pad [1, C, F, H, W] video: repeat last 2 frames, align for pipeline. + + Uses the reference formula: (F_padded + 2 - 5) % 8 == 0, ensuring + the pipeline's streaming loop gets correct iteration counts. 
+ """ + tail = video[:, :, -1:].repeat(1, 1, 2, 1, 1) + video = torch.cat([video, tail], dim=2) + added = 0 + remainder = (video.shape[2] + 2 - 5) % 8 + if remainder != 0: + added = 8 - remainder + pad = video[:, :, -1:].repeat(1, 1, added, 1, 1) + video = torch.cat([video, pad], dim=2) + return video, added + + @staticmethod + def _restore_video_sequence(result, added_frames, expected): + """Strip padding and warmup frames from the output.""" + if added_frames > 0 and result.shape[0] > added_frames: + result = result[:-added_frames] + # Strip the first 2 pipeline warmup frames + if result.shape[0] > 2: + result = result[2:] + # Adjust to exact expected count + if result.shape[0] > expected: + result = result[:expected] + elif result.shape[0] < expected: + pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:]) + result = torch.cat([result, pad], dim=0) + return result + + def _prepare_video(self, frames, scale): + """Convert [F, H, W, C] [0,1] frames to padded [1, C, F, H, W] [-1,1]. + + Bicubic-upscales each frame to the target resolution, normalizes to + [-1, 1], then applies temporal padding for the pipeline. + + Returns: + video: [1, C, F_padded, H, W] tensor + th, tw: padded spatial dimensions + nf: padded frame count (= video.shape[2]) + sh, sw: actual (unpadded) spatial dimensions + added: number of alignment-padding frames added + """ + N, H, W, C = frames.shape + sw, sh, tw, th = self._compute_dims(W, H, scale) + + processed = [] + for i in range(N): + frame = frames[i].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] + upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False) + pad_h, pad_w = th - sh, tw - sw + if pad_h > 0 or pad_w > 0: + upscaled = F.pad(upscaled, (0, pad_w, 0, pad_h), mode='replicate') + normalized = upscaled * 2.0 - 1.0 + processed.append(normalized.squeeze(0).cpu().to(self.dtype)) + + video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0) + + # Apply temporal padding (tail + alignment) + video, added = self._pad_video_5d(video) + nf = video.shape[2] + + return video, th, tw, nf, sh, sw, added + + @staticmethod + def _to_frames(video): + """Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1].""" + from einops import rearrange + v = video.squeeze(0) if video.dim() == 5 else video + v = rearrange(v, "C F H W -> F H W C") + return (v.float() + 1.0) / 2.0 + + # ------------------------------------------------------------------ + # Main upscale method + # ------------------------------------------------------------------ + + @torch.no_grad() + def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104), + topk_ratio=2.0, kv_ratio=2.0, local_range=9, + color_fix=True, unload_dit=False, seed=1, + progress_bar_cmd=None): + """Upscale video frames with FlashVSR. 
+ + Args: + frames: [F, H, W, C] float32 [0, 1] with F >= 21 + scale: Upscaling factor (2 or 4) + tiled: Enable VAE tiled decode (saves VRAM) + tile_size: (H, W) tile size for VAE tiling + topk_ratio: Sparse attention ratio (higher = faster, less detail) + kv_ratio: KV cache ratio (higher = more quality, more VRAM) + local_range: Local attention window (9=sharp, 11=stable) + color_fix: Apply wavelet color correction + unload_dit: Offload DiT before VAE decode (saves VRAM) + seed: Random seed + progress_bar_cmd: Callable wrapping an iterable for progress display + + Returns: + [F, H*scale, W*scale, C] float32 [0, 1] + """ + if progress_bar_cmd is None: + from tqdm import tqdm + progress_bar_cmd = tqdm + + original_count = frames.shape[0] + + # Prepare video tensor (bicubic upscale + pad) + video, th, tw, nf, sh, sw, added_frames = self._prepare_video(frames, scale) + + # Move LQ video to compute device (except for "long" mode which streams) + if "long" not in self.pipe.__class__.__name__.lower(): + video = video.to(self.pipe.device) + + # Run pipeline + out = self.pipe( + prompt="", negative_prompt="", + cfg_scale=1.0, num_inference_steps=1, + seed=seed, tiled=tiled, tile_size=tile_size, + progress_bar_cmd=progress_bar_cmd, + LQ_video=video, + num_frames=nf, height=th, width=tw, + is_full_block=False, if_buffer=True, + topk_ratio=topk_ratio * 768 * 1280 / (th * tw), + kv_ratio=kv_ratio, local_range=local_range, + color_fix=color_fix, unload_dit=unload_dit, + ) + + # Convert to ComfyUI format and crop spatial padding + result = self._to_frames(out).cpu()[:, :sh, :sw, :] + + # Restore original frame count (strip temporal padding + warmup) + result = self._restore_video_sequence(result, added_frames, original_count) + + return result diff --git a/nodes.py b/nodes.py index a25ae3f..71fe9e1 100644 --- a/nodes.py +++ b/nodes.py @@ -8,7 +8,7 @@ import torch import folder_paths from comfy.utils import ProgressBar -from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel +from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel, GIMMVFIModel, FlashVSRModel from .bim_vfi_arch import clear_backwarp_cache from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache @@ -1507,3 +1507,408 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate): result = result[1:] # skip duplicate boundary frame return (result, model) + + +# --------------------------------------------------------------------------- +# FlashVSR nodes (4x video super-resolution) +# --------------------------------------------------------------------------- + +FLASHVSR_HF_REPO = "1038lab/FlashVSR" +FLASHVSR_REQUIRED_FILES = [ + "FlashVSR1_1.safetensors", + "Wan2.1_VAE.safetensors", + "LQ_proj_in.safetensors", + "TCDecoder.safetensors", + "Prompt.safetensors", +] + +FLASHVSR_MODEL_DIR = os.path.join(folder_paths.models_dir, "flashvsr") +if not os.path.exists(FLASHVSR_MODEL_DIR): + os.makedirs(FLASHVSR_MODEL_DIR, exist_ok=True) + + +def download_flashvsr_models(model_dir): + """Download FlashVSR checkpoints from HuggingFace if missing.""" + from huggingface_hub import snapshot_download + + missing = [f for f in FLASHVSR_REQUIRED_FILES + if not os.path.exists(os.path.join(model_dir, f))] + if not missing: + return + + logger.info(f"[FlashVSR] Missing files: {', '.join(missing)}. 
Downloading from HuggingFace...") + snapshot_download( + repo_id=FLASHVSR_HF_REPO, + local_dir=model_dir, + local_dir_use_symlinks=False, + resume_download=True, + ) + + still_missing = [f for f in FLASHVSR_REQUIRED_FILES + if not os.path.exists(os.path.join(model_dir, f))] + if still_missing: + raise FileNotFoundError( + f"[FlashVSR] Failed to download: {', '.join(still_missing)}. " + f"Please download manually from https://huggingface.co/{FLASHVSR_HF_REPO}" + ) + logger.info("[FlashVSR] All checkpoints downloaded successfully.") + + +class _FlashVSRProgressBar: + """Wrap an iterable with a ComfyUI ProgressBar.""" + + def __init__(self, total, pbar, step_ref): + self.total = total + self.pbar = pbar + self.step_ref = step_ref + + def __call__(self, iterable): + return self._Wrapper(iterable, self.pbar, self.step_ref) + + class _Wrapper: + def __init__(self, iterable, pbar, step_ref): + self.iterable = iterable + self.pbar = pbar + self.step_ref = step_ref + self._iter = iter(iterable) + + def __iter__(self): + return self + + def __next__(self): + val = next(self._iter) + self.step_ref[0] += 1 + self.pbar.update_absolute(self.step_ref[0]) + return val + + def __len__(self): + return len(self.iterable) + + +class LoadFlashVSRModel: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mode": (["tiny", "tiny-long", "full"], { + "default": "tiny", + "tooltip": "Pipeline mode. Tiny: fast TCDecoder decode. " + "Tiny-long: streaming TCDecoder, lowest VRAM for long videos. " + "Full: standard VAE decode, highest quality but more VRAM.", + }), + "precision": (["bf16", "fp16"], { + "default": "bf16", + "tooltip": "Model precision. BF16 is faster on modern GPUs. FP16 for older GPUs.", + }), + } + } + + RETURN_TYPES = ("FLASHVSR_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_model" + CATEGORY = "video/FlashVSR" + + def load_model(self, mode, precision): + download_flashvsr_models(FLASHVSR_MODEL_DIR) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 if precision == "bf16" else torch.float16 + + wrapper = FlashVSRModel( + model_dir=FLASHVSR_MODEL_DIR, + mode=mode, + device=device, + dtype=dtype, + ) + + logger.info(f"[FlashVSR] Model loaded (mode={mode}, precision={precision})") + return (wrapper,) + + +class FlashVSRUpscale: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE", { + "tooltip": "Input video frames. Minimum 21 frames required.", + }), + "model": ("FLASHVSR_MODEL", { + "tooltip": "FlashVSR model from the Load FlashVSR Model node.", + }), + "scale": ("INT", { + "default": 4, "min": 2, "max": 4, "step": 2, + "tooltip": "Upscaling factor. 4x is the native resolution; 2x is supported but less optimized.", + }), + "frame_chunk_size": ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Process frames in chunks of this size to bound VRAM (0=all at once). " + "Each chunk must be >= 21 frames. Recommended: 33 (4x8+1) or 65 (8x8+1).", + }), + "tiled": ("BOOLEAN", { + "default": True, + "tooltip": "Enable VAE tiled decode. Reduces VRAM usage significantly.", + }), + "tile_size_h": ("INT", { + "default": 60, "min": 16, "max": 256, "step": 4, + "tooltip": "VAE tile height (in latent space). Larger = faster but more VRAM.", + }), + "tile_size_w": ("INT", { + "default": 104, "min": 16, "max": 256, "step": 4, + "tooltip": "VAE tile width (in latent space). 
Larger = faster but more VRAM.", + }), + "topk_ratio": ("FLOAT", { + "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1, + "tooltip": "Sparse attention ratio. Higher = faster but may lose fine detail.", + }), + "kv_ratio": ("FLOAT", { + "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1, + "tooltip": "KV cache ratio. Higher = better quality, more VRAM.", + }), + "local_range": ([9, 11], { + "default": 9, + "tooltip": "Local attention window. 9=sharper details, 11=more temporal stability.", + }), + "color_fix": ("BOOLEAN", { + "default": True, + "tooltip": "Apply color correction to prevent color shifts from the diffusion process.", + }), + "unload_dit": ("BOOLEAN", { + "default": False, + "tooltip": "Offload DiT to CPU before VAE decode. Saves VRAM but slower.", + }), + "seed": ("INT", { + "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "Random seed for the diffusion process.", + }), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "upscale" + CATEGORY = "video/FlashVSR" + + def upscale(self, images, model, scale, frame_chunk_size, + tiled, tile_size_h, tile_size_w, + topk_ratio, kv_ratio, local_range, + color_fix, unload_dit, seed): + num_frames = images.shape[0] + if num_frames < FlashVSRModel.MIN_FRAMES: + raise ValueError( + f"FlashVSR requires at least {FlashVSRModel.MIN_FRAMES} frames, got {num_frames}" + ) + + tile_size = (tile_size_h, tile_size_w) + + # Build frame chunks + if frame_chunk_size < FlashVSRModel.MIN_FRAMES or frame_chunk_size >= num_frames: + chunks = [(0, num_frames)] + else: + chunks = [] + start = 0 + while start < num_frames: + end = min(start + frame_chunk_size, num_frames) + chunks.append((start, end)) + if end == num_frames: + break + start = end + # If the last chunk is too small, merge it into the previous one + if len(chunks) > 1 and (chunks[-1][1] - chunks[-1][0]) < FlashVSRModel.MIN_FRAMES: + prev_start = chunks[-2][0] + last_end = chunks[-1][1] + chunks = chunks[:-2] + chunks.append((prev_start, last_end)) + + # Estimate total pipeline steps for progress bar + # Mirrors _pad_video_5d: add 2 tail frames, then align with (F+2-5)%8 + total_steps = 0 + for cs, ce in chunks: + padded_n = (ce - cs) + 2 # tail frames appended by _pad_video_5d + remainder = (padded_n + 2 - 5) % 8 + if remainder != 0: + padded_n += 8 - remainder + total_steps += max(1, (padded_n - 1) // 8 - 2) + + pbar = ProgressBar(total_steps) + step_ref = [0] + progress = _FlashVSRProgressBar(total_steps, pbar, step_ref) + + model.load_to_device() + + result_chunks = [] + for chunk_start, chunk_end in chunks: + chunk_frames = images[chunk_start:chunk_end] + + chunk_result = model.upscale( + chunk_frames, + scale=scale, tiled=tiled, tile_size=tile_size, + topk_ratio=topk_ratio, kv_ratio=kv_ratio, + local_range=local_range, color_fix=color_fix, + unload_dit=unload_dit, seed=seed, + progress_bar_cmd=progress, + ) + result_chunks.append(chunk_result) + model.clear_caches() + + model.offload() + from .flashvsr_arch.models.utils import clean_vram + clean_vram() + + return (torch.cat(result_chunks, dim=0),) + + +class FlashVSRSegmentUpscale: + """Process a numbered segment with temporal overlap and crossfade blending. + + Chain multiple instances with Save nodes between them to bound peak RAM. + The model pass-through forces sequential execution so each segment + saves and frees RAM before the next starts. 
+ + Crossfade blending within the overlap region: + - First (overlap - blend) frames: warmup only, discarded from output + - Last blend frames: linear alpha crossfade with previous segment's tail + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE", { + "tooltip": "Full input video frames. Minimum 21 frames required.", + }), + "model": ("FLASHVSR_MODEL", { + "tooltip": "FlashVSR model from Load FlashVSR Model. " + "Chain the model output to the next segment node for sequential execution.", + }), + "segment_index": ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Which segment to process (0-based).", + }), + "segment_size": ("INT", { + "default": 100, "min": 21, "max": 10000, "step": 1, + "tooltip": "Number of input frames per segment.", + }), + "overlap_frames": ("INT", { + "default": 8, "min": 0, "max": 100, "step": 1, + "tooltip": "Number of overlapping frames between adjacent segments. " + "These frames provide temporal context and crossfade blending.", + }), + "blend_frames": ("INT", { + "default": 4, "min": 0, "max": 50, "step": 1, + "tooltip": "Number of frames within the overlap region to crossfade. " + "Must be <= overlap_frames. The rest of the overlap is warmup (discarded).", + }), + "scale": ("INT", { + "default": 4, "min": 2, "max": 4, "step": 2, + "tooltip": "Upscaling factor.", + }), + "tiled": ("BOOLEAN", { + "default": True, + "tooltip": "Enable VAE tiled decode.", + }), + "tile_size_h": ("INT", { + "default": 60, "min": 16, "max": 256, "step": 4, + }), + "tile_size_w": ("INT", { + "default": 104, "min": 16, "max": 256, "step": 4, + }), + "topk_ratio": ("FLOAT", { + "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1, + }), + "kv_ratio": ("FLOAT", { + "default": 2.0, "min": 1.0, "max": 4.0, "step": 0.1, + }), + "local_range": ([9, 11], { + "default": 9, + }), + "color_fix": ("BOOLEAN", { + "default": True, + }), + "unload_dit": ("BOOLEAN", { + "default": False, + }), + "seed": ("INT", { + "default": 1, "min": 1, "max": 0xFFFFFFFFFFFFFFFF, + }), + } + } + + RETURN_TYPES = ("IMAGE", "FLASHVSR_MODEL") + RETURN_NAMES = ("images", "model") + FUNCTION = "upscale" + CATEGORY = "video/FlashVSR" + + def upscale(self, images, model, segment_index, segment_size, + overlap_frames, blend_frames, scale, + tiled, tile_size_h, tile_size_w, + topk_ratio, kv_ratio, local_range, + color_fix, unload_dit, seed): + total_input = images.shape[0] + blend_frames = min(blend_frames, overlap_frames) + + # Clear stale overlap data from previous workflow runs + if segment_index == 0: + model._overlap_tail = None + + # Compute segment boundaries + stride = segment_size - overlap_frames + start = segment_index * stride + end = min(start + segment_size, total_input) + + if start >= total_input: + # Past the end + return (images[:1], model) + + # Ensure minimum frame count + actual_size = end - start + if actual_size < FlashVSRModel.MIN_FRAMES: + start = max(0, end - FlashVSRModel.MIN_FRAMES) + actual_size = end - start + + segment_frames = images[start:end] + + tile_size = (tile_size_h, tile_size_w) + + model.load_to_device() + + result = model.upscale( + segment_frames, + scale=scale, tiled=tiled, tile_size=tile_size, + topk_ratio=topk_ratio, kv_ratio=kv_ratio, + local_range=local_range, color_fix=color_fix, + unload_dit=unload_dit, seed=seed, + ) + + model.clear_caches() + model.offload() + from .flashvsr_arch.models.utils import clean_vram + clean_vram() + + # Handle crossfade blending with previous segment's tail + if segment_index > 
0 and overlap_frames > 0 and hasattr(model, '_overlap_tail'): + prev_tail = model._overlap_tail # [blend_frames, H, W, C] on CPU + + # The overlap region in result: first overlap_frames of the upscaled output + # Within overlap: first (overlap - blend) frames are warmup (discard) + # last blend_frames frames: crossfade with prev_tail + warmup = overlap_frames - blend_frames + + if blend_frames > 0 and prev_tail is not None: + # Linear alpha ramp for crossfade + alpha = torch.linspace(0, 1, blend_frames).view(-1, 1, 1, 1) + blended = (1.0 - alpha) * prev_tail + alpha * result[warmup:warmup + blend_frames] + result = torch.cat([blended, result[overlap_frames:]], dim=0) + else: + result = result[overlap_frames:] + elif segment_index > 0 and overlap_frames > 0: + # No previous tail stored, just skip overlap + result = result[overlap_frames:] + + # Store tail frames for next segment's crossfade + if overlap_frames > 0 and blend_frames > 0 and result.shape[0] > blend_frames: + model._overlap_tail = result[-blend_frames:].cpu().to(torch.float16) + else: + model._overlap_tail = None + + return (result, model) diff --git a/requirements.txt b/requirements.txt index 0a20468..71ab0c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ yacs easydict einops huggingface_hub +safetensors
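
The trickiest part of the new segment node to configure is how `segment_size`, `overlap_frames`, and the 21-frame minimum interact. The sketch below is a standalone illustration, not part of the patch: it mirrors the boundary arithmetic in `FlashVSRSegmentUpscale.upscale()`, and the clip length and settings it uses are assumed values chosen for the example. Note that `blend_frames` does not change these boundaries; it only controls how much of the overlap is crossfaded when segments are merged.

```python
# Standalone sketch, not part of the patch: mirrors the segment boundary logic in
# FlashVSRSegmentUpscale.upscale(). The numbers below are illustrative assumptions.
MIN_FRAMES = 21  # FlashVSR's minimum input length


def segment_bounds(segment_index, total_input, segment_size=100, overlap_frames=8):
    """Return the (start, end) input-frame slice a segment node would process."""
    stride = segment_size - overlap_frames
    start = segment_index * stride
    end = min(start + segment_size, total_input)
    if start >= total_input:
        return None  # past the end; the node returns images[:1]
    if end - start < MIN_FRAMES:
        start = max(0, end - MIN_FRAMES)  # grow backwards to satisfy the 21-frame minimum
    return start, end


# With the assumed settings (stride = 100 - 8 = 92) on a 200-frame clip:
#   segment 0 -> (0, 100)
#   segment 1 -> (92, 192)
#   segment 2 -> (179, 200)  # nominal (184, 200) is under 21 frames, so it grows backwards
#   segment 3 -> None
for idx in range(4):
    print(idx, segment_bounds(idx, total_input=200))
```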