Root cause: three critical differences from the naxci1 reference implementation:

1. Batch decode after the loop → streaming per-chunk TCDecoder decode with LQ conditioning inside the loop. The TCDecoder uses causal convolutions with temporal memory that must be built up incrementally per chunk; batch decode breaks this design and loses LQ frame conditioning, causing ghosting.
2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal variant reads the OLD cache before writing the new one (truly causal), while Buffer writes the cache BEFORE the conv call. Using the wrong variant misaligns the temporal LQ conditioning features.
3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4), matching the naxci1 reference approach (a sketch of this rule follows below).

Changes:
- flashvsr_full.py: streaming TCDecoder decode per chunk with LQ conditioning and per-chunk color correction (was: batch VAE decode after the loop)
- flashvsr_tiny.py: streaming TCDecoder decode per chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build the TCDecoder for ALL modes (including full), fix temporal padding to largest_8n1_leq(N+4), clear the TCDecoder in clear_caches()
- utils.py: add the Causal_LQ4x_Proj class
- nodes.py: update the progress bar estimation for the new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
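
A minimal sketch of the padding rule named above, assuming largest_8n1_leq(x) returns the largest integer of the form 8k+1 that does not exceed x (as the name suggests); this is an illustration, not the repository's exact implementation:

    def largest_8n1_leq(x):
        # largest value of the form 8*k + 1 that is <= x
        return ((x - 1) // 8) * 8 + 1

    # Temporal padding target for an N-frame input:
    #   padded_frames = largest_8n1_leq(N + 4)
    # e.g. N = 21 -> largest_8n1_leq(25) = 25
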
import torch, os, gc
from safetensors import safe_open
from contextlib import contextmanager
from einops import rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import hashlib

CACHE_T = 2

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
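
# Illustrative usage (not part of the original module): materialize a large module on the
# meta device so no real memory is allocated, then attach real weights afterwards.
# The module and file name below are hypothetical; load_state_dict(..., assign=True)
# requires a recent PyTorch.
#
#     with init_weights_on_device(device=torch.device("meta")):
#         model = nn.Linear(4096, 4096)                      # parameters live on "meta"
#     sd = load_state_dict("weights.safetensors", torch_dtype=torch.float16)
#     model.load_state_dict(sd, assign=True)                 # swap in the real tensors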


def load_state_dict_from_folder(file_path, torch_dtype=None):
    state_dict = {}
    for file_name in os.listdir(file_path):
        if "." in file_name and file_name.split(".")[-1] in [
            "safetensors", "bin", "ckpt", "pth", "pt"
        ]:
            state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
    return state_dict


def load_state_dict(file_path, torch_dtype=None):
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)


def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None):
    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def search_for_embeddings(state_dict):
    embeddings = []
    for k in state_dict:
        if isinstance(state_dict[k], torch.Tensor):
            embeddings.append(state_dict[k])
        elif isinstance(state_dict[k], dict):
            embeddings += search_for_embeddings(state_dict[k])
    return embeddings


def search_parameter(param, state_dict):
    for name, param_ in state_dict.items():
        if param.numel() == param_.numel():
            if param.shape == param_.shape:
                if torch.dist(param, param_) < 1e-3:
                    return name
            else:
                if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
                    return name
    return None


def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(source_state_dict[name].shape) >= 1 and source_state_dict[name].shape[0] % 3 == 0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)


def search_for_files(folder, extensions):
    files = []
    if os.path.isdir(folder):
        for file in sorted(os.listdir(folder)):
            files += search_for_files(os.path.join(folder, file), extensions)
    elif os.path.isfile(folder):
        for extension in extensions:
            if folder.endswith(extension):
                files.append(folder)
                break
    return files


def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, torch.Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def split_state_dict_with_prefix(state_dict):
    keys = sorted([key for key in state_dict if isinstance(key, str)])
    prefix_dict = {}
    for key in keys:
        prefix = key if "." not in key else key.split(".")[0]
        if prefix not in prefix_dict:
            prefix_dict[prefix] = []
        prefix_dict[prefix].append(key)
    state_dicts = []
    for prefix, keys in prefix_dict.items():
        sub_state_dict = {key: state_dict[key] for key in keys}
        state_dicts.append(sub_state_dict)
    return state_dicts


def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()
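
# Illustrative usage (not part of the original module): the md5 of the sorted "key:shape"
# signature can act as a fingerprint for recognizing which architecture a checkpoint
# belongs to. The file name is hypothetical.
#
#     sd = load_state_dict("flashvsr_dit.safetensors")
#     print(hash_state_dict_keys(sd))                    # 32-character hex digest
#     print(hash_state_dict_keys(sd, with_shape=False))  # key names only, shapes ignored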


def clean_vram():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if torch.mps.is_available():
        torch.mps.empty_cache()


def get_device_list():
    devs = []
    if torch.cuda.is_available():
        devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]

    if torch.mps.is_available():
        devs += [f"mps:{i}" for i in range(torch.mps.device_count())]

    return devs


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Move all temporal padding to the front (causal) and disable Conv3d's own padding.
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Prepend cached frames from the previous chunk instead of replicate padding.
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding, mode='replicate')

        return super().forward(x)
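
# Shape sketch (illustrative; the kernel/stride values match conv1/conv2 below): with
# kernel (4, 3, 3) and padding (1, 1, 1), self._padding[4] = 2, so two temporal frames
# are prepended. On the first chunk cache_x is None and the front is replicate-padded;
# on later chunks the caller passes the previous chunk's last CACHE_T frames as cache_x,
# so the convolution sees real temporal context across chunk boundaries:
#
#     conv = CausalConv3d(8, 16, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
#     chunk0 = torch.randn(1, 8, 4, 32, 32)
#     y0 = conv(chunk0)                                   # (1, 16, 2, 32, 32), replicate-padded front
#     cache = chunk0[:, :, -CACHE_T:].clone()
#     y1 = conv(torch.randn(1, 8, 4, 32, 32), cache)      # conditioned on real previous frames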


class PixelShuffle3d(nn.Module):
    def __init__(self, ff, hh, ww):
        super().__init__()
        self.ff = ff
        self.hh = hh
        self.ww = ww

    def forward(self, x):
        # x: (B, C, F, H, W)
        return rearrange(x,
                         'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
                         ff=self.ff, hh=self.hh, ww=self.ww)
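
# Shape sketch (illustrative): with ff=1, hh=ww=16, each 16x16 spatial patch is folded
# into the channel dimension before the causal convolutions.
#
#     shuffle = PixelShuffle3d(1, 16, 16)
#     x = torch.randn(1, 3, 8, 256, 256)
#     y = shuffle(x)                      # y.shape == (1, 3*16*16, 8, 16, 16) == (1, 768, 8, 16, 16)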


class Buffer_LQ4x_Proj(nn.Module):

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2, h -> h, w -> w
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2, h -> h, w -> w
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        self.clip_idx = 0

    def forward(self, video):
        self.clear_cache()
        # video: (B, C, F, H, W)

        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)

        out_x = []
        for i in range(iter_):
            x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            # Note: the cache is written BEFORE the conv reads it, so conv1 is conditioned
            # on the current chunk's own tail frames (contrast with Causal_LQ4x_Proj below).
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            if i == 0:
                continue
            x = self.conv2(x, self.cache['conv2'])
            x = self.norm2(x)
            x = self.act2(x)
            out_x.append(x)
        out_x = torch.cat(out_x, dim=2)
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        outputs = []
        for i in range(self.layer_num):
            outputs.append(self.linear_layers[i](out_x))
        return outputs

    def clear_cache(self):
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        if self.clip_idx == 0:
            # self.clear_cache()
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            self.clip_idx += 1
            return None
        else:
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv1'] = cache1_x
            x = self.conv1(x, self.cache['conv1'])
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            x = self.conv2(x, self.cache['conv2'])
            x = self.norm2(x)
            x = self.act2(x)
            out_x = rearrange(x, 'b c f h w -> b (f h w) c')
            outputs = []
            for i in range(self.layer_num):
                outputs.append(self.linear_layers[i](out_x))
            self.clip_idx += 1
            return outputs


class Causal_LQ4x_Proj(nn.Module):
    """Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.

    Key difference: reads the old cache BEFORE writing the new one (truly causal),
    whereas Buffer_LQ4x_Proj writes the cache BEFORE the conv call.
    """

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        self.clip_idx = 0

    def forward(self, video):
        self.clear_cache()
        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)

        out_x = []
        for i in range(iter_):
            x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes NEW cache AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            if i == 0:
                self.cache['conv2'] = cache2_x
                continue
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes NEW cache AFTER
            x = self.norm2(x)
            x = self.act2(x)
            out_x.append(x)
        out_x = torch.cat(out_x, dim=2)
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        outputs = []
        for i in range(self.layer_num):
            outputs.append(self.linear_layers[i](out_x))
        return outputs

    def clear_cache(self):
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD (None) cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            self.cache['conv2'] = cache2_x
            self.clip_idx += 1
            return None
        else:
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes AFTER
            x = self.norm2(x)
            x = self.act2(x)
            out_x = rearrange(x, 'b c f h w -> b (f h w) c')
            outputs = []
            for i in range(self.layer_num):
                outputs.append(self.linear_layers[i](out_x))
            self.clip_idx += 1
            return outputs
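
# Illustrative streaming sketch (not part of the original module): one plausible way to
# drive the projector per chunk, mirroring forward(): prime the caches with the first
# frame, then feed 4-frame chunks. Tensor names and dimensions (out_dim, layer_num) are
# assumed, not taken from the repository.
#
#     proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=30)
#     proj.clear_cache()
#     proj.stream_forward(lq_video[:, :, :1])                # primes conv caches, returns None
#     for start in range(1, lq_video.shape[2], 4):
#         feats = proj.stream_forward(lq_video[:, :, start:start+4])
#         # feats: list of layer_num tensors, each (B, f*h*w, out_dim), used as
#         # per-chunk LQ conditioning for the corresponding transformer blocks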