class Causal_LQ4x_Proj(nn.Module):
    """Causal variant of Buffer_LQ4x_Proj for FlashVSR v1.1.

    Key difference: each conv reads the OLD cache BEFORE the new cache is
    written (truly causal), whereas Buffer_LQ4x_Proj writes the cache BEFORE
    the conv call.

    The module runs two cached 3D convolutions (each followed by RMS norm and
    SiLU) over 4-frame clips and projects the resulting tokens through
    ``layer_num`` independent linear heads.

    Args:
        in_dim: Channel count of the incoming video tensor.
        out_dim: Output feature size of every linear head.
        layer_num: Number of linear heads, i.e. the length of the list that
            ``forward`` / ``stream_forward`` return.
    """

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        # Factors handed to PixelShuffle3d; conv1's input channel count is
        # in_dim * ff * hh * ww.
        # NOTE(review): presumably folds 16x16 spatial patches into channels
        # -- confirm against PixelShuffle3d.
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        # Create self.cache and reset self.clip_idx up front so that
        # stream_forward() works on a fresh instance even when the caller
        # has not invoked clear_cache() yet.
        self.clear_cache()

    def _conv1_stage(self, clip):
        """Run pixel-shuffle -> conv1 -> norm1 -> act1 on one clip.

        conv1 is given the cache left by the previous clip; only afterwards
        is the cache replaced by the last CACHE_T frames of this clip's
        pixel-shuffled input.
        """
        x = self.pixel_shuffle(clip)
        new_cache = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv1(x, self.cache['conv1'])  # reads OLD cache
        self.cache['conv1'] = new_cache  # writes NEW cache AFTER
        return self.act1(self.norm1(x))

    def _conv2_stage(self, x):
        """Run conv2 -> norm2 -> act2, reading the old cache before updating it."""
        new_cache = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv2(x, self.cache['conv2'])  # reads OLD cache
        self.cache['conv2'] = new_cache  # writes NEW cache AFTER
        return self.act2(self.norm2(x))

    def _project(self, tokens):
        """Apply every linear head to ``tokens`` and return the results as a list."""
        return [layer(tokens) for layer in self.linear_layers]

    def forward(self, video):
        """Project a whole video in one call.

        The cache is cleared first. The first frame is repeated three times
        and prepended, then the video is consumed in ``1 + (t - 1) // 4``
        clips of four frames, where ``t`` is the size of dim 2. The first
        clip only primes the conv2 cache and contributes no output.

        Args:
            video: 5-D tensor indexed as (b, c, f, h, w).

        Returns:
            A list of ``layer_num`` tensors, each of shape
            (b, f*h*w, out_dim) in terms of the conv2 output dimensions.

        Note:
            With ``t <= 4`` only the priming clip exists, so nothing is
            collected and the ``torch.cat`` over the empty list fails.
        """
        self.clear_cache()
        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)

        out_x = []
        for i in range(iter_):
            x = self._conv1_stage(video[:, :, i*4:(i+1)*4, :, :])
            if i == 0:
                # First clip: store the conv2 cache only; conv2 is not run.
                self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()
                continue
            out_x.append(self._conv2_stage(x))
        out_x = torch.cat(out_x, dim=2)
        out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
        return self._project(out_x)

    def clear_cache(self):
        """Drop both conv caches and restart the streaming clip counter."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Project one clip of a stream, carrying state between calls.

        On the first call after ``clear_cache()`` the first frame is
        repeated three times and prepended, the caches are primed, and
        ``None`` is returned. Every later call returns the projected tokens
        for the given clip.

        Args:
            video_clip: 5-D tensor indexed as (b, c, f, h, w).

        Returns:
            ``None`` for the first clip, otherwise a list of ``layer_num``
            tensors of shape (b, f*h*w, out_dim).
        """
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self._conv1_stage(video_clip)  # conv1 sees the empty (None) cache
            self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()
            self.clip_idx += 1
            return None

        x = self._conv2_stage(self._conv1_stage(video_clip))
        out_x = rearrange(x, 'b c f h w -> b (f h w) c')
        outputs = self._project(out_x)
        self.clip_idx += 1
        return outputs
self.dit.LQ_proj_in.clear_cache() - latents_total = [] - self.vae.clear_cache() - + frames_total = [] + LQ_pre_idx = 0 + LQ_cur_idx = 0 + if hasattr(self, 'TCDecoder') and self.TCDecoder is not None: + self.TCDecoder.clean_mem() + if unload_dit and hasattr(self, 'dit') and self.dit is not None: current_dit_device = next(iter(self.dit.parameters())).device if str(current_dit_device) != str(self.device): @@ -415,6 +418,7 @@ class FlashVSRFullPipeline(BasePipeline): else: for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = (inner_loop_num-1)*4-3 cur_latents = latents[:, :, :6, :, :] else: LQ_latents = None @@ -430,9 +434,10 @@ class FlashVSRFullPipeline(BasePipeline): else: for layer_idx in range(len(LQ_latents)): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] - - # 推理(无 motion_controller / vace) + + # Denoise noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( self.dit, x=cur_latents, @@ -453,44 +458,41 @@ class FlashVSRFullPipeline(BasePipeline): local_range = local_range, ) - # 更新 latent cur_latents = cur_latents - noise_pred_posi - latents_total.append(cur_latents) - - if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + + # Streaming TCDecoder decode per-chunk with LQ conditioning + cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device) + + if hasattr(self, 'TCDecoder') and self.TCDecoder is not None: + cur_frames = self.TCDecoder.decode_video( + cur_latents.transpose(1, 2), + parallel=False, + show_progress_bar=False, + cond=cur_LQ_frame + ).transpose(1, 2).mul_(2).sub_(1) + else: + cur_frames = self.decode_video(cur_latents, **tiler_kwargs) + + # Per-chunk color correction try: - del pre_cache_k, pre_cache_v - except NameError: + if color_fix: + 
cur_frames = self.ColorCorrector( + cur_frames.to(device=self.device), + cur_LQ_frame, + clip_range=(-1, 1), + chunk_size=None, + method='adain' + ) + except: pass - print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") - self.dit.to('cpu') + + frames_total.append(cur_frames.to('cpu')) + LQ_pre_idx = LQ_cur_idx + + del cur_frames, cur_latents, cur_LQ_frame clean_vram() - latents = torch.cat(latents_total, dim=2) - - del latents_total - clean_vram() - - if skip_vae: - return latents - - # Decode - print("[FlashVSR] Starting VAE decoding...") - frames = self.decode_video(latents, **tiler_kwargs) - - # 颜色校正(wavelet) - try: - if color_fix: - frames = self.ColorCorrector( - frames.to(device=LQ_video.device), - LQ_video[:, :, :frames.shape[2], :, :], - clip_range=(-1, 1), - chunk_size=16, - method='adain' - ) - except: - pass - + frames = torch.cat(frames_total, dim=2) return frames[0] diff --git a/flashvsr_arch/pipelines/flashvsr_tiny.py b/flashvsr_arch/pipelines/flashvsr_tiny.py index 60ed390..14e7363 100644 --- a/flashvsr_arch/pipelines/flashvsr_tiny.py +++ b/flashvsr_arch/pipelines/flashvsr_tiny.py @@ -380,11 +380,11 @@ class FlashVSRTinyPipeline(BasePipeline): if hasattr(self.dit, "LQ_proj_in"): self.dit.LQ_proj_in.clear_cache() - latents_total = [] + frames_total = [] self.TCDecoder.clean_mem() LQ_pre_idx = 0 LQ_cur_idx = 0 - + if unload_dit and hasattr(self, 'dit') and self.dit is not None: current_dit_device = next(iter(self.dit.parameters())).device if str(current_dit_device) != str(self.device): @@ -427,8 +427,8 @@ class FlashVSRTinyPipeline(BasePipeline): LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] - - # 推理(无 motion_controller / vace) + + # Denoise noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( self.dit, x=cur_latents, @@ -449,45 +449,37 @@ class 
FlashVSRTinyPipeline(BasePipeline): local_range = local_range, ) - # 更新 latent cur_latents = cur_latents - noise_pred_posi - latents_total.append(cur_latents) - LQ_pre_idx = LQ_cur_idx - - if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + + # Streaming TCDecoder decode per-chunk with LQ conditioning + cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device) + cur_frames = self.TCDecoder.decode_video( + cur_latents.transpose(1, 2), + parallel=False, + show_progress_bar=False, + cond=cur_LQ_frame + ).transpose(1, 2).mul_(2).sub_(1) + + # Per-chunk color correction try: - del pre_cache_k, pre_cache_v - except NameError: + if color_fix: + cur_frames = self.ColorCorrector( + cur_frames.to(device=self.device), + cur_LQ_frame, + clip_range=(-1, 1), + chunk_size=None, + method='adain' + ) + except: pass - print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") - self.dit.to('cpu') + + frames_total.append(cur_frames.to('cpu')) + LQ_pre_idx = LQ_cur_idx + + del cur_frames, cur_latents, cur_LQ_frame clean_vram() - - latents = torch.cat(latents_total, dim=2) - - del latents_total - clean_vram() - - if skip_vae: - return latents - - # Decode - print("[FlashVSR] Starting VAE decoding...") - frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1) - - # 颜色校正(wavelet) - try: - if color_fix: - frames = self.ColorCorrector( - frames.to(device=LQ_video.device), - LQ_video[:, :, :frames.shape[2], :, :], - clip_range=(-1, 1), - chunk_size=16, - method='adain' - ) - except: - pass + frames = torch.cat(frames_total, dim=2) return frames[0] diff --git a/inference.py b/inference.py index 736ce36..24e5047 100644 --- a/inference.py +++ b/inference.py @@ -648,7 +648,7 @@ class FlashVSRModel: ModelManager, FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline, ) - from .flashvsr_arch.models.utils import 
Buffer_LQ4x_Proj + from .flashvsr_arch.models.utils import Causal_LQ4x_Proj from .flashvsr_arch.models.TCDecoder import build_tcdecoder self.mode = mode @@ -672,16 +672,18 @@ class FlashVSRModel: mm.load_models([dit_path]) Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline self.pipe = Pipeline.from_model_manager(mm, device=device) - self.pipe.TCDecoder = build_tcdecoder( - [512, 256, 128, 128], device, dtype, 16 + 768, - ) - self.pipe.TCDecoder.load_state_dict( - load_file(tcd_path, device=device), strict=False, - ) - self.pipe.TCDecoder.clean_mem() - # LQ frame projection - self.pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(3, 1536, 1).to(device, dtype) + # TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning) + self.pipe.TCDecoder = build_tcdecoder( + [512, 256, 128, 128], device, dtype, 16 + 768, + ) + self.pipe.TCDecoder.load_state_dict( + load_file(tcd_path, device=device), strict=False, + ) + self.pipe.TCDecoder.clean_mem() + + # LQ frame projection — Causal variant for FlashVSR v1.1 + self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype) if os.path.exists(lq_path): lq_sd = load_file(lq_path, device="cpu") cleaned = {} @@ -714,6 +716,8 @@ class FlashVSRModel: self.pipe.denoising_model().LQ_proj_in.clear_cache() if hasattr(self.pipe, "vae") and self.pipe.vae is not None: self.pipe.vae.clear_cache() + if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None: + self.pipe.TCDecoder.clean_mem() # ------------------------------------------------------------------ # Frame preprocessing / postprocessing helpers @@ -743,7 +747,7 @@ class FlashVSRModel: 1. Bicubic-upscale each frame to target resolution 2. Centered symmetric padding to 128-pixel alignment (reflect mode) 3. Normalize to [-1, 1] - 4. Temporal padding: repeat last frame to reach 8k+1 count + 4. 
Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference) No front dummy frames — the pipeline handles LQ indexing correctly starting from frame 0. @@ -780,14 +784,16 @@ class FlashVSRModel: video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0) - # Temporal padding: repeat last frame to reach 8k+1 (pipeline requirement) - target = max(N, 25) # minimum 25 for streaming loop (P >= 1) - remainder = (target - 1) % 8 - if remainder != 0: - target += 8 - remainder + # Temporal padding: N+4 then floor to largest 8k+1 (matches naxci1 reference) + num_with_pad = N + 4 + target = ((num_with_pad - 1) // 8) * 8 + 1 # largest_8n1_leq + if target < 1: + target = 1 if target > N: pad = video[:, :, -1:].repeat(1, 1, target - N, 1, 1) video = torch.cat([video, pad], dim=2) + elif target < N: + video = video[:, :, :target, :, :] nf = video.shape[2] return video, th, tw, nf, sh, sw, pad_top, pad_left diff --git a/nodes.py b/nodes.py index bc48012..69cd9ef 100644 --- a/nodes.py +++ b/nodes.py @@ -1731,14 +1731,12 @@ class FlashVSRUpscale: chunks.append((prev_start, last_end)) # Estimate total pipeline steps for progress bar - # Mirrors _prepare_video: target = max(N, 25), round up to 8k+1 + # Mirrors _prepare_video: largest_8n1_leq(N + 4) total_steps = 0 for cs, ce in chunks: n = ce - cs - target = max(n, 25) - remainder = (target - 1) % 8 - if remainder != 0: - target += 8 - remainder + num_with_pad = n + 4 + target = ((num_with_pad - 1) // 8) * 8 + 1 total_steps += max(1, (target - 1) // 8 - 2) pbar = ProgressBar(total_steps)