From fa250897a22e86fe676cab94ef1b5c2ca2f6156e Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 13 Feb 2026 17:42:20 +0100 Subject: [PATCH] Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause: three critical differences from naxci1 reference implementation: 1. Batch decode after loop → streaming per-chunk TCDecoder decode with LQ conditioning inside the loop. The TCDecoder uses causal convolutions with temporal memory that must be built incrementally per-chunk. Batch decode breaks this design and loses LQ frame conditioning, causing ghosting. 2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal variant reads the OLD cache before writing the new one (truly causal), while Buffer writes cache BEFORE the conv call. Using the wrong variant misaligns temporal LQ conditioning features. 3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4) matching the naxci1 reference approach. Changes: - flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning and per-chunk color correction (was: batch VAE decode after loop) - flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode) - inference.py: use Causal_LQ4x_Proj, build TCDecoder for ALL modes (including full), fix temporal padding to largest_8n1_leq(N+4), clear TCDecoder in clear_caches() - utils.py: add Causal_LQ4x_Proj class - nodes.py: update progress bar estimation for new padding formula Co-Authored-By: Claude Opus 4.6 --- flashvsr_arch/models/utils.py | 102 ++++++++++++++++++++++- flashvsr_arch/pipelines/flashvsr_full.py | 78 ++++++++--------- flashvsr_arch/pipelines/flashvsr_tiny.py | 68 +++++++-------- inference.py | 38 +++++---- nodes.py | 8 +- 5 files changed, 196 insertions(+), 98 deletions(-) diff --git a/flashvsr_arch/models/utils.py b/flashvsr_arch/models/utils.py index 00b5313..b7ccad0 100644 --- a/flashvsr_arch/models/utils.py +++ b/flashvsr_arch/models/utils.py @@ -357,4 +357,104 @@ class Buffer_LQ4x_Proj(nn.Module): outputs.append(self.linear_layers[i](out_x)) self.clip_idx += 1 return outputs - \ No newline at end of file + + +class Causal_LQ4x_Proj(nn.Module): + """Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1. + + Key difference: reads old cache BEFORE writing new cache (truly causal), + whereas Buffer_LQ4x_Proj writes cache BEFORE conv call. + """ + + def __init__(self, in_dim, out_dim, layer_num=30): + super().__init__() + self.ff = 1 + self.hh = 16 + self.ww = 16 + self.hidden_dim1 = 2048 + self.hidden_dim2 = 3072 + self.layer_num = layer_num + + self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww) + + self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) + self.norm1 = RMS_norm(self.hidden_dim1, images=False) + self.act1 = nn.SiLU() + + self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) + self.norm2 = RMS_norm(self.hidden_dim2, images=False) + self.act2 = nn.SiLU() + + self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)]) + + self.clip_idx = 0 + + def forward(self, video): + self.clear_cache() + t = video.shape[2] + iter_ = 1 + (t - 1) // 4 + first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) + video = torch.cat([first_frame, video], dim=2) + + out_x = [] + for i in range(iter_): + x = self.pixel_shuffle(video[:, :, i*4:(i+1)*4, :, :]) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + x = self.conv1(x, self.cache['conv1']) # reads OLD cache + self.cache['conv1'] = cache1_x # writes NEW cache AFTER + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + if i == 0: + self.cache['conv2'] = cache2_x + continue + x = self.conv2(x, self.cache['conv2']) # reads OLD cache + self.cache['conv2'] = cache2_x # writes NEW cache AFTER + x = self.norm2(x) + x = self.act2(x) + out_x.append(x) + out_x = torch.cat(out_x, dim=2) + out_x = rearrange(out_x, 'b c f h w -> b (f h w) c') + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + return outputs + + def clear_cache(self): + self.cache = {} + self.cache['conv1'] = None + self.cache['conv2'] = None + self.clip_idx = 0 + + def stream_forward(self, video_clip): + if self.clip_idx == 0: + first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1) + video_clip = torch.cat([first_frame, video_clip], dim=2) + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + x = self.conv1(x, self.cache['conv1']) # reads OLD (None) cache + self.cache['conv1'] = cache1_x # writes AFTER + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + self.cache['conv2'] = cache2_x + self.clip_idx += 1 + return None + else: + x = self.pixel_shuffle(video_clip) + cache1_x = x[:, :, -CACHE_T:, :, :].clone() + x = self.conv1(x, self.cache['conv1']) # reads OLD cache + self.cache['conv1'] = cache1_x # writes AFTER + x = self.norm1(x) + x = self.act1(x) + cache2_x = x[:, :, -CACHE_T:, :, :].clone() + x = self.conv2(x, self.cache['conv2']) # reads OLD cache + self.cache['conv2'] = cache2_x # writes AFTER + x = self.norm2(x) + x = self.act2(x) + out_x = rearrange(x, 'b c f h w -> b (f h w) c') + outputs = [] + for i in range(self.layer_num): + outputs.append(self.linear_layers[i](out_x)) + self.clip_idx += 1 + return outputs \ No newline at end of file diff --git a/flashvsr_arch/pipelines/flashvsr_full.py b/flashvsr_arch/pipelines/flashvsr_full.py index e26fda0..4eed010 100644 --- a/flashvsr_arch/pipelines/flashvsr_full.py +++ b/flashvsr_arch/pipelines/flashvsr_full.py @@ -388,9 +388,12 @@ class FlashVSRFullPipeline(BasePipeline): if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.clear_cache() - latents_total = [] - self.vae.clear_cache() - + frames_total = [] + LQ_pre_idx = 0 + LQ_cur_idx = 0 + if hasattr(self, 'TCDecoder') and self.TCDecoder is not None: + self.TCDecoder.clean_mem() + if unload_dit and hasattr(self, 'dit') and self.dit is not None: current_dit_device = next(iter(self.dit.parameters())).device if str(current_dit_device) != str(self.device): @@ -415,6 +418,7 @@ class FlashVSRFullPipeline(BasePipeline): else: for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = (inner_loop_num-1)*4-3 cur_latents = latents[:, :, :6, :, :] else: LQ_latents = None @@ -430,9 +434,10 @@ class FlashVSRFullPipeline(BasePipeline): else: for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] - - # 推理(无 motion_controller / vace) + + # Denoise noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( self.dit, x=cur_latents, @@ -453,44 +458,41 @@ class FlashVSRFullPipeline(BasePipeline): local_range = local_range, ) - # 更新 latent cur_latents = cur_latents - noise_pred_posi - latents_total.append(cur_latents) - - if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + + # Streaming TCDecoder decode per-chunk with LQ conditioning + cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device) + + if hasattr(self, 'TCDecoder') and self.TCDecoder is not None: + cur_frames = self.TCDecoder.decode_video( + cur_latents.transpose(1, 2), + parallel=False, + show_progress_bar=False, + cond=cur_LQ_frame + ).transpose(1, 2).mul_(2).sub_(1) + else: + cur_frames = self.decode_video(cur_latents, **tiler_kwargs) + + # Per-chunk color correction try: - del pre_cache_k, pre_cache_v - except NameError: + if color_fix: + cur_frames = self.ColorCorrector( + cur_frames.to(device=self.device), + cur_LQ_frame, + clip_range=(-1, 1), + chunk_size=None, + method='adain' + ) + except: pass - print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") - self.dit.to('cpu') + + frames_total.append(cur_frames.to('cpu')) + LQ_pre_idx = LQ_cur_idx + + del cur_frames, cur_latents, cur_LQ_frame clean_vram() - latents = torch.cat(latents_total, dim=2) - - del latents_total - clean_vram() - - if skip_vae: - return latents - - # Decode - print("[FlashVSR] Starting VAE decoding...") - frames = self.decode_video(latents, **tiler_kwargs) - - # 颜色校正(wavelet) - try: - if color_fix: - frames = self.ColorCorrector( - frames.to(device=LQ_video.device), - LQ_video[:, :, :frames.shape[2], :, :], - clip_range=(-1, 1), - chunk_size=16, - method='adain' - ) - except: - pass - + frames = torch.cat(frames_total, dim=2) return frames[0] diff --git a/flashvsr_arch/pipelines/flashvsr_tiny.py b/flashvsr_arch/pipelines/flashvsr_tiny.py index 60ed390..14e7363 100644 --- a/flashvsr_arch/pipelines/flashvsr_tiny.py +++ b/flashvsr_arch/pipelines/flashvsr_tiny.py @@ -380,11 +380,11 @@ class FlashVSRTinyPipeline(BasePipeline): if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.clear_cache() - latents_total = [] + frames_total = [] self.TCDecoder.clean_mem() LQ_pre_idx = 0 LQ_cur_idx = 0 - + if unload_dit and hasattr(self, 'dit') and self.dit is not None: current_dit_device = next(iter(self.dit.parameters())).device if str(current_dit_device) != str(self.device): @@ -427,8 +427,8 @@ class FlashVSRTinyPipeline(BasePipeline): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] - - # 推理(无 motion_controller / vace) + + # Denoise noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( self.dit, x=cur_latents, @@ -449,45 +449,37 @@ class FlashVSRTinyPipeline(BasePipeline): local_range = local_range, ) - # 更新 latent cur_latents = cur_latents - noise_pred_posi - latents_total.append(cur_latents) - LQ_pre_idx = LQ_cur_idx - - if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + + # Streaming TCDecoder decode per-chunk with LQ conditioning + cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device) + cur_frames = self.TCDecoder.decode_video( + cur_latents.transpose(1, 2), + parallel=False, + show_progress_bar=False, + cond=cur_LQ_frame + ).transpose(1, 2).mul_(2).sub_(1) + + # Per-chunk color correction try: - del pre_cache_k, pre_cache_v - except NameError: + if color_fix: + cur_frames = self.ColorCorrector( + cur_frames.to(device=self.device), + cur_LQ_frame, + clip_range=(-1, 1), + chunk_size=None, + method='adain' + ) + except: pass - print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") - self.dit.to('cpu') + + frames_total.append(cur_frames.to('cpu')) + LQ_pre_idx = LQ_cur_idx + + del cur_frames, cur_latents, cur_LQ_frame clean_vram() - - latents = torch.cat(latents_total, dim=2) - - del latents_total - clean_vram() - - if skip_vae: - return latents - - # Decode - print("[FlashVSR] Starting VAE decoding...") - frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1) - - # 颜色校正(wavelet) - try: - if color_fix: - frames = self.ColorCorrector( - frames.to(device=LQ_video.device), - LQ_video[:, :, :frames.shape[2], :, :], - clip_range=(-1, 1), - chunk_size=16, - method='adain' - ) - except: - pass + frames = torch.cat(frames_total, dim=2) return frames[0] diff --git a/inference.py b/inference.py index 736ce36..24e5047 100644 --- a/inference.py +++ b/inference.py @@ -648,7 +648,7 @@ class FlashVSRModel: ModelManager, FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline, ) - from .flashvsr_arch.models.utils import Buffer_LQ4x_Proj + from .flashvsr_arch.models.utils import Causal_LQ4x_Proj from .flashvsr_arch.models.TCDecoder import build_tcdecoder self.mode = mode @@ -672,16 +672,18 @@ class FlashVSRModel: mm.load_models([dit_path]) Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline self.pipe = Pipeline.from_model_manager(mm, device=device) - self.pipe.TCDecoder = build_tcdecoder( - [512, 256, 128, 128], device, dtype, 16 + 768, - ) - self.pipe.TCDecoder.load_state_dict( - load_file(tcd_path, device=device), strict=False, - ) - self.pipe.TCDecoder.clean_mem() - # LQ frame projection - self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype) + # TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning) + self.pipe.TCDecoder = build_tcdecoder( + [512, 256, 128, 128], device, dtype, 16 + 768, + ) + self.pipe.TCDecoder.load_state_dict( + load_file(tcd_path, device=device), strict=False, + ) + self.pipe.TCDecoder.clean_mem() + + # LQ frame projection — Causal variant for FlashVSR v1.1 + self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype) if os.path.exists(lq_path): lq_sd = load_file(lq_path, device="cpu") cleaned = {} @@ -714,6 +716,8 @@ class FlashVSRModel: self.pipe.denoising_model().LQ_proj_in.clear_cache() if hasattr(self.pipe, "vae") and self.pipe.vae is not None: self.pipe.vae.clear_cache() + if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None: + self.pipe.TCDecoder.clean_mem() # ------------------------------------------------------------------ # Frame preprocessing / postprocessing helpers @@ -743,7 +747,7 @@ class FlashVSRModel: 1. Bicubic-upscale each frame to target resolution 2. Centered symmetric padding to 128-pixel alignment (reflect mode) 3. Normalize to [-1, 1] - 4. Temporal padding: repeat last frame to reach 8k+1 count + 4. Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference) No front dummy frames — the pipeline handles LQ indexing correctly starting from frame 0. @@ -780,14 +784,16 @@ class FlashVSRModel: video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0) - # Temporal padding: repeat last frame to reach 8k+1 (pipeline requirement) - target = max(N, 25) # minimum 25 for streaming loop (P >= 1) - remainder = (target - 1) % 8 - if remainder != 0: - target += 8 - remainder + # Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference) + num_with_pad = N + 4 + target = ((num_with_pad - 1) // 8) * 8 + 1 # largest_8n1_leq + if target < 1: + target = 1 if target > N: pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1) video = torch.cat([video, pad], dim=2) + elif target < N: + video = video[:, :, :target, :, :] nf = video.shape[2] return video, th, tw, nf, sh, sw, pad_top, pad_left diff --git a/nodes.py b/nodes.py index bc48012..69cd9ef 100644 --- a/nodes.py +++ b/nodes.py @@ -1731,14 +1731,12 @@ class FlashVSRUpscale: chunks.append((prev_start, last_end)) # Estimate total pipeline steps for progress bar - # Mirrors _prepare_video: target = max(N, 25), round up to 8k+1 + # Mirrors _prepare_video: largest_8n1_leq(N + 4) total_steps = 0 for cs, ce in chunks: n = ce - cs - target = max(n, 25) - remainder = (target - 1) % 8 - if remainder != 0: - target += 8 - remainder + num_with_pad = n + 4 + target = ((num_with_pad - 1) // 8) * 8 + 1 total_steps += max(1, (target - 1) // 8 - 2) pbar = ProgressBar(total_steps)