Files
ComfyUI-Tween/flashvsr_arch/models/wan_video_dit.py
Ethanfel 76dff7e573 Fix FlashVSR quality: two-stage temporal padding, kv_ratio=3, float64 precision
Root cause of remaining ghosting: our single-stage temporal padding
(N+4 → floor to 8k+1) TRUNCATED frames when N+4 wasn't already 8k+1.
For 50 frames: 50+4=54 → floor to 49, LOSING the last input frame.
The pipeline then processed misaligned LQ→output frame mapping.

Fix matches naxci1/ComfyUI-FlashVSR_Stable two-stage approach:
1. Pad to next_8n5(N) (next integer >= N of form 8k+5, minimum 21)
2. Add 4 → result is always 8(k+1)+1, a valid 8k+1 — NEVER truncates

Also:
- kv_ratio default 2.0→3.0 (matches naxci1, max quality KV cache)
- local_range default 9→11 (more stable temporal consistency)
- sinusoidal_embedding_1d, precompute_freqs_cis, rope_apply: float32→float64
  (matches naxci1 reference precision for embeddings and RoPE)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 18:06:46 +01:00

852 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import os
import time
from typing import Tuple, Optional, List
from einops import rearrange
from .utils import hash_state_dict_keys
# Optional attention backends. Each probe imports a package and checks that
# the entry point used by ``flash_attention`` is callable. An explicit check
# plus ``raise`` is used instead of ``assert`` so the probe still rejects a
# broken install when Python runs with ``-O`` (which strips assert statements).
try:
    import flash_attn_interface
    if not callable(getattr(flash_attn_interface, "flash_attn_func", None)):
        raise ImportError("flash_attn_interface.flash_attn_func is not callable")
    FLASH_ATTN_3_AVAILABLE = True
except Exception:
    FLASH_ATTN_3_AVAILABLE = False
try:
    import flash_attn
    if not callable(getattr(flash_attn, "flash_attn_func", None)):
        raise ImportError("flash_attn.flash_attn_func is not callable")
    FLASH_ATTN_2_AVAILABLE = True
except Exception:
    FLASH_ATTN_2_AVAILABLE = False
try:
    from sageattention import sageattn
    if not callable(sageattn):
        raise ImportError("sageattention.sageattn is not callable")
    SAGE_ATTN_AVAILABLE = True
except Exception:
    SAGE_ATTN_AVAILABLE = False
try:
    from sageattn.core import sparse_sageattn
    if not callable(sparse_sageattn):
        raise ImportError("sageattn.core.sparse_sageattn is not callable")
    SPARSE_SAGE_AVAILABLE = True
except Exception:
    SPARSE_SAGE_AVAILABLE = False
    # Keep the name bound so later references fail predictably.
    sparse_sageattn = None
from PIL import Image
import numpy as np
# ----------------------------
# Local / window masks
# ----------------------------
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
                                       block_w: int,
                                       win_h: int = 6,
                                       win_w: int = 6,
                                       include_self: bool = True,
                                       device=None) -> torch.Tensor:
    """Build a block-level local-attention mask with border-clamped windows.

    Blocks on a ``block_h x block_w`` grid are flattened row-major. Query block
    ``i`` may attend to key block ``j`` when ``j`` falls inside a
    ``win_h x win_w`` window whose top-left corner is ``(r - win_h//2,
    c - win_w//2)``. That corner is clamped to ``[0, block - win]``, so near the
    border the window shifts inward instead of shrinking.

    Returns:
        Bool tensor of shape ``(block_h * block_w, block_h * block_w)``.
        Rows are query blocks and columns are key blocks.
    """
    dev = device or torch.device("cpu")
    grid_r, grid_c = torch.meshgrid(torch.arange(block_h, device=dev),
                                    torch.arange(block_w, device=dev),
                                    indexing="ij")
    rows = grid_r.reshape(-1)
    cols = grid_c.reshape(-1)
    # Inclusive window bounds for every query block.
    row_lo = torch.clamp(rows - win_h // 2, 0, block_h - win_h)
    col_lo = torch.clamp(cols - win_w // 2, 0, block_w - win_w)
    row_hi = row_lo + (win_h - 1)
    col_hi = col_lo + (win_w - 1)
    # Broadcast: dim 0 indexes the query block, dim 1 the key block.
    row_ok = (rows.unsqueeze(0) >= row_lo.unsqueeze(1)) & (rows.unsqueeze(0) <= row_hi.unsqueeze(1))
    col_ok = (cols.unsqueeze(0) >= col_lo.unsqueeze(1)) & (cols.unsqueeze(0) <= col_hi.unsqueeze(1))
    mask = row_ok & col_ok
    if not include_self:
        mask.fill_diagonal_(False)
    return mask
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
                                                    block_w: int,
                                                    win_h: int = 6,
                                                    win_w: int = 6,
                                                    include_self: bool = True,
                                                    device=None) -> torch.Tensor:
    """Build a block-level local-attention mask with an unclamped sliding window.

    Blocks on a ``block_h x block_w`` grid are flattened row-major. Query block
    ``(r, c)`` may attend to key block ``(r2, c2)`` when
    ``-(win_h//2) <= r2 - r <= win_h - 1 - win_h//2``, with the analogous column
    condition. The window is not clamped at the grid border, so border blocks
    see fewer neighbours.

    Returns:
        Bool tensor of shape ``(block_h * block_w, block_h * block_w)``.
        Rows are query blocks and columns are key blocks.
    """
    dev = device or torch.device("cpu")
    grid_r, grid_c = torch.meshgrid(torch.arange(block_h, device=dev),
                                    torch.arange(block_w, device=dev),
                                    indexing="ij")
    rows = grid_r.reshape(-1)
    cols = grid_c.reshape(-1)
    # Signed offsets (key - query) for every (query, key) pair.
    d_row = rows[None, :] - rows[:, None]
    d_col = cols[None, :] - cols[:, None]
    lo_r = -(win_h // 2)
    lo_c = -(win_w // 2)
    row_ok = (d_row >= lo_r) & (d_row <= lo_r + win_h - 1)
    col_ok = (d_col >= lo_c) & (d_col <= lo_c + win_w - 1)
    mask = row_ok & col_ok
    if not include_self:
        mask.fill_diagonal_(False)
    return mask
class WindowPartition3D:
    """Split 5-D ``(B, F, H, W, C)`` tensors into non-overlapping 3-D windows and back."""

    @staticmethod
    def partition(x: torch.Tensor, win: Tuple[int, int, int]):
        """Return windows shaped ``(B * nF * nH * nW, wf * wh * ww, C)``.

        Windows are ordered by batch, then frame, height and width window
        index. Tokens inside a window are ordered (frame, height, width).
        """
        b, f, h, w, c = x.shape
        wf, wh, ww = win
        assert f % wf == 0 and h % wh == 0 and w % ww == 0, "Dims must divide by window size."
        tiles = x.view(b, f // wf, wf, h // wh, wh, w // ww, ww, c)
        # Move the three window-index axes in front of the intra-window axes.
        tiles = tiles.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        return tiles.view(-1, wf * wh * ww, c)

    @staticmethod
    def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
        """Invert :meth:`partition` for a volume of size ``orig = (F, H, W)``."""
        f, h, w = orig
        wf, wh, ww = win
        grid = (f // wf, h // wh, w // ww)
        b = windows.size(0) // (grid[0] * grid[1] * grid[2])
        tiles = windows.view(b, grid[0], grid[1], grid[2], wf, wh, ww, -1)
        # Interleave each window-index axis with its intra-window axis.
        tiles = tiles.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
        return tiles.view(b, f, h, w, -1)
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
                              q_w, k_w, topk=10, local_attn_mask=None):
    """Select, per head, the highest-scoring (query window, key window) pairs.

    Each window of ``q_w`` / ``k_w`` is mean-pooled over its tokens (dim 1).
    Per-head window-to-window dot-product scores are scaled by
    ``1/sqrt(head_dim)``, and pairs outside the tiled ``local_attn_mask`` are
    set to ``-inf``. After a softmax, entries within each (head, temporal
    group) are kept only if they are strictly greater than the
    ``(topk + 1)``-th largest value, so at most ``topk`` survive (ties can
    keep fewer).

    Args:
        batch_size: Must be 1 (asserted).
        nheads: Number of heads; the last dim of ``q_w``/``k_w`` is split into
            ``(nheads, head_dim)``.
        seqlen: Number of groups the query-window axis is split into for the
            top-k selection.
        q_w, k_w: 3-D window tensors ``(num_windows, tokens_per_window, channels)``
            (the ``'s (h d)'`` rearrange after pooling fixes this rank).
        topk: Requested number of kept entries per group. It is capped at the
            group size minus one. Must not be ``None``.
        local_attn_mask: Bool block mask. It is tiled along both axes to the
            score shape; required (asserted).

    Returns:
        Bool tensor ``(batch_size, nheads, num_q_windows, 2 * num_k_windows)``.
        Each key-window column is duplicated by ``repeat_interleave(2, dim=-1)``.
        NOTE(review): presumably so the sparse kernel sees half-window key
        blocks — confirm against the sparse_sageattn block size.
    """
    assert batch_size == 1, "Only batch_size=1 supported for now"
    assert local_attn_mask is not None, "local_attn_mask must be provided"
    # Mean-pool tokens inside each window -> one vector per window.
    avgpool_q = torch.mean(q_w, dim=1)
    avgpool_k = torch.mean(k_w, dim=1)
    avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
    avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
    q_heads = avgpool_q.permute(1, 0, 2)
    k_heads = avgpool_k.permute(1, 0, 2)
    D = avgpool_q.shape[-1]
    # Per-head window-to-window scores: (heads, num_q_windows, num_k_windows).
    scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
    repeat_head = scores.shape[0]
    # Tile the local block mask so it covers every temporal chunk of q and k.
    repeat_len = scores.shape[1] // local_attn_mask.shape[0]
    repeat_num = scores.shape[2] // local_attn_mask.shape[1]
    local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
    local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
    local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
    # Convert bool mask to an additive bias: False -> -inf, True -> 0.
    local_attn_mask = local_attn_mask.to(torch.float32)
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
    scores = scores + local_attn_mask
    attn_map = torch.softmax(scores, dim=-1)
    # Group query windows into `seqlen` chunks per head for the top-k selection.
    attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
    loop_num, s1, s2 = attn_map.shape
    flat = attn_map.reshape(loop_num, -1)
    n = flat.shape[1]  # unused
    apply_topk = min(flat.shape[1]-1, topk)
    # Threshold = (apply_topk + 1)-th largest value; keep strictly greater entries.
    thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
    thresholds = thresholds.unsqueeze(1)
    mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
    mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
    # (Historical note: the variable name on the line above was unified.)
    mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    mask = mask.repeat_interleave(2, dim=-1)
    return mask
@torch.no_grad()
def generate_draft_block_mask_refined(batch_size, nheads, seqlen,
                                      q_w, k_w, topk=10, local_attn_mask=None):
    """Variant of ``generate_draft_block_mask`` that scores key half-windows.

    Query windows are mean-pooled whole. Each key window is split into two
    64-token halves, which are pooled separately. The split is hard-coded as
    ``k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])``, so key windows must hold
    exactly 128 tokens.

    ``local_attn_mask`` is widened with ``repeat_interleave(2, dim=1)`` but,
    unlike the non-refined version, is NOT tiled over temporal chunks. Its
    shape must already match the score shape (asserted).

    Returns:
        Bool tensor ``(batch_size, nheads, num_q_windows, 2 * num_k_windows)``.
    """
    assert batch_size == 1, "Only batch_size=1 supported for now"
    assert local_attn_mask is not None, "local_attn_mask must be provided"
    avgpool_q = torch.mean(q_w, dim=1)
    avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
    q_heads = avgpool_q.permute(1, 0, 2)
    # Split each key window into two halves of 64 tokens and pool each half.
    k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
    avgpool_k_split = torch.mean(k_w_split, dim=2)
    avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)
    avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)
    k_heads = avgpool_k_refined.permute(1, 0, 2)
    D = avgpool_q.shape[-1]
    scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
    repeat_head = scores.shape[0]
    num_q_blocks_local = local_attn_mask.shape[0]  # unused
    num_k_blocks_local = local_attn_mask.shape[1]  # unused
    # Duplicate each key column to match the half-window key axis.
    local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
    repeat_len = scores.shape[1] // local_attn_mask.shape[0]  # unused
    repeat_num = scores.shape[2] // local_attn_mask.shape[1]  # unused
    local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
    # Bool mask -> additive bias (False -> -inf, True -> 0).
    local_attn_mask = local_attn_mask.to(torch.float32)
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
    assert scores.shape == local_attn_mask.shape
    scores = scores + local_attn_mask
    attn_map = torch.softmax(scores, dim=-1)
    # NOTE: it=seqlen may need adjusting depending on what seqlen means here.
    attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
    loop_num, s1, s2 = attn_map.shape
    flat = attn_map.reshape(loop_num, -1)
    apply_topk = min(flat.shape[1]-1, topk)
    # Keep entries strictly above the (apply_topk + 1)-th largest value.
    thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
    thresholds = thresholds.unsqueeze(1)
    mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
    mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
    mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    return mask
# ----------------------------
# Attention kernels
# ----------------------------
def _sdpa_fallback(q, k, v, num_heads):
    """Dense multi-head attention through PyTorch SDPA (no optional dependencies).

    ``q``, ``k`` and ``v`` are ``(b, s, num_heads * head_dim)``; the output has
    the same layout.
    """
    q, k, v = (rearrange(t, "b s (n d) -> b n s d", n=num_heads) for t in (q, k, v))
    attended = F.scaled_dot_product_attention(q, k, v)
    return rearrange(attended, "b n s d -> b s (n d)", n=num_heads)
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, enable_sageattention=True):
    """Multi-head attention dispatcher over the available backends.

    ``q``, ``k`` and ``v`` are ``(b, s, num_heads * head_dim)`` (per the
    ``"b s (n d)"`` rearrange patterns), and the result has the same layout.

    Backend priority:
      1. ``sparse_sageattn`` with ``attention_mask`` as an int8 block mask.
         Used when a mask is given, ``enable_sageattention`` is true and the
         backend is available. On any exception the backend is disabled for
         the rest of the process and dense SDPA is used; the mask is ignored.
      2. PyTorch SDPA when ``compatibility_mode`` is set.
      3. FlashAttention 3, then FlashAttention 2.
      4. ``sageattn`` (disabled process-wide on first failure, SDPA fallback).
      5. PyTorch SDPA.

    Only path 1 uses ``attention_mask``; every other path computes dense
    attention. ``return_KV`` is accepted but unused.
    """
    # Flags are rebound on failure so a broken backend is only tried once.
    global SPARSE_SAGE_AVAILABLE, SAGE_ATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE
    if attention_mask is not None and enable_sageattention and SPARSE_SAGE_AVAILABLE:
        try:
            q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
            k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
            v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
            base_blockmask = attention_mask
            x = sparse_sageattn(
                q, k, v,
                mask_id=base_blockmask.to(torch.int8),
                is_causal=False,
                tensor_layout="HND"
            )
            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
        except Exception:
            SPARSE_SAGE_AVAILABLE = False
            print("[FlashVSR] sparse_sageattn failed (unsupported GPU?), falling back to SDPA")
            # q,k,v already rearranged to [b, n, s, d] above.
            # NOTE(review): this assumes the failure came from sparse_sageattn
            # itself (after all three rearranges). This dense fallback ignores
            # attention_mask.
            x = F.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        x = _sdpa_fallback(q, k, v, num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        # Some FA3 builds return (out, softmax_lse); keep only the output.
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        try:
            q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
            k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
            v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
            x = sageattn(q, k, v)
            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
        except Exception:
            SAGE_ATTN_AVAILABLE = False
            print("[FlashVSR] sageattn failed (unsupported GPU?), falling back to SDPA")
            # q,k,v already rearranged to [b, n, s, d] above
            x = F.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        x = _sdpa_fallback(q, k, v, num_heads)
    return x
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Return ``x * (1 + scale) + shift`` (shift/scale modulation of normalised activations)."""
    gain = 1 + scale
    return x * gain + shift
def sinusoidal_embedding_1d(dim, position):
    """Sinusoidal embedding of 1-D ``position`` values, computed in float64.

    Returns ``[cos(p * f_i) ..., sin(p * f_i) ...]`` with
    ``f_i = 10000 ** (-i / (dim // 2))``, cast back to ``position.dtype``.
    For odd ``dim`` the result has ``dim - 1`` columns.
    """
    n_freq = max(dim // 2, 1)
    exponents = torch.arange(n_freq, dtype=torch.float64, device=position.device)
    inv_freq = torch.pow(10000.0, -exponents / n_freq)
    angles = torch.outer(position.to(torch.float64), inv_freq)
    emb = torch.cat((angles.cos(), angles.sin()), dim=1)
    return emb.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Return RoPE tables for the (frame, height, width) axes.

    The head dimension is split as ``dim - 2 * (dim // 3)`` for frames and
    ``dim // 3`` each for height and width. Every table is produced by
    ``precompute_freqs_cis`` with the same ``end`` and ``theta``.
    """
    hw_dim = dim // 3
    frame_dim = dim - 2 * hw_dim
    return tuple(precompute_freqs_cis(d, end, theta) for d in (frame_dim, hw_dim, hw_dim))
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Complex RoPE factors ``exp(i * pos * theta ** (-2j / dim))`` in complex128.

    Returns a tensor of shape ``(end, max(dim // 2, 1))``; row ``p`` holds the
    unit-magnitude rotations for position ``p``.
    """
    n_freq = max(dim // 2, 1)
    even_idx = torch.arange(0, dim, 2, dtype=torch.float64)[:n_freq]
    inv_freq = torch.pow(theta, -even_idx / max(dim, 1))
    positions = torch.arange(end, dtype=torch.float64)
    phase = positions[:, None] * inv_freq[None, :]
    return torch.polar(torch.ones_like(phase), phase)
def rope_apply(x, freqs, num_heads):
    """Rotate ``x`` by the complex RoPE factors ``freqs`` (computed in float64).

    ``x`` is ``(b, s, num_heads * head_dim)``. Adjacent channel pairs are viewed
    as complex numbers and multiplied by ``freqs``, which broadcasts against
    ``(b, s, num_heads, head_dim // 2)``. The result is flattened back to
    ``(b, s, num_heads * head_dim)`` in ``x``'s original dtype.
    """
    per_head = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    b, s, n = per_head.shape[:3]
    as_complex = torch.view_as_complex(per_head.to(torch.float64).reshape(b, s, n, -1, 2))
    rot = freqs.to(dtype=as_complex.dtype, device=as_complex.device)
    return torch.view_as_real(as_complex * rot).flatten(2).to(x.dtype)
# ----------------------------
# Norms & Blocks
# ----------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable per-channel gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by ``sqrt(mean(x**2) + eps)`` along the last dimension."""
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # Normalise in float32 for stability, cast back, then apply the gain.
        normed = self.norm(x.float())
        return normed.to(x.dtype) * self.weight
class AttentionModule(nn.Module):
    """``nn.Module`` wrapper that forwards to :func:`flash_attention` with fixed head count."""

    def __init__(self, num_heads, enable_sageattention=True):
        super().__init__()
        self.num_heads = num_heads
        self.enable_sageattention = enable_sageattention

    def forward(self, q, k, v, attention_mask=None):
        return flash_attention(
            q=q,
            k=k,
            v=v,
            num_heads=self.num_heads,
            attention_mask=attention_mask,
            enable_sageattention=self.enable_sageattention,
        )
class SelfAttention(nn.Module):
    """Windowed block-sparse self-attention with an optional streaming KV cache.

    Tokens are partitioned into ``(2, 8, 8)`` windows. A block mask from
    ``generate_draft_block_mask`` selects, per query window, the key windows
    to attend to. That selection is restricted by a local spatial window of
    ``local_range`` blocks.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads, enable_sageattention=enable_sageattention)
        # Lazily built local block mask plus the key it was built for. It is a
        # plain attribute (not a registered buffer), so ``module.to(device)``
        # does not move it; the device is therefore part of the cache key.
        self.local_attn_mask = None
        self.local_attn_mask_h = None
        self.local_attn_mask_w = None
        self.local_range = None

    def _get_local_attn_mask(self, h, w, local_range, device):
        """Return the cached local block mask, rebuilding it when grid, range or device changes."""
        block_h, block_w = h // 8, w // 8
        stale = (
            self.local_attn_mask is None
            or self.local_attn_mask_h != block_h
            or self.local_attn_mask_w != block_w
            or self.local_range != local_range
            or self.local_attn_mask.device != device
        )
        if stale:
            self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(
                block_h, block_w, local_range, local_range, include_self=True, device=device)
            self.local_attn_mask_h = block_h
            self.local_attn_mask_w = block_w
            self.local_range = local_range
        return self.local_attn_mask

    def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
                train_img=False, block_id=None, kv_len=None, is_full_block=False,
                is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range=9):
        """Attend over ``x`` (``(B, f*h*w, D)``) and optionally update a KV cache.

        Args:
            freqs: RoPE factors passed to ``rope_apply`` for q and k.
            f, h, w: Latent grid; ``L == f * h * w`` is asserted.
            topk: Block budget forwarded to ``generate_draft_block_mask``.
            kv_len: Maximum number of temporal chunks kept in the returned cache.
            is_stream: When true, also return the updated (k, v) window caches.
                Without a prior cache ``f`` must be 6; with one it must be 2.
            pre_cache_k, pre_cache_v: Previous window caches, concatenated in
                front of the current key/value windows.
            local_range: Local spatial window size (in 8x8 blocks) for the mask.
            local_num, train_img, block_id, is_full_block: Accepted but unused.

        Returns:
            ``self.o(x)`` or, when ``is_stream``, ``(self.o(x), cache_k, cache_v)``.
        """
        B, L, D = x.shape
        if is_stream and pre_cache_k is not None and pre_cache_v is not None:
            assert f==2, "f must be 2"
        if is_stream and (pre_cache_k is None or pre_cache_v is None):
            assert f==6, " start f must be 6"
        assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)
        win = (2, 8, 8)
        q = q.view(B, f, h, w, D)
        k = k.view(B, f, h, w, D)
        v = v.view(B, f, h, w, D)
        q_w = WindowPartition3D.partition(q, win)
        k_w = WindowPartition3D.partition(k, win)
        v_w = WindowPartition3D.partition(v, win)
        seqlen = f//win[0]
        # Windows per temporal chunk; computed before the cache is prepended.
        one_len = k_w.shape[0] // B // seqlen
        if pre_cache_k is not None and pre_cache_v is not None:
            k_w = torch.cat([pre_cache_k, k_w], dim=0)
            v_w = torch.cat([pre_cache_v, v_w], dim=0)
        block_n = q_w.shape[0] // B
        block_s = q_w.shape[1]
        block_n_kv = k_w.shape[0] // B
        reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
        reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
        reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
        local_attn_mask = self._get_local_attn_mask(h, w, local_range, k_w.device)
        attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=local_attn_mask)
        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
        cur_block_n, cur_block_s, _ = k_w.shape
        cache_num = cur_block_n // one_len
        # Drop the oldest temporal chunk once the cache exceeds kv_len chunks.
        if cache_num > kv_len:
            cache_k = k_w[one_len:, :, :]
            cache_v = v_w[one_len:, :, :]
        else:
            cache_k = k_w
            cache_v = v_w
        x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
        x = WindowPartition3D.reverse(x, win, (f, h, w))
        x = x.view(B, f*h*w, D)
        if is_stream:
            return self.o(x), cache_k, cache_v
        return self.o(x)
class CrossAttention(nn.Module):
    """Cross-attention over a text context with a persistent K/V cache.

    Only the text context is supported. ``init_cache`` projects the context
    into keys and values once; ``forward`` then reuses them on every call.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, enable_sageattention: bool = True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        # Sage attention is always disabled here; the ``enable_sageattention``
        # argument is not used.
        self.attn = AttentionModule(self.num_heads, enable_sageattention=False)
        # Persistent K/V cache, filled by init_cache().
        self.cache_k = None
        self.cache_v = None

    @torch.no_grad()
    def init_cache(self, ctx: torch.Tensor):
        """Cache keys/values for ``ctx``: [B, S_ctx, dim], the context after text_embedding."""
        keys = self.k(ctx)
        self.cache_k = self.norm_k(keys)
        self.cache_v = self.v(ctx)

    def clear_cache(self):
        """Drop the cached keys and values."""
        self.cache_k = self.cache_v = None

    def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
        """Attend from ``x`` to the cached context keys/values.

        ``y`` is the text context (no other branches). It and ``is_stream``
        are accepted for interface compatibility but unused; K/V come from
        ``init_cache``, which must have been called first (asserted).
        """
        query = self.norm_q(self.q(x))
        assert self.cache_k is not None and self.cache_v is not None
        attended = self.attn(query, self.cache_k, self.cache_v)
        return self.o(attended)
class GateModule(nn.Module):
    """Gated residual update: returns ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
class DiTBlock(nn.Module):
    """One DiT block: modulated self-attention, cross-attention, then a modulated FFN."""

    def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, enable_sageattention: bool = True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.self_attn = SelfAttention(dim, num_heads, eps, enable_sageattention=enable_sageattention)
        self.cross_attn = CrossAttention(dim, num_heads, eps, enable_sageattention=False)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        # Learned base for the six modulation vectors; t_mod is added per call.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
                train_img=False, block_id=None, kv_len=None, is_full_block=False,
                is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range=9):
        """Run the block on ``x``.

        ``t_mod`` is added to ``self.modulation`` and chunked along dim 1 into
        shift/scale/gate for attention and for the FFN.

        Returns:
            ``x`` when ``is_stream`` is false, otherwise
            ``(x, cache_k, cache_v)`` with the self-attention window caches.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_result = self.self_attn(
            input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
            kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
            pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range=local_range)
        # SelfAttention returns (out, cache_k, cache_v) in streaming mode and a
        # bare tensor otherwise. Unpacking three values unconditionally made
        # the non-streaming path (e.g. WanModel.forward) raise ValueError.
        if isinstance(attn_result, tuple):
            self_attn_output, self_attn_cache_k, self_attn_cache_v = attn_result
        else:
            self_attn_output = attn_result
            self_attn_cache_k = self_attn_cache_v = None
        x = self.gate(x, gate_msa, self_attn_output)
        x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        if is_stream:
            return x, self_attn_cache_k, self_attn_cache_v
        return x
class MLP(torch.nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection.

    With ``has_pos_emb`` a learned ``(1, 514, 1280)`` positional embedding,
    zero-initialised, is added to the input before the projection.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        h = x + self.emb_pos.to(dtype=x.dtype, device=x.device) if self.has_pos_emb else x
        return self.proj(h)
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a per-token linear un-embedding.

    The linear layer emits ``out_dim * prod(patch_size)`` channels per token,
    which ``WanModel.unpatchify`` folds back into pixels/latents.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        mod = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift, scale = mod.chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
# ----------------------------
# WanModel (no image branch) — the cross-attention KV cache is built by reinit_cross_kv, not in __init__
# ----------------------------
class WanModel(torch.nn.Module):
    """Wan video DiT backbone (text-conditioned, no image branch) used by FlashVSR.

    ``forward`` does not embed ``context``. Each block's cross-attention reads
    K/V that were cached by :meth:`reinit_cross_kv`.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool = False,
        enable_sageattention: bool = True,
    ):
        # has_image_input is accepted for config compatibility but not used.
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.patch_size = patch_size
        # Patch embedding: 3-D conv with kernel == stride == patch_size.
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        # Text / time embeddings.
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        # Produces the six per-block modulation vectors.
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        # Transformer blocks.
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, ffn_dim, eps, enable_sageattention=enable_sageattention)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # (frame, height, width) RoPE tables, default 1024 positions per axis.
        self.freqs = precompute_freqs_cis_3d(head_dim)
        self._cross_kv_initialized = False

    # Optional: manually clear / re-initialise the cross-attention KV cache.
    def clear_cross_kv(self):
        """Drop the cached cross-attention K/V in every block."""
        for blk in self.blocks:
            blk.cross_attn.clear_cache()
        self._cross_kv_initialized = False

    @torch.no_grad()
    def reinit_cross_kv(self, new_context: torch.Tensor):
        """Embed ``new_context`` with ``text_embedding`` and cache its K/V in every block."""
        ctx_txt = self.text_embedding(new_context)
        for blk in self.blocks:
            blk.cross_attn.init_cache(ctx_txt)
        self._cross_kv_initialized = True

    def patchify(self, x: torch.Tensor):
        """Embed ``x`` (``b c f h w``) into tokens ``(b, f*h*w, dim)`` plus the token grid."""
        x = self.patch_embedding(x)
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Fold per-token patch outputs back to ``b c (f*pf) (h*ph) (w*pw)``."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                LQ_latents: Optional[List[torch.Tensor]] = None,
                train_img: bool = False,
                topk_ratio: Optional[float] = None,
                kv_ratio: Optional[float] = None,
                local_num: Optional[int] = None,
                is_full_block: bool = False,
                causal_idx: Optional[int] = None,
                **kwargs,
                ):
        """Run the full (non-streaming) DiT over a latent clip.

        Args:
            x: Latents fed to ``patch_embedding``.
            timestep: Timesteps for ``sinusoidal_embedding_1d``.
            context: Unused here; cross-attention uses the cached K/V.
            LQ_latents: Optional per-block tensors added to the tokens before
                block ``i`` for ``i < len(LQ_latents)``.
            topk_ratio: Sparse-attention budget as a multiple of
                ``window_size ** 2``; defaults to 2.0 when ``None``.
            kv_ratio: KV-cache length as a multiple of ``window_size``. When
                ``None`` it is drawn at random from ``local_num``.
            local_num: When ``None`` it is drawn at random from
                ``{seqlen-4, seqlen-3, seqlen-2, seqlen}``.
            causal_idx, kwargs: Accepted but unused.

        Returns:
            Output of ``head`` unpatchified back to the input grid.
        """
        # Time / text embeddings.
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        # Text would still be embedded here, but CrossAttention ignores it when
        # a cache already exists, so the embedding is skipped.
        # context = self.text_embedding(context)
        # Patchify the input.
        x, (f, h, w) = self.patchify(x)
        B = x.shape[0]
        # Window / mask hyper-parameters.
        win = (2, 8, 8)
        seqlen = f//win[0]
        if local_num is None:
            local_random = random.random()
            if local_random < 0.3:
                local_num = seqlen - 3
            elif local_random < 0.4:
                local_num = seqlen - 4
            elif local_random < 0.5:
                local_num = seqlen - 2
            else:
                local_num = seqlen
        window_size = win[0]*h*w//128
        square_num = window_size*window_size
        # Honour a caller-supplied ratio; previously it was always overwritten.
        if topk_ratio is None:
            topk_ratio = 2.0
        topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
        if kv_ratio is None:
            kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
        kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
        # Unused, but the draw is kept so the global RNG stream is unchanged.
        decay_ratio = random.uniform(0.7, 1.0)
        # 3-D RoPE: concatenate frame/height/width factors per token.
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        # Blocks (non-streaming: is_stream=False, no pre-cache).
        for block_id, block in enumerate(self.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x += LQ_latents[block_id]
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs, f, h, w, local_num, topk,
                            train_img, block_id, kv_len, is_full_block, False,
                            None, None,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs, f, h, w, local_num, topk,
                        train_img, block_id, kv_len, is_full_block, False,
                        None, None,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
                          train_img, block_id, kv_len, is_full_block, False,
                          None, None)
        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x

    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()
# ----------------------------
# State dict converter (original key mappings kept; has_image_input is no longer used)
# ----------------------------
class WanModelStateDictConverter:
    """Convert external Wan checkpoints to this module's parameter names.

    Each ``from_*`` method returns ``(state_dict, config)``. ``config`` is chosen
    by matching ``hash_state_dict_keys`` of the checkpoint against known
    hashes and is ``{}`` when the hash is unknown.
    """
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename Diffusers-format Wan keys.

        Keys found in ``rename_dict`` are renamed directly. Otherwise the key's
        second component (the block index) is replaced with ``"0"`` and looked
        up as a block-0 template; on a match the original block index is put
        back into the renamed key. Keys matching neither lookup are dropped.
        The config hash is computed on the input (pre-rename) keys.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
            else:
                # Try the block-0 template: "blocks.<i>.rest" -> "blocks.0.rest".
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    # Restore the original block index in the renamed key.
                    name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
                    state_dict_[name_] = param
                # Keys matching neither lookup are silently dropped.
        if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config

    def from_civitai(self, state_dict):
        """Drop ``vace*`` keys and pick a config from the remaining keys' hash.

        Key names are otherwise returned unchanged.
        """
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
        # Keep the original hash -> config table; the has_image_input branch is
        # not used by this implementation.
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
        elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
            config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
        else:
            config = {}
        return state_dict, config