Commit Graph

4 Commits

Author SHA1 Message Date
Ethanfel 82e449681c fix: cast mel_converter and wav to float32 before cuFFT in DITTO
cuFFT does not support bfloat16. mel_converter was being moved to device
without an explicit dtype, inheriting bfloat16 from the model context.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:59:55 +02:00
Ethanfel 357b875e5e fix: strip inference tensor flags in DITTO optimizer
Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":

1. clip_f / sync_f loaded from main-thread inference_mode carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors.
   Also clone x after Phase 1 before the STE reconnection — Phase 1
   runs under no_grad and produces outputs that may still carry the
   flag through the conditions path.

2. net_generator.unnormalize + feature_utils.decode called outside
   any checkpoint wrapper with requires_grad=True input. Backward
   tried to save inference-flagged model weights. Wrapped both calls
   in checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:18:20 +02:00
Ethanfel 211494a91c fix: DITTO gradient never reached x0, remove unused imports and dead code
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.

Fix: use a straight-through estimator after the no-grad prefix:
  x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.

Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:10:02 +02:00
Ethanfel 1e9551152e feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes
BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:04:05 +02:00