Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection

Root cause: three critical differences from naxci1 reference implementation:

1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ
   conditioning inside the loop. The TCDecoder uses causal convolutions with
   temporal memory that must be built incrementally per-chunk. Batch decode
   breaks this design and loses LQ frame conditioning, causing ghosting.

2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal
   variant reads the OLD cache before writing the new one (truly causal),
   while Buffer writes cache BEFORE the conv call. Using the wrong variant
   misaligns temporal LQ conditioning features.

3. Temporal padding formula: changed from rounding N up to the next 8k+1
   to taking the largest 8k+1 value <= N+4 (largest_8n1_leq(N+4)),
   matching the naxci1 reference implementation.

Changes:
- flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning
  and per-chunk color correction (was: batch VAE decode after loop)
- flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including
  full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in
  clear_caches()
- utils.py: add Causal_LQ4x_Proj class
- nodes.py: update progress bar estimation for new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 17:42:20 +01:00
parent 94d9818675
commit fa250897a2
5 changed files with 196 additions and 98 deletions

View File

@@ -358,3 +358,103 @@ class Buffer_LQ4x_Proj(nn.Module):
self.clip_idx += 1
return outputs
class Causal_LQ4x_Proj(nn.Module):
    """Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.

    Projects a low-quality (LQ) video clip into one conditioning embedding
    per DiT layer, using a pixel-shuffle followed by two temporally-strided
    causal 3D convolutions with a rolling per-conv temporal cache.

    Key difference from Buffer_LQ4x_Proj: each conv reads the OLD cache
    BEFORE the new cache is written (truly causal), whereas Buffer_LQ4x_Proj
    writes the cache BEFORE the conv call. The read-then-write ordering
    around ``self.cache`` is load-bearing — do not reorder those lines.
    """
    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        # Pixel-shuffle factors: fold a (f=1, h=16, w=16) patch into channels.
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num
        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
        # Two causal convs, each halving the temporal length (stride (2,1,1)),
        # for an overall 4x temporal downsample across conv1+conv2.
        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()
        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()
        # One independent output projection per DiT layer.
        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
        # Number of clips consumed by stream_forward (0 == next clip is warm-up).
        self.clip_idx = 0

    def forward(self, video):
        """Whole-video path: process ``video`` (b, c, t, h, w) in 4-frame chunks.

        Returns a list of ``layer_num`` tensors shaped (b, f*h*w, hidden) —
        one per linear projection. Resets all temporal caches first.
        """
        self.clear_cache()
        t = video.shape[2]
        # Number of 4-frame chunks; assumes t is 4k+1 — TODO confirm against callers.
        iter_ = 1 + (t - 1) // 4
        # Prepend 3 copies of frame 0 so chunk 0 is a full 4-frame warm-up
        # whose only job is to prime the conv caches (its output is discarded).
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)
        out_x = []
        for i in range(iter_):
            x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :])
            # Snapshot the tail frames BEFORE conv1 consumes the old cache.
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes NEW cache AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            if i == 0:
                # Warm-up chunk: prime conv2's cache only, emit no output.
                self.cache['conv2'] = cache2_x
                continue
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes NEW cache AFTER
            x = self.norm2(x)
            x = self.act2(x)
            out_x.append(x)
        out_x = torch.cat(out_x, dim=2)
        # Flatten (frames, height, width) into a token axis for the projections.
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        outputs = []
        for i in range(self.layer_num):
            outputs.append(self.linear_layers[i](out_x))
        return outputs

    def clear_cache(self):
        """Reset both conv caches and the streaming clip counter."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Streaming path: process one clip at a time, carrying caches between calls.

        The first call (clip_idx == 0) is a warm-up: it prepends 3 copies of
        the first frame, primes both conv caches, and returns None — no
        embeddings are produced. Every later call returns a list of
        ``layer_num`` projection tensors for this clip. Call clear_cache()
        before starting a new video.
        """
        if self.clip_idx == 0:
            # Warm-up: replicate frame 0 so the causal convs see a full window.
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD (None) cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            # Prime conv2's cache; conv2 itself is not run on the warm-up clip.
            self.cache['conv2'] = cache2_x
            self.clip_idx += 1
            return None
        else:
            x = self.pixel_shuffle(video_clip)
            cache1_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
            self.cache['conv1'] = cache1_x  # writes AFTER
            x = self.norm1(x)
            x = self.act1(x)
            cache2_x = x[:, :, -CACHE_T:, :, :].clone()
            x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
            self.cache['conv2'] = cache2_x  # writes AFTER
            x = self.norm2(x)
            x = self.act2(x)
            # Flatten (frames, height, width) into a token axis per projection.
            out_x = rearrange(x, 'b c f h w -> b (f h w) c')
            outputs = []
            for i in range(self.layer_num):
                outputs.append(self.linear_layers[i](out_x))
            self.clip_idx += 1
            return outputs

View File

@@ -388,8 +388,11 @@ class FlashVSRFullPipeline(BasePipeline):
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
latents_total = []
self.vae.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
self.TCDecoder.clean_mem()
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
current_dit_device = next(iter(self.dit.parameters())).device
@@ -415,6 +418,7 @@ class FlashVSRFullPipeline(BasePipeline):
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = (inner_loop_num-1)*4-3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents = None
@@ -430,9 +434,10 @@ class FlashVSRFullPipeline(BasePipeline):
else:
for layer_idx in range(len(LQ_latents)):
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
@@ -453,44 +458,41 @@ class FlashVSRFullPipeline(BasePipeline):
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
latents_total.append(cur_latents)
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
try:
del pre_cache_k, pre_cache_v
except NameError:
pass
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
self.dit.to('cpu')
clean_vram()
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
latents = torch.cat(latents_total, dim=2)
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
else:
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
del latents_total
clean_vram()
if skip_vae:
return latents
# Decode
print("[FlashVSR] Starting VAE decoding...")
frames = self.decode_video(latents, **tiler_kwargs)
# 颜色校正wavelet
# Per-chunk color correction
try:
if color_fix:
frames = self.ColorCorrector(
frames.to(device=LQ_video.device),
LQ_video[:, :, :frames.shape[2], :, :],
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=16,
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]

View File

@@ -380,7 +380,7 @@ class FlashVSRTinyPipeline(BasePipeline):
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
latents_total = []
frames_total = []
self.TCDecoder.clean_mem()
LQ_pre_idx = 0
LQ_cur_idx = 0
@@ -428,7 +428,7 @@ class FlashVSRTinyPipeline(BasePipeline):
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
# 推理(无 motion_controller / vace
# Denoise
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
@@ -449,45 +449,37 @@ class FlashVSRTinyPipeline(BasePipeline):
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
latents_total.append(cur_latents)
LQ_pre_idx = LQ_cur_idx
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
try:
del pre_cache_k, pre_cache_v
except NameError:
pass
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
self.dit.to('cpu')
clean_vram()
# Streaming TCDecoder decode per-chunk with LQ conditioning
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
cur_frames = self.TCDecoder.decode_video(
cur_latents.transpose(1, 2),
parallel=False,
show_progress_bar=False,
cond=cur_LQ_frame
).transpose(1, 2).mul_(2).sub_(1)
latents = torch.cat(latents_total, dim=2)
del latents_total
clean_vram()
if skip_vae:
return latents
# Decode
print("[FlashVSR] Starting VAE decoding...")
frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1)
# 颜色校正wavelet
# Per-chunk color correction
try:
if color_fix:
frames = self.ColorCorrector(
frames.to(device=LQ_video.device),
LQ_video[:, :, :frames.shape[2], :, :],
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=16,
chunk_size=None,
method='adain'
)
except:
pass
frames_total.append(cur_frames.to('cpu'))
LQ_pre_idx = LQ_cur_idx
del cur_frames, cur_latents, cur_LQ_frame
clean_vram()
frames = torch.cat(frames_total, dim=2)
return frames[0]

View File

@@ -648,7 +648,7 @@ class FlashVSRModel:
ModelManager, FlashVSRFullPipeline,
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
)
from .flashvsr_arch.models.utils import Buffer_LQ4x_Proj
from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
from .flashvsr_arch.models.TCDecoder import build_tcdecoder
self.mode = mode
@@ -672,6 +672,8 @@ class FlashVSRModel:
mm.load_models([dit_path])
Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
self.pipe = Pipeline.from_model_manager(mm, device=device)
# TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
self.pipe.TCDecoder = build_tcdecoder(
[512, 256, 128, 128], device, dtype, 16 + 768,
)
@@ -680,8 +682,8 @@ class FlashVSRModel:
)
self.pipe.TCDecoder.clean_mem()
# LQ frame projection
self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype)
# LQ frame projection — Causal variant for FlashVSR v1.1
self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
if os.path.exists(lq_path):
lq_sd = load_file(lq_path, device="cpu")
cleaned = {}
@@ -714,6 +716,8 @@ class FlashVSRModel:
self.pipe.denoising_model().LQ_proj_in.clear_cache()
if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
self.pipe.vae.clear_cache()
if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
self.pipe.TCDecoder.clean_mem()
# ------------------------------------------------------------------
# Frame preprocessing / postprocessing helpers
@@ -743,7 +747,7 @@ class FlashVSRModel:
1. Bicubic-upscale each frame to target resolution
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
3. Normalize to [-1, 1]
4. Temporal padding: repeat last frame to reach 8k+1 count
4. Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
No front dummy frames — the pipeline handles LQ indexing correctly
starting from frame 0.
@@ -780,14 +784,16 @@ class FlashVSRModel:
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
# Temporal padding: repeat last frame to reach 8k+1 (pipeline requirement)
target = max(N, 25) # minimum 25 for streaming loop (P >= 1)
remainder = (target - 1) % 8
if remainder != 0:
target += 8 - remainder
# Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference)
num_with_pad = N + 4
target = ((num_with_pad - 1) // 8) * 8 + 1 # largest_8n1_leq
if target < 1:
target = 1
if target > N:
pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1)
video = torch.cat([video, pad], dim=2)
elif target < N:
video = video[:, :, :target, :, :]
nf = video.shape[2]
return video, th, tw, nf, sh, sw, pad_top, pad_left

View File

@@ -1731,14 +1731,12 @@ class FlashVSRUpscale:
chunks.append((prev_start, last_end))
# Estimate total pipeline steps for progress bar
# Mirrors _prepare_video: target = max(N, 25), round up to 8k+1
# Mirrors _prepare_video: largest_8n1_leq(N + 4)
total_steps = 0
for cs, ce in chunks:
n = ce - cs
target = max(n, 25)
remainder = (target - 1) % 8
if remainder != 0:
target += 8 - remainder
num_with_pad = n + 4
target = ((num_with_pad - 1) // 8) * 8 + 1
total_steps += max(1, (target - 1) // 8 - 2)
pbar = ProgressBar(total_steps)