From 1d8b9b59e0d6f9cd93524b4d32e1c724da6d2b0c Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 23:57:03 +0100 Subject: [PATCH] debug: add DIT velocity diagnostic at t=1 to isolate DIT vs VAE quality issue Co-Authored-By: Claude Sonnet 4.6 --- nodes/sampler.py | 14 ++++++++++++++ nodes/text_only.py | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/nodes/sampler.py b/nodes/sampler.py index 7ae0d4c..60155a3 100644 --- a/nodes/sampler.py +++ b/nodes/sampler.py @@ -111,6 +111,20 @@ class PrismAudioSampler: from prismaudio_core.inference.sampling import sample_discrete_euler + # Diagnostic: log DIT velocity at first step to verify model is working + t_diag = torch.ones([noise.shape[0]], dtype=noise.dtype, device=noise.device) + with torch.no_grad(): + v_diag = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + vd = v_diag.float() + print(f"[PrismAudio] DIT velocity@t=1: shape={tuple(vd.shape)} mean={vd.mean():.4f} std={vd.std():.4f} min={vd.min():.4f} max={vd.max():.4f}", flush=True) + # Also check uncond (cfg_scale=1.0) to verify conditioning is active + v_uncond = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=1.0, batch_cfg=True) + vu = v_uncond.float() + print(f"[PrismAudio] DIT velocity@t=1 uncond: mean={vu.mean():.4f} std={vu.std():.4f}", flush=True) + diff = (vd - vu).abs() + print(f"[PrismAudio] DIT cond-uncond diff: mean={diff.mean():.4f} max={diff.max():.4f}", flush=True) + del v_diag, v_uncond, vd, vu, diff + def on_step(info): pbar.update(1) diff --git a/nodes/text_only.py b/nodes/text_only.py index 8053e1e..517e2a4 100644 --- a/nodes/text_only.py +++ b/nodes/text_only.py @@ -85,6 +85,19 @@ class PrismAudioTextOnly: from prismaudio_core.inference.sampling import sample_discrete_euler + # Diagnostic: log DIT velocity at first step to verify model is working + t_diag = torch.ones([noise.shape[0]], dtype=noise.dtype, device=noise.device) + with torch.no_grad(): + v_diag = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + vd = v_diag.float() + print(f"[PrismAudio] DIT velocity@t=1: shape={tuple(vd.shape)} mean={vd.mean():.4f} std={vd.std():.4f} min={vd.min():.4f} max={vd.max():.4f}", flush=True) + v_uncond = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=1.0, batch_cfg=True) + vu = v_uncond.float() + print(f"[PrismAudio] DIT velocity@t=1 uncond: mean={vu.mean():.4f} std={vu.std():.4f}", flush=True) + diff = (vd - vu).abs() + print(f"[PrismAudio] DIT cond-uncond diff: mean={diff.mean():.4f} max={diff.max():.4f}", flush=True) + del v_diag, v_uncond, vd, vu, diff + def on_step(info): pbar.update(1)