8fa2699551
References were stored in normalized flow-matching space (net_generator.normalize(z_sample)) but the style loss compares against unnormalize(x) which is in VAE latent space. The optimizer was minimizing L1 between tensors at different scales, pushing the ODE endpoint out of distribution and producing noise. Fix: store reference latents in VAE space (z_sample directly) so both ref_mean/ref_gram and x_un are in the same coordinate system. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
516 lines
23 KiB
Python
516 lines
23 KiB
Python
"""SelVA DITTO Optimizer.
|
||
|
||
Inference-time noise optimization: optimizes the initial noise latent x_0
|
||
using a style loss against BJ reference clips, backpropagating through the
|
||
ODE solver. All model weights remain frozen — only x_0 changes.
|
||
|
||
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
|
||
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
|
||
|
||
Style loss: latent-statistics matching (mean latent vector + optional Gram
matrix) against BJ reference clips encoded once with the VAE encoder. The loss
is evaluated on the unnormalized ODE endpoint, before the VAE decoder and the
vocoder — optimization only requires the DiT, not the decoder or BigVGAN.
|
||
|
||
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
|
||
forward pass activations) instead of O(N steps). Backward recomputes each
|
||
step's activations on demand.
|
||
"""
|
||
|
||
import dataclasses
|
||
import threading
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import torchaudio
|
||
import comfy.utils
|
||
import comfy.model_management
|
||
import folder_paths
|
||
|
||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||
|
||
|
||
def _load_wav(path):
|
||
"""Load audio file to [channels, samples] float32 tensor."""
|
||
try:
|
||
return torchaudio.load(str(path))
|
||
except Exception:
|
||
pass
|
||
import soundfile as sf
|
||
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
|
||
wav = torch.from_numpy(data.T)
|
||
return wav, sr
|
||
|
||
|
||
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
||
"""Style loss between generated mel and precomputed reference statistics.
|
||
|
||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
||
ref_mean: [n_mels] mean spectrum of reference clips (detached)
|
||
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
|
||
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
|
||
Start at 0; enable only if mean-only optimization converges
|
||
without noise, then increase slowly (0.01–0.1).
|
||
"""
|
||
m = mel_gen.squeeze(0) # [n_mels, T]
|
||
|
||
# Mean spectrum loss — captures spectral envelope
|
||
gen_mean = m.mean(dim=-1) # [n_mels]
|
||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||
|
||
if gram_weight <= 0.0:
|
||
return loss_mean
|
||
|
||
# Gram matrix loss — captures timbral texture (can add noise if too high)
|
||
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||
|
||
return loss_mean + gram_weight * loss_gram
|
||
|
||
|
||
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
|
||
"""Style loss computed directly in VAE latent space.
|
||
|
||
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
|
||
ref_mean: [C_lat] mean latent vector of reference clips
|
||
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
|
||
gram_weight: weight for Gram component — 0 = mean only (recommended start)
|
||
|
||
Operating in latent space avoids backprop through the VAE decoder, which
|
||
is @torch.inference_mode() and produces noisy, unstable gradients.
|
||
"""
|
||
# Mean latent loss — matches average activation per channel
|
||
gen_mean = z.mean(dim=0) # [C_lat]
|
||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
||
|
||
if gram_weight <= 0.0:
|
||
return loss_mean
|
||
|
||
# Gram matrix — inter-channel covariance, position-invariant
|
||
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
|
||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
||
|
||
return loss_mean + gram_weight * loss_gram
|
||
|
||
|
||
class SelvaDittoOptimizer:
    """DITTO inference-time noise optimization.

    Freezes all model weights and optimizes only the initial noise latent x_0
    to make the generated audio sound like the BJ reference clips.
    No training data or gradient updates to the model — per-video per-run.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Sound description. Leave empty to use features prompt.",
                }),
                "negative_prompt": ("STRING", {
                    "default": "", "multiline": False,
                }),
                "reference_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). "
                               "Reference mel statistics are precomputed from these once.",
                }),
                "n_opt_steps": ("INT", {
                    "default": 50, "min": 5, "max": 500,
                    "tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
                               "each step requires ~2 DiT forward passes.",
                }),
                "opt_lr": ("FLOAT", {
                    "default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
                    "tooltip": "Adam learning rate for x_0 optimization. "
                               "0.02–0.05 is recommended; 0.1 (paper default) causes oscillation.",
                }),
                "n_ode_steps": ("INT", {
                    "default": 10, "min": 5, "max": 50,
                    "tooltip": "Euler ODE steps run during each optimization iteration. "
                               "Lower = faster optimization (10–15 is a good trade-off). "
                               "Final generation always uses the steps parameter below.",
                }),
                "n_grad_steps": ("INT", {
                    "default": 5, "min": 1, "max": 50,
                    "tooltip": "ODE steps to differentiate through (truncated BPTT). "
                               "Higher = more accurate gradient, more VRAM. "
                               "Must be ≤ n_ode_steps. 5 is a good default.",
                }),
                "style_weight": ("FLOAT", {
                    "default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
                    "tooltip": "Weight of the BJ style loss. High values push harder toward "
                               "BJ style but add noise. Start at 0.1 and increase slowly.",
                }),
                "gram_weight": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
                    "tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
                               "the mean spectrum loss. 0 = mean spectrum only (less noise). "
                               "0.1 adds texture matching but can introduce white noise.",
                }),
                "anchor_weight": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
                    "tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
                               "Prevents optimization from pushing x0 out of the flow's "
                               "expected distribution (which causes white noise). "
                               "Higher = cleaner audio, weaker style. 1.0 is a safe default.",
                }),
                "steps": ("INT", {
                    "default": 25, "min": 1, "max": 200,
                    "tooltip": "Euler steps for the final generation pass (after optimization).",
                }),
                "cfg_strength": ("FLOAT", {
                    "default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
            "optional": {
                "normalize": ("BOOLEAN", {"default": True}),
                "target_lufs": ("FLOAT", {
                    "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",)
    FUNCTION = "optimize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "DITTO inference-time noise optimization (arXiv:2401.12179). "
        "Optimizes the initial noise latent x_0 to match BJ reference clips "
        "via mel statistics style loss, backpropagating through the ODE. "
        "All model weights frozen — zero quality degradation risk."
    )

    def optimize(self, model, features, prompt, negative_prompt,
                 reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
                 normalize=True, target_lufs=-27.0):
        """Encode the reference clips, optimize x_0, and generate audio.

        Reference clips found under ``reference_dir`` are converted to mel
        spectrograms, encoded with the VAE encoder, and reduced to a mean
        latent vector and a mean Gram matrix. Those statistics are handed to
        ``_do_optimize``, which runs on a separate thread.

        Args:
            model: dict read for the keys "dtype", "strategy", "generator",
                "feature_utils", "variant" and "seq_cfg".
            features: dict read for "variant", "prompt", "duration" here, and
                passed on to ``_do_optimize``.
            prompt: text prompt; when blank, ``features["prompt"]`` is used.
            negative_prompt: negative text prompt (may be empty).
            reference_dir: directory of reference audio; a relative path is
                resolved against ``folder_paths.models_dir``.
            n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight,
            gram_weight, anchor_weight, steps, cfg_strength, seed, normalize,
            target_lufs: forwarded unchanged to ``_do_optimize``.

        Returns:
            A 1-tuple holding the value returned by ``_do_optimize``.

        Raises:
            ValueError: variant mismatch, missing/non-positive duration, or a
                blank ``reference_dir``.
            FileNotFoundError: the directory does not exist or holds no audio.
            RuntimeError: the VAE has no encoder, or no clip could be encoded.
            Exception: anything raised inside the worker thread is re-raised.
        """
        import traceback

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]
        mel_converter = feature_utils.mel_converter

        # Validate variant match
        feat_variant = features.get("variant")
        if feat_variant is not None and feat_variant != model["variant"]:
            raise ValueError(
                f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
                f"Re-run Feature Extractor."
            )

        if not prompt or not prompt.strip():
            prompt = features.get("prompt", "")

        # Resolve duration and seq_cfg
        duration = features.get("duration", 0)
        if duration <= 0:
            raise ValueError("[DITTO] Features contain no duration field.")
        seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
        sample_rate = seq_cfg.sampling_rate

        # A blank path would resolve to the models directory itself and the
        # recursive glob below would then pick up every audio file stored
        # anywhere under it — reject it explicitly instead.
        ref_dir_text = (reference_dir or "").strip()
        if not ref_dir_text:
            raise ValueError(
                "[DITTO] reference_dir is empty. "
                "Set it to a directory containing reference audio clips."
            )

        ref_dir = Path(ref_dir_text)
        if not ref_dir.is_absolute():
            ref_dir = Path(folder_paths.models_dir) / ref_dir
        if not ref_dir.exists():
            raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")

        ref_files = []
        for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
            ref_files.extend(ref_dir.rglob(ext))
        if not ref_files:
            raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
        # Fixed order so the random draws below map to the same clips each run.
        ref_files.sort()

        if not hasattr(feature_utils.tod.vae, "encoder"):
            raise RuntimeError(
                "[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
                "Reload the model with the encoder enabled."
            )

        # Bring the model onto the compute device BEFORE encoding references:
        # the mel tensors are created on `device`, so the VAE encoder has to
        # live there too. The finally-block guarantees the model is moved
        # back even when reference loading or optimization fails.
        offloaded = strategy == "offload_to_cpu"
        if offloaded:
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        _result = [None]
        _exc = [None]

        try:
            print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
            mel_converter.to(device, torch.float32)  # cuFFT requires float32

            # z.sample() draws noise; seed first so the reference statistics
            # (and hence the result) are reproducible for a given seed.
            # _do_optimize re-seeds with the same value before drawing x_0.
            torch.manual_seed(seed)

            # Style loss is computed in latent space (after
            # net_generator.unnormalize) rather than mel space — this avoids
            # backpropagating through the VAE decoder (which is
            # @torch.inference_mode() and produces noisy gradients).
            ref_latents = []
            with torch.no_grad():
                for rf in ref_files:
                    try:
                        wav, sr = _load_wav(rf)
                        if wav.shape[0] > 1:
                            wav = wav.mean(0, keepdim=True)  # downmix to mono
                        if sr != sample_rate:
                            wav = torchaudio.functional.resample(wav, sr, sample_rate)
                        wav = wav.squeeze(0).to(device, torch.float32)
                        mel = mel_converter(wav.unsqueeze(0)).to(dtype)  # [1, n_mels, T_mel]
                        # encode → sample → VAE latent space (matches unnormalize(x) in loss)
                        z = feature_utils.tod.encode(mel)  # DiagonalGaussianDistribution
                        z_sample = z.sample().transpose(1, 2)  # [1, T_lat, C_lat]
                        ref_latents.append(z_sample.to(dtype).squeeze(0).clone())  # [T_lat, C_lat]
                    except Exception as e:
                        # Best effort: one unreadable clip must not abort the run.
                        print(f" [DITTO] Skip {rf.name}: {e}", flush=True)

            if not ref_latents:
                raise RuntimeError("[DITTO] No usable reference clips.")

            # Precompute reference latent statistics (done once — detached, no grad)
            with torch.no_grad():
                all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
                ref_mean = all_means.mean(0)  # [C_lat]
                all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
                ref_gram = torch.stack(all_grams).mean(0)  # [C_lat, C_lat]

            print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
                  f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
                  f"grad_steps={n_grad_steps}", flush=True)

            pbar = comfy.utils.ProgressBar(n_opt_steps + steps)

            def _worker():
                # Runs on its own thread; errors are captured and re-raised
                # on the calling thread after join().
                try:
                    _result[0] = _do_optimize(
                        net_generator, feature_utils, mel_converter,
                        features, prompt, negative_prompt,
                        ref_mean, ref_gram,
                        seq_cfg, sample_rate, device, dtype,
                        n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                        style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
                        normalize, target_lufs, pbar,
                    )
                except Exception as e:
                    _exc[0] = e
                    traceback.print_exc()

            t = threading.Thread(target=_worker, daemon=True)
            t.start()
            t.join()
        finally:
            if offloaded:
                net_generator.to(get_offload_device())
                feature_utils.to(get_offload_device())
                soft_empty_cache()

        if _exc[0] is not None:
            raise _exc[0]
        return (_result[0],)
|
||
|
||
|
||
def _do_optimize(net_generator, feature_utils, mel_converter,
                 features, prompt, negative_prompt,
                 ref_mean, ref_gram,
                 seq_cfg, sample_rate, device, dtype,
                 n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
                 normalize, target_lufs, pbar):
    """Optimization loop — runs in a fresh thread (no inference_mode active).

    Optimizes the initial noise x_0 with Adam against a latent-space style
    loss plus an anchor penalty, then runs a final Euler pass from the
    optimized x_0 and decodes/vocodes the result.

    Args:
        net_generator: flow model; used via ``update_seq_lengths``,
            ``preprocess_conditions``, ``get_empty_conditions``,
            ``ode_wrapper``, ``unnormalize`` and ``latent_dim``.
        feature_utils: provides ``encode_text_clip``, ``decode``, ``vocode``.
        mel_converter: only has its tensors re-cloned and frozen here.
        features: dict with "clip_features" and "sync_features" tensors.
        prompt, negative_prompt: text conditions; a blank negative prompt
            yields ``negative_text_features=None``.
        ref_mean, ref_gram: reference latent statistics for the style loss.
        seq_cfg: supplies ``latent_seq_len``.
        sample_rate: copied into the returned dict.
        device, dtype: placement of conditions, x_0 and time grids.
        n_opt_steps, opt_lr: number of Adam steps and their learning rate.
        n_ode_steps: Euler steps per optimization iteration.
        n_grad_steps: trailing Euler steps that are differentiated
            (clamped to ``n_ode_steps``).
        style_weight, gram_weight, anchor_weight: loss weights.
        steps: Euler steps of the final generation pass.
        cfg_strength: forwarded to ``net_generator.ode_wrapper``.
        seed: passed to ``torch.manual_seed``.
        normalize, target_lufs: optional RMS-based loudness scaling.
        pbar: progress object; ``update(1)`` is called once per optimization
            step and once per final Euler step.

    Returns:
        ``{"waveform": <float CPU tensor>, "sample_rate": sample_rate}``.
    """

    # The reference statistics were computed on the caller's thread and may
    # carry the inference-tensor flag; clone() here produces normal tensors.
    # Casting to the model dtype keeps the loss single-dtype — a mixed-dtype
    # loss would send a float32 gradient into bfloat16 ops and fail with
    # "Found dtype Float but expected BFloat16" during backward.
    ref_mean = ref_mean.clone().detach().to(dtype)
    ref_gram = ref_gram.clone().detach().to(dtype)

    torch.manual_seed(seed)

    clip_f = features["clip_features"].to(device, dtype).clone()
    sync_f = features["sync_features"].to(device, dtype).clone()

    # Strip inference-mode flags from all model weights and buffers BEFORE any
    # forward pass. Parameters were loaded in ComfyUI's inference_mode context;
    # operations on inference tensors produce inference tensors, so conditions
    # computed from tainted weights would also be tainted. clone() outside
    # inference_mode produces a normal tensor regardless of the source flag.
    # NOTE: this replaces the parameter/buffer objects on the modules in place
    # and leaves every parameter with requires_grad=False.
    def _strip_inference(module):
        for mod in module.modules():
            for name, buf in list(mod._buffers.items()):
                if buf is not None:
                    mod._buffers[name] = buf.clone()
            for name, param in list(mod._parameters.items()):
                if param is not None:
                    mod._parameters[name] = torch.nn.Parameter(
                        param.data.clone(), requires_grad=False
                    )

    _strip_inference(net_generator)
    _strip_inference(feature_utils)
    _strip_inference(mel_converter)

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=clip_f.shape[1],
        sync_seq_len=sync_f.shape[1],
    )

    # Conditions are constants of the optimization — build them without grad.
    with torch.no_grad():
        text_clip = feature_utils.encode_text_clip([prompt])

        neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
            if negative_prompt.strip() else None

        conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
        empty_conditions = net_generator.get_empty_conditions(
            bs=1, negative_text_features=neg_text_clip
        )

    # Clone all tensors inside conditions/empty_conditions to ensure no inference
    # flags survived from intermediate computations inside preprocess_conditions.
    # Handles tensors nested in dicts, lists and tuples; anything else is
    # returned untouched.
    def _clone_nested(obj):
        if isinstance(obj, torch.Tensor):
            return obj.clone()
        elif isinstance(obj, dict):
            return {k: _clone_nested(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_clone_nested(v) for v in obj)
        return obj

    conditions = _clone_nested(conditions)
    empty_conditions = _clone_nested(empty_conditions)

    # Initial noise — x_0 is the parameter we optimize
    x0_init = torch.randn(
        1, seq_cfg.latent_seq_len, net_generator.latent_dim,
        device=device, dtype=dtype,
    )
    x0 = torch.nn.Parameter(x0_init.clone())
    x0_init = x0_init.detach()  # anchor — kept fixed, no grad
    optimizer = torch.optim.Adam([x0], lr=opt_lr)

    # n_grad_steps must not exceed n_ode_steps
    n_grad_steps = min(n_grad_steps, n_ode_steps)
    n_free_steps = n_ode_steps - n_grad_steps  # steps run without gradient

    # Uniform time grid on [0, 1] for the Euler solver used while optimizing.
    ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)

    print(f"[DITTO] Optimizing x_0 "
          f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)

    # Freeze all model weights (double-check — should already be frozen at inference)
    net_generator.requires_grad_(False)
    feature_utils.requires_grad_(False)
    mel_converter.requires_grad_(False)

    for opt_step in range(n_opt_steps):
        comfy.model_management.throw_exception_if_processing_interrupted()

        # ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
        # Detach from x0 so Phase 1 does not build a computation graph.
        with torch.no_grad():
            x = x0.detach()
            for i in range(n_free_steps):
                t = ts[i]
                dt = ts[i + 1] - t
                flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
                x = x + dt * flow

        # Straight-through estimator: reconnect x to x0's gradient path by
        # adding the zero tensor (x0 - x0.detach()). This adds zero value but
        # creates a grad_fn pointing back to x0, so loss.backward() will
        # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
        # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
        # treated as identity for gradient purposes (truncated BPTT).
        #
        # x may carry an inference tensor flag from Phase 1 (derived from
        # conditions which were built outside inference_mode but may have
        # propagated the flag). .clone() strips it so the STE addition does
        # not try to save an inference tensor for backward.
        x = x.clone()
        x = x + (x0 - x0.detach())

        # ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
        for i in range(n_free_steps, n_ode_steps):
            t = ts[i]
            dt = ts[i + 1] - t

            # Gradient checkpointing: recompute forward during backward,
            # avoiding storage of DiT activations for each step.
            # `t=t` binds the current time as a default argument: the
            # recomputation happens after this loop has moved on, when a
            # plain closure would see the last value of t.
            def _ode_step(x_in, t=t):
                return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)

            flow = torch.utils.checkpoint.checkpoint(
                _ode_step, x, use_reentrant=False
            )
            x = x + dt * flow

        # ── Style loss in latent space ───────────────────────────────────────
        # Unnormalize x back to VAE latent space — fully differentiable, no
        # decode needed. ref_mean/ref_gram are computed from encoded reference
        # clips in the same space. Avoids backprop through VAE decoder which
        # is @torch.inference_mode() and produces noisy gradients.
        x_un = net_generator.unnormalize(x)  # [1, T_lat, C_lat]
        style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)

        # Anchor regularization — penalize x0 drifting from its initial N(0,1)
        # value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
        # the ODE into an out-of-distribution region that decodes as white noise.
        anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
        loss = style_loss + anchor_loss

        optimizer.zero_grad()
        loss.backward()  # gradient flows through Phase 2 + STE back to x0.grad
        torch.nn.utils.clip_grad_norm_([x0], 1.0)  # cap the update's gradient norm at 1.0
        optimizer.step()

        pbar.update(1)

        # Log roughly ten times over the run (every step when n_opt_steps < 10).
        if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
            print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
                  f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
                  f"x0_std={x0.data.std().item():.3f}", flush=True)

    # ── Final generation with optimized x_0 ─────────────────────────────────
    print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)

    with torch.no_grad():
        # Separate time grid: the final pass uses `steps`, not n_ode_steps.
        fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
        x = x0.detach()
        for i in range(steps):
            comfy.model_management.throw_exception_if_processing_interrupted()
            t = fm_ts[i]
            dt = fm_ts[i + 1] - t
            flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
            x = x + dt * flow
            pbar.update(1)

        # Latent → spectrogram → waveform.
        x1_unnorm = net_generator.unnormalize(x)
        spec = feature_utils.decode(x1_unnorm)
        audio = feature_utils.vocode(spec)

    print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
          flush=True)

    # Bring the waveform to a 3-D layout with a single channel axis at dim 1:
    # a 2-D result gets the axis inserted, a multi-channel 3-D result is
    # averaged down to one channel.
    audio = audio.float()
    if audio.dim() == 2:
        audio = audio.unsqueeze(1)
    elif audio.dim() == 3 and audio.shape[1] != 1:
        audio = audio.mean(dim=1, keepdim=True)

    if normalize:
        # target_lufs is applied as a plain RMS level in dB (no loudness
        # weighting is computed here); the clamp avoids division by zero.
        target_rms = 10 ** (target_lufs / 20.0)
        rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
        audio = audio * (target_rms / rms)
        # If the gain pushed the peak above full scale, rescale to peak 1.0.
        peak = audio.abs().max().clamp(min=1e-8)
        if peak > 1.0:
            audio = audio / peak

    print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
    return {"waveform": audio.cpu(), "sample_rate": sample_rate}
|