feat: add SelVA VAE Roundtrip node
Encodes audio through the VAE then decodes straight back, bypassing the diffusion model entirely. Use this to isolate whether saturation artifacts are introduced by the codec reconstruction (VAE/DAC) or by the LoRA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,7 @@ _NODES = {
|
|||||||
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
||||||
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
||||||
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
|
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
|
||||||
|
"SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -0,0 +1,125 @@
|
|||||||
|
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
|
||||||
|
|
||||||
|
Useful for diagnosing codec reconstruction quality: if the output sounds
|
||||||
|
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
|
||||||
|
not the diffusion model or LoRA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
|
|
||||||
|
|
||||||
|
# Base directory for SelVA files: the "selva" folder under folder_paths.models_dir.
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaVaeRoundtrip:
    """Node that encodes audio with the SelVA VAE and immediately decodes it.

    No diffusion or LoRA step is involved, so the output reflects codec
    reconstruction quality only.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required inputs: a SelVA model bundle and an audio clip."""
        required = {
            "model": ("SELVA_MODEL",),
            "audio": ("AUDIO",),
        }
        return {"required": required}

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio_reconstructed",)
    OUTPUT_TOOLTIPS = (
        "Audio after VAE encode → decode roundtrip. "
        "Compare to the input to hear codec reconstruction quality.",
    )
    FUNCTION = "roundtrip"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Encodes the input audio through the SelVA VAE then decodes it straight back. "
        "Use this to isolate codec reconstruction quality from generation quality. "
        "If the output sounds degraded compared to the input, the VAE/DAC is the "
        "bottleneck — not the model or LoRA."
    )

    def roundtrip(self, model, audio):
        """Run ``audio`` through VAE encode -> decode -> vocode and return the result.

        Args:
            model: Mapping that provides ``"mode"`` (``"16k"`` selects the 16k
                VAE weights, anything else the 44k weights) and ``"seq_cfg"``
                (read for ``sampling_rate`` and ``duration``).
            audio: Mapping with ``"waveform"`` and ``"sample_rate"``. Only the
                first batch item is used, averaged over its channel axis.

        Returns:
            A 1-tuple holding an audio dict with the reconstructed waveform
            (on CPU, float32, single channel) and the model sample rate.

        Raises:
            FileNotFoundError: If the VAE weight file is not on disk.
        """
        from selva_core.model.utils.features_utils import FeaturesUtils

        mode, seq_cfg = model["mode"], model["seq_cfg"]
        device = get_device()

        if mode == "16k":
            weight_file = "v1-16.pth"
        else:
            weight_file = "v1-44.pth"
        weight_path = _SELVA_DIR / "ext" / weight_file
        if not weight_path.exists():
            raise FileNotFoundError(
                f"[VAE Roundtrip] VAE weight not found: {weight_path}. "
                "Run SelVA Model Loader first to auto-download weights."
            )

        # The encoder has to be requested explicitly via need_vae_encoder.
        print("[VAE Roundtrip] Loading VAE...", flush=True)
        codec = FeaturesUtils(
            tod_vae_ckpt=str(weight_path),
            enable_conditions=False,
            mode=mode,
            need_vae_encoder=True,
        )
        codec = codec.to(device).eval()

        try:
            source = audio["waveform"]  # expected layout: [1, C, L]
            source_sr = audio["sample_rate"]

            # Collapse the first batch item to a mono 1-D signal.
            mono = source[0].mean(0)

            # Bring the signal to the model's sample rate when they differ.
            model_sr = seq_cfg.sampling_rate
            if source_sr != model_sr:
                resampled = torchaudio.functional.resample(
                    mono.unsqueeze(0), source_sr, model_sr
                )
                mono = resampled.squeeze(0)
                print(f"[VAE Roundtrip] Resampled {source_sr} → {model_sr} Hz",
                      flush=True)

            # Force the clip to exactly the model's configured duration:
            # cut the tail if too long, zero-pad the end if too short.
            wanted = int(seq_cfg.duration * seq_cfg.sampling_rate)
            have = mono.shape[0]
            if have > wanted:
                mono = mono[:wanted]
                print(f"[VAE Roundtrip] Trimmed to {seq_cfg.duration:.1f}s", flush=True)
            elif have < wanted:
                mono = torch.nn.functional.pad(mono, (0, wanted - have))

            batch = mono.unsqueeze(0).to(device).float()  # [1, L]

            with torch.no_grad():
                # Encode, taking the distribution's mode rather than a sample.
                posterior = codec.encode_audio(batch)
                z = posterior.mode().clone()  # [1, latent_dim, T]
                print(f"[VAE Roundtrip] Latent: shape={tuple(z.shape)} "
                      f"mean={z.mean():.4f} std={z.std():.4f}", flush=True)

                # Decode directly: no normalization and no generation step.
                # decode() is fed the latent with axes 1 and 2 swapped.
                mel = codec.decode(z.transpose(1, 2))
                recon = codec.vocode(mel)

            recon = recon.float().cpu()

            # Normalise whatever rank the vocoder returned to [B, 1, L].
            rank = recon.dim()
            if rank == 1:
                recon = recon[None, None, :]
            elif rank == 2:
                recon = recon[:, None, :]
            elif rank == 3 and recon.shape[1] != 1:
                recon = recon.mean(dim=1, keepdim=True)

            print(f"[VAE Roundtrip] Output: shape={tuple(recon.shape)} "
                  f"peak={recon.abs().max():.4f} "
                  f"rms={recon.pow(2).mean().sqrt():.4f}", flush=True)
        finally:
            # Drop the VAE reference and ask for a cache cleanup whether or
            # not the roundtrip succeeded.
            del codec
            soft_empty_cache()

        return ({"waveform": recon, "sample_rate": seq_cfg.sampling_rate},)
|
||||||
Reference in New Issue
Block a user