fix(vae-roundtrip): use model feature_utils for decode, add normalize/unnormalize, normalize output
- Load fresh FeaturesUtils only for encoding; use model["feature_utils"] for decode+vocode to mirror the exact path the sampler takes - Apply generator.normalize() → unnormalize() around the encoded latent so the decoder receives latents in the same space it expects from inference - Log both encoded and norm→unnorm latent stats to diagnose round-trip fidelity - Normalize output to -27 dBFS (matching training clip RMS) and clamp to [-1, 1] to prevent clipping artifacts in the output waveform Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ not the diffusion model or LoRA.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -46,9 +47,12 @@ class SelvaVaeRoundtrip:
|
|||||||
def roundtrip(self, model, audio):
|
def roundtrip(self, model, audio):
|
||||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
|
||||||
mode = model["mode"]
|
mode = model["mode"]
|
||||||
seq_cfg = model["seq_cfg"]
|
seq_cfg = model["seq_cfg"]
|
||||||
device = get_device()
|
dtype = model["dtype"]
|
||||||
|
device = get_device()
|
||||||
|
generator = model["generator"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
|
||||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
||||||
vae_path = _SELVA_DIR / "ext" / vae_name
|
vae_path = _SELVA_DIR / "ext" / vae_name
|
||||||
@@ -58,9 +62,10 @@ class SelvaVaeRoundtrip:
|
|||||||
"Run SelVA Model Loader first to auto-download weights."
|
"Run SelVA Model Loader first to auto-download weights."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE with encoder enabled
|
# Load encoder only — decoder/vocoder come from model["feature_utils"]
|
||||||
print("[VAE Roundtrip] Loading VAE...", flush=True)
|
# to mirror exactly what the sampler uses
|
||||||
vae = FeaturesUtils(
|
print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
|
||||||
|
vae_enc = FeaturesUtils(
|
||||||
tod_vae_ckpt=str(vae_path),
|
tod_vae_ckpt=str(vae_path),
|
||||||
enable_conditions=False,
|
enable_conditions=False,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@@ -72,10 +77,8 @@ class SelvaVaeRoundtrip:
|
|||||||
waveform = audio["waveform"] # [1, C, L]
|
waveform = audio["waveform"] # [1, C, L]
|
||||||
sr_in = audio["sample_rate"]
|
sr_in = audio["sample_rate"]
|
||||||
|
|
||||||
# Flatten to mono [L]
|
wav = waveform[0].mean(0) # mono [L]
|
||||||
wav = waveform[0].mean(0)
|
|
||||||
|
|
||||||
# Resample to model sample rate if needed
|
|
||||||
if sr_in != seq_cfg.sampling_rate:
|
if sr_in != seq_cfg.sampling_rate:
|
||||||
wav = torchaudio.functional.resample(
|
wav = torchaudio.functional.resample(
|
||||||
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
|
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
|
||||||
@@ -83,53 +86,69 @@ class SelvaVaeRoundtrip:
|
|||||||
print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz",
|
print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz",
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
# Trim or pad to model duration
|
|
||||||
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
|
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
|
||||||
if wav.shape[0] > target_len:
|
if wav.shape[0] > target_len:
|
||||||
wav = wav[:target_len]
|
wav = wav[:target_len]
|
||||||
print(f"[VAE Roundtrip] Trimmed to {seq_cfg.duration:.1f}s", flush=True)
|
|
||||||
elif wav.shape[0] < target_len:
|
elif wav.shape[0] < target_len:
|
||||||
import torch.nn.functional as F
|
|
||||||
wav = F.pad(wav, (0, target_len - wav.shape[0]))
|
wav = F.pad(wav, (0, target_len - wav.shape[0]))
|
||||||
|
|
||||||
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
|
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Encode
|
# Encode: audio → raw latent [1, latent_dim, T]
|
||||||
dist = vae.encode_audio(wav_b)
|
dist = vae_enc.encode_audio(wav_b)
|
||||||
latent = dist.mode().clone() # [1, latent_dim, T]
|
latent = dist.mode().clone()
|
||||||
|
|
||||||
# Trim/pad latent to the exact model sequence length
|
# Trim/pad to exact model sequence length (same as _prepare_dataset)
|
||||||
# (same as _prepare_dataset) so the decoder produces the right duration
|
|
||||||
tgt = seq_cfg.latent_seq_len
|
tgt = seq_cfg.latent_seq_len
|
||||||
if latent.shape[2] < tgt:
|
if latent.shape[2] < tgt:
|
||||||
import torch.nn.functional as F
|
|
||||||
latent = F.pad(latent, (0, tgt - latent.shape[2]))
|
latent = F.pad(latent, (0, tgt - latent.shape[2]))
|
||||||
elif latent.shape[2] > tgt:
|
elif latent.shape[2] > tgt:
|
||||||
latent = latent[:, :, :tgt]
|
latent = latent[:, :, :tgt]
|
||||||
|
|
||||||
print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
|
# To [B, T, latent_dim] — layout the generator uses
|
||||||
f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)
|
latent_t = latent.transpose(1, 2).to(dtype)
|
||||||
|
print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
|
||||||
|
flush=True)
|
||||||
|
|
||||||
# Decode straight back — no normalization, no generation
|
# Normalize → unnormalize mirrors the training/inference pipeline:
|
||||||
latent_t = latent.transpose(1, 2) # [1, T, latent_dim]
|
# training normalizes encoded latents; sampler unnormalizes before decode.
|
||||||
spec = vae.decode(latent_t)
|
# This ensures the latent is in the same space the decoder expects.
|
||||||
out = vae.vocode(spec)
|
latent_norm = generator.normalize(latent_t.clone())
|
||||||
|
latent_unnorm = generator.unnormalize(latent_norm)
|
||||||
|
print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
|
||||||
|
flush=True)
|
||||||
|
|
||||||
|
# Decode using model's feature_utils — same path as the sampler
|
||||||
|
orig_device = next(feature_utils.parameters()).device
|
||||||
|
if orig_device != device:
|
||||||
|
feature_utils.to(device)
|
||||||
|
try:
|
||||||
|
spec = feature_utils.decode(latent_unnorm)
|
||||||
|
out = feature_utils.vocode(spec)
|
||||||
|
finally:
|
||||||
|
if orig_device != device:
|
||||||
|
feature_utils.to(orig_device)
|
||||||
|
|
||||||
out = out.float().cpu()
|
out = out.float().cpu()
|
||||||
if out.dim() == 1:
|
if out.dim() == 1:
|
||||||
out = out.unsqueeze(0).unsqueeze(0) # [1, 1, L]
|
out = out.unsqueeze(0).unsqueeze(0)
|
||||||
elif out.dim() == 2:
|
elif out.dim() == 2:
|
||||||
out = out.unsqueeze(1)
|
out = out.unsqueeze(1)
|
||||||
elif out.dim() == 3 and out.shape[1] != 1:
|
elif out.dim() == 3 and out.shape[1] != 1:
|
||||||
out = out.mean(dim=1, keepdim=True)
|
out = out.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
|
||||||
|
target_rms = 10 ** (-27.0 / 20.0)
|
||||||
|
out = out * (target_rms / rms)
|
||||||
|
out = out.clamp(-1.0, 1.0)
|
||||||
|
|
||||||
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
|
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
|
||||||
f"peak={out.abs().max():.4f} "
|
f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
|
||||||
f"rms={out.pow(2).mean().sqrt():.4f}", flush=True)
|
flush=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
del vae
|
del vae_enc
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
|
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
|
||||||
|
|||||||
Reference in New Issue
Block a user