debug: add latent and audio stats logging to T2A node

Print fakes latent stats (mean/std/min/max) and audio pre-norm stats
to diagnose whether diffusion output is numerically reasonable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 22:06:39 +01:00
parent 11457fc27a
commit 45633788a4
+8 -1
View File
@@ -90,6 +90,9 @@ class PrismAudioTextOnly:
batch_cfg=True,
)
fakes_f = fakes.float()
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
if strategy == "offload_to_cpu":
diffusion.model.to(get_offload_device())
diffusion.conditioner.to(get_offload_device())
@@ -98,7 +101,7 @@ class PrismAudioTextOnly:
# VAE decode in fp32 (snake activations overflow in fp16)
with torch.amp.autocast(device_type=device.type, enabled=False):
audio = diffusion.pretransform.decode(fakes.float())
audio = diffusion.pretransform.decode(fakes_f)
if strategy == "offload_to_cpu":
diffusion.pretransform.to(get_offload_device())
@@ -106,8 +109,12 @@ class PrismAudioTextOnly:
# Peak normalize then clamp
audio = audio.float()
pre_norm_std = audio.std().item()
pre_norm_peak = audio.abs().max().item()
peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1)
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)