From 10e6095e3196c3242880cb59d4e72f55de322374 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 19:50:01 +0200 Subject: [PATCH] fix(vae-roundtrip): use model feature_utils for decode, add normalize/unnormalize, normalize output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Load fresh FeaturesUtils only for encoding; use model["feature_utils"] for decode+vocode to mirror the exact path the sampler takes - Apply generator.normalize() → unnormalize() around the encoded latent so the decoder receives latents in the same space it expects from inference - Log both encoded and norm→unnorm latent stats to diagnose round-trip fidelity - Normalize output to -27 dBFS (matching training clip RMS) and clamp to [-1, 1] to prevent clipping artifacts in the output waveform Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_vae_roundtrip.py | 75 ++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/nodes/selva_vae_roundtrip.py b/nodes/selva_vae_roundtrip.py index 5b1cff4..1152e0e 100644 --- a/nodes/selva_vae_roundtrip.py +++ b/nodes/selva_vae_roundtrip.py @@ -6,6 +6,7 @@ not the diffusion model or LoRA. """ import torch +import torch.nn.functional as F import torchaudio from pathlib import Path @@ -46,9 +47,12 @@ class SelvaVaeRoundtrip: def roundtrip(self, model, audio): from selva_core.model.utils.features_utils import FeaturesUtils - mode = model["mode"] - seq_cfg = model["seq_cfg"] - device = get_device() + mode = model["mode"] + seq_cfg = model["seq_cfg"] + dtype = model["dtype"] + device = get_device() + generator = model["generator"] + feature_utils = model["feature_utils"] vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth" vae_path = _SELVA_DIR / "ext" / vae_name @@ -58,9 +62,10 @@ class SelvaVaeRoundtrip: "Run SelVA Model Loader first to auto-download weights." ) - # Load VAE with encoder enabled - print("[VAE Roundtrip] Loading VAE...", flush=True) - vae = FeaturesUtils( + # Load encoder only — decoder/vocoder come from model["feature_utils"] + # to mirror exactly what the sampler uses + print("[VAE Roundtrip] Loading VAE encoder...", flush=True) + vae_enc = FeaturesUtils( tod_vae_ckpt=str(vae_path), enable_conditions=False, mode=mode, @@ -72,10 +77,8 @@ class SelvaVaeRoundtrip: waveform = audio["waveform"] # [1, C, L] sr_in = audio["sample_rate"] - # Flatten to mono [L] - wav = waveform[0].mean(0) + wav = waveform[0].mean(0) # mono [L] - # Resample to model sample rate if needed if sr_in != seq_cfg.sampling_rate: wav = torchaudio.functional.resample( wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate @@ -83,53 +86,69 @@ class SelvaVaeRoundtrip: print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz", flush=True) - # Trim or pad to model duration target_len = int(seq_cfg.duration * seq_cfg.sampling_rate) if wav.shape[0] > target_len: wav = wav[:target_len] - print(f"[VAE Roundtrip] Trimmed to {seq_cfg.duration:.1f}s", flush=True) elif wav.shape[0] < target_len: - import torch.nn.functional as F wav = F.pad(wav, (0, target_len - wav.shape[0])) wav_b = wav.unsqueeze(0).to(device).float() # [1, L] with torch.no_grad(): - # Encode - dist = vae.encode_audio(wav_b) - latent = dist.mode().clone() # [1, latent_dim, T] + # Encode: audio → raw latent [1, latent_dim, T] + dist = vae_enc.encode_audio(wav_b) + latent = dist.mode().clone() - # Trim/pad latent to the exact model sequence length - # (same as _prepare_dataset) so the decoder produces the right duration + # Trim/pad to exact model sequence length (same as _prepare_dataset) tgt = seq_cfg.latent_seq_len if latent.shape[2] < tgt: - import torch.nn.functional as F latent = F.pad(latent, (0, tgt - latent.shape[2])) elif latent.shape[2] > tgt: latent = latent[:, :, :tgt] - print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} " - f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True) + # To [B, T, latent_dim] — layout the generator uses + latent_t = latent.transpose(1, 2).to(dtype) + print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}", + flush=True) - # Decode straight back — no normalization, no generation - latent_t = latent.transpose(1, 2) # [1, T, latent_dim] - spec = vae.decode(latent_t) - out = vae.vocode(spec) + # Normalize → unnormalize mirrors the training/inference pipeline: + # training normalizes encoded latents; sampler unnormalizes before decode. + # This ensures the latent is in the same space the decoder expects. + latent_norm = generator.normalize(latent_t.clone()) + latent_unnorm = generator.unnormalize(latent_norm) + print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}", + flush=True) + + # Decode using model's feature_utils — same path as the sampler + orig_device = next(feature_utils.parameters()).device + if orig_device != device: + feature_utils.to(device) + try: + spec = feature_utils.decode(latent_unnorm) + out = feature_utils.vocode(spec) + finally: + if orig_device != device: + feature_utils.to(orig_device) out = out.float().cpu() if out.dim() == 1: - out = out.unsqueeze(0).unsqueeze(0) # [1, 1, L] + out = out.unsqueeze(0).unsqueeze(0) elif out.dim() == 2: out = out.unsqueeze(1) elif out.dim() == 3 and out.shape[1] != 1: out = out.mean(dim=1, keepdim=True) + rms = out.pow(2).mean().sqrt().clamp(min=1e-8) + target_rms = 10 ** (-27.0 / 20.0) + out = out * (target_rms / rms) + out = out.clamp(-1.0, 1.0) + print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} " - f"peak={out.abs().max():.4f} " - f"rms={out.pow(2).mean().sqrt():.4f}", flush=True) + f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}", + flush=True) finally: - del vae + del vae_enc soft_empty_cache() return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)