diff --git a/nodes/sampler.py b/nodes/sampler.py index 7ae0d4c..60155a3 100644 --- a/nodes/sampler.py +++ b/nodes/sampler.py @@ -111,6 +111,20 @@ class PrismAudioSampler: from prismaudio_core.inference.sampling import sample_discrete_euler + # Diagnostic: log DIT velocity at first step to verify model is working + t_diag = torch.ones([noise.shape[0]], dtype=noise.dtype, device=noise.device) + with torch.no_grad(): + v_diag = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + vd = v_diag.float() + print(f"[PrismAudio] DIT velocity@t=1: shape={tuple(vd.shape)} mean={vd.mean():.4f} std={vd.std():.4f} min={vd.min():.4f} max={vd.max():.4f}", flush=True) + # Also check uncond (cfg_scale=1.0) to verify conditioning is active + v_uncond = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=1.0, batch_cfg=True) + vu = v_uncond.float() + print(f"[PrismAudio] DIT velocity@t=1 uncond: mean={vu.mean():.4f} std={vu.std():.4f}", flush=True) + diff = (vd - vu).abs() + print(f"[PrismAudio] DIT cond-uncond diff: mean={diff.mean():.4f} max={diff.max():.4f}", flush=True) + del v_diag, v_uncond, vd, vu, diff + def on_step(info): pbar.update(1) diff --git a/nodes/text_only.py b/nodes/text_only.py index 8053e1e..517e2a4 100644 --- a/nodes/text_only.py +++ b/nodes/text_only.py @@ -85,6 +85,19 @@ class PrismAudioTextOnly: from prismaudio_core.inference.sampling import sample_discrete_euler + # Diagnostic: log DIT velocity at first step to verify model is working + t_diag = torch.ones([noise.shape[0]], dtype=noise.dtype, device=noise.device) + with torch.no_grad(): + v_diag = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + vd = v_diag.float() + print(f"[PrismAudio] DIT velocity@t=1: shape={tuple(vd.shape)} mean={vd.mean():.4f} std={vd.std():.4f} min={vd.min():.4f} max={vd.max():.4f}", flush=True) + v_uncond = diffusion.model(noise, t_diag, **cond_inputs, cfg_scale=1.0, batch_cfg=True) + vu = v_uncond.float() + print(f"[PrismAudio] DIT velocity@t=1 uncond: mean={vu.mean():.4f} std={vu.std():.4f}", flush=True) + diff = (vd - vu).abs() + print(f"[PrismAudio] DIT cond-uncond diff: mean={diff.mean():.4f} max={diff.max():.4f}", flush=True) + del v_diag, v_uncond, vd, vu, diff + def on_step(info): pbar.update(1)