Optimized x0 was reaching std=2.72 vs expected ~1.0 for flow matching.
An out-of-distribution initial condition maps to white noise in the output.
After each step, rescale x0 back toward unit std if it exceeds 1.5.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
mel_converter outputs float32 (cuFFT requirement), but VAE encoder weights
are bfloat16. Cast mel to dtype before encode to avoid type mismatch.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Root cause of white noise: backpropagating through vae.decode produces
unstable gradients — the VAE decoder was designed for inference only.
Fix: encode reference clips to VAE latent space once (no grad), compute
mean + Gram matrix statistics there, and compute style loss directly on
net_generator.unnormalize(x) — a single differentiable linear operation.
The gradient path is now: loss → x (unnormalized) → ODE → x0, with no
decoder in the backward pass.
Also adds VAE encoder availability check (fails cleanly if encoder was
deleted to save VRAM).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
White noise on output was caused by the Gram matrix loss pushing the latent
into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only)
and style_weight defaults to 0.1 instead of 1.0. Users can enable Gram
gradually once mean-only optimization converges cleanly.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
optimize() does return (_result[0],) to wrap for ComfyUI. _do_optimize was
returning (dict,) instead of dict, causing double-wrapping: ((dict,),).
ComfyUI then received a tuple as audio and failed on audio["waveform"].
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.
Fix: clone().detach().to(dtype) at the start of _do_optimize.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
feature_utils.decode and autoencoder.decode are both decorated with
@torch.inference_mode(), which unconditionally destroys grad_fn on all
outputs — making loss.backward() fail with 'does not require grad'.
Fix: call feature_utils.tod.vae.decode() directly, which has no decorator
and is fully differentiable. Transpose matches the original wrapper signature.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
_unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving
inference-mode weight tensors during backward. Since _strip_inference() now
cleans all params/buffers before any forward pass, the checkpoint is no longer
needed and was silently breaking the gradient chain from mel_gen back to x0.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Root cause: net_generator/feature_utils/mel_converter parameters were loaded
in ComfyUI's inference_mode; operations on inference tensors propagate the flag,
so conditions computed from tainted weights were also tainted. checkpoint()
with use_reentrant=False then failed trying to save inference tensors during
the backward recompute pass.
Fix: _strip_inference() clones all params/buffers of all three models before
any forward pass, and _clone_nested() cleans any residual inference flags in
the conditions/empty_conditions output tensors.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
cuFFT does not support bfloat16. mel_converter was being moved to device
without an explicit dtype, inheriting bfloat16 from the model context.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":
1. clip_f / sync_f loaded from main-thread inference_mode carry the
inference flag. Clone them on entry to the worker thread so the
conditions built from them are clean non-inference tensors.
Also clone x after Phase 1 before the STE reconnection — Phase 1
runs under no_grad and produces outputs that may still carry the
flag through the conditions path.
2. net_generator.unnormalize + feature_utils.decode called outside
any checkpoint wrapper with requires_grad=True input. Backward
tried to save inference-flagged model weights. Wrapped both calls
in checkpoint(use_reentrant=False) so they recompute on backward
instead of storing activations.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.
Fix: use a straight-through estimator after the no-grad prefix:
x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.
Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>