Files
ComfyUI-Tween/flashvsr_arch/models/utils.py
Ethanfel fa250897a2 Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection
Root cause: three critical differences from the naxci1 reference implementation:

1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ
   conditioning inside the loop. The TCDecoder uses causal convolutions with
   temporal memory that must be built incrementally per-chunk. Batch decode
   breaks this design and loses LQ frame conditioning, causing ghosting.

2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal
   variant reads the OLD cache before writing the new one (truly causal),
   while the Buffer variant writes the cache BEFORE the conv call. Using the
   wrong variant misaligns the temporal LQ conditioning features.

3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4)
   matching the naxci1 reference approach.

Changes:
- flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning
  and per-chunk color correction (was: batch VAE decode after loop)
- flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including
  full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in
  clear_caches()
- utils.py: add Causal_LQ4x_Proj class
- nodes.py: update progress bar estimation for new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:42:20 +01:00

460 lines
18 KiB
Python

import torch, os, gc
from safetensors import safe_open
from contextlib import contextmanager
from einops import rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import hashlib
# Number of trailing frames kept as a temporal cache: the LQ projection
# classes below slice x[:, :, -CACHE_T:, :, :] and pass it as the cache_x
# argument of CausalConv3d.
CACHE_T = 2
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
    """Context manager that places newly registered module parameters on ``device``.

    While active, ``torch.nn.Module.register_parameter`` is monkeypatched so that
    every parameter registered (e.g. inside a module's ``__init__``) is replaced
    by a copy moved to ``device`` (default: ``torch.device("meta")``).

    If ``include_buffers`` is True, ``torch.nn.Module.register_buffer`` is patched
    the same way, and ``torch.empty`` / ``torch.zeros`` / ``torch.ones`` /
    ``torch.full`` are wrapped so their ``device`` keyword is forced to ``device``.

    All patches are undone in the ``finally`` clause, so they are reverted even
    if the ``with`` body raises.

    Args:
        device: Target device for parameters (and buffers, when enabled).
        include_buffers: Also redirect buffers and the four tensor factories above.

    NOTE(review): the patches replace class attributes on ``torch.nn.Module`` and
    functions in the ``torch`` namespace, i.e. they are process-global while active.
    """
    # Keep the originals so the finally-block can restore them.
    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Register normally first, then swap the stored parameter for a copy on
        # `device`, rebuilt with the same Parameter subclass.
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            # NOTE(review): this aliases the parameter's own __dict__ (no copy),
            # so the next line also writes "requires_grad" into it; the dict is
            # then splatted into the subclass constructor as keyword arguments.
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        # Same idea for buffers: register, then move the stored tensor to `device`.
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        # Wrap a torch factory so any caller-supplied `device` is overridden.
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)
        return wrapper

    if include_buffers:
        # Remember the original factories so they can be restored on exit.
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}
    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        # Undo every patch, in the same scope the patches were applied.
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
def load_state_dict_from_folder(file_path, torch_dtype=None):
    """Merge the state dicts of all checkpoint files directly inside a folder.

    Files whose extension is one of ``safetensors``, ``bin``, ``ckpt``, ``pth``
    or ``pt`` are loaded with :func:`load_state_dict`, and their entries are
    merged into one dict. The scan is not recursive.

    Files are visited in sorted name order, so when two files define the same
    key the later file (by name) wins deterministically. Entries that are not
    regular files (e.g. a directory called ``foo.pt``) are skipped.

    Args:
        file_path: Folder to scan.
        torch_dtype: Optional dtype forwarded to ``load_state_dict``.

    Returns:
        dict: The merged state dict (empty if no checkpoint files were found).
    """
    extensions = ("safetensors", "bin", "ckpt", "pth", "pt")
    state_dict = {}
    # os.listdir order is arbitrary; sorting makes key collisions resolve
    # identically on every run and platform.
    for file_name in sorted(os.listdir(file_path)):
        if "." not in file_name or file_name.split(".")[-1] not in extensions:
            continue
        full_path = os.path.join(file_path, file_name)
        # A directory named like a checkpoint would crash the loader.
        if not os.path.isfile(full_path):
            continue
        state_dict.update(load_state_dict(full_path, torch_dtype=torch_dtype))
    return state_dict
def load_state_dict(file_path, torch_dtype=None):
    """Load one checkpoint file, picking the reader from its file name.

    Paths ending in ``.safetensors`` go to ``load_state_dict_from_safetensors``;
    every other path is handed to ``load_state_dict_from_bin``. ``torch_dtype``
    is forwarded unchanged to whichever reader is chosen.
    """
    reader = (
        load_state_dict_from_safetensors
        if file_path.endswith(".safetensors")
        else load_state_dict_from_bin
    )
    return reader(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a ``.safetensors`` file onto the CPU.

    Args:
        file_path: Path to the ``.safetensors`` file.
        torch_dtype: If given, each tensor is cast to this dtype right after it
            is read, before the next one is loaded.

    Returns:
        dict mapping tensor names to tensors.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as reader:
        for tensor_name in reader.keys():
            loaded = reader.get_tensor(tensor_name)
            tensors[tensor_name] = loaded if torch_dtype is None else loaded.to(torch_dtype)
    return tensors
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a ``torch.save`` checkpoint onto the CPU with ``weights_only=True``.

    Args:
        file_path: Path to a ``.bin`` / ``.ckpt`` / ``.pth`` / ``.pt`` file.
        torch_dtype: If given, every top-level tensor value is cast to this dtype
            in place; non-tensor values are left untouched.

    Returns:
        The loaded object (expected to be a dict-like state dict).
    """
    checkpoint = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return checkpoint
    for entry in checkpoint:
        value = checkpoint[entry]
        if isinstance(value, torch.Tensor):
            checkpoint[entry] = value.to(torch_dtype)
    return checkpoint
def search_for_embeddings(state_dict):
    """Collect every tensor in a possibly nested dict, depth-first in key order.

    Values that are neither tensors nor dicts are ignored.
    """
    found = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            found.append(value)
        elif isinstance(value, dict):
            found.extend(search_for_embeddings(value))
    return found
def search_parameter(param, state_dict):
    """Return the name of the first tensor in ``state_dict`` numerically equal to ``param``.

    Candidates must have the same number of elements. If the shapes match, the
    tensors are compared directly; otherwise both are flattened first. A
    candidate matches when the Euclidean distance is below 1e-3.

    Returns:
        The matching key, or ``None`` when nothing matches.
    """
    for name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            distance = torch.dist(param, candidate)
        else:
            distance = torch.dist(param.flatten(), candidate.flatten())
        if distance < 1e-3:
            return name
    return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source->target key mapping found by matching tensor values.

    For each source tensor, ``search_parameter`` looks for an equal tensor in
    ``target_state_dict``. On a match, a line ``"src": "dst",`` is printed.

    When there is no direct match and ``split_qkv`` is True, a tensor whose
    first dimension is divisible by 3 is split into three equal chunks along
    dim 0. If all three chunks match, ``"src": [dst1, dst2, dst3],`` is printed.

    Finally, every target key that was never matched is reported with
    ``Cannot find``. Nothing is returned; output goes to stdout.
    """
    matched_keys = set()
    with torch.no_grad():
        for name, tensor in source_state_dict.items():
            match = search_parameter(tensor, target_state_dict)
            if match is not None:
                print(f'"{name}": "{match}",')
                matched_keys.add(match)
                continue
            splittable = split_qkv and len(tensor.shape) >= 1 and tensor.shape[0] % 3 == 0
            if not splittable:
                continue
            chunk = tensor.shape[0] // 3
            parts = [
                search_parameter(tensor[idx * chunk: idx * chunk + chunk], target_state_dict)
                for idx in range(3)
            ]
            if None not in parts:
                print(f'"{name}": {parts},')
                matched_keys.update(parts)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose path ends with any of ``extensions``.

    Directory entries are visited in sorted name order. A file path passed
    directly is returned (in a one-element list) if it matches. A path that
    does not exist yields an empty list.
    """
    if os.path.isdir(folder):
        hits = []
        for entry in sorted(os.listdir(folder)):
            hits.extend(search_for_files(os.path.join(folder, entry), extensions))
        return hits
    if os.path.isfile(folder) and any(folder.endswith(ext) for ext in extensions):
        return [folder]
    return []
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialise the (string) keys of a possibly nested state dict into one string.

    For every tensor value the bare key is emitted, plus ``key:d0_d1_...`` when
    ``with_shape`` is True. For a nested dict, ``key|<nested string>`` is
    emitted. Non-string keys and other value types are skipped.

    The entries are sorted and joined with ``,``. The exact format (including
    the duplicate bare key when shapes are on) is kept unchanged because the
    result is hashed by ``hash_state_dict_keys``.
    """
    entries = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, torch.Tensor):
            if with_shape:
                dims = "_".join(str(dim) for dim in value.shape)
                entries.append(f"{key}:{dims}")
            entries.append(key)
        elif isinstance(value, dict):
            nested = convert_state_dict_keys_to_single_str(value, with_shape=with_shape)
            entries.append(f"{key}|{nested}")
    return ",".join(sorted(entries))
def split_state_dict_with_prefix(state_dict):
    """Group string keys by their first dotted component.

    Keys are processed in sorted order. Each group becomes its own sub-dict,
    and the groups are returned in the order their prefix first appears.
    Non-string keys are dropped.
    """
    groups = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        groups.setdefault(key.split(".")[0], []).append(key)
    return [{member: state_dict[member] for member in members} for members in groups.values()]
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the hex MD5 digest of ``convert_state_dict_keys_to_single_str``'s output.

    The string is UTF-8 encoded before hashing.
    """
    signature = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    digest = hashlib.md5(signature.encode("UTF-8"))
    return digest.hexdigest()
def clean_vram():
    """Run the Python garbage collector and release cached accelerator memory.

    Calls ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()`` when
    CUDA is available. Calls ``torch.mps.empty_cache()`` when the MPS backend
    reports itself available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    # Probe MPS via torch.backends.mps.is_available(), the documented check.
    # torch.mps.is_available may not exist in every supported PyTorch version;
    # getattr also tolerates builds without the mps backend module.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        torch.mps.empty_cache()
def get_device_list():
    """Return device strings for every visible CUDA GPU and any MPS device.

    Returns:
        list[str]: e.g. ``["cuda:0", "cuda:1"]`` or ``["mps:0"]``; empty when
        neither backend is available.
    """
    devs = []
    if torch.cuda.is_available():
        devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    # Use the documented torch.backends.mps availability check (see clean_vram
    # for why torch.mps.is_available is avoided).
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # torch.mps.device_count may be missing on older PyTorch; fall back to
        # assuming a single MPS device in that case.
        count_fn = getattr(getattr(torch, "mps", None), "device_count", None)
        mps_count = count_fn() if count_fn is not None else 1
        devs += [f"mps:{i}" for i in range(mps_count)]
    return devs
class RMS_norm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    ``x`` is L2-normalised along the channel axis (dim 1 when ``channel_first``,
    otherwise the last dim) and multiplied by ``sqrt(dim)``. This is equivalent
    to dividing by the RMS over that axis, up to ``F.normalize``'s epsilon. The
    result is then scaled by ``gamma`` and shifted by ``bias``.

    Args:
        dim: Size of the normalised channel axis.
        channel_first: Normalise dim 1 (True) or the last dim (False).
        images: With ``channel_first``, shape the parameters as ``(dim, 1, 1)``
            (4-D inputs) instead of ``(dim, 1, 1, 1)`` (5-D inputs).
        bias: Learn an additive bias; otherwise a constant ``0.`` is added.
    """
    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            # Trailing singleton dims let gamma/bias broadcast over spatial
            # (and, for non-image inputs, temporal) axes.
            shape = (dim, 1, 1) if images else (dim, 1, 1, 1)
        else:
            shape = (dim,)
        self.channel_first = channel_first
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        normed = F.normalize(x, dim=axis)
        return normed * self.scale * self.gamma + self.bias
class CausalConv3d(nn.Conv3d):
    """3D convolution that is causal along the frame (time) axis.

    The ``padding=(pt, ph, pw)`` given to ``nn.Conv3d`` is not applied by the
    base class. Instead ``forward`` pads explicitly with ``mode='replicate'``:
    ``pw``/``ph`` on both sides spatially, and ``2 * pt`` frames on the *front*
    of the time axis only.

    When ``cache_x`` is supplied (and front padding is non-zero), those frames
    are prepended to the input, and the front padding shrinks by the number of
    cached frames.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t, pad_h, pad_w = self.padding[0], self.padding[1], self.padding[2]
        # F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back).
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        # Turn off nn.Conv3d's own symmetric padding.
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        w_l, w_r, h_t, h_b, t_front, t_back = self._padding
        if cache_x is not None and t_front > 0:
            x = torch.cat([cache_x.to(x.device), x], dim=2)
            t_front -= cache_x.shape[2]
        x = F.pad(x, [w_l, w_r, h_t, h_b, t_front, t_back], mode='replicate')
        return super().forward(x)
class PixelShuffle3d(nn.Module):
    """Fold (ff, hh, ww) blocks of a 5-D tensor into the channel axis.

    ``(B, C, F*ff, H*hh, W*ww)`` becomes ``(B, C*ff*hh*ww, F, H, W)``. Within the
    new channel axis, the original channel varies slowest and the width offset
    fastest. This is the rearrangement
    ``'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w'``, written with plain
    ``reshape``/``permute``.

    Each of the last three input extents must be divisible by its factor.
    """
    def __init__(self, ff, hh, ww):
        super().__init__()
        # Block sizes along time, height and width.
        self.ff = ff
        self.hh = hh
        self.ww = ww

    def forward(self, x):
        ff, hh, ww = self.ff, self.hh, self.ww
        batch, chans, frames, height, width = x.shape
        # Split F/H/W into (outer, block) pairs: b c f ff h hh w ww ...
        split = x.reshape(batch, chans, frames // ff, ff, height // hh, hh, width // ww, ww)
        # ... move the block axes next to the channel: b c ff hh ww f h w ...
        moved = split.permute(0, 1, 3, 5, 7, 2, 4, 6)
        # ... and merge (c, ff, hh, ww) into a single channel axis.
        return moved.reshape(batch, chans * ff * hh * ww, frames // ff, height // hh, width // ww)
class Buffer_LQ4x_Proj(nn.Module):
    """Project a low-quality (LQ) video into ``layer_num`` token sequences.

    Pipeline:
      1. 16x16 spatial pixel-shuffle (``PixelShuffle3d(1, 16, 16)``).
      2. Two ``CausalConv3d`` stages, each with stride 2 in time, followed by
         ``RMS_norm`` + ``SiLU``.
      3. Flatten to tokens and apply ``layer_num`` independent ``nn.Linear`` heads.

    "Buffer" variant: before each conv call, the conv's cache is overwritten with
    the *current* chunk's trailing ``CACHE_T`` frames. The conv therefore sees
    its own chunk's tail prepended, not the previous chunk's. This ordering is
    kept unchanged; it presumably matches the checkpoints used with this class
    (TODO confirm before altering).

    Args:
        in_dim: Channels of the LQ video before pixel-shuffling.
        out_dim: Output size of each linear head.
        layer_num: Number of linear heads, i.e. length of the returned list.
    """
    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num
        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()
        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()
        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
        # Create self.cache (and reset clip_idx) up front. Previously self.cache
        # only existed after forward() or clear_cache(), so calling
        # stream_forward() on a fresh module raised AttributeError.
        self.clear_cache()

    def forward(self, video):
        """Encode a whole LQ clip in one call.

        ``video`` is (B, C, F, H, W); H and W must be divisible by 16. The first
        frame is repeated 3 extra times at the front, and the clip is processed
        in 4-frame chunks. The first chunk only fills the caches and produces no
        output.

        Returns:
            list of ``layer_num`` tensors, each (B, N, out_dim), where N is the
            flattened (frame, h, w) token count.

        NOTE(review): for F <= 4 only the first chunk runs, ``out_x`` stays empty
        and ``torch.cat`` raises.
        """
        self.clear_cache()
        # x: (B, C, F, H, W)
        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)
        out_x = []
        for i in range(iter_):
            x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
            # Buffer ordering: cache is set to this chunk's own tail, then used.
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            if i == 0:
                continue
            x = self.conv2(x, self.cache['conv2'])
            x = self.norm2(x)
            x = self.act2(x)
            out_x.append(x)
        out_x = torch.cat(out_x, dim = 2)
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        outputs = []
        for i in range(self.layer_num):
            outputs.append(self.linear_layers[i](out_x))
        return outputs

    def clear_cache(self):
        """Reset both conv caches to ``None`` and the streaming clip counter to 0."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Encode one clip of a stream, keeping cache state across calls.

        On the first call (``clip_idx == 0``), the first frame is repeated 3
        times in front, the caches are filled, and ``None`` is returned. Later
        calls return the same kind of list as ``forward``.

        NOTE(review): for parity with ``forward``, callers presumably pass 1
        frame first and 4-frame clips afterwards; TODO confirm against callers.
        """
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            self.clip_idx += 1
            return None
        else:
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            x = self.conv2(x, self.cache['conv2'])
            x = self.norm2(x)
            x = self.act2(x)
            out_x = rearrange(x, 'b c f h w -> b (f h w) c')
            outputs = []
            for i in range(self.layer_num):
                outputs.append(self.linear_layers[i](out_x))
            self.clip_idx += 1
            return outputs
class Causal_LQ4x_Proj(nn.Module):
    """Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.

    Same layers as ``Buffer_LQ4x_Proj``:
      - 16x16 spatial pixel-shuffle;
      - two ``CausalConv3d`` + ``RMS_norm`` + ``SiLU`` stages with stride 2 in time;
      - ``layer_num`` linear heads.

    Key difference: each conv reads the cache left by the *previous* chunk
    (truly causal), and only afterwards stores the current chunk's trailing
    ``CACHE_T`` frames. ``Buffer_LQ4x_Proj`` instead writes the cache BEFORE
    the conv call.

    Args:
        in_dim: Channels of the LQ video before pixel-shuffling.
        out_dim: Output size of each linear head.
        layer_num: Number of linear heads, i.e. length of the returned list.
    """
    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num
        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()
        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()
        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
        # Create self.cache (and reset clip_idx) up front. Previously
        # stream_forward() on a fresh module read self.cache['conv1'] before it
        # existed and raised AttributeError.
        self.clear_cache()

    def forward(self, video):
        """Encode a whole LQ clip in one call.

        ``video`` is (B, C, F, H, W); H and W must be divisible by 16. The first
        frame is repeated 3 extra times at the front, and the clip is processed
        in 4-frame chunks. The first chunk only fills the caches and produces no
        output.

        Returns:
            list of ``layer_num`` tensors, each (B, N, out_dim), where N is the
            flattened (frame, h, w) token count.

        NOTE(review): for F <= 4 only the first chunk runs, ``out_x`` stays empty
        and ``torch.cat`` raises.
        """
        self.clear_cache()
        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)
        out_x = []
        for i in range(iter_):
            x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes NEW cache AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            if i == 0:
                self.cache['conv2'] = cache2_x
                continue
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes NEW cache AFTER
            x = self.norm2(x)
            x = self.act2(x)
            out_x.append(x)
        out_x = torch.cat(out_x, dim=2)
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        outputs = []
        for i in range(self.layer_num):
            outputs.append(self.linear_layers[i](out_x))
        return outputs

    def clear_cache(self):
        """Reset both conv caches to ``None`` and the streaming clip counter to 0."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Encode one clip of a stream, keeping causal cache state across calls.

        On the first call (``clip_idx == 0``), the first frame is repeated 3
        times in front. conv1 runs without a cache, and conv2's cache is filled
        from conv1's output; ``None`` is returned. Later calls read the previous
        clip's caches and return the same kind of list as ``forward``.

        NOTE(review): for parity with ``forward``, callers presumably pass 1
        frame first and 4-frame clips afterwards; TODO confirm against callers.
        """
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD (None) cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            self.clip_idx += 1
            return None
        else:
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes AFTER
            x = self.norm2(x)
            x = self.act2(x)
            out_x = rearrange(x, 'b c f h w -> b (f h w) c')
            outputs = []
            for i in range(self.layer_num):
                outputs.append(self.linear_layers[i](out_x))
            self.clip_idx += 1
            return outputs