Commit Graph

8 Commits

Author SHA1 Message Date
Ethanfel 732df151b0 fix: cast ref_mean/ref_gram to model dtype before loss computation
ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.

Fix: clone().detach().to(dtype) at the start of _do_optimize.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:48:41 +02:00
Ethanfel 817b75df49 fix: bypass @torch.inference_mode() on decode to preserve gradient chain
feature_utils.decode and autoencoder.decode are both decorated with
@torch.inference_mode(), which unconditionally destroys grad_fn on all
outputs — making loss.backward() fail with 'does not require grad'.

Fix: call feature_utils.tod.vae.decode() directly, which has no decorator
and is fully differentiable. Transpose matches the original wrapper signature.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:44:35 +02:00
Ethanfel 1f02d73a3e fix: remove checkpoint wrapper on decode — direct call preserves grad chain
_unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving
inference-mode weight tensors during backward. Since _strip_inference() now
cleans all params/buffers before any forward pass, the checkpoint is no longer
needed and was silently breaking the gradient chain from mel_gen back to x0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:40:00 +02:00
Ethanfel fb255edaf0 fix: strip inference-mode tensor flags in DITTO before conditions computation
Root cause: net_generator/feature_utils/mel_converter parameters were loaded
in ComfyUI's inference_mode; operations on inference tensors propagate the flag,
so conditions computed from tainted weights were also tainted. checkpoint()
with use_reentrant=False then failed trying to save inference tensors during
the backward recompute pass.

Fix: _strip_inference() clones all params/buffers of all three models before
any forward pass, and _clone_nested() cleans any residual inference flags in
the conditions/empty_conditions output tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:35:15 +02:00
Ethanfel 82e449681c fix: cast mel_converter and wav to float32 before cuFFT in DITTO
cuFFT does not support bfloat16. mel_converter was being moved to device
without an explicit dtype, inheriting bfloat16 from the model context.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:59:55 +02:00
Ethanfel 357b875e5e fix: strip inference tensor flags in DITTO optimizer
Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":

1. clip_f / sync_f loaded from main-thread inference_mode carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors.
   Also clone x after Phase 1 before the STE reconnection — Phase 1
   runs under no_grad and produces outputs that may still carry the
   flag through the conditions path.

2. net_generator.unnormalize + feature_utils.decode called outside
   any checkpoint wrapper with requires_grad=True input. Backward
   tried to save inference-flagged model weights. Wrapped both calls
   in checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:18:20 +02:00
Ethanfel 211494a91c fix: DITTO gradient never reached x0, remove unused imports and dead code
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.

Fix: use a straight-through estimator after the no-grad prefix:
  x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.

Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:10:02 +02:00
Ethanfel 1e9551152e feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes
BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:04:05 +02:00