Fix FlashVSR ghosting: streaming TCDecoder decode + Causal LQ projection
Root cause: three critical differences from the naxci1 reference implementation:

1. Batch decode after the loop → streaming per-chunk TCDecoder decode with LQ conditioning inside the loop. The TCDecoder uses causal convolutions with temporal memory that must be built incrementally, chunk by chunk. Batch decoding breaks this design and loses LQ frame conditioning, causing ghosting.
2. Buffer_LQ4x_Proj → Causal_LQ4x_Proj for FlashVSR v1.1. The causal variant reads the OLD cache before writing the new one (truly causal), while the Buffer variant writes the cache BEFORE the conv call. Using the wrong variant misaligns the temporal LQ conditioning features.
3. Temporal padding formula: changed from round-up to largest_8n1_leq(N+4), matching the naxci1 reference approach.

Changes:

- flashvsr_full.py: streaming TCDecoder decode per-chunk with LQ conditioning and per-chunk color correction (was: batch VAE decode after the loop)
- flashvsr_tiny.py: streaming TCDecoder decode per-chunk (was: batch decode)
- inference.py: use Causal_LQ4x_Proj, build the TCDecoder for ALL modes (including full), fix temporal padding to largest_8n1_leq(N+4), clear the TCDecoder in clear_caches()
- utils.py: add the Causal_LQ4x_Proj class
- nodes.py: update the progress bar estimation for the new padding formula

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -388,9 +388,12 @@ class FlashVSRFullPipeline(BasePipeline):
|
||||
if hasattr(self.dit, "LQ_proj_in"):
|
||||
self.dit.LQ_proj_in.clear_cache()
|
||||
|
||||
latents_total = []
|
||||
self.vae.clear_cache()
|
||||
|
||||
frames_total = []
|
||||
LQ_pre_idx = 0
|
||||
LQ_cur_idx = 0
|
||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
||||
self.TCDecoder.clean_mem()
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
||||
current_dit_device = next(iter(self.dit.parameters())).device
|
||||
if str(current_dit_device) != str(self.device):
|
||||
@@ -415,6 +418,7 @@ class FlashVSRFullPipeline(BasePipeline):
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = (inner_loop_num-1)*4-3
|
||||
cur_latents = latents[:, :, :6, :, :]
|
||||
else:
|
||||
LQ_latents = None
|
||||
@@ -430,9 +434,10 @@ class FlashVSRFullPipeline(BasePipeline):
|
||||
else:
|
||||
for layer_idx in range(len(LQ_latents)):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
||||
|
||||
# 推理(无 motion_controller / vace)
|
||||
|
||||
# Denoise
|
||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
||||
self.dit,
|
||||
x=cur_latents,
|
||||
@@ -453,44 +458,41 @@ class FlashVSRFullPipeline(BasePipeline):
|
||||
local_range = local_range,
|
||||
)
|
||||
|
||||
# 更新 latent
|
||||
cur_latents = cur_latents - noise_pred_posi
|
||||
latents_total.append(cur_latents)
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
|
||||
|
||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
||||
|
||||
if hasattr(self, 'TCDecoder') and self.TCDecoder is not None:
|
||||
cur_frames = self.TCDecoder.decode_video(
|
||||
cur_latents.transpose(1, 2),
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cur_LQ_frame
|
||||
).transpose(1, 2).mul_(2).sub_(1)
|
||||
else:
|
||||
cur_frames = self.decode_video(cur_latents, **tiler_kwargs)
|
||||
|
||||
# Per-chunk color correction
|
||||
try:
|
||||
del pre_cache_k, pre_cache_v
|
||||
except NameError:
|
||||
if color_fix:
|
||||
cur_frames = self.ColorCorrector(
|
||||
cur_frames.to(device=self.device),
|
||||
cur_LQ_frame,
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=None,
|
||||
method='adain'
|
||||
)
|
||||
except:
|
||||
pass
|
||||
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
|
||||
self.dit.to('cpu')
|
||||
|
||||
frames_total.append(cur_frames.to('cpu'))
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
del cur_frames, cur_latents, cur_LQ_frame
|
||||
clean_vram()
|
||||
|
||||
latents = torch.cat(latents_total, dim=2)
|
||||
|
||||
del latents_total
|
||||
clean_vram()
|
||||
|
||||
if skip_vae:
|
||||
return latents
|
||||
|
||||
# Decode
|
||||
print("[FlashVSR] Starting VAE decoding...")
|
||||
frames = self.decode_video(latents, **tiler_kwargs)
|
||||
|
||||
# 颜色校正(wavelet)
|
||||
try:
|
||||
if color_fix:
|
||||
frames = self.ColorCorrector(
|
||||
frames.to(device=LQ_video.device),
|
||||
LQ_video[:, :, :frames.shape[2], :, :],
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=16,
|
||||
method='adain'
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
frames = torch.cat(frames_total, dim=2)
|
||||
return frames[0]
|
||||
|
||||
|
||||
|
||||
@@ -380,11 +380,11 @@ class FlashVSRTinyPipeline(BasePipeline):
|
||||
if hasattr(self.dit, "LQ_proj_in"):
|
||||
self.dit.LQ_proj_in.clear_cache()
|
||||
|
||||
latents_total = []
|
||||
frames_total = []
|
||||
self.TCDecoder.clean_mem()
|
||||
LQ_pre_idx = 0
|
||||
LQ_cur_idx = 0
|
||||
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and self.dit is not None:
|
||||
current_dit_device = next(iter(self.dit.parameters())).device
|
||||
if str(current_dit_device) != str(self.device):
|
||||
@@ -427,8 +427,8 @@ class FlashVSRTinyPipeline(BasePipeline):
|
||||
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
||||
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
||||
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
||||
|
||||
# 推理(无 motion_controller / vace)
|
||||
|
||||
# Denoise
|
||||
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
||||
self.dit,
|
||||
x=cur_latents,
|
||||
@@ -449,45 +449,37 @@ class FlashVSRTinyPipeline(BasePipeline):
|
||||
local_range = local_range,
|
||||
)
|
||||
|
||||
# 更新 latent
|
||||
cur_latents = cur_latents - noise_pred_posi
|
||||
latents_total.append(cur_latents)
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
|
||||
|
||||
# Streaming TCDecoder decode per-chunk with LQ conditioning
|
||||
cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
|
||||
cur_frames = self.TCDecoder.decode_video(
|
||||
cur_latents.transpose(1, 2),
|
||||
parallel=False,
|
||||
show_progress_bar=False,
|
||||
cond=cur_LQ_frame
|
||||
).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
||||
# Per-chunk color correction
|
||||
try:
|
||||
del pre_cache_k, pre_cache_v
|
||||
except NameError:
|
||||
if color_fix:
|
||||
cur_frames = self.ColorCorrector(
|
||||
cur_frames.to(device=self.device),
|
||||
cur_LQ_frame,
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=None,
|
||||
method='adain'
|
||||
)
|
||||
except:
|
||||
pass
|
||||
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
|
||||
self.dit.to('cpu')
|
||||
|
||||
frames_total.append(cur_frames.to('cpu'))
|
||||
LQ_pre_idx = LQ_cur_idx
|
||||
|
||||
del cur_frames, cur_latents, cur_LQ_frame
|
||||
clean_vram()
|
||||
|
||||
latents = torch.cat(latents_total, dim=2)
|
||||
|
||||
del latents_total
|
||||
clean_vram()
|
||||
|
||||
if skip_vae:
|
||||
return latents
|
||||
|
||||
# Decode
|
||||
print("[FlashVSR] Starting VAE decoding...")
|
||||
frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1)
|
||||
|
||||
# 颜色校正(wavelet)
|
||||
try:
|
||||
if color_fix:
|
||||
frames = self.ColorCorrector(
|
||||
frames.to(device=LQ_video.device),
|
||||
LQ_video[:, :, :frames.shape[2], :, :],
|
||||
clip_range=(-1, 1),
|
||||
chunk_size=16,
|
||||
method='adain'
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
frames = torch.cat(frames_total, dim=2)
|
||||
return frames[0]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user