Files
ComfyUI-SelVA/nodes/selva_vae_roundtrip.py
T
Ethanfel 8195c3114a feat: add SelVA VAE Roundtrip node
Encodes audio through the VAE then decodes straight back, bypassing the
diffusion model entirely. Use this to isolate whether saturation artifacts
are introduced by the codec reconstruction (VAE/DAC) or by the LoRA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:15:20 +02:00

126 lines
4.4 KiB
Python

"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
Useful for diagnosing codec reconstruction quality: if the output sounds
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
not the diffusion model or LoRA.
"""
import torch
import torchaudio
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
# Directory named "selva" under folder_paths.models_dir; the roundtrip node
# looks for the VAE checkpoint in its "ext" subfolder.
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
class SelvaVaeRoundtrip:
    """Node that encodes audio with the SelVA VAE and decodes it straight back.

    No diffusion model or LoRA is involved, so any degradation heard in the
    output comes from the codec reconstruction alone.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a loaded SelVA model and an audio clip."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "audio": ("AUDIO",),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio_reconstructed",)
    OUTPUT_TOOLTIPS = (
        "Audio after VAE encode → decode roundtrip. "
        "Compare to the input to hear codec reconstruction quality.",
    )
    FUNCTION = "roundtrip"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Encodes the input audio through the SelVA VAE then decodes it straight back. "
        "Use this to isolate codec reconstruction quality from generation quality. "
        "If the output sounds degraded compared to the input, the VAE/DAC is the "
        "bottleneck — not the model or LoRA."
    )

    def roundtrip(self, model, audio):
        """Encode ``audio`` through the VAE and immediately decode it.

        Args:
            model: Mapping providing ``"mode"`` (``"16k"`` selects the 16 kHz
                checkpoint, anything else the 44 kHz one) and ``"seq_cfg"``,
                from which ``sampling_rate`` and ``duration`` are read.
            audio: Mapping providing ``"waveform"`` and ``"sample_rate"``.
                Only the first batch item is used and its channels are
                averaged to mono.

        Returns:
            A one-element tuple holding a dict with the reconstructed
            ``"waveform"`` (float tensor on CPU) and the model's
            ``"sample_rate"``.

        Raises:
            FileNotFoundError: If the VAE checkpoint is not present on disk.
        """
        # Function-scope import: selva_core is only needed when the node runs.
        from selva_core.model.utils.features_utils import FeaturesUtils

        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        device = get_device()

        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _SELVA_DIR / "ext" / vae_name
        if not vae_path.exists():
            raise FileNotFoundError(
                f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
                "Run SelVA Model Loader first to auto-download weights."
            )

        # Load VAE with encoder enabled
        print("[VAE Roundtrip] Loading VAE...", flush=True)
        vae = FeaturesUtils(
            tod_vae_ckpt=str(vae_path),
            enable_conditions=False,
            mode=mode,
            need_vae_encoder=True,
        ).to(device).eval()

        try:
            # Prepare input audio
            waveform = audio["waveform"]  # [1, C, L] per the original annotation
            sr_in = audio["sample_rate"]

            # Flatten to mono [L]: first batch item, channels averaged.
            wav = waveform[0].mean(0)

            # Resample to model sample rate if needed
            if sr_in != seq_cfg.sampling_rate:
                wav = torchaudio.functional.resample(
                    wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
                ).squeeze(0)
                # The separator keeps the two rates from running together
                # in the log (previously printed e.g. "4800044100 Hz").
                print(
                    f"[VAE Roundtrip] Resampled {sr_in} → "
                    f"{seq_cfg.sampling_rate} Hz",
                    flush=True,
                )

            # Trim or pad to model duration
            target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
            if wav.shape[0] > target_len:
                wav = wav[:target_len]
                print(f"[VAE Roundtrip] Trimmed to {seq_cfg.duration:.1f}s", flush=True)
            elif wav.shape[0] < target_len:
                # Right-pad with zeros. Logged like the trim case because the
                # added silence affects the statistics printed below.
                wav = torch.nn.functional.pad(wav, (0, target_len - wav.shape[0]))
                print(f"[VAE Roundtrip] Padded to {seq_cfg.duration:.1f}s", flush=True)

            wav_b = wav.unsqueeze(0).to(device).float()  # [1, L]

            with torch.no_grad():
                # Encode
                dist = vae.encode_audio(wav_b)
                latent = dist.mode().clone()  # [1, latent_dim, T]
                print(f"[VAE Roundtrip] Latent: shape={tuple(latent.shape)} "
                      f"mean={latent.mean():.4f} std={latent.std():.4f}", flush=True)

                # Decode straight back — no normalization, no generation
                latent_t = latent.transpose(1, 2)  # [1, T, latent_dim]
                spec = vae.decode(latent_t)
                out = vae.vocode(spec)

            out = out.float().cpu()

            # Bring the vocoder output to a 3-D single-channel tensor,
            # whatever rank it came back with.
            if out.dim() == 1:
                out = out.unsqueeze(0).unsqueeze(0)  # [1, 1, L]
            elif out.dim() == 2:
                out = out.unsqueeze(1)
            elif out.dim() == 3 and out.shape[1] != 1:
                out = out.mean(dim=1, keepdim=True)

            print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
                  f"peak={out.abs().max():.4f} "
                  f"rms={out.pow(2).mean().sqrt():.4f}", flush=True)
        finally:
            # Drop the VAE reference and clear caches even when the roundtrip
            # fails; any exception still propagates to the caller.
            del vae
            soft_empty_cache()

        return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)