Root cause: three critical differences from naxci1 reference implementation: 1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ conditioning inside the loop. The TCDecoder uses causal convolutions with temporal memory that must be built incrementally per-chunk. Batch decode breaks this design and loses LQ frame conditioning, causing ghosting. 2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal variant reads the OLD cache before writing the new one (truly causal), while Buffer writes cache BEFORE the conv call. Using the wrong variant misaligns temporal LQ conditioning features. 3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4) matching the naxci1 reference approach. Changes: - flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning and per-chunk color correction (was: batch VAE decode after loop) - flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode) - inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in clear_caches() - utils.py: add Causal_LQ4x_Proj class - nodes.py: update progress bar estimation for new padding formula Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
626 lines
26 KiB
Python
626 lines
26 KiB
Python
import types
|
||
import os
|
||
import time
|
||
from typing import Optional, Tuple, Literal
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
from einops import rearrange
|
||
from PIL import Image
|
||
from tqdm import tqdm
|
||
# import pyfiglet
|
||
|
||
from ..models.utils import clean_vram
|
||
from ..models import ModelManager
|
||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||
from ..schedulers.flow_match import FlowMatchScheduler
|
||
from .base import BasePipeline
|
||
|
||
# -----------------------------
|
||
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
||
# -----------------------------
|
||
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
||
N, C = feat.shape[:2]
|
||
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
||
std = var.sqrt().view(N, C, 1, 1)
|
||
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
||
return mean, std
|
||
|
||
|
||
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
    """Re-normalise ``content_feat`` to the per-channel statistics of ``style_feat``.

    The content is whitened with its own per-channel mean/std and then
    re-scaled and shifted with the style's mean/std (adaptive instance
    normalisation).

    Args:
        content_feat: 4-D tensor (N, C, H, W) whose structure is kept.
        style_feat: 4-D tensor with the same N and C, supplying the statistics.

    Returns:
        Tensor with the shape of ``content_feat``.
    """
    assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
    target_shape = content_feat.size()
    s_mean, s_std = _calc_mean_std(style_feat)
    c_mean, c_std = _calc_mean_std(content_feat)
    whitened = (content_feat - c_mean.expand(target_shape)) / c_std.expand(target_shape)
    restyled = whitened * s_std.expand(target_shape)
    return restyled + s_mean.expand(target_shape)
|
||
|
||
|
||
# -----------------------------
|
||
# 小波式模糊与分解/重构(ColorCorrector 用)
|
||
# -----------------------------
|
||
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
||
vals = [
|
||
[0.0625, 0.125, 0.0625],
|
||
[0.125, 0.25, 0.125 ],
|
||
[0.0625, 0.125, 0.0625],
|
||
]
|
||
return torch.tensor(vals, dtype=dtype, device=device)
|
||
|
||
|
||
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
    """Blur each channel of ``x`` with the 3x3 kernel dilated by ``radius``.

    The input is replicate-padded by ``radius`` on every side; a 3x3 kernel
    with dilation ``radius`` spans ``2 * radius + 1`` pixels, so the output
    keeps the input's height and width.

    Args:
        x: 4-D tensor (N, C, H, W).
        radius: Dilation of the kernel and padding width, in pixels.

    Returns:
        Blurred tensor with the same shape as ``x``.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    channels = x.shape[1]
    kernel = _make_gaussian3x3_kernel(x.dtype, x.device)
    # One kernel copy per channel; groups=channels filters each channel independently.
    depthwise_weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    padded = F.pad(x, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(
        padded,
        depthwise_weight,
        bias=None,
        stride=1,
        padding=0,
        dilation=radius,
        groups=channels,
    )
|
||
|
||
|
||
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split ``x`` into accumulated high-frequency detail and a low-frequency residual.

    At level ``i`` the current low-frequency image is blurred with radius
    ``2 ** i``; what the blur removed is added to the detail component and
    the blurred image becomes the input of the next level.

    Args:
        x: 4-D tensor (N, C, H, W).
        levels: Number of blur levels to apply.

    Returns:
        ``(high, low)``: the summed detail and the final blurred image.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    detail = torch.zeros_like(x)
    approx = x
    for level in range(levels):
        smoothed = _wavelet_blur(approx, 2 ** level)
        detail = detail + (approx - smoothed)
        approx = smoothed
    return detail, approx
|
||
|
||
|
||
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Combine the high-frequency detail of ``content`` with the low frequencies of ``style``.

    Both inputs are decomposed with the same number of levels; the result is
    the content's detail component plus the style's low-frequency residual.
    """
    content_high = _wavelet_decompose(content, levels=levels)[0]
    style_low = _wavelet_decompose(style, levels=levels)[1]
    return content_high + style_low
|
||
|
||
# -----------------------------
|
||
# Safetensors support ---------
|
||
# -----------------------------
|
||
# Optional dependency: `st_load_file` is safetensors' loader when the package
# is installed, and None otherwise (callers must check before using it).
try:
    from safetensors.torch import load_file as st_load_file
except ImportError:
    st_load_file = None
    print("Warning: 'safetensors' not installed. Safetensors (.safetensors) files cannot be loaded.")
|
||
|
||
# -----------------------------
|
||
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
||
# -----------------------------
|
||
class TorchColorCorrectorWavelet(nn.Module):
    """Stateless colour corrector for video tensors.

    Transfers the colour of a low-quality reference clip onto a high-quality
    clip, frame by frame, using either the wavelet recombination (default)
    or AdaIN statistics matching. The module holds no learnable parameters;
    ``levels`` is the only configuration.
    """

    def __init__(self, levels: int = 5):
        super().__init__()
        # Number of blur levels used by the wavelet method.
        self.levels = levels

    @staticmethod
    def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Fold the frame axis into the batch axis: (B, C, f, H, W) -> (B*f, C, H, W)."""
        assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
        B, C, f, H, W = x.shape
        frames = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
        return frames, B, f

    @staticmethod
    def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
        """Inverse of ``_flatten_time``: (B*f, C, H, W) -> (B, C, f, H, W)."""
        BF, C, H, W = y.shape
        assert BF == B * f
        return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)

    def _correct_clip(
        self,
        hq_clip: torch.Tensor,
        lq_clip: torch.Tensor,
        clip_range: Tuple[float, float],
        method: str,
    ) -> torch.Tensor:
        """Colour-correct one (B, C, f, H, W) clip and clamp it to ``clip_range``."""
        hq4, batch, n_frames = self._flatten_time(hq_clip)
        lq4, _, _ = self._flatten_time(lq_clip)
        if method == 'wavelet':
            corrected = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
        elif method == 'adain':
            corrected = _adain(hq4, lq4)
        else:
            raise ValueError(f"未知 method: {method}")
        corrected = torch.clamp(corrected, *clip_range)
        return self._unflatten_time(corrected, batch, n_frames)

    def forward(
        self,
        hq_image: torch.Tensor,  # (B, C, f, H, W)
        lq_image: torch.Tensor,  # (B, C, f, H, W)
        clip_range: Tuple[float, float] = (-1.0, 1.0),
        method: Literal['wavelet', 'adain'] = 'wavelet',
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Return ``hq_image`` with its colour matched to ``lq_image``.

        Args:
            hq_image: High-quality clip, (B, 3, f, H, W).
            lq_image: Reference clip with exactly the same shape.
            clip_range: ``(min, max)`` the result is clamped to.
            method: ``'wavelet'`` or ``'adain'``.
            chunk_size: When given and smaller than ``f``, frames are
                processed in groups of this many along the time axis and
                the results concatenated; otherwise the whole clip is
                processed at once.

        Returns:
            Corrected clip with the same shape as ``hq_image``.

        Raises:
            AssertionError: If the shapes differ or the input is not (B, 3, f, H, W).
            ValueError: If ``method`` is not one of the supported names.
        """
        assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
        assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

        total_frames = hq_image.shape[2]
        if chunk_size is None or chunk_size >= total_frames:
            return self._correct_clip(hq_image, lq_image, clip_range, method)

        pieces = []
        for start in range(0, total_frames, chunk_size):
            end = min(start + chunk_size, total_frames)
            pieces.append(
                self._correct_clip(
                    hq_image[:, :, start:end],
                    lq_image[:, :, start:end],
                    clip_range,
                    method,
                )
            )
        return torch.cat(pieces, dim=2)
|
||
|
||
|
||
# -----------------------------
|
||
# 简化版 Pipeline(仅 dit + vae)
|
||
# -----------------------------
|
||
class FlashVSRTinyPipeline(BasePipeline):
    """Simplified streaming pipeline that uses only the DiT and a TCDecoder.

    Usage order, as enforced by the code below:

    1. ``from_model_manager`` / ``fetch_models`` to obtain ``dit`` and ``vae``.
    2. ``init_cross_kv`` exactly once to cache the text context and timestep
       embeddings.
    3. ``__call__`` to process a low-quality video chunk by chunk.

    NOTE(review): ``self.TCDecoder`` is used by ``decode_video`` and
    ``__call__`` but is never assigned in this class; it is presumably
    attached by the caller that builds the pipeline — confirm.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.use_unified_sequence_parallel = False
        # Filled by init_cross_kv(); __call__ refuses to run while this is None.
        self.prompt_emb_posi = None
        self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap the DiT's layers for CPU offloading, then enable CPU offload.

        Args:
            num_persistent_param_in_dit: Passed through as ``max_num_param``;
                modules beyond that budget use ``overflow_module_config``,
                whose only difference is that they are on-loaded to CPU
                rather than to ``self.device``.
        """
        # Only the DiT is wrapped here.
        dtype = next(iter(self.dit.parameters())).dtype
        from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Pull the ``wan_video_dit`` and ``wan_video_vae`` models from ``model_manager``."""
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
        """Build a pipeline whose models come from ``model_manager``.

        ``device`` and ``torch_dtype`` default to the manager's values.
        ``use_usp`` is accepted but ignored: unified sequence parallelism is
        always left disabled here.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        # Optional unified-sequence-parallel entry point (off by default).
        pipe.use_unified_sequence_parallel = False
        return pipe

    def denoising_model(self):
        """Return the DiT used for denoising."""
        return self.dit

    # -------------------------
    # Explicit cross-attention KV pre-initialisation
    # -------------------------
    def init_cross_kv(
        self,
        context_tensor: Optional[torch.Tensor] = None,
        prompt_path=None,
    ):
        """Cache the text context and initialise the DiT's cross-attention KV.

        Must be called once before ``__call__``. The context is taken from
        ``context_tensor`` when given, otherwise loaded from ``prompt_path``
        (``.safetensors`` via safetensors, anything else via ``torch.load``).
        Also precomputes the timestep embeddings (``self.timestep``,
        ``self.t``, ``self.t_mod``) for timestep 1000 and configures the
        scheduler for a single step.

        Args:
            context_tensor: Ready-made context tensor; takes precedence.
            prompt_path: Path to a saved context tensor.

        Raises:
            RuntimeError: If ``self.dit`` has not been set.
            ValueError: If neither argument is given, or a safetensors file
                holds several tensors and none is named ``'context'``.
            ImportError: If a ``.safetensors`` path is given but safetensors
                is not installed.
            AttributeError: If the DiT lacks ``reinit_cross_kv``.
        """
        # Validate before touching devices, so a missing model gives the
        # intended error and not whatever load_models_to_device does with it.
        if self.dit is None:
            raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")

        self.load_models_to_device(["dit"])

        if context_tensor is None:
            if prompt_path is None:
                raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")

            if prompt_path.lower().endswith(".safetensors"):
                if st_load_file is None:
                    raise ImportError("The 'safetensors' library must be installed to load .safetensors files.")

                # safetensors yields a dict; accept a single entry or an
                # explicit 'context' key.
                loaded_dict = st_load_file(prompt_path, device=self.device)
                if len(loaded_dict) == 1:
                    ctx = next(iter(loaded_dict.values()))
                elif 'context' in loaded_dict:
                    ctx = loaded_dict['context']
                else:
                    raise ValueError(f"Safetensors file {prompt_path} does not contain an obvious single tensor ('context' key not found and multiple keys exist).")
            else:
                # Default behaviour for .pth, .pt, etc.
                # NOTE(security): torch.load unpickles the file; only use
                # prompt files from a trusted source.
                ctx = torch.load(prompt_path, map_location=self.device)
        else:
            ctx = context_tensor

        ctx = ctx.to(dtype=self.torch_dtype, device=self.device)

        if self.prompt_emb_posi is None:
            self.prompt_emb_posi = {}
        self.prompt_emb_posi['context'] = ctx

        if not hasattr(self.dit, "reinit_cross_kv"):
            raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
        self.dit.reinit_cross_kv(ctx)

        self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
        self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
        self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
        # Scheduler
        self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
        self.load_models_to_device([])

    def prepare_unified_sequence_parallel(self):
        """Return the sequence-parallel flag as a kwargs dict."""
        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}

    def prepare_extra_input(self, latents=None):
        """Return extra model inputs; this pipeline has none."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode ``input_video`` to latents with the VAE."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode ``latents`` to frames with the VAE (not the TCDecoder)."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def decode_video(self, latents, cond=None, **kwargs):
        """Decode ``latents`` with the TCDecoder, optionally conditioned on ``cond``.

        ``latents`` is transposed on dims 1 and 2 before decoding because the
        TCDecoder expects (B, F, C, H, W); the result is transposed back to
        (B, C, F, H, W) and mapped with ``x * 2 - 1`` (range -1 to 1 per the
        original comment). Extra ``kwargs`` are ignored.
        """
        frames = self.TCDecoder.decode_video(
            latents.transpose(1, 2),
            parallel=False,
            show_progress_bar=False,
            cond=cond,
        ).transpose(1, 2).mul_(2).sub_(1)

        return frames

    def _stream_project_lq(self, LQ_video, frame_ranges):
        """Feed LQ frame windows through ``LQ_proj_in.stream_forward`` in order.

        Each ``(start, stop)`` in ``frame_ranges`` selects frames on dim 2.
        Windows for which ``stream_forward`` returns None are skipped; the
        per-layer outputs of the others are concatenated along dim 1.

        Returns:
            The list of per-layer tensors, or None if every window was skipped.
        """
        lq_proj = self.denoising_model().LQ_proj_in
        LQ_latents = None
        for start, stop in frame_ranges:
            cur = lq_proj.stream_forward(LQ_video[:, :, start:stop, :, :])
            if cur is None:
                continue
            if LQ_latents is None:
                LQ_latents = cur
            else:
                for layer_idx in range(len(LQ_latents)):
                    LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
        return LQ_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        negative_prompt="",
        denoising_strength=1.0,
        seed=None,
        rand_device="gpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(60, 104),
        tile_stride=(30, 52),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="Wan2.1-T2V-14B",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        LQ_video=None,
        is_full_block=False,
        if_buffer=False,
        topk_ratio=2.0,
        kv_ratio=3.0,
        local_range = 9,
        color_fix = True,
        unload_dit = False,
        skip_vae = False,
    ):
        """Process ``LQ_video`` chunk by chunk and return the decoded frames.

        Each iteration projects the next LQ frames, runs one denoising pass
        of the DiT on the matching latent slice, decodes the result with the
        TCDecoder conditioned on the LQ frames of that chunk, optionally
        colour-corrects it, and moves it to CPU.

        Args:
            cfg_scale: Must be passed as 1.0 (asserted); note the default is 5.0.
            height, width: Adjusted by ``check_resize_height_width``.
            num_frames: Rounded up to the next value with
                ``num_frames % 4 == 1``. Needs to be at least 25 so that
                one chunk is produced.
            seed: Forwarded to ``generate_noise``.
            progress_bar_cmd: Wraps the chunk iterator (default ``tqdm``).
            LQ_video: Required. Sliced as ``[:, :, a:b, :, :]``, so frames
                are on dim 2 — presumably (B, C, F, H, W); colour correction
                additionally requires its chunks to match the decoded
                frames' shape.
            is_full_block, topk_ratio, kv_ratio, local_range: Forwarded to
                ``model_fn_wan_video``.
            if_buffer: When true the noise has one fewer latent frame.
            color_fix: Apply per-chunk AdaIN colour correction against the
                LQ frames. A failure is reported and the uncorrected chunk
                is kept.
            unload_dit: When true, move the DiT to ``self.device`` first if
                it is on another device.

        The remaining parameters (``prompt``, ``negative_prompt``,
        ``denoising_strength``, ``rand_device``, ``num_inference_steps``,
        ``sigma_shift``, ``tiled``, ``tile_size``, ``tile_stride``,
        ``tea_cache_*``, ``progress_bar_st``, ``skip_vae``) are accepted for
        interface compatibility and not used by this method.

        Returns:
            The first batch element of all chunks concatenated along the
            frame axis, on CPU.

        Raises:
            AssertionError: If ``cfg_scale`` is not 1.0.
            RuntimeError: If ``init_cross_kv`` has not been called.
            ValueError: If ``LQ_video`` is None or ``num_frames`` is too
                small to form a chunk.
        """
        # Only cfg=1.0 is accepted (same as the original code).
        assert cfg_scale == 1.0, "cfg_scale must be 1.0"

        # init_cross_kv() must have been called first.
        if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
            raise RuntimeError(
                "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
                "    pipe.init_cross_kv()\n"
                "或传入自定义 context:\n"
                "    pipe.init_cross_kv(context_tensor=your_context_tensor)"
            )

        # Every chunk slices LQ_video for decoder conditioning, so None can
        # never succeed; fail early with a clear message.
        if LQ_video is None:
            raise ValueError("LQ_video is required: it conditions both the DiT and the TCDecoder.")

        # Size fix-up
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        process_total_num = (num_frames - 1) // 8 - 2
        if process_total_num <= 0:
            raise ValueError(
                f"num_frames={num_frames} is too small: at least 25 frames are needed to produce one chunk."
            )

        # Initial noise
        latent_frames = (num_frames - 1) // 4 if if_buffer else (num_frames - 1) // 4 + 1
        latents = self.generate_noise(
            (1, 16, latent_frames, height//8, width//8),
            seed=seed, device=self.device, dtype=self.torch_dtype,
        )

        is_stream = True

        # Clear any LQ_proj_in cache left over from a previous run.
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        frames_total = []
        self.TCDecoder.clean_mem()
        LQ_pre_idx = 0
        LQ_cur_idx = 0

        if unload_dit and hasattr(self, 'dit') and self.dit is not None:
            current_dit_device = next(iter(self.dit.parameters())).device
            if str(current_dit_device) != str(self.device):
                print(f"[FlashVSR] DiT is on {current_dit_device}, moving it to target device {self.device}...")
                self.dit.to(self.device)

        for cur_process_idx in progress_bar_cmd(range(process_total_num)):
            if cur_process_idx == 0:
                pre_cache_k = [None] * len(self.dit.blocks)
                pre_cache_v = [None] * len(self.dit.blocks)
                # First chunk: 7 windows — frames [0:1], [1:5], ..., [21:25].
                inner_loop_num = 7
                frame_ranges = [
                    (max(0, inner_idx*4-3), (inner_idx+1)*4-3)
                    for inner_idx in range(inner_loop_num)
                ]
                LQ_cur_idx = (inner_loop_num-1)*4-3
                cur_latents = latents[:, :, :6, :, :]
            else:
                # Later chunks: 2 windows of 4 frames each.
                inner_loop_num = 2
                frame_ranges = [
                    (cur_process_idx*8+17+inner_idx*4, cur_process_idx*8+21+inner_idx*4)
                    for inner_idx in range(inner_loop_num)
                ]
                LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
                cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]

            LQ_latents = self._stream_project_lq(LQ_video, frame_ranges)

            # Denoise
            noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
                self.dit,
                x=cur_latents,
                timestep=self.timestep,
                context=None,
                tea_cache=None,
                use_unified_sequence_parallel=False,
                LQ_latents=LQ_latents,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k,
                pre_cache_v=pre_cache_v,
                topk_ratio=topk_ratio,
                kv_ratio=kv_ratio,
                cur_process_idx=cur_process_idx,
                t_mod=self.t_mod,
                t=self.t,
                local_range = local_range,
            )

            cur_latents = cur_latents - noise_pred_posi

            # Streaming TCDecoder decode per chunk, conditioned on the LQ
            # frames since the previous chunk's end index.
            cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
            cur_frames = self.decode_video(cur_latents, cond=cur_LQ_frame)

            # Per-chunk colour correction: best effort, but never silent.
            if color_fix:
                try:
                    cur_frames = self.ColorCorrector(
                        cur_frames.to(device=self.device),
                        cur_LQ_frame,
                        clip_range=(-1, 1),
                        chunk_size=None,
                        method='adain'
                    )
                except Exception as exc:
                    print(f"[FlashVSR] Color correction failed for chunk {cur_process_idx}; keeping uncorrected frames: {exc}")

            frames_total.append(cur_frames.to('cpu'))
            LQ_pre_idx = LQ_cur_idx

            del cur_frames, cur_latents, cur_LQ_frame
            clean_vram()

        frames = torch.cat(frames_total, dim=2)
        return frames[0]
|
||
|
||
|
||
# -----------------------------
|
||
# TeaCache(保留原逻辑;此处默认不启用)
|
||
# -----------------------------
|
||
class TeaCache:
    """Decides when a denoising step may reuse the previous step's residual.

    ``check`` accumulates a polynomial-rescaled relative L1 change of the
    modulation input between consecutive steps; while the accumulated value
    stays below ``rel_l1_thresh`` the step can be skipped and ``update``
    re-applies the residual recorded by ``store``.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Set up counters and pick the rescaling polynomial for ``model_id``.

        Raises:
            ValueError: If ``model_id`` has no coefficient entry.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients (highest degree first, as np.poly1d expects).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and return True when the computation can be skipped.

        The first and last steps of a cycle are always computed. When a step
        is computed, the accumulator is reset and a copy of ``x`` is kept so
        ``store`` can derive the residual afterwards.
        """
        modulated_inp = t_mod.clone()
        on_boundary_step = self.step == 0 or self.step == self.num_inference_steps - 1
        if on_boundary_step:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            relative_change = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(relative_change)
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between ``hidden_states`` and the input saved by ``check``."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the last stored residual."""
        return hidden_states + self.previous_residual
|
||
|
||
|
||
# -----------------------------
|
||
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
||
# -----------------------------
|
||
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one forward pass of the DiT on a latent chunk (no vace / motion controller).

    Steps: patchify ``x``; build rotary position frequencies whose temporal
    part starts at 0 for the first chunk and at ``4 + 2 * cur_process_idx``
    afterwards; run every block (adding ``LQ_latents[block_id]`` to the
    tokens first when available) while updating the per-block K/V caches in
    place; apply the head and unpatchify.

    Args:
        dit: Model providing ``patchify``, ``freqs``, ``blocks``, ``head``
            and ``unpatchify``.
        x: Latent chunk to denoise.
        timestep: Unused here; the precomputed ``t`` / ``t_mod`` are used.
        context: Forwarded to every block.
        tea_cache: Optional TeaCache; when its ``check`` says so, the blocks
            are skipped and the cached residual is applied.
        use_unified_sequence_parallel: Shard the token sequence across the
            sequence-parallel group (requires ``xfuser``).
        LQ_latents: Per-block conditioning features, indexed by block id.
        pre_cache_k, pre_cache_v: Per-block cache lists, updated in place.
        topk_ratio, kv_ratio, local_range, is_full_block, is_stream:
            Attention settings forwarded to the blocks.
        cur_process_idx: Index of the chunk being processed.
        t_mod, t: Precomputed timestep modulation and embedding.

    Returns:
        ``(x, pre_cache_k, pre_cache_v)``.
    """
    # patchify
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    local_num = f // win[0]
    window_size = win[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: only the temporal offset differs between chunks.
    t_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx*2
    temporal_freqs = dit.freqs[0][t_start:t_start + f].view(f, 1, 1, -1).expand(f, h, w, -1)
    height_freqs = dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1)
    width_freqs = dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    freqs = torch.cat(
        [temporal_freqs, height_freqs, width_freqs], dim=-1
    ).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (disabled by default)
    tea_cache_update = False if tea_cache is None else tea_cache.check(dit, x, t_mod)

    # Unified sequence parallel (disabled by default)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            shards = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
            x = shards[get_sequence_parallel_rank()]

    # Block stack
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            cached_k = None if pre_cache_k is None else pre_cache_k[block_id]
            cached_v = None if pre_cache_v is None else pre_cache_v[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=cached_k,
                pre_cache_v=cached_v,
                local_range = local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import get_sp_group
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
|