189 Commits

Author SHA1 Message Date
Ethanfel 40d29bcaf8 feat: add experiment configs for logit+cosine combo and BigVGAN decoder fine-tuning
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 16:48:21 +02:00
Ethanfel 65dc549494 feat: add reference audio comparison metrics to LoRA trainer eval
New _reference_metrics() computes LSD, MCD, and per-band correlation
between eval samples and the original source audio at each checkpoint.
Loads reference audio once before the training loop and logs metrics
alongside existing spectral metrics.

Also fix batch_size in lora_optimized_dataset.json (4 -> 16).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 15:04:07 +02:00
Ethanfel f745e241c4 chore: sanitize tooltips/comments + add experiment configs
- Replace all BJ references with generic "target style/audio" in
  activation steering, DITTO optimizer, and BigVGAN trainer
- Add latent_mixup_alpha/latent_noise_sigma to LoRA scheduler defaults
- Add bigvgan_disc_fm_retest.json and lora_optimized_dataset.json

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 13:44:37 +02:00
Ethanfel 082a2da438 fix: restore dtype after float32 STFT in discriminator spectrogram
torch.stft requires float32 input, but the .float() cast was not
reversed before the spectrogram hit bfloat16 Conv2d weights. Save
the original dtype and cast back after abs().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 12:13:55 +02:00
Ethanfel c28e090196 fix: cast discriminator inputs to match bfloat16 dtype in BigVGAN FM loss
The frozen discriminators are loaded in model dtype (bfloat16) but vocoder
waveform outputs are float32, causing a Conv2d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 11:36:02 +02:00
Ethanfel af6c225f53 feat: add dataset pipeline nodes + latent augmentation for LoRA trainer
New dataset pipeline nodes:
- SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution
- SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility
- SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking

Improvements:
- Inspector: silence detection (max_silence_fraction param)
- Saver: origin_name lookup for augmented clips' npz pairing
- LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization
- LoRA trainer: one-time SR mismatch warning in _load_audio

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 11:32:00 +02:00
Ethanfel 30127c13ca feat: add BigVGAN vocoder sweep scheduler node
Runs a series of BigVGAN fine-tuning experiments from a JSON sweep file.
Audio clips loaded once, vocoder deep-copied per experiment, results
collected in experiment_summary.json with comparison loss curves.
Resume-aware — skips completed experiments on re-run.

Includes overnight sweep config (8 experiments): snake alpha steps,
GAFilter ablation, phase loss weight, discriminator FM, all_params.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 02:39:56 +02:00
Ethanfel 4226297735 chore: remove debug VRAM logging
Training confirmed working — VRAM usage is normal backward-pass
activation memory, not a leak. Removed all debug _vram_log and _vram
calls. Kept the video_enc offload and torch.cuda.empty_cache fixes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:50:08 +02:00
Ethanfel 4297715a08 debug: add driver-level VRAM reporting + offload video_enc
torch.cuda.memory_allocated only tracks PyTorch allocator. Added
torch.cuda.mem_get_info to see actual CUDA driver memory usage.
Also offload video_enc (TextSynch) which was missed in the original
offload — stays on GPU when strategy != offload_to_cpu.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:48:04 +02:00
Ethanfel 9af4bbdd91 fix: force torch.cuda.empty_cache() after pre-generation and CLIP encoding
PyTorch's caching allocator reserves GPU memory from pre-generation
(~90 GiB for generator + tod) and doesn't return it to CUDA/OS.
soft_empty_cache may not call torch.cuda.empty_cache(). Force a full
cache release after CLIP encoding and after LoRA mel pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:42:45 +02:00
Ethanfel 89d6fccd28 debug: add per-operation VRAM logging in first training step
Logs VRAM at: after target_mel, after vocoder forward, before loss,
after loss computation, and after backward. Only logs for step 0 to
avoid spam. Will identify which operation causes the 94 GiB spike.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:35:54 +02:00
Ethanfel bd84242fa1 debug: add VRAM logging at offload and training checkpoints
Logs torch.cuda.memory_allocated/reserved at each step: before unload,
after unload_all_models, after feature_utils.to(cpu), after generator
to(cpu), after cache clear, after mel_converter to(device), and before
training loop. This will identify what's holding VRAM.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:28:31 +02:00
Ethanfel 5a2c003fb2 fix: move baseline sample after inference flag stripping
_save_sample("baseline") was called before the vocoder's inference
tensors were sanitized, causing "Inference tensors do not track version
counter". Moved it after the clone/detach loop and vocoder.to(device).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:26:11 +02:00
Ethanfel f8d4d77b0d fix: pre-compute text CLIP embeddings in main thread to avoid inference tensor crash
CLIP weights are inference tensors from ComfyUI loading. inference_mode
is thread-local, so the worker thread can't use CLIP even with a context
manager. Pre-compute all text embeddings in the main thread (where
inference_mode IS active), clone+detach to normal tensors, and pass them
to the worker via text_clip_cache dict. CLIP no longer needs to be on
GPU during pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:19:44 +02:00
Ethanfel 32e5344ea2 fix: wrap CLIP encoding in inference_mode during pre-generation
CLIP weights are inference tensors from ComfyUI loading. The worker
thread runs without inference_mode, so PyTorch rejects inference tensors
in multi_head_attention_forward (version counter tracking). Wrap the
encode_text_clip call in torch.inference_mode() since text encoding
doesn't need gradients.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:10:58 +02:00
Ethanfel 10a71b0c4f fix: offload entire model to CPU in main thread before worker starts
The previous offload ran inside the worker thread, but by then ComfyUI
had already loaded the full model to GPU. Now feature_utils.to('cpu')
and generator.to('cpu') run in the main thread right after
unload_all_models(), before the worker starts. vocoder.to(device, dtype)
is called explicitly after inference flag stripping in _do_train to
bring only the vocoder back to GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:56:13 +02:00
Ethanfel 37a27160aa fix: match mel dtype to vocoder in baseline sample generation
ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16
before inference flag stripping. Cast mel to vocoder's dtype to prevent
input/bias type mismatch during baseline sample save.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:45:31 +02:00
Ethanfel cb9a1eef01 fix: stop loading full feature_utils to GPU before training
feature_utils.to(device) was loading CLIP ViT-H, synchformer, T5, VAE,
and vocoder (~90 GiB) to GPU for the entire training run. Now only
mel_converter (tiny) is moved to GPU. Pre-generation manages its own
device placement: temporarily moves CLIP and tod to GPU, then moves them
back when done. This frees ~90 GiB for the backward pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:44:38 +02:00
Ethanfel d70c611bf7 fix: offload CLIP, synchformer, T5, generator, VAE to CPU before training
Only the vocoder and mel_converter are needed during BigVGAN training.
The rest of the SelVA pipeline (CLIP ViT-H, synchformer, T5, generator,
VAE) was staying on GPU and consuming ~90 GiB, leaving no room for
backward pass activations. Now offloaded individually to CPU before
the training loop starts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:33:07 +02:00
Ethanfel 4e6cc4d519 feat: cache pre-generated LoRA mels to disk for reuse
LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow.
Cache results to a .pt file next to the output, keyed by a SHA-256 hash
of the LoRA adapter content + generation parameters (seed, steps, CFG,
duration, sample rate, npz file list). Automatically reused on subsequent
runs when parameters haven't changed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:30:20 +02:00
Ethanfel 0854bd2638 fix: cast discriminators to model dtype to match vocoder output
Discriminators are constructed as float32 but receive bfloat16 tensors
from the vocoder. Cast to model dtype on load to prevent conv dtype
mismatch in feature matching loss.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:25:04 +02:00
Ethanfel 187b2e3169 fix: cast GAFilter to model dtype after injection
GAFilter conv weights are created as float32 but the rest of the vocoder
is bfloat16. vocoder.to(device) missed the dtype cast, causing conv1d
dtype mismatch when Snake bfloat16 output flows into GAFilter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:24:11 +02:00
Ethanfel 608746ce7b fix: cast input mel to model dtype before vocoder forward pass
mel_converter outputs float32 (cuFFT requirement) but vocoder weights are
bfloat16 from model loading. Cast input_mel back to model dtype before
feeding the vocoder to avoid conv1d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:18:05 +02:00
Ethanfel bba5aec7a5 fix: add CFG to LoRA mel pre-generation to match inference conditions
Pre-generated mels were using a bare forward pass with no classifier-free
guidance, producing mels that don't match what the vocoder sees at inference
(where cfg_strength=4.5 is the default). Now uses ode_wrapper with
preprocess_conditions/get_empty_conditions, same as the sampler node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:17:16 +02:00
Ethanfel d06936802b fix: cast mel_converter buffers to float32 to match STFT input dtype
mel_basis and hann_window buffers inherit bfloat16 from model loading.
Since all mel_converter inputs are cast to float32 for cuFFT, the
internal buffers must also be float32 to avoid matmul dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:10:52 +02:00
Ethanfel bee518a855 fix: cast all STFT inputs to float32 to prevent cuFFT bfloat16 crash
cuFFT does not support bfloat16 tensors. When the model is loaded in
bfloat16, all torch.stft calls (mel_converter, discriminator spectrogram,
multi-resolution STFT loss) crash. Add .float() at every STFT boundary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 23:53:36 +02:00
Ethanfel 48b72c0be0 feat: add LoRA mel pre-generation to BigVGAN vocoder trainer
When a lora_adapter path is provided, the trainer pre-generates
LoRA-distorted mels for each training clip (full ODE generation +
VAE decode) and trains the vocoder to produce clean audio from them.
This teaches the vocoder to compensate for LoRA latent distribution
shift without requiring perfectly aligned training pairs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 23:26:36 +02:00
Ethanfel e16480b4c9 feat: add PiSSA/rsLoRA support to scheduler and PiSSA sweep experiment
Thread init_mode and use_rslora through the scheduler's config parsing,
experiment record, and _train_inner call. Default alpha changed to 2*rank
to match trainer. Add pissa_sweep.json with 7 experiments ablating PiSSA
init vs standard, rsLoRA scaling, and learning rate variations at rank 128.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 22:07:27 +02:00
Ethanfel 784fb2753f feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing intruder dimension problem:

1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of
   pretrained weight. Starts on-manifold, eliminates intruder dimensions
   at init. Base weight stores residual W_res = W - B@A*scale.

2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of
   alpha/rank. Prevents gradient collapse at high ranks (128+).

3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained
   LoRA update, gradient-sensitivity reweighting to suppress remaining
   intruder dimensions. Runs automatically after training completes.

4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder
   dimensions per arXiv:2410.21228.

5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents
   erasing learned style weights).

6. random.choices replaced with random.sample when batch_size <= dataset
   size (eliminates duplicate samples per batch).

PiSSA checkpoints include base weights (residual). Loader/evaluator
updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 21:54:36 +02:00
Ethanfel ecf828b007 fix: move vocoder to correct device after GAFilter injection
inject_gafilters creates Conv1d modules on CPU. load_state_dict
preserves existing param devices but GAFilter params stay on CPU,
causing device mismatch during vocode. Save target device before
injection, then move entire vocoder after loading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 20:28:55 +02:00
Ethanfel 793368af18 fix: strip inference flag before unnormalize in LoRA trainer eval
x1_pred is an inference tensor (computed from inference-mode weights
loaded by ComfyUI). generator.unnormalize() uses in-place mul_/add_
which fails on inference tensors. Clone strips the flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 20:01:53 +02:00
Ethanfel 1d1ae61409 fix: move only VAE+vocoder to GPU during eval to prevent device mismatch
The previous check (next(feature_utils_orig.parameters()).device) only
inspected the first parameter (from CLIP), missing CPU-stranded vocoder
weights when the module was in a mixed-device state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 19:36:02 +02:00
Ethanfel 8fa2699551 fix: correct DITTO reference latent space mismatch
References were stored in normalized flow-matching space
(net_generator.normalize(z_sample)) but the style loss compares against
unnormalize(x) which is in VAE latent space. The optimizer was minimizing
L1 between tensors at different scales, pushing the ODE endpoint out of
distribution and producing noise.

Fix: store reference latents in VAE space (z_sample directly) so both
ref_mean/ref_gram and x_un are in the same coordinate system.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:57:08 +02:00
Ethanfel 14fabf01f9 fix: reduce opt_lr step to 0.001 to allow finer lr control in DITTO
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:40:21 +02:00
Ethanfel 445da1e69b fix: replace std clamp with anchor regularization to prevent OOD noise
The std clamp was post-hoc and only addressed magnitude, not direction.
x0 was drifting to mean=-0.55/std=3.1 (ODE expected mean=0/std=1).

Replace with anchor_weight * MSE(x0, x0_init) added directly to the loss.
The optimizer now balances style matching against staying near the initial
N(0,1) noise — gradient-aware, prevents both magnitude and mean drift.

Also logs style/anchor losses and x0_std per step for diagnostics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:30:05 +02:00
Ethanfel fa6c4fa834 fix: clamp x0 std after each optimizer step to prevent OOD noise
Optimized x0 was reaching std=2.72 vs expected ~1.0 for flow matching.
An out-of-distribution initial condition maps to white noise in the output.
After each step, rescale x0 back toward unit std if it exceeds 1.5.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:23:39 +02:00
Ethanfel 286681edff fix: cast mel to model dtype before VAE encode in DITTO reference loading
mel_converter outputs float32 (cuFFT requirement), but VAE encoder weights
are bfloat16. Cast mel to dtype before encode to avoid type mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:18:41 +02:00
Ethanfel 056a7b973d fix: enable VAE encoder in model loader — required for DITTO reference encoding
need_vae_encoder=False was deleting the encoder to save a small amount of VRAM.
DITTO now needs it to encode reference clips to latent space for style loss.
The spectrogram VAE encoder is small enough that the overhead is negligible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:15:27 +02:00
Ethanfel 633fe36fbb fix: compute DITTO style loss in latent space to eliminate VAE decoder noise
Root cause of white noise: backpropagating through vae.decode produces
unstable gradients — the VAE decoder was designed for inference only.

Fix: encode reference clips to VAE latent space once (no grad), compute
mean + Gram matrix statistics there, and compute style loss directly on
net_generator.unnormalize(x) — a single differentiable linear operation.
The gradient path is now: loss → x (unnormalized) → ODE → x0, with no
decoder in the backward pass.

Also adds VAE encoder availability check (fails cleanly if encoder was
deleted to save VRAM).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:12:31 +02:00
Ethanfel 8862089fd0 fix: remove 32-clip cap on DITTO reference loading — use all available clips
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:10:10 +02:00
Ethanfel 608e7df04b feat: add gram_weight param to DITTO, reduce default style_weight to 0.1
White noise on output was caused by the Gram matrix loss pushing the latent
into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only)
and style_weight defaults to 0.1 instead of 1.0. Users can enable Gram
gradually once mean-only optimization converges cleanly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:03:32 +02:00
Ethanfel 101b1bdb41 fix: _do_optimize returns dict not tuple — prevent double-wrapping AUDIO output
optimize() does return (_result[0],) to wrap for ComfyUI. _do_optimize was
returning (dict,) instead of dict, causing double-wrapping: ((dict,),).
ComfyUI then received a tuple as audio and failed on audio["waveform"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:56:59 +02:00
Ethanfel 732df151b0 fix: cast ref_mean/ref_gram to model dtype before loss computation
ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.

Fix: clone().detach().to(dtype) at the start of _do_optimize.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:48:41 +02:00
Ethanfel 817b75df49 fix: bypass @torch.inference_mode() on decode to preserve gradient chain
feature_utils.decode and autoencoder.decode are both decorated with
@torch.inference_mode(), which unconditionally destroys grad_fn on all
outputs — making loss.backward() fail with 'does not require grad'.

Fix: call feature_utils.tod.vae.decode() directly, which has no decorator
and is fully differentiable. Transpose matches the original wrapper signature.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:44:35 +02:00
Ethanfel 1f02d73a3e fix: remove checkpoint wrapper on decode — direct call preserves grad chain
_unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving
inference-mode weight tensors during backward. Since _strip_inference() now
cleans all params/buffers before any forward pass, the checkpoint is no longer
needed and was silently breaking the gradient chain from mel_gen back to x0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:40:00 +02:00
Ethanfel fb255edaf0 fix: strip inference-mode tensor flags in DITTO before conditions computation
Root cause: net_generator/feature_utils/mel_converter parameters were loaded
in ComfyUI's inference_mode; operations on inference tensors propagate the flag,
so conditions computed from tainted weights were also tainted. checkpoint()
with use_reentrant=False then failed trying to save inference tensors during
the backward recompute pass.

Fix: _strip_inference() clones all params/buffers of all three models before
any forward pass, and _clone_nested() cleans any residual inference flags in
the conditions/empty_conditions output tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:35:15 +02:00
Ethanfel 8ccc2438e4 fix: remove FlashSR (audiosr incompatible with Python 3.12), add training loss CSV
- Drop SelvaFlashSR node — audiosr pins numpy<=1.23.5 which cannot build
  on Python 3.12 (pkgutil.ImpImporter removed); use Saganaki22/ComfyUI-AudioSR instead
- BigVGAN trainer now writes <output_stem>_training_log.csv alongside the
  checkpoint: step, total, fm, mel, stft, phase, l2sp columns, line-buffered
  so loss can be tailed live during training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:18:34 +02:00
Ethanfel 8371466e44 fix: guarantee length preservation in _ActivationWithGAFilter
Activation1d's anti-alias Kaiser sinc resampling (asymmetric pad_left /
pad_right) can produce ±1-2 sample rounding in edge cases, causing the
BigVGAN AMPBlock residual addition (xt + x) to fail with a size mismatch.

Trim or pad the output to exactly match the input length so the resblock
skip connection always has matching dimensions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:39:03 +02:00
Ethanfel ba0499b77c fix: FlashSR device handling and remove unused tmp_out
Use device="auto" for audiosr.build_model — safer than passing a device
string that may not be accepted in all audiosr versions.
Remove unused tmp_out temp file that was created but never written to.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:32:02 +02:00
Ethanfel ce62bccc1f feat: add post-generation audio enhancement nodes
Three new nodes for post-generation quality improvement:

- SelvaHarmonicExciter: multi-band exciter (HPF → tanh saturation → mix)
  restores harmonic richness lost in BigVGAN HF reconstruction

- SelvaFlashSR: audio super-resolution via FlashSR basic model
  (haoheliu/versatile_audio_super_resolution, requires pip install audiosr)
  predicts missing HF content above vocoder reconstruction ceiling

- SelvaOutputNormalizer: BS.1770-4 LUFS normalization + true peak limiting
  for consistent loudness on generated outputs (pyloudnorm)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:27:39 +02:00
Ethanfel 45fced55bc fix: exclude GAFilter params from L2-SP regularization
L2-SP anchors trainable params to their pretrained values. GAFilter is a
newly initialized module (identity FIR filter) with no pretrained values —
anchoring it to identity initialization would resist learning. Exclude
gafilter params from the L2-SP loss so they train freely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:19:52 +02:00
Ethanfel db112394e8 feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader
Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel
depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN
residual blocks. Initialized as identity so training starts from pretrained
behaviour.

- inject_gafilters() walks resblocks.*.activations and wraps each Activation1d
  with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically
- Trained alongside Snake alphas in snake_alpha_only mode
- Checkpoint saves has_gafilter + gafilter_kernel_size metadata
- Loader detects metadata and injects before load_state_dict so weights populate correctly
- Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:15:14 +02:00
Ethanfel c53ea5517c feat: add FA-GAN phase-aware STFT loss to BigVGAN trainer
Adds L1 loss on real, imaginary, and magnitude STFT components across
three resolutions (FA-GAN, arXiv:2407.04575). Penalizes phase smearing
directly — magnitude-only losses cannot distinguish correct spectrum
with wrong phase from a smeared spectrum.

Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top
of both the discriminator FM path and the fallback mel+STFT path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:09:31 +02:00
Ethanfel 82e449681c fix: cast mel_converter and wav to float32 before cuFFT in DITTO
cuFFT does not support bfloat16. mel_converter was being moved to device
without an explicit dtype, inheriting bfloat16 from the model context.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:59:55 +02:00
Ethanfel 15fc5f0793 feat: add SelvaDatasetCompressor node for parallel compression
Mild 2:1-3:1 parallel compression via pedalboard.Compressor to reduce
within-clip loudness variance after LUFS normalization. Blend ratio
keeps transients intact while tightening dynamics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:36:27 +02:00
Ethanfel 48493a3f0d feat: add SelvaDatasetSaver node with NPZ sidecar copy
Saves all clips in an AUDIO_DATASET to FLAC. When npz_source_dir is
provided, copies the matching .npz for each clip so FLAC/NPZ pairs
stay in sync after the inspector filters out bad clips.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:27:48 +02:00
Ethanfel becb38c27e fix: use soundfile for WAV/FLAC/OGG to bypass torchcodec/FFmpeg dependency
torchaudio was defaulting to the torchcodec backend which requires FFmpeg
shared libraries not present in the ComfyUI venv, silently skipping every
clip and producing an empty dataset.

Also add experiments/vocoder_finetune.json for the BJ vocoder LoRA run
(lr=3e-4, rank=128, 10k steps).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:16:22 +02:00
Ethanfel b9f95cfd7e fix: detect silent discriminator load failure and fall back explicitly
If no matching key was found for MPD or MRD in the checkpoint, the for-loops
completed silently and randomly-initialized discriminators were used as frozen
feature extractors — producing meaningless feature matching loss while
appearing to work. Now raises RuntimeError (caught by outer except) which
triggers the existing fallback to mel+STFT losses with a clear warning.
Also prints available checkpoint keys to help diagnose format mismatches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:39:55 +02:00
Ethanfel f50afa9796 fix: guard _estimate_snr against short clips, fix freqs device in _check_hf_shelf
Bug 1: mono.unfold(0, 2048, 512) returns an empty tensor for clips shorter
than 2048 samples (~46ms). torch.quantile on an empty tensor crashes with
"quantile() input tensor must be non-empty". Guard: return 60.0 (assume
clean) for clips too short to frame — the pipeline has no minimum-length
filter so any short file in the dataset folder would crash the Inspector.

Bug 2: torch.linspace(...) in _check_hf_shelf created a CPU tensor, making
band_lo/band_hi CPU boolean masks. Indexing a GPU mag_sq tensor with CPU
masks crashes. Pass device=mono.device so freqs lands on the same device
as the audio.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:28:36 +02:00
Ethanfel 8a85819f97 feat: register audio dataset pipeline nodes in __init__.py 2026-04-09 14:25:57 +02:00
Ethanfel f1c4654bab feat: add SelvaDatasetItemExtractor node 2026-04-09 14:24:58 +02:00
Ethanfel 2d06cb2f52 fix: pass device to hann_window in _check_hf_shelf to avoid GPU mismatch 2026-04-09 14:22:13 +02:00
Ethanfel 0731addea9 feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:20:03 +02:00
Ethanfel 7eb9bd5745 feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:17:44 +02:00
Ethanfel 057bfb813d feat: add SelvaDatasetResampler node (soxr VHQ)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:13:45 +02:00
Ethanfel 2c71d4c184 feat: add SelvaDatasetLoader node
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:09:43 +02:00
Ethanfel d25df10aa5 feat: add audio dataset pipeline skeleton 2026-04-09 14:05:31 +02:00
Ethanfel d70a4d2123 docs: add audio dataset pipeline implementation plan 2026-04-09 14:02:46 +02:00
Ethanfel 2b10205657 fix: raise segment_seconds max from 4s to 30s
Hardcoded max of 4.0 prevented using full 8s clips. Raised to 30s.
Also bumped default from 1.0 to 2.0 as a more sensible starting point.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:49:50 +02:00
Ethanfel 8166c56552 perf: gradient checkpointing on vocoder forward to reduce activation memory
BigVGAN's 512x upsampling stack stores huge intermediate activations for
backward even in snake_alpha_only mode (only 5K trainable params, but
activation graph runs through the full network after each snake op).

Wrapping vocoder() in checkpoint(use_reentrant=False) recomputes activations
during backward instead of storing them — ~2x compute cost, large reduction
in peak VRAM. Should allow batch_size > 1 on 96 GB without OOM.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:45:24 +02:00
Ethanfel eece79ccae fix: correct MRD channel width to 128 and unload models before training
Two bugs:

1. _DiscriminatorR used channels=32 but the BigVGAN pretrained discriminator
   checkpoint has channels=128. All convs in _DiscriminatorR now use 128,
   matching the checkpoint architecture so state_dict loads without error.

2. BigVGAN trainer OOM: SelVA generator and other ComfyUI models remain in
   VRAM during training (~90 GiB used). Add unload_all_models() + cache
   flush before the training loop to reclaim VRAM headroom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:40:01 +02:00
Ethanfel 357b875e5e fix: strip inference tensor flags in DITTO optimizer
Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":

1. clip_f / sync_f loaded from main-thread inference_mode carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors.
   Also clone x after Phase 1 before the STE reconnection — Phase 1
   runs under no_grad and produces outputs that may still carry the
   flag through the conditions path.

2. net_generator.unnormalize + feature_utils.decode called outside
   any checkpoint wrapper with requires_grad=True input. Backward
   tried to save inference-flagged model weights. Wrapped both calls
   in checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:18:20 +02:00
Ethanfel 211494a91c fix: DITTO gradient never reached x0, remove unused imports and dead code
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.

Fix: use a straight-through estimator after the no-grad prefix:
  x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.

Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:10:02 +02:00
Ethanfel 1e9551152e feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes
BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:04:05 +02:00
Ethanfel f17f6f0863 feat: save ground truth spectrogram once for direct comparison
Writes _gt_spec.png from ref_mel before training starts so each step's
_spec.png can be compared against the unmodified vocoder roundtrip target.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:05:47 +02:00
Ethanfel 304d9d01bf feat: save mel spectrogram PNG alongside each eval sample
Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample
call now writes both a .wav and a _spec.png so training progress is visible
without listening. Colour map is blue→green→yellow (viridis-ish), low
frequencies at the bottom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:03:28 +02:00
Ethanfel 0128a81cc2 fix: use full first clip for eval samples instead of 1s segment
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:01:52 +02:00
Ethanfel 710261f5be fix: add soundfile fallback for torchaudio.save in sample writing
Same environment has no compatible ffmpeg/torchcodec for saving.
Mirror the _load_wav pattern: try torchaudio, fall back to soundfile.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:58:07 +02:00
Ethanfel 5df2abd6dd fix: handle all three inference-tensor sources in vocoder sanitization
remove_parametrizations() stores weight as a plain __dict__ tensor (not
nn.Parameter), making it invisible to _parameters iteration. Also, buffers
(Activation1d anti-aliasing filters) are inference tensors that break the
backward graph mid-network. Fix all three categories:
1. _parameters: clone().detach(), wrap as Parameter
2. plain __dict__ tensors: clone(), register_parameter (also makes trainable)
3. _buffers: clone() to strip inference flag without parametrizing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:54:41 +02:00
Ethanfel b243908873 debug: inspect conv_pre parametrizations and _parameters keys
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:46:16 +02:00
Ethanfel 9df855ee0e debug: print is_inference() status before failing conv_pre call
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:41:51 +02:00
Ethanfel 78f8aa98ad fix: clone inference tensors at thread entry to strip the inference flag
torch.inference_mode is thread-local, but the inference flag lives on the
tensor object. Operations on inference tensors always propagate it, even in
a clean thread. The only escape is .clone() called outside inference_mode.
At thread entry (inference_mode disabled): clone clips and mel_converter
buffers to get clean normal tensors before any training computation.
Vocoder parameter clone() also now works correctly in this thread context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:35:48 +02:00
Ethanfel e870446b0f fix: run BigVGAN training in a fresh thread to escape inference_mode
torch.inference_mode is thread-local. ComfyUI sets it on the node-execution
thread; inference_mode(False) alone is insufficient to escape it in some
environments (e.g. async wrappers, lora-manager hook). A new thread always
starts clean. Moved all training logic into _do_train() called via
threading.Thread so every tensor is a normal autograd tensor by default.
Simplified parameter cloning: clone().detach().requires_grad_(True).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:30:53 +02:00
Ethanfel df63b147e9 fix: sanitize all submodule buffers of mel_converter + guarantee target_mel output
Previous fix only iterated mel_converter._buffers (direct buffers). Submodules
(e.g. Spectrogram.window) still held inference tensors. Switch to .modules()
to cover all nested buffers, matching the vocoder parameter sanitization.
Also add a zeros+copy_ safety net on target_mel output so conv can save it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:14:12 +02:00
Ethanfel 51ac099073 fix: sanitize target_flat — clips are inference tensors from outer inference_mode
The clips list is built inside ComfyUI's inference_mode context, so every
element is an inference tensor. torch.stack().clone() propagates the flag.
Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor,
so mel_converter(target_flat) inside no_grad produces a saveable input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:09:26 +02:00
Ethanfel b7565ec458 fix: sanitize inference tensors in BigVGAN trainer via zeros+copy_ pattern
param.data.clone() and tensor.detach() on inference tensors both produce
inference tensors — the flag propagates through all operations on them.
Inside inference_mode(False), torch.zeros() creates genuine normal tensors.
Use zeros+copy_ to sanitize both vocoder parameters and mel_converter
buffers once before training, so autograd can save inputs for backward.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:05:36 +02:00
Ethanfel 0fcb6d3106 fix(bigvgan-trainer): replace parameter objects to fully strip inference tensor flag
param.data = clone() only replaces storage — the nn.Parameter object itself
retains the inference tensor flag set when the model was loaded. Replace each
parameter with a fresh nn.Parameter(data.clone()) created inside
inference_mode(False) so both the object and its data are normal tensors.
Move optimizer creation to after re-creation so it references the new objects.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:58:57 +02:00
Ethanfel c86306bde8 fix(bigvgan-trainer): clone vocoder parameters to strip inference tensor flag
The vocoder is loaded inside ComfyUI's torch.inference_mode(), making all
its parameters inference tensors. Autograd cannot save inference tensors
for backward even with requires_grad=True. Clone all parameters inside
torch.inference_mode(False) before training to get normal tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:55:16 +02:00
Ethanfel f04d59fe63 fix(bigvgan-trainer): clone mel outputs to strip inference tensor flag from buffers
mel_converter buffers (mel_basis, hann_window) are inference tensors
because the model was loaded inside ComfyUI's torch.inference_mode().
Operations on them propagate the flag to outputs. Clone both target_mel
and pred_mel to get normal autograd-compatible tensors. .clone() is
differentiable so the grad graph to vocoder parameters is preserved.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:51:28 +02:00
Ethanfel daa36a5f7b fix(bigvgan-trainer): clone target tensor to exit inference mode before backward
Clips loaded outside torch.inference_mode(False) are inference tensors.
Autograd cannot save them for backward. .clone() creates a normal tensor,
same fix pattern as selva_lora_trainer's dist.mode().clone().

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:47:47 +02:00
Ethanfel 16e20b30ce fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers
Model loaded in bf16 causes mel_basis buffer to be bf16. Audio loaded
from disk is float32, causing matmul dtype mismatch. Cast all audio
tensors to model["dtype"] before passing to mel_converter/vocoder.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:46:01 +02:00
Ethanfel ea7dfed27a fix(bigvgan-trainer): fallback to soundfile when torchaudio ffmpeg backend fails
torchcodec/libavutil soname mismatch causes torchaudio to fail on every
file load, silently emptying clips. Add _load_wav() that tries torchaudio
first then falls back to soundfile (handles wav/flac without ffmpeg).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:41:59 +02:00
Ethanfel 81ff0d46c9 fix(bigvgan-trainer): resolve device mismatch in _save_sample after offload
After the finally block, offload_to_cpu moves the vocoder to CPU while
ref_mel stays on GPU. Fix: detect vocoder's current device via
next(vocoder.parameters()).device and move ref_mel there before vocoding.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:35:07 +02:00
Ethanfel 9fdeb65182 feat(bigvgan-trainer): add eval samples at checkpoints and end
Saves baseline.wav (ground truth roundtrip before training), stepN.wav
at each save_every checkpoint, and final.wav after training completes.
All use the same fixed reference segment (clip 0, position 0) for
direct comparison across checkpoints.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:30:34 +02:00
Ethanfel 790a53e3df fix(bigvgan): add 44k/BigVGANv2 support to trainer and loader
44k variants use BigVGANv2 directly as the vocoder (no wrapper, no
@inference_mode decorator), accessible at feature_utils.tod.vocoder.
16k wraps BigVGANVocoder inside BigVGAN, accessed at .vocoder.vocoder.
Both trainer and loader now branch on model["mode"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:28:32 +02:00
Ethanfel 9c784b4bdb feat: add BigVGAN vocoder fine-tuner and loader nodes
Spectral-loss-only fine-tuning of the BigVGAN vocoder (mel→waveform)
on BJ audio clips. DiT and VAE are completely frozen. Losses: mel L1
reconstruction + multi-resolution STFT magnitude L1 (same three
resolutions as the BigVGAN discriminator config). Saves in
{'generator': state_dict} format compatible with the original BigVGAN
checkpoint. Loader replaces vocoder weights in the loaded SELVA_MODEL
in-place so no full model reload is needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:26:12 +02:00
Ethanfel 115a0c3718 feat(steering): conditional-only injection + per-position vectors
Two improvements for stronger steering effect:

1. Apply steering only during the conditional predict_flow pass by
   monkey-patching predict_flow to set a flag via identity check
   (cond is conditions). Hooks skip the unconditional pass, so
   steering is amplified by cfg_strength (~4.5x) instead of canceling
   out in the CFG guidance term.

2. Restore per-position [seq, hidden] steering vectors instead of
   seq-averaged [hidden]. More spatially specific — captures positional
   activation patterns rather than a global mean. Seq length mismatch
   at inference time handled via linear interpolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:02:51 +02:00
Ethanfel 95923cdf42 feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual
inversion. Extractor runs frozen generator on dataset with BJ vs empty
conditions, records mean hidden-state delta per block, saves [hidden_dim]
vectors (seq-averaged so they broadcast to any inference duration). Loader
reads the bundle. Sampler registers forward hooks during the ODE that add
strength × vec to each block output, cleaned up in a finally block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:38:26 +02:00
Ethanfel 28ee3db337 feat(sampler): add ti_strength blend for TI injection
TI via text conditioning produces buzz because SelVA's text path is
mean-pooled into a global DiT bias — not rich per-token cross-attention
like SD. The optimizer learns a constant spectral artifact rather than
semantic style shift.

ti_strength=1.0 (default) = full injection as before.
ti_strength<1.0 = lerp between original and injected text_clip,
allowing the effect to be dialled back without retraining.
Applies to both text_clip and neg_text_clip symmetrically.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:07:57 +02:00
Ethanfel b89167cfae fix(ti-trainer): clamp token norm to CLIP manifold to prevent buzz artifacts
Diagnosis: learned tokens grew to norm ~3.2 while real CLIP content tokens
sit at ~1.0. Model never trained on embeddings that large — activates buzz
artifact instead of semantic style shift.

Fix: measure mean token norm from content positions (1–20) of dataset CLIP
embeddings at startup, clamp learned_tokens per-token after every optimizer
step to max 1.5× that reference (50% headroom). Token norm is now logged
as current/limit for easy monitoring.

ti_sweep_1.json: rebuild around norm_clamp group — n4_clamped (primary
diagnostic), prefix_clamped, n8_prefix_clamped, warm_clamped.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:54:23 +02:00
Ethanfel f9d092158a fix(ti): lower default lr/batch, add lr_batch sweep group
n4_baseline showed token_norm growing linearly without plateau — classic
sign of lr too high relative to parameter count. With only K×1024 params,
gradient signal per param is already high-magnitude; high lr causes
overshoot rather than convergence.

- Default lr: 1e-3 → 2e-4 (matches LoRA working regime)
- Default batch_size: 16 → 4 (more diverse gradients, helps norm saturate)
- ti_sweep_1.json: add lr_batch group (lr_low_b4, lr_mid_b8,
  lr_low_b4_prefix, lr_2e3), restructure with clearer groups,
  annotate n4_baseline as completed with findings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:42:22 +02:00
Ethanfel 92535deab2 fix(ti-scheduler): save comparison image after each completed experiment
Previously the comparison PNG was only written at the very end of the sweep,
so an interrupted run produced no image at all. Now _save_comparison() is
called right after _write_summary() for every successful experiment, keeping
loss_comparison.png current throughout the sweep.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:39:30 +02:00
Ethanfel 0b24207ca5 feat(ti-trainer): generate baseline.wav once before training starts
Saves baseline.wav + baseline.png in the checkpoint dir using the same
seed as the TI eval samples — direct A/B comparison at every checkpoint
without re-generating the baseline each time.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:33:28 +02:00
Ethanfel e1a2f0ed7d feat: add inject_mode (suffix/prefix) to TI pipeline
Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps),
token_norm grew linearly without plateau — generator likely ignores last-K
CLIP positions (EOS/padding zone) where suffix injects.

Fix: add inject_mode parameter throughout the pipeline:
- "suffix": replace last K positions (original behavior, model may ignore)
- "prefix": replace positions 1:1+K right after BOS — highest attention
  weight in CLIP, much stronger gradient signal expected

Changes:
- selva_textual_inversion_trainer.py: _inject_tokens() helper centralises
  the torch.cat construction for both modes; used in training loop and eval;
  inject_mode stored in checkpoint files
- selva_textual_inversion_loader.py: reads inject_mode from checkpoint,
  includes in TEXTUAL_INVERSION bundle
- selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field
- selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and
  _train_inner call
- ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm);
  n4_baseline marked completed; suffix experiments retained for comparison

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:31:52 +02:00
Ethanfel f96265da23 feat(ti-trainer): add loss curve IMAGE output
Reuses _draw_loss_curve + _smooth_losses + _pil_to_tensor from the LoRA
trainer — raw loss in light blue, smoothed overlay in blue, matches the
LoRA trainer's visual style.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:20:44 +02:00
Ethanfel c0d95ce356 feat: add ti_sweep_1 experiment file
First TI sweep covering the three most impactful axes:
- token_count group: n_tokens 4 / 8 / 16 (capacity vs overfitting)
- learning_rate group: 5e-4 / 1e-3 / 2e-3 with n_tokens=4
- warm_init group: n4 and n8 seeded from 'mechanical impact sound design'

7 experiments total, 3000 steps each, same data_dir as LoRA sweeps.
n4_baseline (lr=1e-3, random init) is the primary reference point.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:14:31 +02:00
Ethanfel e37bfe1b1c feat: add SelVA TI Scheduler for sweep-based textual inversion experiments
- SelvaTiScheduler: runs a JSON-defined sweep of TI training experiments,
  loading the dataset once and reusing it across runs
- Collects per-experiment loss history, final/min loss, stability metric
  (loss_std_last_quarter), and duration — written to experiment_summary.json
  after each completed run so partial sweeps survive interruption
- Resume-aware: skips experiments already marked completed in an existing
  summary file
- Outputs smoothed loss comparison chart (same axes, one curve per experiment)
- SelvaTextualInversionTrainer._train_inner now returns a dict
  {embeddings_path, loss_history} so the scheduler can read results;
  train() extracts just the path for ComfyUI

JSON format: name, description, data_dir, output_root, base config,
experiments list with id + param overrides

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:13:04 +02:00
Ethanfel bb07bc8169 fix(ti-trainer): guard spectral metrics, drop unused imports
- Wrap _spectral_metrics + _save_spectrogram in try-except so a matplotlib
  or STFT error doesn't abort the checkpoint save (matches LoRA trainer)
- Remove unused `import math` and `_pil_to_tensor` import
- Drop dead `img` variable (_save_spectrogram returns None)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:10:19 +02:00
Ethanfel e36cdd7947 fix(ti-trainer): fix gradient flow and spectral metric shapes
- Replace in-place text_clip assignment with torch.cat so the computation
  graph correctly links text_input → learned_tokens; in-place assignment
  into a requires_grad=False leaf severs the graph and learned_tokens
  receives no gradients
- _spectral_metrics(wav, sr): was passing wav.unsqueeze(0) [1,1,L] instead
  of wav [1,L]; stft mean(dim=1) would return wrong shape [1,T] not [n_freqs]
- _save_spectrogram(wav, sr, ...): was passing wav.squeeze(0) [L] (1D)
  instead of wav [1,L] as the function expects

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:08:13 +02:00
Ethanfel e56ece9c1c feat: add SelVA Textual Inversion Trainer and Loader nodes
Learns K CLIP token embeddings ([K, 1024]) with all model weights frozen,
keeping generated latents on the decoder's natural manifold — avoids the
quality degradation that affects LoRA on BJ's audio dataset.

- selva_textual_inversion_trainer.py: trains learned_tokens via AdamW,
  injects into last K positions of 77-token CLIP embedding, checkpoints
  with eval audio + spectral metrics
- selva_textual_inversion_loader.py: loads .pt bundle, returns
  TEXTUAL_INVERSION dict for sampler
- selva_sampler.py: optional textual_inversion input; injects into both
  text_clip and neg_text_clip before preprocess_conditions
- __init__.py: registers both new nodes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:01:44 +02:00
Ethanfel eed7eefeac feat: add SelVA HF Smoother and Spectral Matcher preprocessing nodes
Two ComfyUI nodes to reduce domain mismatch between custom training audio
and the MMAudio VAE's expected spectral distribution:

SelvaHfSmoother: blends a low-pass filtered copy (biquad) with the original
at a configurable cutoff and blend ratio. Attenuates extreme HF content that
BigVGANv2 handles poorly. RMS-preserving.

SelvaSpectralMatcher: computes the log-mel energy profile of the clip,
compares it per-band to the VAE's normalization means (DATA_MEAN_80D/128D),
and applies a smooth STFT-domain gain correction to match the codec's training
distribution. Configurable strength and max_gain_db clamp. RMS-preserving.

Recommended workflow: SpectralMatcher → HfSmoother → feature extraction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 20:28:16 +02:00
Ethanfel 107bb05f17 fix(vae-roundtrip): pass bigvgan path to encoder-only FeaturesUtils
AutoEncoderModule unconditionally asserts vocoder_ckpt_path is not None
even when need_vae_encoder=True. Pass best_netG.pt to satisfy the assert;
the vocoder weights are not actually used since decode+vocode go through
model["feature_utils"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 20:05:44 +02:00
Ethanfel 10e6095e31 fix(vae-roundtrip): use model feature_utils for decode, add normalize/unnormalize, normalize output
- Load fresh FeaturesUtils only for encoding; use model["feature_utils"] for
  decode+vocode to mirror the exact path the sampler takes
- Apply generator.normalize() → unnormalize() around the encoded latent so the
  decoder receives latents in the same space it expects from inference
- Log both encoded and norm→unnorm latent stats to diagnose round-trip fidelity
- Normalize output to -27 dBFS (matching training clip RMS) and clamp to [-1, 1]
  to prevent clipping artifacts in the output waveform

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:50:01 +02:00
Ethanfel 528d33be39 fix: trim/pad latent to seq_cfg.latent_seq_len before decoding
Without this the decoder produced 7s instead of 8s due to STFT rounding.
Same fix as _prepare_dataset uses for training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:22:09 +02:00
Ethanfel 8195c3114a feat: add SelVA VAE Roundtrip node
Encodes audio through the VAE then decodes straight back, bypassing the
diffusion model entirely. Use this to isolate whether saturation artifacts
are introduced by the codec reconstruction (VAE/DAC) or by the LoRA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:15:20 +02:00
Ethanfel c8e6b91f67 feat: add alpha_scale_sweep to fix LoRA noise contamination
Previous sweep used alpha=rank (scale=1.0) which at rank 128/256 drowned
base model priors — spectral flatness went from 0.013 (baseline) to 0.094.
This sweep tests alpha dramatically below rank across r16/r32/r128 to find
the scale where LoRA nudges rather than overwrites.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:55:05 +02:00
Ethanfel fdce9cbbf1 feat: evaluate adapters on all dataset clips, not just clip_001
- _eval_sample gains clip_idx param (default 0, backward compatible)
- Evaluator loops over all dataset clips per adapter, saves one WAV per clip
- Reference metrics computed for all clips and averaged
- Comparison chart and summary use avg_metrics across all clips
- Eliminates bias from evaluating on an unrepresentative single clip

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:42:55 +02:00
Ethanfel 42ceb4b153 fix: preserve original audio extension when copying reference file
shutil.copy2 was writing FLAC binary to reference.wav — unplayable.
Now copies as reference{.flac/.wav/etc} matching the source extension.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:31:26 +02:00
Ethanfel 4505b89db1 feat: add reference audio to LoRA evaluator
Loads the first clip's original audio (same clip used for inference),
copies it to output_dir/reference.wav, runs spectral metrics and
saves a spectrogram. Appears first in the comparison chart so generated
samples can be judged against the target sound.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:30:33 +02:00
Ethanfel dbfa7b23fe feat: add eval_r128_candidates.json
Evaluates top 5 adapters from r128_sweet_spot: baseline, lr_5e4_r128,
lr_3e4_r256, lr_3e4_r128, curriculum_lr_3e4 final + step 6000 checkpoint
(before regression) for spectral comparison.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:28:28 +02:00
Ethanfel d2e1ea7b80 feat: add SelVA LoRA Evaluator node
Generates audio samples from a list of adapters against a fixed reference
clip, collects spectral metrics for each, and outputs a comparison bar
chart + eval_summary.json. Useful for comparing sweep candidates before
committing to a next round of training.

JSON format: name, data_dir, output_dir, steps, seed, adapters[{id, path}].
Empty path = baseline (no LoRA).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:26:50 +02:00
Ethanfel 9a47508d2d fix: lower RMS normalization target from -23/-20 to -27 dBFS
Training clips at -23 LUFS measure -25 to -31 dBFS RMS (avg ~-27).
Normalizing output to -23 dBFS was 4-8 dB too loud, causing saturation
on clips with high crest factor and peaks near 0 dBFS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:19:20 +02:00
Ethanfel 678c050f11 fix: make normalize(x1) assignment explicit in training loop
normalize() uses in-place ops so it worked, but reading the return value
makes the intent clear and guards against future refactors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 15:43:42 +02:00
Ethanfel 1be07a80d2 feat: add cosine LR decay schedule to trainer and scheduler
- Add lr_schedule param (constant|cosine) to SelvaLoraTrainer
- Cosine decays LR from initial value to ~0 after warmup, preventing
  the oscillation observed at steps 6000-8000 with lr=2e-4 flat
- Wire lr_schedule through scheduler _PARAM_DEFAULTS and _train_inner call
- Add g5_r128_lr_2e4_cosine and g5_r128_lr_3e4_cosine to r128_sweet_spot sweep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:25:01 +02:00
Ethanfel 58e1985af2 feat: SelVA Skip Experiment node + save partial scalars on skip
- New node: SelVA Skip Experiment — writes skip_current.flag from UI,
  queue in a second workflow tab while scheduler is running
- SkipExperiment now attaches partial loss/grad/spectral data to the
  exception so the scheduler saves all collected scalars in the summary

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:10:43 +02:00
Ethanfel 264dc49d42 feat: skip_current.flag to cancel experiment and move to next
Create the flag file in the sweep output_root to skip the running
experiment at the next log interval (every 50 steps):
  touch /path/to/experiment/skip_current.flag

Scheduler marks it as 'skipped' in the summary and continues.
Skipped experiments are NOT resumed on restart (unlike failed ones).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:09:01 +02:00
Ethanfel fec5c86f09 feat: add spectral_flatness and temporal_variance to eval metrics
spectral_flatness (Wiener entropy) — 0=tonal, 1=white noise.
Rising value across steps directly flags noise contamination.
temporal_variance — RMS std/mean per frame. Low = lifeless/compressed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:45:40 +02:00
Ethanfel 2861327016 feat: spectral metrics per eval sample in experiment summary
Computes hf_energy_ratio (>4kHz), spectral_centroid_hz, spectral_rolloff_hz
at each save_every checkpoint. Logged to console and stored in
experiment_summary.json under results.spectral_metrics[step].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:44:43 +02:00
Ethanfel c4687521ef feat: save spectrogram PNG alongside each eval sample
Log-frequency dB spectrogram (inferno colormap, 100Hz–16kHz) saved as
step_XXXXX.png next to step_XXXXX.wav in samples/ subfolder.
Makes high-frequency rolloff (low bitrate signature) immediately visible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:42:34 +02:00
Ethanfel 8717af2728 fix: prevent saturation from RMS normalization clipping peaks
RMS normalize to target then scale back if peaks exceed 1.0,
preserving dynamics instead of hard-clipping transients.
Eval sample target updated to -23 dBFS to match training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:29:29 +02:00
Ethanfel 78e9838a83 fix: replace peak normalization with RMS normalization at -20 dBFS
Peak norm was slamming output to full scale regardless of content level,
making generated audio several times louder than training clips.
RMS norm to -20 dBFS matches typical processed audio level.
Sampler exposes target_lufs (-40 to -6, default -20) for user control.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:06:48 +02:00
Ethanfel 94610b8943 feat: r128_sweet_spot sweep — noise-free LR search + rank 256
9 experiments targeting loss 0.25-0.35 without LoRA+ noise.
Tests higher base LR (2e-4/3e-4/5e-4), curriculum combos, conservative
LoRA+ ratio=4, and rank 256 baseline + lr=3e-4.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:46:08 +02:00
Ethanfel f5f7f2ae68 fix: eval sample seed 0 -> 42
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:32:43 +02:00
Ethanfel 1663b39833 fix: bump eval sample to 25 ODE steps (was 8)
Inference is fast on RTX PRO 6000 — 8 steps was washing out quality
differences between experiments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:32:27 +02:00
Ethanfel a7923d5fb7 feat: r64_overnight sweep — focused rank-64 ablation at 8000 steps
15 experiments across rank (64/128), alpha, regularisation, LR, target
layers, and combined stacks. Based on tier1_thorough early results
confirming rank 64 sounds best perceptually.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 01:32:23 +02:00
Ethanfel 786a57c424 feat: sweep resume + 5 additional experiments (LR, target, extended)
Scheduler: on re-run, reads existing experiment_summary.json and skips
already-completed experiments — safe to stop and restart mid-sweep.

tier1_thorough: adds g5 (lr 3e-5/3e-4), g6 (full target attn.qkv+linear1
at r16 and r64), and g4_full_r64_6k (6000-step extended run) — 17 total.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:59:16 +02:00
Ethanfel f15e02b0b8 fix: eval samples use fixed clip/seed, save to samples/ subfolder
- Always sample dataset[0] with fixed noise seed so checkpoints are
  directly comparable (hear the model improve step by step)
- Save to output_dir/samples/step_XXXXX.wav instead of alongside checkpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:54:37 +02:00
Ethanfel 0682a536cb fix: point data_dir to features/ subdir where .npz and audio live
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:45:32 +02:00
Ethanfel 0000878e76 feat: thorough overnight sweep + dataset browser updates
- Dataset browser: audio/features now resolve through features/ subdir
- tier1_sweep.json: update data_dir to BJ dataset path
- tier1_thorough.json: 12-experiment overnight sweep across 4 groups
  (rank 16/32/64, alpha scaling, LoRA+/dropout/curriculum isolation,
  full Tier 1 stack at r16 and r64) — output to BJ/experiment/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:38:19 +02:00
Ethanfel 675644189d feat: add SelVA Dataset Browser node
Companion node for inspecting dataset.json entries by integer index.
Outputs video (.mp4), audio (.wav/.flac), features (.npz), frames dir,
mask dir, label, and max_index for constraining the index widget range.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:55:27 +02:00
Ethanfel 82fb7a0009 docs: note AudioX shows no perceptual quality gain on V2A vs SelVA
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 09:12:00 +02:00
Ethanfel af4777d2d7 docs: add AudioX vs SelVA evaluation
Architecture comparison, capability matrix, integration cost estimate,
LoRA training difficulty analysis, and license implications.
Verdict: SelVA remains preferred for V2A + LoRA fine-tuning; AudioX
adds value for music generation, inpainting, and text-to-audio tasks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 09:11:09 +02:00
Ethanfel ed8abf7a5b docs: add video format recommendations to dataset preparation section
New section 1.1 covers aspect ratio (16:9 landscape preferred), resolution
(≥480p), frame rate (any, use VHS_VIDEOINFO), and portrait handling
(center-crop to square). Based on CLIP 384px and Synchformer 224px internals.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:44:14 +02:00
Ethanfel 21ed93d3ee docs: add audio dataset pipeline reference doc
Full research notes on cleaning, augmentation, and quality metrics for
generative model training. Covers LUFS normalization, AudioSep, waveform
augmentation (pitch shift, RIR, EQ), latent mixup, DNSMOS gating, tool
install commands, and key paper references.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:37:48 +02:00
Ethanfel f1e2bbd55b feat: add first experiment sweep file for Tier 1 ablation
6 experiments: baseline, LoRA+ (ratio=16), dropout 0.05, dropout 0.1,
curriculum sampling, and all three combined. bf16 batch 16, 2000 steps,
seed 42. data_dir placeholder needs to be updated before running.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:15:06 +02:00
Ethanfel 3d9221c248 fix: three bugs in scheduler and trainer
- trainer: raise ValueError early when remaining steps < log_interval (50)
  instead of UnboundLocalError on smoothed_img/final_path at return
- trainer: use None in grad_norm_history instead of silent 0.0 when
  grad_accum > log_interval and no optimizer step fired in the interval
- trainer: include start_step in _train_inner return dict
- scheduler: use start_step from result dict for min_loss_step and
  loss_at_steps (fixes wrong step labels on resumed experiments)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:11:25 +02:00
Ethanfel 2d200395af feat: add grad norm logging and richer experiment summary output
trainer:
- Track gradient norm before clipping at each optimizer step
- Log avg grad_norm per log_interval alongside loss in console output
- Include grad_norm_history in _train_inner return dict

scheduler:
- Add system block to summary (GPU name, VRAM, torch/CUDA version)
- Include full loss_history and grad_norm_history arrays in each
  experiment result (50-step resolution, not just save_every checkpoints)
- Add loss_std_last_quarter stability metric (std dev of raw loss over
  last 25% of steps — high value indicates unstable training)
- Add log_interval field so consumers know the x-axis resolution

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:06:39 +02:00
Ethanfel 3ec380a27e feat: add SelVA LoRA Scheduler node for automated experiment sweeps
- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level
  function so the dataset can be encoded once and reused across experiments
- Change _train_inner() return value from tuple to dict (adds loss_history,
  meta, completed; train() unpacks for ComfyUI — no change to node outputs)
- New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments
  sequentially, writes experiment_summary.json (updated after each run) and
  loss_comparison.png with all smoothed curves overlaid on the same axes
- Register SelvaLoraScheduler in nodes/__init__.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:03:21 +02:00
Ethanfel 9bc2568543 docs: document LoRA dropout, LoRA+, and curriculum timestep sampling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 12:45:53 +02:00
Ethanfel eb63c1ead7 feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not frozen base weights),
  0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with
  configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal for first N% of steps then uniform,
  directly addresses early convergence + fine-detail degradation at
  boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 12:43:18 +02:00
Ethanfel 5baa070e61 docs: add observations section with fp32/batch/precision findings
Work-in-progress empirical notes: fp32 batch 32 reaches same quality as
bf16 batch 16 in 1/3 the steps but overfits past ~2000 steps on 10 clips.
Lower loss does not reliably mean better audio on small datasets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 02:34:53 +02:00
Ethanfel 9fc739fe9e docs: add prompt guide and masking note to dataset preparation section
Poor prompts and missing masks are a common source of white noise in LoRA
training — imprecise sync features force the adapter to compensate with noise.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:43:28 +02:00
Ethanfel 57fae4a8ce chore: default timestep_mode back to uniform
logit_normal reaches lower loss but perceptual improvement over uniform
is dataset-dependent. Keeping uniform as default to match original MMAudio
training behavior; logit_normal remains available as an option.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:21:08 +02:00
Ethanfel 8e919c0459 fix: resolve relative and Unix-style output_dir paths to ComfyUI output folder
On Windows, /folder is drive-relative (no drive letter) rather than a real
absolute path. Redirect these to ComfyUI's output directory so files don't
land at C:\folder. Also redirects plain relative paths (e.g. lora_output)
to output/ instead of the process working directory.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:14:04 +02:00
Ethanfel fec8eaac95 fix: save adapter and loss curves on cancel, not only on normal completion
Wraps training loop in try/finally so adapter_final.pt and loss PNGs are
always written. On cancellation the adapter is named
adapter_cancelled_stepXXXXX.pt so it can be used with --resume.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:06:44 +02:00
Ethanfel d83632e754 fix: pad/trim clip and sync features to fixed seq_len at dataset load time
Clips from shorter videos produce fewer CLIP frames (e.g. 2s → 16 frames,
8s → 64 frames). Mixed-length datasets would cause torch.stack() to fail
during batching. Normalize to seq_cfg.clip_seq_len / sync_seq_len at load,
same as latents are already normalized to latent_seq_len.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:54:05 +02:00
Ethanfel a5014e49eb feat: add logit-normal timestep sampling to reduce white noise artifacts
Uniform timestep sampling undertrained t>0.8 (the final denoising steps),
leaving residual noise that CFG amplifies at inference. Logit-normal sampling
concentrates training near t=0.5 while still covering the full range, improving
high-t coverage and reducing noise floor in generated audio.

Default changed from uniform to logit_normal (sigma=1.0). Previous behavior
available with timestep_mode=uniform.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:35:42 +02:00
Ethanfel 8ae0ba3c7d fix: increment adapter_final filename on resume to avoid overwriting previous final
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:15:31 +02:00
Ethanfel 2b2b438307 fix: set OUTPUT_NODE=True on SelVA Feature Extractor so it runs without connected outputs
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:11:16 +02:00
Ethanfel 39984f73c2 docs: add observed batching results to training guide
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:05:16 +02:00
Ethanfel 1f8cd6f930 docs: rewrite LORA_TRAINING.md with real-world findings
- Added batch_size VRAM table and updated step recommendations for batched training
- Added adapter strength section with practical guidance (0.6-0.7 for noise)
- Added ComfyUI node as Option A for training (not just CLI)
- Noted .mp3 as not recommended, soundfile fallback implied
- Added output files section with sample_*.wav and loss curve PNGs
- Added "LoRA has no effect" troubleshooting (wrong node wired)
- Updated loss convergence targets based on observed training runs
- Clarified linear1 target: 150+ clips recommended

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:00:36 +02:00
Ethanfel 20f8138146 chore: show batch_size in training step log
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:45:43 +02:00
Ethanfel 09b3b94ddd feat: add batch_size parameter to training (default 4)
Replaces single-sample steps with batched sampling via random.choices().
Tensors are stacked to [B, T, C] before the forward pass; t is now [B].
Default grad_accum lowered to 1 since real batching gives stable gradients.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:36:12 +02:00
Ethanfel 3f67de694c feat: save loss_raw.png and loss_smoothed.png to output_dir
Raw curve shown in light blue, EMA-smoothed (beta=0.9) overlay in darker
blue. Both saved as PNG at end of training. The node IMAGE output now
returns the smoothed version. Live preview also uses the smoothed overlay.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:15:48 +02:00
Ethanfel 423e174b88 debug: print lora_A norm after loading to confirm adapter applied
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:05:23 +02:00
Ethanfel 4806daa4ca chore: lower default warmup_steps from 500 to 100
500 warmup steps is 25% of a 2000-step run — too long. 100 steps lets
the full lr kick in much earlier without sacrificing stability.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:51:27 +02:00
Ethanfel 16b3eb11cc fix: pass max_size=800 to progress bar preview (was 85px wide)
The third element in ComfyUI's preview tuple is max_size in pixels, not
JPEG quality. Passing 85 was capping the live loss curve at 85×40px.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:48:56 +02:00
Ethanfel 004ea63f62 fix: fall back to soundfile for torchaudio.save when torchcodec unavailable
Same torchcodec/FFmpeg issue as the load path, now on the eval sample save.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:44:04 +02:00
Ethanfel afb3242eca fix: disable inference_mode entirely for training via inference_mode(False)
torch.enable_grad() alone is insufficient: operations on inference tensors
(created inside ComfyUI's outer inference_mode context) produce inference
tensors even inside enable_grad, breaking autograd. inference_mode(False)
exits the inference context so the deepcopy, apply_lora, and training loop
run with a fully clean autograd context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:40:50 +02:00
Ethanfel 849f31e2a6 fix: create LoRA params inside torch.enable_grad() to escape inference_mode
torch.enable_grad() re-enables grad tracking but nn.Parameters created while
torch.inference_mode() is active are inference tensors that can't enter autograd
regardless. Splitting into _train_inner() and calling it inside enable_grad()
ensures the deepcopy, apply_lora, and the training loop all run with a clean
autograd context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:36:28 +02:00
Ethanfel 505d445eb3 fix: wrap training loop in torch.enable_grad()
ComfyUI executes all nodes inside torch.no_grad(), which prevents gradient
tracking and makes loss.backward() fail. torch.enable_grad() overrides it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:32:00 +02:00
Ethanfel 8fade1b0e3 fix: initialize LoRA params on same device as wrapped linear
apply_lora() is called after generator.to(device), so lora_A/lora_B were
being created on CPU while the rest of the model was on CUDA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:17:29 +02:00
Ethanfel ad57432803 fix: pad/trim latent to exact latent_seq_len after VAE encoding
STFT hop-size rounding produces ±1 latent frame vs the expected seq length.
Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:12:20 +02:00
Ethanfel 43f732f904 fix: transpose VAE latent from [B,C,T] to [B,T,C] before generator
VAE encoder returns channels-first [B, latent_dim, T]; the generator
expects time-first [B, T, latent_dim] (same convention as decode which
already does .transpose(1,2)). Fixes normalize() size mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:08:00 +02:00
Ethanfel 6b9adf0816 fix: fall back to soundfile when torchcodec FFmpeg libs are missing
Recent torchaudio defaults to torchcodec as the audio backend, which requires
FFmpeg shared libraries. Falls back to soundfile for envs where torchcodec
can't load (e.g. containerised ComfyUI without system FFmpeg).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:03:57 +02:00
Ethanfel 52434a053a fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure
torch.stft requires float32 input — casting vae_utils to bf16 caused silent
failures during dataset pre-loading. Also adds traceback.print_exc() so future
clip-load errors are visible in the ComfyUI log.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 21:57:20 +02:00
Ethanfel 56c8d5d6b4 feat: save eval audio sample alongside each checkpoint
At every save_every steps, run a quick 8-step no-CFG inference pass on
a random training clip and save the decoded waveform as
sample_stepXXXXX.wav next to the checkpoint. Uses the existing
generator.unnormalize + feature_utils.decode + vocode pipeline from
the sampler. Failure is non-fatal (logged and skipped).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 21:47:02 +02:00
Ethanfel b430953602 feat: live loss curve preview during training
- Send updated loss curve to ComfyUI frontend every 50 steps via
  pbar_train.update_absolute() with a JPEG preview tuple — same
  mechanism as KSampler's denoising previews.
- Fix x-axis step labels for resumed runs (previously always started
  at 0; now correctly shows start_step + offset).
- Split _draw_loss_curve (returns PIL Image) from _pil_to_tensor
  (converts for ComfyUI IMAGE output).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:11:38 +02:00
Ethanfel 57cd3dd4b4 fix: use load_lora for resume and remove redundant inference_mode wrapper
- Resume now calls load_lora() instead of load_state_dict() directly,
  giving proper warnings for missing/unexpected LoRA keys.
- Remove redundant `with torch.inference_mode():` around encode_audio
  (already @inference_mode decorated); dist.mode().clone() pattern
  is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:09:35 +02:00
Ethanfel f206a1b38c feat: add SelVA LoRA Trainer ComfyUI node
Runs the full training loop inside ComfyUI. Reuses the already-loaded
CLIP model from the inference model for text encoding; loads only a
minimal VAE encoder separately (freed after dataset pre-loading).

Outputs:
- SELVA_MODEL with LoRA applied (ready to connect directly to Sampler)
- adapter_path STRING (for SelVA LoRA Loader in future sessions)
- loss_curve IMAGE (PIL-rendered line chart of training loss per 50 steps)

Progress is shown via ComfyUI ProgressBar (two phases: dataset loading,
then training steps). Resume is supported via resume_path input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:07:38 +02:00
Ethanfel 2f4641247a feat: add resume support to train_lora.py
Step checkpoints now save optimizer state, scheduler state, and step
number alongside the LoRA weights. Pass --resume path/to/adapter_stepXXXXX.pt
to continue training from that checkpoint. --steps always means total steps,
so resuming from 1000 with --steps 2000 trains 1000 more steps.

adapter_final.pt format is unchanged (state_dict + meta only) so
SelvaLoraLoader remains compatible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 16:59:30 +02:00
Ethanfel 8e9114b92c docs: add clip length and scalable dataset size recommendations
- Clip length section: fixed 8s duration, padding/trim behavior, per-sound-type
  strategies (continuous, short events, repeating, onset placement).
- Dataset size table: 5-10 / 15-30 / 30-60 / 60-150 / 150-300 / 300+ clips
  with scenario and expected result for each tier.
- Note on diversity vs quantity.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 16:34:50 +02:00
Ethanfel 63b4391573 fix: named .npz files always start at _001
dog_bark_001.npz, dog_bark_002.npz instead of dog_bark.npz, dog_bark_001.npz.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:44:26 +02:00
Ethanfel 89af5a468c docs: add LoRA training guide
Covers dataset preparation (ComfyUI feature extraction + clean audio),
training CLI reference, tuning guide (rank/steps/lr), adapter loading
in ComfyUI, and troubleshooting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:43:09 +02:00
Ethanfel c88e27742c fix: sanitize name field and remove double load_npz call
- _resolve_named_path: replace / \ and null in name to prevent path
  traversal outside cache_dir (would cause a confusing FileNotFoundError
  at np.savez time instead of at path resolution).
- train_lora: load_npz was called twice per clip when prompt was in
  prompts.txt; consolidate to a single call before prompt resolution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:30:25 +02:00
Ethanfel cbcd154c96 feat: add name field with auto-increment to SelvaFeatureExtractor
When name is provided, features are saved as name.npz (or name_001.npz,
name_002.npz etc. if the file already exists) instead of a content hash —
useful for building a named training dataset. Hash-based caching is
unchanged when name is left empty.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:16:51 +02:00
Ethanfel 1eb82d8050 refactor: train_lora accepts .npz + audio pairs instead of raw video
- Input is now pre-extracted .npz files (from SelvaFeatureExtractor) paired
  with clean audio files (same stem). Visual features no longer re-extracted
  during training.
- FeaturesUtils loaded with enable_conditions=False (VAE only) — Synchformer
  and T5 are no longer loaded, saving ~3-4 GB VRAM.
- CLIP text encoder loaded separately via patch_clip so text prompt can differ
  from the one used during feature extraction.
- Prompt priority: prompts.txt override > embedded in .npz > directory name.
- Removed: torchvision video loading, frame sampling/resizing, net_video_enc,
  synchformer path check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:14:26 +02:00
Ethanfel cde280049b fix: correct LoRALinear dtype and remove unused import
- LoRALinear now creates lora_A/lora_B with dtype matching the base
  linear's weight, preventing a float32/bf16 mismatch at forward time
  when the generator is loaded in bf16 or fp16.
- Remove unused `import math` from train_lora.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:57:09 +02:00
Ethanfel 437c62b28f feat: LoRA fine-tuning for SelVA generator
Teaches the model new/partial sound classes from custom video+audio pairs.
Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model.

selva_core/model/lora.py
  LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices.
  B initialised to zero → zero adapter contribution at init.
  apply_lora(): walks named_modules, replaces matching nn.Linear in-place.
  Default target: "attn.qkv" (all 21 SelfAttention QKV projections in
  large_44k). Add "linear1" to also wrap post-attention output projections.
  get_lora_state_dict() / load_lora() for ~10 MB save/load.

train_lora.py (standalone script, no ComfyUI dependency)
  Data format: directory of video files + optional prompts.txt
  ("filename: description"). Falls back to directory name as prompt.
  Pre-extracts features for all clips into RAM, then trains from those.
  Training loop: encode audio→latent (need_vae_encoder=True), flow
  matching MSE loss on velocity prediction, backward on LoRA params only.
  Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata.
  Key verified interfaces used:
    encode_audio() → DiagonalGaussianDistribution; .mode().clone() required
    normalize() is in-place
    forward(latent, clip_f, sync_f, text_f, t) takes raw tensors

nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node)
  Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights.
  strength param scales lora_B to adjust adapter contribution at inference.
  Reads rank/alpha/target from embedded metadata if present.
  Returns a patched SELVA_MODEL bundle for use with the existing Sampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:38:46 +02:00
45 changed files with 11164 additions and 118 deletions
+459
View File
@@ -0,0 +1,459 @@
# LoRA Training for SelVA
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
---
## Overview
Training is split into two steps:
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 34 GB of VRAM.
---
## Requirements
Same environment as SelVA inference. Additional Python packages:
```
torchaudio
soundfile
```
---
## Step 1 — Prepare the dataset
### 1.1 Video format
The feature extractor accepts any input but internally resamples frames to fixed square resolutions (384×384 for CLIP, 224×224 for Synchformer). Both encoders were trained on standard video datasets — predominantly landscape footage. This has two practical implications:
**Aspect ratio** — use **16:9 landscape** whenever possible. Portrait clips (9:16) are mechanically supported but the bicubic stretch into square distorts the image relative to the encoders' training distribution, which can degrade sync feature quality. If your source is portrait, center-crop to square before extraction. Square (1:1) is also fine.
**Resolution** — anything ≥ 480p is sufficient. The extractor downscales to 384px and 224px regardless of source resolution; higher resolution adds no benefit.
**Frame rate** — any. Connect `VHS_VIDEOINFO` from VHS LoadVideo to the feature extractor so fps is read automatically from the file instead of being entered manually.
| Format | Recommendation |
|---|---|
| Aspect ratio | 16:9 landscape (preferred) or 1:1 square |
| Resolution | ≥ 480p (720p+ is fine, no upper limit that matters) |
| Frame rate | Any — set via VHS_VIDEOINFO |
| Portrait (9:16) | Center-crop to square before extraction |
### 1.2 Extract visual features in ComfyUI
For each video clip you want to train on:
1. Load the video with a VHS LoadVideo node.
2. Connect it to **SelVA Feature Extractor**.
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
### Prompt guide
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
**Good prompts are specific about:**
- The sound source (what object is making the sound)
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
- The action producing the sound (if applicable)
| Sound | Weak prompt | Strong prompt |
|---|---|---|
| Dog bark | `dog` | `a large dog barking loudly` |
| Footsteps | `walking` | `heavy boots on a wooden floor` |
| Water | `water` | `water dripping into a metal bucket` |
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
| Door | `door` | `a heavy wooden door slamming shut` |
**Rules of thumb:**
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
### Masking note
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
### 1.3 Collect clean audio
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
```
dataset/my_sound/
dog_bark_001.npz ← from SelVA Feature Extractor
dog_bark_001.wav ← clean isolated audio recording
dog_bark_002.npz
dog_bark_002.wav
dog_bark_003.npz
dog_bark_003.wav
```
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
### 1.4 Optional: prompts.txt
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
```
# One line per file: filename: prompt text
dog_bark.npz: a large dog barking aggressively
dog_bark_001.npz: a dog barking in the distance
```
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
---
## Step 2 — Train
### Option A — SelVA LoRA Trainer node (ComfyUI)
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
```
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
```
### Option B — Command line
```bash
python train_lora.py \
--data_dir dataset/my_sound \
--output_dir lora_output/my_sound \
--variant large_44k \
--selva_dir /path/to/ComfyUI/models/selva \
--rank 16 \
--steps 4000 \
--batch_size 4 \
--lr 1e-4
```
The script will:
1. Load the VAE, CLIP text encoder, and generator.
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
3. Train LoRA adapters for the specified number of steps.
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
---
## CLI Reference
| Argument | Default | Description |
|---|---|---|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
| `--selva_dir` | required | Path to SelVA model weights directory |
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
| `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps |
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
| `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
| `--seed` | `42` | Random seed |
| `--timestep_mode` | `uniform` | Timestep sampling: `uniform`, `logit_normal`, or `curriculum` |
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` / `curriculum` |
| `--curriculum_switch` | `0.6` | Fraction of steps to use logit_normal before switching to uniform. Only with `curriculum` |
| `--lora_dropout` | `0.0` | Dropout on the LoRA path only. `0.05``0.1` helps regularize on small datasets |
| `--lora_plus_ratio` | `1.0` | LoRA+ LR ratio: `lr_B = lr × ratio`. `1.0` = standard LoRA, `16.0` = LoRA+ |
---
## Step 3 — Load the adapter in ComfyUI
Connect **SelVA LoRA Loader** between the model loader and the sampler:
```
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
```
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
| Input | Description |
|---|---|
| `model` | SELVA_MODEL from the model loader |
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
---
## Tuning Guide
### Clip length
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
**Practical recommendations:**
| Sound type | Clip strategy |
|---|---|
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 23 times per clip |
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
| Sound with a clear onset (explosion, splash) | Put the onset at ~12 s from the start, not at 0 s — gives the model context |
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
### How many clips do I need?
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
| Dataset size | Scenario | Expected result |
|---|---|---|
| **510 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
| **1530 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
| **3060 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
| **60150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
| **150300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
### Batch size
| Batch size | VRAM (large_44k) | Use case |
|---|---|---|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
| `8` | ~15 GB | Better convergence on larger datasets |
| `16` | ~20 GB | Best gradient quality when VRAM allows |
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
### Rank
| Rank | Use case |
|---|---|
| `8` | Fine details on a sound the model already knows well |
| `16` | Default — good balance of capacity and VRAM |
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
Higher rank increases VRAM usage and overfitting risk on small datasets.
### Steps
With `batch_size=4` as the default, these are rough guidelines:
| Dataset size | Recommended steps |
|---|---|
| 1020 clips | 20004000 |
| 2050 clips | 40008000 |
| 50+ clips | 600015000 |
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
### Learning rate
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
### Target layers
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
```bash
--target attn.qkv linear1
```
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
### Timestep sampling mode
Controls how training timesteps are sampled at each step.
`uniform` (default) samples all timesteps equally — equivalent to original MMAudio training.
`logit_normal` concentrates more steps near t=0.5 via `sigmoid(N(0, σ))`. This is the semantically rich mid-noise region. Consistently reaches a lower loss floor but the perceptual improvement on small datasets is marginal.
`curriculum` uses logit_normal for the first `curriculum_switch` fraction of steps (default 60%), then switches to uniform for the remainder. The motivation: logit_normal accelerates early structure learning but undertrains the high-t boundary region; uniform then fills in the fine detail. A switch message is logged when the transition happens.
| Mode | When to use |
|---|---|
| `uniform` (default) | Baseline — safe, equivalent to original training |
| `logit_normal` | When you want a lower loss floor; marginal on small datasets |
| `curriculum` | Experimental — may improve convergence quality on small datasets |
The `logit_normal_sigma` parameter controls the width of the logit-normal distribution (used by both `logit_normal` and the first phase of `curriculum`):
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
- σ=0.5: sharper peak, less coverage of extremes
- σ=2.0: broader, approaches uniform
### LoRA dropout
`lora_dropout` applies dropout to the input of the LoRA path (not the frozen base linear). It regularizes the low-rank update without disturbing pretrained weights — helpful on small datasets where the LoRA would otherwise overfit to the training clips.
| Value | Use case |
|---|---|
| `0.0` (default) | No regularization — fine for 30+ clips |
| `0.05` | Light regularization — recommended starting point on 1020 clips |
| `0.1` | Stronger regularization — use if loss plateaus but audio is still noisy |
Dropout is not saved in the adapter file — it only affects training. Loading the adapter at inference does not require setting dropout.
### LoRA+ (asymmetric learning rate)
`lora_plus_ratio` splits the learning rate between LoRA A and B matrices: `lr_B = lr × ratio`. The B matrix is the output-side projection and benefits from a higher LR. Setting ratio to 16 enables the LoRA+ scheme from arXiv:2402.12354.
| Ratio | Effect |
|---|---|
| `1.0` (default) | Standard LoRA — identical A and B learning rates |
| `4.0` | Mild asymmetry |
| `16.0` | LoRA+ — faster convergence, especially on early steps |
LoRA+ is orthogonal to dropout and curriculum sampling — all three can be combined.
### Adapter strength at inference
| Strength | Effect |
|---|---|
| `0.50.7` | Conservative — blends adapter with base model, less noise |
| `1.0` | Full adapter strength (default) |
| `>1.0` | Exaggerated effect, may introduce artifacts |
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.60.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
### Loss interpretation
A typical loss curve:
- Starts around `0.81.0`
- Should reach `0.550.65` after convergence on a clean sound class with 1030 clips
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
- Below `0.1` on a small dataset means overfitting
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
### Precision
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
---
## Output files
```
lora_output/my_sound/
adapter_step00500.pt ← step checkpoint (includes optimizer state for resume)
adapter_step01000.pt
...
adapter_final.pt ← final adapter with embedded metadata (inference only)
meta.json ← human-readable metadata
sample_step00500.wav ← quick eval sample at each checkpoint
loss_raw.png ← raw loss curve
loss_smoothed.png ← EMA-smoothed loss curve
```
`adapter_final.pt` format:
```python
{
"state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
"meta": {
"variant": "large_44k",
"rank": 16,
"alpha": 16.0,
"target": ["attn.qkv"],
"steps": 2000
}
}
```
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
---
## Troubleshooting
**`No layers matched target=...`**
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
**`No .npz files found in ...`**
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
**`No audio file found for clip.npz`**
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
**The sound is audible but there is white noise on top**
Lower the adapter strength to `0.60.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
**LoRA appears to have no effect**
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
**Loss does not decrease**
- Increase `batch_size` for more stable gradients.
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
- Check that the audio files are clean and actually contain the target sound.
- Check that the `.npz` features were extracted with a relevant prompt.
**Loss explodes or NaN**
- Lower the learning rate (`5e-5`).
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
**Loss plateaus early (above 0.7)**
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
---
## Observations (work in progress)
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
### Precision and batch size
| Config | Smoothed loss at step 2000 | Notes |
|---|---|---|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 60008000 at ~0.59 |
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
| fp32 batch 32 | ~0.58 | Matches bf16 batch 16 at step 6000 already at step 2000 |
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
### logit_normal vs uniform
logit_normal consistently reaches a lower loss floor than uniform. However perceptual improvement is dataset-dependent — on 10 clips the difference is marginal. May be more impactful with larger datasets. No conclusion yet.
### White noise
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
- Too few clips for the model to confidently predict the target sound
- Imprecise extraction prompts producing unfocused sync features
- Missing mask when multiple objects are in frame
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.03.5 or adapter strength to 0.60.7 helps at inference.
+254 -8
View File
@@ -58,7 +58,7 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From SelVA Model Loader | | `model` | From SelVA Model Loader (or any loader/loader chain) |
| `features` | From SelVA Feature Extractor | | `features` | From SelVA Feature Extractor |
| `prompt` | Text description — leave empty to use the prompt stored in features | | `prompt` | Text description — leave empty to use the prompt stored in features |
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) | | `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
@@ -66,22 +66,261 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| `steps` | Sampling steps (default: 25) | | `steps` | Sampling steps (default: 25) |
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) | | `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
| `seed` | RNG seed | | `seed` | RNG seed |
| `normalize` | Peak-normalize output to [-1, 1] (default: true) | | `normalize` | RMS-normalize output to `target_lufs` (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
**Output:** `AUDIO` **Output:** `AUDIO`
--- ---
## Workflow ### SelVA LoRA Loader
Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
**Output:** `model` (SELVA_MODEL with adapter injected)
---
### SelVA LoRA Trainer
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
---
### SelVA LoRA Scheduler
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `experiments_file` | Path to JSON sweep config |
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Skip Experiment
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
**Output:** `flag_path` (STRING)
---
### SelVA LoRA Evaluator
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
---
### SelVA Dataset Browser
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
**Outputs:** video path, audio path, frames directory, label, total count
---
### SelVA VAE Roundtrip
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `audio` | AUDIO to test |
**Output:** `audio_reconstructed` (AUDIO)
---
### SelVA HF Smoother
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
**Output:** `audio` (AUDIO)
---
### SelVA Spectral Matcher
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
**Output:** `audio` (AUDIO)
---
### SelVA Textual Inversion Trainer
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
---
### SelVA Textual Inversion Loader
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
---
### SelVA TI Scheduler
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Activation Steering Extractor
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with `.npz` feature files |
| `output_path` | Where to save `steering_vectors.pt` |
| `n_samples` | Clips to average over (default: 16) |
| `seed` | RNG seed |
**Output:** `steering_path` (STRING)
---
### SelVA Activation Steering Loader
Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
**Output:** `steering_vectors` (STEERING_VECTORS)
---
### SelVA BigVGAN Trainer
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with target-style audio files (searched recursively) |
| `output_path` | Where to save the fine-tuned vocoder `.pt` |
| `train_mode` | `snake_alpha_only` (default) or `all_params` |
| `steps` | Training steps (default: 2000) |
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
| `batch_size` | Clips per step (default: 4) |
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) |
| `save_every` | Checkpoint interval in steps (default: 500) |
| `seed` | RNG seed |
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
---
### SelVA BigVGAN Loader
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
---
### SelVA DITTO Optimizer
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `features` | From SelVA Feature Extractor |
| `prompt` | Sound description (leave empty to use features prompt) |
| `negative_prompt` | Sounds to suppress |
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) |
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
| `steps` | Euler steps for the final generation pass (default: 25) |
| `cfg_strength` | CFG scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | *(optional)* RMS normalize output (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
**Output:** `AUDIO`
---
## Workflows
### Basic generation
``` ```
VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
│ (video_info) ─► (fps auto) │ (video_info)
│ (features) ────────────────────────────────────►│ │ (features) ──────────────────────────────────►│
│ (prompt) ──────────────────────────────────────►│ │ (prompt) ────────────────────────────────────►│
``` ```
Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction. ### DITTO style transfer (recommended first approach)
```
SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
SelVA Feature Extractor ──(features)────────────────────────────────────►│
(prompt) ──────────────────────────────────────►│
BJ reference_dir ───────────────────────────────────────────────────────►│
```
No training required. Each run optimizes x₀ independently for the current video and reference set.
### Vocoder fine-tuning
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
▲ ▲
checkpoint .pt SelVA Feature Extractor
```
### LoRA training
See [LORA_TRAINING.md](LORA_TRAINING.md).
--- ---
@@ -127,8 +366,15 @@ The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available,
--- ---
## Style Transfer
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
---
## Credits ## Credits
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training - [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework - [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization
+158
View File
@@ -0,0 +1,158 @@
# Style Transfer for SelVA
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
---
## Why standard fine-tuning is hard
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~1015 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
**LoRA** makes this worse: it introduces "intruder dimensions" — new high-rank singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.210.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
The approaches below are ordered by expected quality and ease of use.
---
## Tier 1 — DITTO (recommended first try)
**Node: SelVA DITTO Optimizer**
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
**How it works:**
For each video clip you want to process:
1. Run SelVA Feature Extractor as usual.
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
```
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
SelVA Feature Extractor ──(features)────────────────────────►│
(prompt) ──────────────────────────►│
BJ clips ───────────────────────────(reference_dir) ─────────►│
```
**Tuning guide:**
| Parameter | Starting value | When to adjust |
|---|---|---|
| `n_opt_steps` | 50 | Increase to 100200 if style shift is too subtle |
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps |
| `style_weight` | 1.0 | Increase to 25 for stronger BJ character; watch for incoherence |
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~46 GB additional VRAM over baseline inference.
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 515 minutes depending on GPU.
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
---
## Tier 2 — Vocoder Fine-tuning
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
### Why plain mel L1 loss fails
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 38 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count.
### `snake_alpha_only` mode (default, recommended)
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
### `all_params` mode with discriminator
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
### Workflow
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor
```
### Tuning guide
| Parameter | Default | Notes |
|---|---|---|
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
| `steps` | 2000 | 10002000 for snake_alpha_only; 30005000 for all_params |
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
| `batch_size` | 4 | 48 for stable gradients |
| `segment_seconds` | 1.0 | 12 s segments recommended |
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
---
## Tier 3 — DITTO + Vocoder (combined)
Stack both:
```
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
```
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
---
## What doesn't work (and why)
### Standard LoRA
LoRA introduces "intruder dimensions" — high-rank singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
### Textual inversion
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
### Activation steering (global mean difference)
The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 36 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
---
## Reference dataset preparation
Use the same audio clips for both DITTO and vocoder fine-tuning:
- **Minimum:** 2030 clips. DITTO works from 5+; vocoder benefits from 40+.
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
+170
View File
@@ -0,0 +1,170 @@
# Audio Dataset Pipeline for Generative Model Training
Research notes on audio cleaning, augmentation, and quality metrics for LoRA fine-tuning of MMAudio/SelVA. Based on papers and tooling survey (April 2026).
---
## Core Principle
Augmentation for generative models ≠ augmentation for classifiers.
The goal is **not invariance** — it is expanding the training manifold so the model learns the distribution of a sound rather than memorizing a fixed set of waveforms.
With 10 clips, velocity field collapse (arXiv:2410.23594) is mathematically expected: the flow-matching model memorizes the training trajectories instead of generalizing. More diverse data is the only real fix.
---
## Recommended Pipeline
### Step 1 — Quality Screening
```python
# Clipping check
clip_ratio = np.sum(np.abs(audio) >= 0.99) / len(audio) # flag if > 0.1%
# DC offset check + removal
dc = np.mean(audio)
audio -= dc
# LUFS normalization to -14 LUFS (essential for training consistency)
# pip install pyloudnorm
import pyloudnorm as pyln
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio)
audio = pyln.normalize.loudness(audio, loudness, -14.0)
# Or via ffmpeg: ffmpeg -af loudnorm=I=-14:LRA=7:TP=-1
# DNSMOS quality gate (discard if OVRL < 3.5 for training; < 2.5 is unusable)
# from Microsoft DNS-Challenge repo
```
### Step 2 — Cleaning
| Tool | Install | Use |
|---|---|---|
| **AudioSep** | `pip install audiosep` | Isolate target sound from background — most impactful tool |
| **noisereduce** | `pip install noisereduce` | Light stationary/non-stationary denoising, preserves character |
| **librosa** | `pip install librosa` | Silence trimming: `librosa.effects.trim(audio, top_db=30)` |
| **torchaudio.transforms.Fade** | (torchaudio) | Prevent click artifacts at clip edges |
| **DeepFilterNet** | `pip install deepfilternet` | Heavy denoising — good for speech, may alter tonal sounds |
**AudioSep usage:**
```python
from audiosep import AudioSep
model = AudioSep.from_pretrained("audio-agi/audiosep")
# ~1.5 GB checkpoint, ~4 GB VRAM
model.inference(audio_path, "a dog barking loudly", output_path)
```
### Step 3 — Waveform Augmentation (10 clips → 50100)
Apply stochastically per clip:
| Transform | Params | Notes |
|---|---|---|
| **PitchShift** | ±13 semitones | 3 variants per clip. Limit to ±1 st for tonal/pitched sounds |
| **ApplyImpulseResponse** | 5 different RIRs | 5 variants per clip — EchoThief (~150 free IRs) or pyroomacoustics |
| **LoudnessNormalization** | ±2 dB random | Subtle level variation |
| **SevenBandParametricEQ** | ±3 dB | Gentle spectral variation |
| **TimeStretch** | 0.91.1× only | Do NOT use 2× to pad short clips — breaks video sync |
```python
# pip install audiomentations pedalboard pyroomacoustics
import audiomentations as A
augment = A.Compose([
A.PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
A.ApplyImpulseResponse(ir_paths="path/to/irs/", p=0.5),
A.SevenBandParametricEQ(min_gain_db=-3, max_gain_db=3, p=0.3),
A.LoudnessNormalization(min_lufs=-16, max_lufs=-12, p=0.5),
A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3),
])
audio_aug = augment(samples=audio, sample_rate=sr)
```
**RIR sources:**
- EchoThief: ~150 free real-world IRs (churches, caves, parking garages)
- pyroomacoustics: synthetic room simulation, fully controllable
### Step 4 — Latent Augmentation (at training time)
After VAE encoding:
**Latent mixup** between same-category pairs:
```python
# Mix latents BEFORE flow-matching noise is added
# Only mix clips from the same sound category — cross-category mixing produces garbage
lam = torch.distributions.Beta(0.4, 0.4).sample()
z_mix = lam * z1 + (1 - lam) * z2
```
With 10 clips: C(10,2) = 45 possible pairs → significant expansion without new recordings.
**Small Gaussian noise:**
```python
z_noised = z + torch.randn_like(z) * 0.02 * z.std()
```
Prevents trivial memorization of exact latent coordinates.
MusicLDM (arXiv:2308.01546) shows latent mixup > waveform mixup for generative quality.
---
## Transforms to AVOID for Generative Training
| Transform | Why |
|---|---|
| ClippingDistortion, BitCrush, TanhDistortion, Mp3Compression | Model learns the artifact |
| Reverse | Breaks temporal structure for video-to-audio task |
| TimeMask (creating silence gaps) | Unnatural — model learns to produce silence |
| TimeStretch > 1.3× | Phase vocoder artifacts become part of the target distribution |
| Heavy background noise (< 15 dB SNR) | Model learns to reproduce the noise |
---
## Quality Metrics
| Metric | Tool | Threshold |
|---|---|---|
| DNSMOS P.835 (SIG/BAK/OVRL) | Microsoft DNS-Challenge | OVRL > 3.5 for training |
| LUFS | pyloudnorm | Normalize all clips to -14 LUFS |
| WADA-SNR | (standalone) | No-reference SNR estimate |
| Clipping ratio | NumPy | Flag if > 0.1% of samples at ±0.99 |
---
## Tool Reference
| Tool | Install | Purpose |
|---|---|---|
| audiomentations | `pip install audiomentations` | Primary augmentation library |
| pedalboard | `pip install pedalboard` | Higher quality pitch shift, IR convolution |
| AudioSep | `pip install audiosep` | Source separation / isolation |
| noisereduce | `pip install noisereduce` | Non-stationary denoising |
| DeepFilterNet | `pip install deepfilternet` | Heavy denoising (speech-optimized) |
| pyloudnorm | `pip install pyloudnorm` | LUFS normalization |
| Silero VAD | `pip install silero-vad` | Voice/silence detection |
| pyroomacoustics | `pip install pyroomacoustics` | Synthetic RIR generation |
---
## Integration with PrismAudio / SelVA
No established ComfyUI audio preprocessing ecosystem as of early 2026. Build thin wrapper nodes around the tools above. PrismAudio already has all required patterns (subprocess isolation, AUDIO type transport).
**Target node set:**
- `SelVA Dataset Cleaner` — wraps noisereduce + LUFS normalization + trim + DNSMOS gate
- `SelVA Dataset Augmenter` — wraps audiomentations Compose pipeline
Steps 13 are preprocessing (run once before feature extraction).
Step 4 (latent mixup) is a training loop modification — integrate into `selva_lora_trainer.py`.
---
## Key Papers
| Paper | ArXiv | Finding |
|---|---|---|
| MusicLDM | 2308.01546 | Latent mixup > waveform mixup for generative quality |
| EDMSound | 2311.08667 | Memorization documented — same failure mode as 10-clip training |
| Synthio | 2410.02056 | Synthetic audio as augmentation data (ICLR 2025) |
| HunyuanVideo-Foley | 2508.16930 | V2A data pipeline at scale (100K hrs) |
| FM memorization | 2410.23594 | Velocity field collapse theory — proves early overfitting on small datasets |
+184
View File
@@ -0,0 +1,184 @@
# AudioX vs SelVA — Evaluation
AudioX (arXiv:2503.10522, ICLR 2026) is a unified multimodal audio generation model from HKUST.
This document compares it against SelVA/MMAudio and assesses the cost of adding it to PrismAudio.
---
## Quick Decision Guide
| Situation | Use |
|---|---|
| Video → realistic sound effects | **SelVA** — faster, purpose-built, MIT license |
| Music generation from video or text | **AudioX** — SelVA cannot do this |
| Audio inpainting / music continuation | **AudioX** — SelVA cannot do this |
| LoRA fine-tuning on a custom sound | **SelVA** — full training infrastructure already exists |
| Variable output duration | **AudioX** — SelVA is fixed at 8 s |
| Inference speed matters | **SelVA** — 25 steps vs 250 (10× faster) |
| Non-commercial research | Either |
| Any commercial use | **SelVA only** — AudioX is CC-BY-NC-4.0 |
---
## Architecture
| Dimension | SelVA (MMAudio) | AudioX-MAF |
|---|---|---|
| Core paradigm | Flow matching | Diffusion (k-diffusion / DPM++) |
| Inference steps | 25 ODE steps (Euler) | 250 diffusion steps (DPM++ 3M SDE) |
| Sample rate | 44.1 kHz (large) / 16 kHz (small) | 48 kHz (fixed) |
| Generator | MM-DiT, velocity prediction | ContinuousMMDiTTransformer |
| Video encoder | Synchformer | Synchformer (AudioX custom re-impl, same concept) |
| VAE / codec | DAC (descript-audio-codec) | DAC + AudioCraft options |
| Text encoder | T5-large | T5 (configurable small → XXL) |
| Video-audio fusion | Cross-attention in MM-DiT | MAF: dual-projection (dim alignment + seq length alignment) |
| Output duration | Fixed 8 s | Configurable via `sample_size` (default ~44 s at 48kHz) |
| Training data | ~2 M samples (MMAudio paper) | 7 M samples (IF-caps dataset, curated) |
| License | MIT | CC-BY-NC-4.0 |
**MAF (Multimodal Adaptive Fusion):** AudioX's key architectural contribution. Instead of directly
concatenating multimodal tokens into the DiT's cross-attention, MAF projects each modality to
match the latent's sequence length via a dedicated linear + transposed-conv stack, then applies
`MMDitSingleBlock` layers for cross-modal fusion. The paper reports this improves cross-modal
alignment particularly for video-to-audio tasks.
**Flow matching vs diffusion:** Flow matching (SelVA) trains a single velocity field to move
directly from noise to data along a straight trajectory — this is why 25 steps suffice. Standard
diffusion (AudioX) approximates a longer stochastic path, requiring 250 steps for quality output.
This is not a quality difference per se; flow matching is simply more efficient.
---
## Capabilities
| Task | SelVA | AudioX |
|---|---|---|
| Video → sound effects | ✓ (primary use case) | ✓ |
| Text → sound effects | Partial (T5 conditions quality but not primary) | ✓ (strong benchmark scores) |
| Video → music | ✗ | ✓ |
| Text → music | ✗ | ✓ |
| Audio inpainting | ✗ | ✓ (mask_args parameter) |
| Music continuation | ✗ | ✓ (init_audio parameter) |
| Variable output duration | ✗ (fixed 8 s) | ✓ |
| Multiple input modalities simultaneously | Partial | ✓ (text + video + audio at once) |
AudioX benchmarks claim superior results on text-to-audio (AudioCaps) and text-to-music
(MusicCaps) vs prior models. Video-to-audio comparison against MMAudio specifically is not
prominently featured in the paper. Perceptual evaluation confirms this: AudioX does not sound
noticeably better than SelVA on video-to-audio tasks. AudioX's advantage is **breadth**
(music, inpainting, variable duration), not raw video-to-audio quality.
---
## Integration Cost
Adding AudioX inference-only nodes to PrismAudio would require:
### New nodes (3 files)
```
nodes/
audiox_model_loader.py AUDIOX_MODEL loader — get_pretrained_model("HKUSTAudio/AudioX-MAF")
audiox_sampler.py wraps generate_diffusion_cond(), inputs: model + text + video + audio
audiox_feature_extractor.py optional — pre-extract Synchformer sync features (caching)
```
### Installation
```bash
pip install git+https://github.com/ZeyueT/AudioX.git
```
New dependencies not currently in PrismAudio:
- `pytorch-lightning==2.4.0`
- `k-diffusion==0.1.1`
- `v-diffusion-pytorch==0.0.2`
- `descript-audio-codec==1.0.0` (already used by SelVA — no conflict, same package)
- `gradio==4.44.1` (optional — only for the upstream Gradio UI)
Model weights: `HKUSTAudio/AudioX-MAF` on HuggingFace (~several GB).
### Inference API surface
```python
from audiox import get_pretrained_model
from audiox.inference.generation import generate_diffusion_cond
model, config = get_pretrained_model("HKUSTAudio/AudioX-MAF")
output = generate_diffusion_cond(
model,
steps=250,
cfg_scale=6.0,
conditioning={
"text_prompt": "a dog barking",
"video_prompt": {"video": frames_tensor, "sync_features": sync_feat},
"seconds_total": 8.0,
},
sample_size=384000, # 8 s at 48kHz
sample_rate=48000,
device="cuda",
)
# output: torch.Tensor (batch, channels, num_samples) float32 [-1, 1]
```
---
## LoRA Training
Adding AudioX LoRA training to PrismAudio is **significantly harder** than the SelVA trainer:
| Aspect | SelVA LoRA | AudioX LoRA |
|---|---|---|
| Loss function | Single MSE velocity loss | Diffusion loss over 250-step schedule |
| Training steps needed | ~2000 steps practical | Unknown — likely much more |
| Step cost | Fast (1 velocity prediction) | Slow (full diffusion forward pass per step) |
| Existing infrastructure | Full trainer + scheduler + experiments | Nothing — would need to build from scratch |
| Noise schedule | Trivial (linear interpolation) | Cosine alpha-sigma schedule |
| Prior art for LoRA | LoRA on flow matching well-studied | Less explored; closer to Stable Diffusion LoRA |
**Conclusion:** AudioX LoRA training is feasible (it would follow SD-style LoRA with the DPM++
noise schedule) but would be a substantial new project. Not worth building until inference nodes
are stable and there is a clear use case that SelVA cannot serve.
---
## License
AudioX weights are released under **CC-BY-NC-4.0** (Creative Commons Non-Commercial).
- Free for personal use, research, and non-commercial projects
- **Cannot be used in commercial products or services** without a separate agreement
- Attribution required
- SelVA/MMAudio: MIT (no restrictions)
If PrismAudio is ever distributed as part of a commercial tool, AudioX nodes must be clearly
opt-in with a license warning, or excluded entirely.
---
## Recommendation
**Short term:** AudioX is not a replacement for SelVA for the current use case (video → custom
sound effects with LoRA fine-tuning). SelVA is faster, has full training infrastructure, and
is MIT licensed.
**When AudioX becomes worth integrating:**
- If you need to generate background music synchronized to video
- If you need audio inpainting (fill a gap in an existing audio track)
- If you need text-to-audio generation without a video input
- After verifying the CC-BY-NC-4.0 license is acceptable for your use
**Estimated integration effort for inference nodes only:** 23 days of work (3 new node files,
dependency management, testing). No changes to existing SelVA nodes required — they would
coexist in the same package.
---
## References
- Paper: arXiv:2503.10522 — *AudioX: Diffusion Transformer for Anything-to-Audio Generation*
- GitHub: https://github.com/ZeyueT/AudioX
- Model weights: https://huggingface.co/HKUSTAudio/AudioX-MAF
- Demo: https://huggingface.co/spaces/Zeyue7/AudioX
- Project page: https://zeyuet.github.io/AudioX/
@@ -0,0 +1,606 @@
# Audio Dataset Pipeline Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Add 5 chainable ComfyUI nodes for in-memory audio dataset preprocessing: load → resample → LUFS normalize → inspect/filter → extract single item.
**Architecture:** Single new file `nodes/selva_dataset_pipeline.py` defines a custom `AUDIO_DATASET` type (list of dicts) and all 5 node classes. Nodes are stateless transforms — each takes `AUDIO_DATASET` and returns `AUDIO_DATASET`. No disk I/O except in the Loader. Register all nodes in `nodes/__init__.py`.
**Tech Stack:** `pyloudnorm` (BS.1770-4 LUFS), `soxr` (VHQ resampling), `torchaudio`, `torch`. Both confirmed present in the ComfyUI environment at `/media/p5/miniforge3/envs/latestcomfyui`.
---
## The `AUDIO_DATASET` type
Used as the ComfyUI type string `"AUDIO_DATASET"`. At runtime it is a Python list of dicts:
```python
[
{
"waveform": torch.Tensor, # shape [1, C, L], float32, range [-1, 1]
"sample_rate": int,
"name": str, # original filename stem, for reporting
},
...
]
```
---
### Task 1: Create the file skeleton and AUDIO_DATASET constant
**Files:**
- Create: `nodes/selva_dataset_pipeline.py`
**Step 1: Write the file with imports and type constant only**
```python
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
```
**Step 2: Verify import works (no test framework needed — just a quick smoke check)**
```bash
cd /media/p5/Comfyui-Prismaudio
python3 -c "from nodes.selva_dataset_pipeline import AUDIO_DATASET; print(AUDIO_DATASET)"
```
Expected output: `AUDIO_DATASET`
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add audio dataset pipeline skeleton"
```
---
### Task 2: SelvaDatasetLoader
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Loader class**
```python
class SelvaDatasetLoader:
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("STRING", {
"default": "",
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
def load(self, folder: str):
folder = Path(folder.strip())
if not folder.exists():
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
if not files:
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
dataset = []
for f in sorted(files):
try:
wav, sr = torchaudio.load(str(f)) # [C, L]
wav = wav.unsqueeze(0).float() # [1, C, L]
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
except Exception as e:
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
return (dataset,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader
node = SelvaDatasetLoader()
ds, = node.load('/media/unraid/davinci/Selva/BJ')
print(len(ds), 'clips', ds[0]['name'], ds[0]['waveform'].shape, ds[0]['sample_rate'])
"
```
Expected: prints clip count, first clip name, shape like `torch.Size([1, 2, 352800])`, sample rate.
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLoader node"
```
---
### Task 3: SelvaDatasetResampler
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Resampler class**
Uses `soxr` directly for VHQ quality. `soxr.resample` operates on numpy arrays, shape `[L, C]` (time-first).
```python
class SelvaDatasetResampler:
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_sr": ("INT", {
"default": 44100, "min": 8000, "max": 192000,
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "resample"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
def resample(self, dataset, target_sr: int):
import soxr
out = []
changed = 0
for item in dataset:
sr = item["sample_rate"]
if sr == target_sr:
out.append(item)
continue
wav = item["waveform"][0] # [C, L]
# soxr expects [L, C] (time-first), float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
changed += 1
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetResampler
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetResampler().resample(ds, 44100)
print('ok', ds2[0]['sample_rate'], ds2[0]['waveform'].shape)
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetResampler node (soxr VHQ)"
```
---
### Task 4: SelvaDatasetLUFSNormalizer
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the LUFS normalizer class**
`pyloudnorm.Meter` requires numpy float64 array shape `[L]` (mono) or `[L, C]` (multichannel, channels last). True peak limit applied after gain.
```python
class SelvaDatasetLUFSNormalizer:
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_lufs": ("FLOAT", {
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
)
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
out = []
skipped = 0
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
meter = pyln.Meter(sr)
try:
loudness = meter.integrated_loudness(wav_np)
except Exception:
skipped += 1
out.append(item)
continue
if not np.isfinite(loudness):
skipped += 1
out.append(item)
continue
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_norm = wav * gain_linear
# True peak limit
peak = wav_norm.abs().max().item()
if peak > tp_linear:
wav_norm = wav_norm * (tp_linear / peak)
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
print(
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
flush=True,
)
return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetLUFSNormalizer
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetLUFSNormalizer().normalize(ds, -23.0, -1.0)
print('ok', ds2[0]['name'], ds2[0]['waveform'].abs().max().item())
"
```
Expected: peak ≤ ~0.89 (≈ -1 dBTP).
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)"
```
---
### Task 5: SelvaDatasetInspector
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add helper functions for artifact detection**
```python
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
Method: compare mean energy in 15 kHz band vs 1520 kHz band via STFT.
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
"""
if sr < 32000:
return False # can't assess HF at low sample rates
n_fft = 2048
hop = 512
window = torch.hann_window(n_fft)
mono = wav[0].mean(0) # [L]
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000)
if band_hi.sum() == 0:
return False
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
return ratio_db > 40.0
def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L]
frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item()
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
return 20.0 * np.log10(p95 / p05 + 1e-8)
```
**Step 2: Add the Inspector class**
```python
class SelvaDatasetInspector:
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"skip_rejected": ("BOOLEAN", {
"default": True,
"tooltip": "If True, flagged clips are removed from the output dataset. "
"If False, all clips pass through but the report still lists issues.",
}),
"min_snr_db": ("FLOAT", {
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
"tooltip": "Clips with estimated SNR below this value are flagged.",
}),
"check_codec_artifacts": ("BOOLEAN", {
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET, "STRING")
RETURN_NAMES = ("dataset", "report")
FUNCTION = "inspect"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Analyze each clip for clipping, low SNR, and codec artifacts. "
"Outputs a filtered AUDIO_DATASET and a text report. "
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
for item in dataset:
wav = item["waveform"]
sr = item["sample_rate"]
name = item["name"]
issues = []
# Clipping
peak = wav.abs().max().item()
if peak > 0.99:
issues.append(f"clipping (peak={peak:.3f})")
# Low SNR
snr = _estimate_snr(wav)
if snr < min_snr_db:
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
# Codec artifacts
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
if not skip_rejected:
clean.append(item)
else:
clean.append(item)
lines.append(f" OK {name}")
lines.append("=" * 40)
lines.append(
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
+ (" (removed)" if skip_rejected else " (kept)")
)
report = "\n".join(lines)
print(f"[DatasetInspector]\n{report}", flush=True)
return (clean, report)
```
**Step 3: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetInspector
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
clean, report = SelvaDatasetInspector().inspect(ds, skip_rejected=False, min_snr_db=15.0, check_codec_artifacts=True)
print(report)
"
```
Expected: report with per-clip OK/FLAGGED lines and summary counts.
**Step 4: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)"
```
---
### Task 6: SelvaDatasetItemExtractor
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the extractor class**
```python
class SelvaDatasetItemExtractor:
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
Bridges the dataset pipeline to any node that accepts a standard AUDIO
input — save audio, HF Smoother, Spectral Matcher, etc.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"index": ("INT", {
"default": 0, "min": 0, "max": 9999,
"tooltip": "0-based index. Wraps around if index >= dataset length.",
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "INT")
RETURN_NAMES = ("audio", "name", "total")
FUNCTION = "extract"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Extract one clip from an AUDIO_DATASET by index. "
"Returns standard AUDIO (compatible with all audio nodes), "
"the clip name, and the total dataset length."
)
def extract(self, dataset, index: int):
if not dataset:
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
idx = index % len(dataset)
item = dataset[idx]
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
print(
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
flush=True,
)
return (audio, item["name"], len(dataset))
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetItemExtractor
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
audio, name, total = SelvaDatasetItemExtractor().extract(ds, 0)
print(name, total, audio['waveform'].shape, audio['sample_rate'])
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetItemExtractor node"
```
---
### Task 7: Register all nodes in __init__.py
**Files:**
- Modify: `nodes/__init__.py:4-25`
**Step 1: Add the 5 new entries to `_NODES`**
Add inside the `_NODES` dict, after `"SelvaDittoOptimizer"`:
```python
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
```
**Step 2: Verify registration**
```bash
python3 -c "
import sys; sys.path.insert(0, '/media/p5/Comfyui-Prismaudio')
from nodes import NODE_CLASS_MAPPINGS
keys = [k for k in NODE_CLASS_MAPPINGS if 'Dataset' in k]
print(keys)
"
```
Expected: list of 5 dataset node keys.
**Step 3: Final commit**
```bash
git add nodes/__init__.py
git commit -m "feat: register audio dataset pipeline nodes in __init__.py"
```
---
## Summary
5 nodes in `nodes/selva_dataset_pipeline.py`, all registered in `__init__.py`:
| Node | In | Out |
|------|----|-----|
| SelvaDatasetLoader | folder path | AUDIO_DATASET |
| SelvaDatasetResampler | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetLUFSNormalizer | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetInspector | AUDIO_DATASET | AUDIO_DATASET + STRING |
| SelvaDatasetItemExtractor | AUDIO_DATASET + index | AUDIO + name + total |
Dependencies: `pyloudnorm`, `soxr` — both confirmed present in the ComfyUI env.
+77
View File
@@ -0,0 +1,77 @@
{
"name": "alpha_scale_sweep",
"description": "Fix LoRA noise contamination (flatness 0.013→0.094 at alpha=rank). Root cause: alpha=rank (scale=1.0) at high rank drowns base model priors. Testing dramatically lower alpha to nudge rather than overwrite. All runs at lr=3e-4 (best stable LR from r128_sweet_spot).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/alpha_scale_sweep",
"base": {
"steps": 6000,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "g1_r16_alpha4",
"group": "conservative",
"description": "Back to basics: rank=16 alpha=4 (scale=0.25). Small adapter, gentle scale — cleanest possible LoRA signal.",
"rank": 16,
"alpha": 4.0
},
{
"id": "g1_r16_alpha16",
"group": "conservative",
"description": "rank=16 alpha=16 (scale=1.0) — the original default. Reference point: is the noise issue rank-specific or universal?",
"rank": 16,
"alpha": 16.0
},
{
"id": "g2_r32_alpha8",
"group": "mid",
"description": "rank=32 alpha=8 (scale=0.25). More capacity than r16 but still gentle scale.",
"rank": 32,
"alpha": 8.0
},
{
"id": "g2_r32_alpha32",
"group": "mid",
"description": "rank=32 alpha=32 (scale=1.0). Same rank, full scale — isolates whether scale or rank is causing noise.",
"rank": 32,
"alpha": 32.0
},
{
"id": "g3_r128_alpha8",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=8 (scale=0.0625). High capacity, very gentle contribution — can r128 stay clean at low alpha?",
"rank": 128,
"alpha": 8.0
},
{
"id": "g3_r128_alpha16",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=16 (scale=0.125). Slightly more signal than alpha=8.",
"rank": 128,
"alpha": 16.0
},
{
"id": "g3_r128_alpha32",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=32 (scale=0.25). Same scale as r16_alpha4 and r32_alpha8 — comparable contribution across ranks.",
"rank": 128,
"alpha": 32.0
}
]
}
+31
View File
@@ -0,0 +1,31 @@
{
"name": "bigvgan_disc_fm_retest",
"description": "Retest discriminator feature matching after bfloat16 dtype fix. Uses optimal config from overnight sweep (snake_alpha, GAFilter, lr=1e-4, phase=1.0, L2-SP=1e-3, 5000 steps).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_disc_fm_retest",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_5k_control",
"description": "Control: best config from overnight sweep without discriminator. Baseline for A/B comparison."
},
{
"id": "disc_fm_5k",
"description": "Discriminator feature matching at 5k steps. Tests if perceptual FM loss improves over mel+phase alone.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
}
]
}
@@ -0,0 +1,35 @@
{
"name": "bigvgan_optimized_dataset",
"description": "BigVGAN fine-tuning on optimized dataset (134 clips, 44.1kHz, LUFS-normalized). Standard mode (no LoRA) — trains decoder to faithfully reconstruct target domain audio from mel spectrograms. Uses optimal config from prior sweeps.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_optimized_dataset",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42
},
"experiments": [
{
"id": "standard_5k",
"description": "Standard mode: mel from clean FLAC → BigVGAN → reconstruct FLAC. No LoRA. Directly improves VAE roundtrip quality."
},
{
"id": "disc_fm_5k",
"description": "Standard mode + discriminator feature matching. Tests if perceptual loss helps on clean audio reconstruction.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "standard_10k",
"description": "Extended 10k steps. More data passes on 134 clips may extract more from the optimized dataset.",
"steps": 10000
}
]
}
+65
View File
@@ -0,0 +1,65 @@
{
"name": "bigvgan_overnight",
"description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
"base": {
"train_mode": "snake_alpha_only",
"steps": 3000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_3k_baseline",
"description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
},
{
"id": "snake_5k",
"description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
"steps": 5000
},
{
"id": "snake_no_gafilter",
"description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
"use_gafilter": false
},
{
"id": "snake_no_phase",
"description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
"lambda_phase": 0.0
},
{
"id": "snake_phase_2",
"description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
"lambda_phase": 2.0
},
{
"id": "snake_lr5e-5",
"description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
"lr": 5e-5,
"steps": 5000
},
{
"id": "snake_disc_fm",
"description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "all_2k_l2sp1e-2",
"description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
"train_mode": "all_params",
"steps": 2000,
"lr": 1e-5,
"lambda_l2sp": 1e-2
}
]
}
+39
View File
@@ -0,0 +1,39 @@
{
"name": "eval_r128_candidates",
"description": "Top candidates from r128_sweet_spot. Comparing the two lowest-loss runs, the stable lr=3e-4, and the curriculum run that hit 0.161 before regressing. Baseline included as perceptual reference.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_dir": "/media/unraid/davinci/Selva/BJ/evals/r128_candidates",
"steps": 25,
"seed": 42,
"adapters": [
{
"id": "baseline",
"description": "No LoRA — base model output for perceptual reference"
},
{
"id": "lr_5e4_r128",
"description": "Best loss overall (0.137), still descending at step 10k",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_5e4/adapter_final.pt"
},
{
"id": "lr_3e4_r256",
"description": "Tied with lr_5e4 at 0.139, higher rank — does extra capacity help perceptually?",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g4_r256_lr_3e4/adapter_final.pt"
},
{
"id": "lr_3e4_r128",
"description": "Stable plateau from step 4k to 10k (0.221) — visually confirmed clean spectrograms",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_3e4/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4",
"description": "Best min loss of all (0.161 at step 6k), regressed to 0.193 after curriculum switch — curious if the early checkpoint sounds better",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4_step6000",
"description": "Same run at its actual best step (before regression) — compare against adapter_final to hear the regression",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_step06000.pt"
}
]
}
+33
View File
@@ -0,0 +1,33 @@
{
"name": "lora_logit_cosine_combo",
"description": "Combine the two best findings from optimized dataset sweep: logit-normal timestep sampling + cosine LR schedule. Both individually outperformed baseline by large margins (56% and 68% lower loss). Tests if gains stack.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_logit_cosine_combo",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "logit_normal_cosine",
"description": "Logit-normal timesteps + cosine LR decay. Combining the two best individual improvements.",
"timestep_mode": "logit_normal",
"lr_schedule": "cosine"
},
{
"id": "logit_normal_control",
"description": "Control: logit-normal only (constant LR). Reproduces previous winner for direct comparison.",
"timestep_mode": "logit_normal"
}
]
}
+64
View File
@@ -0,0 +1,64 @@
{
"name": "lora_optimized_dataset",
"description": "LoRA training on optimized dataset (134 clips: resampled 44.1kHz, LUFS-normalized, spectral matched, HF smoothed, gain-augmented). Tests latent augmentation and schedule variants on top of known-best config (PiSSA, rank=128, lr=3e-4).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_optimized_dataset",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "baseline",
"description": "Control: known-best config (PiSSA r128 lr=3e-4) on the optimized dataset. No latent augmentation."
},
{
"id": "latent_mixup",
"description": "Latent mixup alpha=0.4 (MusicLDM). Tests if mixing training latents reduces memorization on 134 clips.",
"latent_mixup_alpha": 0.4
},
{
"id": "latent_noise",
"description": "Latent noise sigma=0.02. Mild Gaussian noise on training latents for regularization.",
"latent_noise_sigma": 0.02
},
{
"id": "mixup_and_noise",
"description": "Both latent mixup (0.4) and noise (0.02). Combined regularization.",
"latent_mixup_alpha": 0.4,
"latent_noise_sigma": 0.02
},
{
"id": "cosine_schedule",
"description": "Cosine LR decay. lr=3e-4 was stable with constant, but cosine may extract more from 5k steps.",
"lr_schedule": "cosine"
},
{
"id": "cosine_mixup",
"description": "Cosine LR + latent mixup. Best regularization combo candidate.",
"lr_schedule": "cosine",
"latent_mixup_alpha": 0.4
},
{
"id": "logit_normal",
"description": "Logit-normal timestep sampling (sigma=1.0). Concentrates training near t=0.5 where flow matching is hardest.",
"timestep_mode": "logit_normal"
},
{
"id": "curriculum_mixup",
"description": "Curriculum timesteps (logit_normal first 60%, then uniform) + latent mixup. Full regularization stack.",
"timestep_mode": "curriculum",
"latent_mixup_alpha": 0.4
}
]
}
+62
View File
@@ -0,0 +1,62 @@
{
"name": "pissa_sweep",
"description": "PiSSA vs standard init ablation at rank 128. Best prior config (lr=3e-4, bs=16, 10k steps) as baseline. PiSSA starts on-manifold via SVD init — should eliminate intruder dimensions. rsLoRA stabilises scaling at high rank.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": true
},
"experiments": [
{
"id": "standard_baseline",
"description": "Standard Kaiming init + classic alpha/rank scaling. Replicates best prior config for A/B comparison.",
"init_mode": "standard",
"use_rslora": false
},
{
"id": "pissa_rslora",
"description": "PiSSA init + rsLoRA scaling. Full Tier-S config. Should start on-manifold and avoid intruder dimensions."
},
{
"id": "pissa_classic_scale",
"description": "PiSSA init + classic alpha/rank scaling. Isolates PiSSA contribution from rsLoRA.",
"use_rslora": false
},
{
"id": "standard_rslora",
"description": "Standard init + rsLoRA only. Isolates rsLoRA contribution from PiSSA.",
"init_mode": "standard"
},
{
"id": "pissa_rslora_lr1e-4",
"description": "PiSSA+rsLoRA at lower lr=1e-4. PiSSA starts closer to optimum — may need less aggressive lr.",
"lr": 1e-4
},
{
"id": "pissa_rslora_lr5e-4",
"description": "PiSSA+rsLoRA at higher lr=5e-4. Test if on-manifold start tolerates faster learning.",
"lr": 5e-4
},
{
"id": "pissa_rslora_dropout",
"description": "PiSSA+rsLoRA with dropout 0.05. Note: PiSSA forces dropout=0 (principal components should not be dropped) — this tests standard init with rsLoRA + dropout.",
"init_mode": "standard",
"lora_dropout": 0.05
}
]
}
+103
View File
@@ -0,0 +1,103 @@
{
"name": "r128_sweet_spot",
"description": "Find the noise-free sweet spot on rank 128. LoRA+ ratio=16 caused noise — testing higher base LR without LoRA+ as a cleaner alternative. Target loss range 0.250.35. Also probing rank 256 since 102GB VRAM allows it.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r128_lr_2e4",
"group": "lr",
"description": "LR=2e-4. Conservative 2× step up from baseline — noise-free descent toward sweet spot.",
"lr": 2e-4
},
{
"id": "g1_r128_lr_3e4",
"group": "lr",
"description": "LR=3e-4. 3× baseline — landed at 0.41 on r64, should reach 0.250.35 on r128.",
"lr": 3e-4
},
{
"id": "g1_r128_lr_5e4",
"group": "lr",
"description": "LR=5e-4. Aggressive but no LoRA+ B-matrix asymmetry — cleaner noise profile.",
"lr": 5e-4
},
{
"id": "g2_r128_curriculum",
"group": "curriculum",
"description": "Curriculum only at baseline LR. Clean slow descent — reference for what curriculum contributes alone.",
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum",
"group": "curriculum",
"description": "LR=3e-4 + curriculum. Speed of higher LR with coverage of curriculum — no LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum_dropout",
"group": "curriculum",
"description": "LR=3e-4 + curriculum + dropout=0.05. Full controlled stack without LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum",
"lora_dropout": 0.05
},
{
"id": "g3_r128_lora_plus_4",
"group": "lora_plus",
"description": "LoRA+ ratio=4 (lr_B=4e-4). Much more conservative than ratio=16 — tests if noise came from ratio not the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g4_r256_baseline",
"group": "rank256",
"description": "Rank 256 at baseline LR. 102GB VRAM makes this viable — does more capacity keep helping?",
"rank": 256
},
{
"id": "g4_r256_lr_3e4",
"group": "rank256",
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
"rank": 256,
"lr": 3e-4
},
{
"id": "g5_r128_lr_2e4_cosine",
"group": "cosine",
"description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 60008000 by decaying LR to ~0 instead of staying flat.",
"lr": 2e-4,
"lr_schedule": "cosine"
},
{
"id": "g5_r128_lr_3e4_cosine",
"group": "cosine",
"description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
"lr": 3e-4,
"lr_schedule": "cosine"
}
]
}
+130
View File
@@ -0,0 +1,130 @@
{
"name": "r64_overnight",
"description": "Focused rank-64 overnight sweep. All experiments use rank 64 as base — confirmed best from tier1_thorough early results. 8000 steps to reach convergence (none converged at 4000).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r64_overnight",
"base": {
"steps": 8000,
"rank": 64,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r64_baseline",
"group": "rank",
"description": "Rank 64 baseline — clean reference at 8000 steps."
},
{
"id": "g1_r128_baseline",
"group": "rank",
"description": "Rank 128 — 102GB VRAM makes this free. Does doubling rank from 64 help further?",
"rank": 128
},
{
"id": "g2_r64_alpha_32",
"group": "alpha",
"description": "Rank 64 alpha=32 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 32.0
},
{
"id": "g2_r64_alpha_16",
"group": "alpha",
"description": "Rank 64 alpha=16 (scale=0.25). More aggressive scale reduction — may over-constrain.",
"alpha": 16.0
},
{
"id": "g3_r64_lora_plus",
"group": "regularisation",
"description": "LoRA+ ratio=16. lr_B = 16 × lr_A. Faster convergence at constant step budget.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_r64_dropout_0.05",
"group": "regularisation",
"description": "Dropout=0.05. Light sparsity regularisation on LoRA path.",
"lora_dropout": 0.05
},
{
"id": "g3_r64_dropout_0.1",
"group": "regularisation",
"description": "Dropout=0.1. Stronger regularisation — tests if 49 clips needs heavier constraint.",
"lora_dropout": 0.1
},
{
"id": "g3_r64_curriculum",
"group": "regularisation",
"description": "Curriculum sampling: logit_normal for steps 1-4800, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_r64_lr_low",
"group": "lr",
"description": "LR=3e-5. 3× lower — checks if 1e-4 is overshooting at rank 64.",
"lr": 3e-5
},
{
"id": "g4_r64_lr_high",
"group": "lr",
"description": "LR=3e-4. 3× higher — may converge faster but risk instability.",
"lr": 3e-4
},
{
"id": "g5_r64_target_full",
"group": "target",
"description": "Rank 64 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g5_r128_target_full",
"group": "target",
"description": "Rank 128 + full target. Maximum possible coverage with available VRAM.",
"rank": 128,
"target": "attn.qkv linear1"
},
{
"id": "g6_r64_full_tier1",
"group": "combined",
"description": "All Tier 1 at rank 64: LoRA+ 16 + dropout 0.05 + curriculum. Full stack at 8000 steps.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r64_alpha32_full",
"group": "combined",
"description": "Rank 64 alpha=32 + all Tier 1. Best alpha scaling + best regularisation stack.",
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r128_full_tier1",
"group": "combined",
"description": "Rank 128 + all Tier 1. Tests if more capacity + regularisation beats rank 64 full.",
"rank": 128,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
}
]
}
+52
View File
@@ -0,0 +1,52 @@
{
"name": "ti_sweep_1",
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
"base": {
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_text": "",
"lr": 2e-4,
"n_tokens": 4,
"inject_mode": "suffix"
},
"experiments": [
{
"id": "n4_baseline",
"group": "reference",
"description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
},
{
"id": "n4_clamped",
"group": "norm_clamp",
"description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
},
{
"id": "n4_prefix_clamped",
"group": "norm_clamp",
"description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
"inject_mode": "prefix"
},
{
"id": "n8_prefix_clamped",
"group": "norm_clamp",
"description": "8 tokens, prefix, clamped. More capacity without the artifact.",
"n_tokens": 8,
"inject_mode": "prefix"
},
{
"id": "n4_prefix_warm_clamped",
"group": "norm_clamp",
"description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
"inject_mode": "prefix",
"init_text": "mechanical impact sound design"
}
]
}
+61
View File
@@ -0,0 +1,61 @@
{
"name": "tier1_sweep",
"description": "Ablation of Tier 1 improvements: LoRA+, dropout, curriculum sampling. Baseline = uniform, no regularisation.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "lora_sweeps/tier1_sweep",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "baseline",
"description": "Standard LoRA — no Tier 1 changes. Reference point."
},
{
"id": "lora_plus_16",
"description": "LoRA+ only: lr_B = 16 * lr_A. Should converge faster in early steps.",
"lora_plus_ratio": 16.0
},
{
"id": "dropout_0.05",
"description": "LoRA dropout 0.05 only. Light regularisation for 49-clip dataset.",
"lora_dropout": 0.05
},
{
"id": "dropout_0.1",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "curriculum",
"description": "Curriculum sampling only: logit_normal for steps 1-2400, then uniform. Should improve convergence vs pure uniform.",
"timestep_mode": "curriculum"
},
{
"id": "full_tier1",
"description": "All Tier 1 combined: LoRA+ + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "rank_64",
"description": "Rank 64 baseline — MMAudio LoRA guide default. More expressive adapter for 49-clip dataset.",
"rank": 64
}
]
}
+144
View File
@@ -0,0 +1,144 @@
{
"name": "tier1_thorough",
"description": "Full overnight Tier 1 ablation on 49-clip BJ dataset. 4 groups: rank, alpha, regularisation, and best combinations. ~10-12h depending on GPU.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/tier1_thorough",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 1000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_rank_16",
"group": "rank",
"description": "Rank 16 baseline — reference point for all groups."
},
{
"id": "g1_rank_32",
"group": "rank",
"description": "Rank 32 — midpoint. Does doubling rank improve quality without overfitting?",
"rank": 32
},
{
"id": "g1_rank_64",
"group": "rank",
"description": "Rank 64 — MMAudio LoRA guide default. Maximum expressiveness at 49 clips.",
"rank": 64
},
{
"id": "g2_alpha_half_r16",
"group": "alpha",
"description": "Alpha=8 with rank 16 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 8.0
},
{
"id": "g2_alpha_half_r64",
"group": "alpha",
"description": "Alpha=32 with rank 64 (scale=0.5). Best-practice scaling for high-rank adapters.",
"rank": 64,
"alpha": 32.0
},
{
"id": "g3_lora_plus_4",
"group": "regularisation",
"description": "LoRA+ ratio=4 — conservative asymmetric LR. Lower bound for the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g3_lora_plus_16",
"group": "regularisation",
"description": "LoRA+ ratio=16 — standard from FLUX LoRA literature. Faster early convergence.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_dropout_0.05",
"group": "regularisation",
"description": "LoRA dropout 0.05 only. Light sparsity regularisation (arXiv:2404.09610).",
"lora_dropout": 0.05
},
{
"id": "g3_dropout_0.1",
"group": "regularisation",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "g3_curriculum",
"group": "regularisation",
"description": "Curriculum sampling only: logit_normal steps 1-2400, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r16",
"group": "combined",
"description": "All Tier 1 at rank 16: LoRA+ 16 + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r64",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32. Best expressiveness + best regularisation.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g5_lr_low",
"group": "lr",
"description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
"lr": 3e-5
},
{
"id": "g5_lr_high",
"group": "lr",
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
"lr": 3e-4
},
{
"id": "g6_target_full_r16",
"group": "target",
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g6_target_full_r64",
"group": "target",
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
"rank": 64,
"alpha": 32.0,
"target": "attn.qkv linear1"
},
{
"id": "g4_full_r64_6k",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum",
"steps": 6000,
"save_every": 1000
}
]
}
+30
View File
@@ -0,0 +1,30 @@
{
"name": "vocoder_finetune",
"description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "r128_lr_3e4_bj_vocoder",
"description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
}
]
}
+30
View File
@@ -5,6 +5,36 @@ _NODES = {
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"), "SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"), "SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"), "SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
"SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
"SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"),
"SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
"SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
"SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
"SelvaDatasetCompressor": (".selva_dataset_pipeline", "SelvaDatasetCompressor", "SelVA Dataset Compressor"),
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
"SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
"SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
"SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
} }
for key, (module_path, class_name, display_name) in _NODES.items(): for key, (module_path, class_name, display_name) in _NODES.items():
@@ -0,0 +1,201 @@
"""SelVA Activation Steering Extractor.
Computes per-block steering vectors by running the frozen generator on the
training dataset and recording how target style's conditioning shifts the DiT hidden
states vs. empty/unconditional conditioning.
For each block i:
steering[i] = mean(latent_hidden | target style conditions)
- mean(latent_hidden | empty conditions)
The resulting vectors are injected at inference time (via SelVA Sampler's
steering_strength input) to nudge the denoising trajectory toward target style's
activation patterns without modifying any model weights.
"""
import random
from pathlib import Path
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import _prepare_dataset
def _collect_activations(generator, conditions, latent, t_tensor):
"""Run one predict_flow call, collecting latent hidden states per block.
Returns a list of [seq, hidden_dim] float32 CPU tensors,
one per block (joint_blocks first, then fused_blocks).
"""
activations = []
def make_hook(is_joint):
def hook(module, input, output):
h = output[0] if is_joint else output
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
return hook
handles = []
for block in generator.joint_blocks:
handles.append(block.register_forward_hook(make_hook(is_joint=True)))
for block in generator.fused_blocks:
handles.append(block.register_forward_hook(make_hook(is_joint=False)))
try:
with torch.no_grad():
generator.predict_flow(latent, t_tensor, conditions)
finally:
for h in handles:
h.remove()
return activations # list of n_blocks tensors [seq, hidden]
class SelvaActivationSteeringExtractor:
"""Computes activation steering vectors from a training dataset.
Runs the frozen generator on N clips at random timesteps with both
target style-conditioned and empty-conditioned inputs, then saves the mean
difference per DiT block to a .pt file.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "extract"
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("steering_path",)
OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
DESCRIPTION = (
"Computes per-block activation steering vectors: mean(target style activations) "
"mean(empty activations) at each DiT block. Load the result with "
"SelVA Activation Steering Loader and connect to the Sampler."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
}),
"output_path": ("STRING", {
"default": "steering_vectors.pt",
"tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
}),
"n_samples": ("INT", {
"default": 16, "min": 1, "max": 256,
"tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
},
}
def extract(self, model, data_dir, output_path, n_samples, seed):
device = get_device()
dtype = model["dtype"]
seq_cfg = model["seq_cfg"]
data_dir = Path(data_dir.strip())
if not data_dir.is_absolute():
data_dir = Path(folder_paths.models_dir) / data_dir
if not data_dir.exists():
raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")
out_path = Path(output_path.strip())
if not out_path.is_absolute():
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
print(f"[Steering] data_dir = {data_dir}", flush=True)
print(f"[Steering] output = {out_path}\n", flush=True)
dataset = _prepare_dataset(model, data_dir, device)
generator = model["generator"]
generator.eval()
torch.manual_seed(seed)
random.seed(seed)
indices = random.choices(range(len(dataset)), k=n_samples)
n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
style_sums = [None] * n_blocks
empty_sums = [None] * n_blocks
counts = [0] * n_blocks
pbar = comfy.utils.ProgressBar(n_samples)
for sample_i, clip_idx in enumerate(indices):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024]
sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768]
text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024]
# x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
clip_latent_seq_len = x1_cpu.shape[1]
generator.update_seq_lengths(
latent_seq_len=clip_latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = generator.get_empty_conditions(bs=1)
# Random timestep and noise latent for this clip
t_val = torch.rand(1).item()
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
latent = torch.randn(
1, clip_latent_seq_len, generator.latent_dim,
device=device, dtype=dtype,
)
style_acts = _collect_activations(generator, conditions, latent, t_tensor)
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
if style_sums[i] is None:
style_sums[i] = st.clone()
empty_sums[i] = em.clone()
else:
style_sums[i] += st
empty_sums[i] += em
counts[i] += 1
pbar.update(1)
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
# Steering vector per block: mean(target style) - mean(empty)
steering_vectors = []
for i in range(n_blocks):
vec = (style_sums[i] - empty_sums[i]) / counts[i] # [hidden]
steering_vectors.append(vec)
norm = vec.norm().item()
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
n_joint = len(generator.joint_blocks)
payload = {
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
"n_blocks": n_blocks,
"n_joint": n_joint,
"n_fused": len(generator.fused_blocks),
"latent_seq_len": seq_cfg.latent_seq_len,
"n_samples": n_samples,
"seed": seed,
"mode": model["mode"],
"variant": model["variant"],
}
torch.save(payload, str(out_path))
print(f"\n[Steering] Saved: {out_path}", flush=True)
soft_empty_cache()
return (str(out_path),)
+62
View File
@@ -0,0 +1,62 @@
"""SelVA Activation Steering Loader.
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaActivationSteeringLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("STEERING_VECTORS",)
RETURN_NAMES = ("steering_vectors",)
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
DESCRIPTION = (
"Loads activation steering vectors from a .pt file produced by "
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
"denoising toward the target activation patterns."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "steering_vectors.pt",
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[Steering] File not found: {p}")
payload = torch.load(str(p), map_location="cpu", weights_only=False)
n_blocks = payload["n_blocks"]
n_joint = payload["n_joint"]
n_fused = payload["n_fused"]
n_vecs = len(payload["steering_vectors"])
print(f"[Steering] Loaded: {p}", flush=True)
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
f"latent_seq_len={payload['latent_seq_len']} "
f"n_samples={payload['n_samples']}", flush=True)
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
mean_norm = sum(norms) / len(norms)
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
return (payload,)
+153
View File
@@ -0,0 +1,153 @@
"""SelVA Audio Post-Processing nodes.
Post-generation enhancement applied to standard AUDIO outputs:
SelvaHarmonicExciter — multi-band harmonic exciter (HPF → tanh → mix)
SelvaOutputNormalizer — LUFS normalization + true peak limiting
"""
import numpy as np
import torch
from .utils import SELVA_CATEGORY
class SelvaHarmonicExciter:
"""Multi-band harmonic exciter for post-generation enhancement.
Isolates high-frequency content above a cutoff, applies tanh saturation
to generate 2nd/3rd harmonics, then mixes back with the dry signal.
Restores harmonic richness lost during BigVGAN vocoder reconstruction.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 3000.0, "min": 500.0, "max": 16000.0, "step": 100.0,
"tooltip": "Highpass cutoff frequency in Hz. Only content above this is excited. "
"3000 Hz targets the upper harmonics BigVGAN tends to smear.",
}),
"drive": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.5,
"tooltip": "Saturation drive. Higher = more harmonics generated. "
"2-3 is subtle, 5+ is aggressive.",
}),
"mix": ("FLOAT", {
"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Wet/dry blend. 0.1-0.2 is subtle enhancement, "
"0.5+ is aggressive harmonic addition.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "excite"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Multi-band harmonic exciter. Applies tanh saturation to the high-frequency band "
"to restore harmonics lost during BigVGAN vocoder reconstruction. "
"Uses pedalboard.HighpassFilter for band isolation."
)
def excite(self, audio, cutoff_hz: float, drive: float, mix: float):
from pedalboard import Pedalboard, HighpassFilter
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
wav_np = wav.float().numpy() # [C, T]
# Isolate HF band
board = Pedalboard([HighpassFilter(cutoff_frequency_hz=cutoff_hz)])
hf = board(wav_np, sr) # [C, T]
# Tanh saturation — normalize by drive so output stays in [-1, 1]
excited = np.tanh(hf * drive) / max(drive, 1.0)
# Mix back with dry
mixed = wav_np + mix * excited
# Soft clip to prevent going over
mixed = np.tanh(mixed)
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, T]
print(
f"[HarmonicExciter] cutoff={cutoff_hz}Hz drive={drive} mix={mix:.0%}",
flush=True,
)
return ({"waveform": wav_out, "sample_rate": sr},)
class SelvaOutputNormalizer:
"""Normalize generated audio to a target LUFS level with true peak limiting.
Apply as the final node before saving — brings generated audio to a
consistent loudness target regardless of input video loudness variation.
Uses pyloudnorm (BS.1770-4).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"target_lufs": ("FLOAT", {
"default": -14.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. "
"-14 LUFS for streaming (Spotify/YouTube), "
"-9 to -7 for production masters.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP applied after LUFS gain.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize output audio to a target LUFS level (BS.1770-4) with true peak limiting. "
"Apply as the last node before saving. Uses pyloudnorm."
)
def normalize(self, audio, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
wav_np = wav.permute(1, 0).double().numpy() # [T, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [T] mono
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(wav_np)
if not np.isfinite(loudness):
print("[OutputNormalizer] Could not measure loudness — clip too short or silent. Passing through.", flush=True)
return (audio,)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_out = wav * gain_linear
peak = wav_out.abs().max().item()
if peak > tp_linear:
wav_out = wav_out * (tp_linear / peak)
print(
f"[OutputNormalizer] {loudness:.1f} LUFS → {target_lufs} LUFS "
f"gain={gain_db:+.1f}dB TP={true_peak_dbtp}dBTP",
flush=True,
)
return ({"waveform": wav_out.unsqueeze(0), "sample_rate": sr},)
+293
View File
@@ -0,0 +1,293 @@
"""SelVA Audio Preprocessors — condition training clips for codec compatibility.
Two nodes that reduce the domain mismatch between custom training audio and the
MMAudio VAE's expected spectral distribution, improving LoRA training quality:
SelvaHfSmoother — soft low-pass blend to attenuate extreme HF content
SelvaSpectralMatcher — adaptive per-band EQ toward the codec's training distribution
Root cause they address: MMAudio was trained on natural sounds (speech, foley, env)
with limited engineered HF content. The BigVGANv2 vocoder (frozen, pre-trained) handles
the codec's HF reconstruction poorly for sound design / music training clips, because
those clips land in a latent-space region the vocoder never saw during training.
Recommended order: SpectralMatcher → HfSmoother → feature extraction → LoRA training.
"""
import numpy as np
import torch
import torchaudio.functional as AF
from .utils import SELVA_CATEGORY
# ── Mel filterbank (same algorithm as selva_core/ext/mel_converter.py) ────────
def _mel_filterbank(sr: int, n_fft: int, n_mels: int,
fmin: float, fmax: float) -> torch.Tensor:
"""Returns mel filterbank matrix [n_mels, n_fft//2+1]."""
def hz_to_mel(f):
return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)
def mel_to_hz(m):
return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)
n_freqs = n_fft // 2 + 1
fft_freqs = np.linspace(0.0, sr / 2.0, n_freqs)
mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
hz_pts = mel_to_hz(mel_pts)
fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
for m in range(1, n_mels + 1):
lo, mid, hi = hz_pts[m - 1], hz_pts[m], hz_pts[m + 1]
up = (fft_freqs - lo) / (mid - lo + 1e-12)
down = (hi - fft_freqs) / (hi - mid + 1e-12)
fb[m - 1] = np.maximum(0.0, np.minimum(up, down))
return torch.from_numpy(fb)
# ── VAE target log-mel means (source: selva_core/ext/autoencoder/vae.py) ──────
# These are the per-band expected log-mel energy means from MMAudio's training data.
# Used as the spectral matching target: clips are EQ'd to match this profile.
_TARGET_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439,
-1.2922, -1.2927, -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912,
-1.4313, -1.4152, -1.4527, -1.4728, -1.4568, -1.5101, -1.5051, -1.5172,
-1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, -1.6081, -1.6331,
-1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377,
-1.8417, -1.8643, -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673,
-1.9824, -2.0042, -2.0215, -2.0436, -2.0766, -2.1064, -2.1418, -2.1855,
-2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, -2.4659, -2.5072,
-2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673,
]
_TARGET_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006,
-2.2357, -2.4597, -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047,
-2.7483, -2.5926, -2.7462, -2.7033, -2.7386, -2.8112, -2.7502, -2.9594,
-2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, -3.1191, -2.9893,
-3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509,
-3.5089, -3.4647, -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747,
-3.7072, -3.7279, -3.7283, -3.7795, -3.8259, -3.8447, -3.8663, -3.9182,
-3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, -4.1488, -4.1874,
-4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053,
-5.4927, -5.5712, -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103,
-6.0955, -6.1673, -6.2362, -6.3120, -6.3926, -6.4797, -6.5565, -6.6511,
-6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, -7.6136, -7.7469,
-7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861,
]
_MEL_CONFIGS = {
"16k": dict(sr=16_000, n_fft=1024, n_mels=80, hop=256, fmin=0, fmax=8_000,
target=_TARGET_MEAN_80D, log10=True),
"44k": dict(sr=44_100, n_fft=2048, n_mels=128, hop=512, fmin=0, fmax=22_050,
target=_TARGET_MEAN_128D, log10=False),
}
# ── Node 1: HF Smoother ───────────────────────────────────────────────────────
class SelvaHfSmoother:
"""Soft high-frequency attenuation for LoRA training clip preprocessing.
Blends a low-pass filtered copy of the audio with the original. Attenuates
the extreme HF content common in engineered sound design that the BigVGANv2
vocoder handles poorly, bringing the clip closer to the spectral region the
MMAudio codec was trained on (natural sounds with limited HF energy).
A blend of 0.7 at 12 kHz is a transparent starting point — audible only on
close comparison. Increase blend or lower cutoff if roundtrip quality is still
poor after spectral matching.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered. 0.7 is a transparent starting point.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Blends a low-pass filtered version of the audio with the original to gently attenuate "
"high-frequency content that the SelVA codec handles poorly. "
"Use before feature extraction to improve LoRA training targets. "
"Run after SelVA Spectral Matcher for best results."
)
def process(self, audio, cutoff_hz: float, blend: float):
waveform = audio["waveform"].float() # [1, C, L]
sr = audio["sample_rate"]
filtered = AF.lowpass_biquad(waveform, sr, cutoff_hz)
out = blend * filtered + (1.0 - blend) * waveform
# Preserve RMS level — LPF removes energy, keep the clip at its original loudness
rms_in = waveform.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = out.pow(2).mean().sqrt().clamp(min=1e-8)
out = out * (rms_in / rms_out)
peak = out.abs().max()
if peak > 1.0:
out = out / peak
print(f"[HF Smoother] cutoff={cutoff_hz:.0f} Hz blend={blend:.2f} "
f"rms={rms_in:.4f}{out.pow(2).mean().sqrt():.4f} "
f"peak={out.abs().max():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr},)
# ── Node 2: Spectral Matcher ──────────────────────────────────────────────────
class SelvaSpectralMatcher:
"""Adaptive per-band EQ toward the SelVA VAE's expected spectral distribution.
Computes the log-mel energy profile of the clip and compares it to the per-band
means stored in the VAE's normalization buffers (the statistics MMAudio was trained
on). Applies a smooth frequency-domain gain correction so the clip's spectral shape
matches what the codec expects, improving encode→decode roundtrip quality and
therefore LoRA training target quality.
The correction is additive in log space (multiplicative in linear), so it only
changes spectral balance — not the waveform's timing or phase structure.
max_gain_db clamps the correction to prevent extreme boosts on very quiet bands.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution. "
"0.8 is a good starting point.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB. Prevents extreme boosts on "
"very quiet frequency bands. 12 dB is conservative.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Applies a smooth per-band gain correction to bring the audio's spectral profile "
"in line with the MMAudio VAE's expected distribution, derived from the per-band "
"normalization statistics baked into the VAE weights. "
"Use before feature extraction to improve LoRA training target quality. "
"Run before SelVA HF Smoother."
)
def process(self, audio, mode: str, strength: float, max_gain_db: float):
cfg = _MEL_CONFIGS[mode]
waveform = audio["waveform"].float() # [1, C, L]
sr_in = audio["sample_rate"]
sr_tgt = cfg["sr"]
n_fft = cfg["n_fft"]
hop = cfg["hop"]
# ── flatten to mono and resample if needed ────────────────────────────
wav = waveform[0].mean(0) # [L]
if sr_in != sr_tgt:
wav = AF.resample(wav.unsqueeze(0), sr_in, sr_tgt).squeeze(0)
device = wav.device
window = torch.hann_window(n_fft, device=device)
# ── STFT ──────────────────────────────────────────────────────────────
stft = torch.stft(wav, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True, return_complex=True) # [n_freqs, T]
mag = stft.abs() # [n_freqs, T]
# ── current log-mel mean per band ─────────────────────────────────────
fb = _mel_filterbank(sr_tgt, n_fft, cfg["n_mels"],
cfg["fmin"], cfg["fmax"]).to(device) # [n_mels, n_freqs]
mel_mag = torch.matmul(fb, mag).clamp(min=1e-5) # [n_mels, T]
if cfg["log10"]:
mel_log = torch.log10(mel_mag)
else:
mel_log = torch.log(mel_mag)
current_mean = mel_log.mean(dim=-1) # [n_mels]
target_mean = torch.tensor(cfg["target"], device=device) # [n_mels]
# ── per-mel-band gain (log space) ─────────────────────────────────────
mel_gain = (target_mean - current_mean) * strength # [n_mels]
# Clamp to ±max_gain_db
if cfg["log10"]:
max_log = max_gain_db / 20.0 # log10: 20 log10 = dB
else:
max_log = max_gain_db / 8.6859 # ln: 20 * log10(e) ≈ 8.686
mel_gain = mel_gain.clamp(-max_log, max_log)
# ── map mel gains → STFT frequency bins (weighted average) ────────────
fb_sum = fb.sum(0).clamp(min=1e-8) # [n_freqs]
freq_gain = (mel_gain @ fb) / fb_sum # [n_freqs]
if cfg["log10"]:
linear_gain = 10.0 ** freq_gain # [n_freqs]
else:
linear_gain = torch.exp(freq_gain) # [n_freqs]
# ── apply gain in frequency domain and reconstruct ───────────────────
stft_out = stft * linear_gain.unsqueeze(-1) # [n_freqs, T]
wav_out = torch.istft(stft_out, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True,
length=wav.shape[0]) # [L]
# ── resample back to original sr ──────────────────────────────────────
if sr_in != sr_tgt:
wav_out = AF.resample(wav_out.unsqueeze(0), sr_tgt, sr_in).squeeze(0)
# ── preserve original RMS level ───────────────────────────────────────
rms_in = wav.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = wav_out.pow(2).mean().sqrt().clamp(min=1e-8)
wav_out = wav_out * (rms_in / rms_out)
peak = wav_out.abs().max()
if peak > 1.0:
wav_out = wav_out / peak
# ── reshape to match input layout [1, C, L] ───────────────────────────
out = wav_out.unsqueeze(0).unsqueeze(0)
if waveform.shape[1] > 1:
out = out.expand(-1, waveform.shape[1], -1).clone()
gain_db_range = (
20.0 * torch.log10(linear_gain.clamp(min=1e-8))
)
print(f"[Spectral Matcher] mode={mode} strength={strength:.2f} "
f"gain [{gain_db_range.min():.1f}, {gain_db_range.max():.1f}] dB "
f"rms={rms_in:.4f}{out.pow(2).mean().sqrt():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr_in},)
+77
View File
@@ -0,0 +1,77 @@
"""SelVA BigVGAN Loader.
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
The model is modified in-place so ComfyUI's model cache is updated — no need to
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from .selva_bigvgan_trainer import inject_gafilters
class SelvaBigvganLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
DESCRIPTION = (
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
"Supports both 16k and 44k models. "
"Connect the output to SelVA Sampler instead of the base model loader."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"path": ("STRING", {
"default": "bigvgan_bj.pt",
"tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
"Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, model, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
if "generator" not in ckpt:
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
mode = model["mode"]
if mode == "16k":
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
elif mode == "44k":
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
else:
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
# Remember device before injecting new modules (which default to CPU)
target_device = next(vocoder.parameters()).device
if ckpt.get("has_gafilter", False):
kernel_size = ckpt.get("gafilter_kernel_size", 9)
n_gaf = inject_gafilters(vocoder, kernel_size)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
vocoder.load_state_dict(ckpt["generator"])
vocoder.to(target_device)
vocoder.eval()
print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
return (model,)
+625
View File
@@ -0,0 +1,625 @@
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
Each experiment inherits from a shared `base` config and overrides specific keys.
Audio clips are loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image.
JSON format:
{
"name": "bigvgan_sweep",
"description": "optional note",
"data_dir": "/path/to/audio/clips",
"output_root": "/path/to/output",
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
...
]
}
"""
import copy
import csv
import json
import threading
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_bigvgan_trainer import (
_do_train,
_pregenerate_lora_mels,
_load_wav,
)
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
from .selva_lora_scheduler import (
_get_system_info,
_resolve_path,
_draw_comparison_curves,
)
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"train_mode": "snake_alpha_only",
"steps": 2000,
"lr": 1e-4,
"batch_size": 4,
"segment_seconds": 2.0,
"lambda_l2sp": 1e-3,
"use_gafilter": True,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 500,
"seed": 42,
"discriminator_path": "",
"lora_adapter": "",
}
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge param defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _parse_training_log(log_path: Path) -> list:
"""Parse BigVGAN training CSV → list of total_loss values."""
losses = []
if not log_path.exists():
return losses
try:
with open(log_path) as f:
reader = csv.DictReader(f)
for row in reader:
losses.append(float(row["total_loss"]))
except Exception:
pass
return losses
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
"""Build {step: loss} at each save_every boundary.
Uses round-to-nearest to handle log_interval that doesn't divide
save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
"""
result = {}
for target in range(save_every, total_steps + 1, save_every):
# loss_history[i] = loss at step (i+1)*log_interval
idx = round(target / log_interval) - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
class SelvaBigvganScheduler:
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
Audio clips are loaded once and reused across all experiments. Each experiment
deep-copies the vocoder and trains independently. Results are written to
`experiment_summary.json` after every completed run so partial results are
preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
"Audio clips are loaded once and reused across all experiments. "
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "bigvgan_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
mode = model["mode"]
dtype = model["dtype"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
if mode == "16k":
original_vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
original_vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[BigVGAN Scheduler] {description}", flush=True)
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load audio clips once
# ------------------------------------------------------------------
# Find minimum segment length across all experiments so we load enough
min_segment_seconds = float("inf")
for exp in spec["experiments"]:
cfg = _merge_config(base_cfg, exp)
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
min_segment_samples = int(min_segment_seconds * sample_rate)
audio_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
audio_files.extend(data_dir.rglob(ext))
if not audio_files:
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
clips = []
for af in audio_files:
try:
wav, sr = _load_wav(af)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0) # [L]
if wav.shape[0] >= min_segment_samples:
clips.append(wav.cpu())
else:
print(f" [BigVGAN Scheduler] Skip {af.name}: "
f"shorter than {min_segment_seconds}s", flush=True)
except Exception as e:
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
if not clips:
raise RuntimeError(
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
)
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
# ------------------------------------------------------------------
# 4. Offload unused components to free VRAM
# ------------------------------------------------------------------
comfy.model_management.unload_all_models()
feature_utils.to("cpu")
if "generator" in model:
model["generator"].to("cpu")
if "video_enc" in model:
model["video_enc"].to("cpu")
soft_empty_cache()
# ------------------------------------------------------------------
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
# ------------------------------------------------------------------
text_clip_cache = {}
any_lora = any(
_merge_config(base_cfg, exp).get("lora_adapter", "")
for exp in spec["experiments"]
)
if any_lora:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_model.to(device)
try:
for npz_path in npz_files:
data = dict(np.load(str(npz_path), allow_pickle=False))
prompt = prompt_map.get(
npz_path.name, data.get("prompt", default_prompt)
)
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
tc = feature_utils.encode_text_clip([prompt])
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
finally:
if clip_model is not None:
clip_model.to("cpu")
soft_empty_cache()
if device.type == "cuda":
torch.cuda.empty_cache()
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
f"CLIP embeddings", flush=True)
# ------------------------------------------------------------------
# 6. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 100),
"start_step": 0,
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[BigVGAN Scheduler] Resuming — skipping "
f"{len(completed_ids)} completed: "
f"{sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[BigVGAN Scheduler] Could not read existing summary "
f"({e}) — starting fresh", flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": len(clips),
"experiments": [],
}
def _write_summary():
summary_path.write_text(
json.dumps(summary, indent=2), encoding="utf-8"
)
_write_summary()
# ------------------------------------------------------------------
# 7. Compute total steps for progress bar
# ------------------------------------------------------------------
total_steps = 0
for exp in spec["experiments"]:
if exp["id"] not in completed_ids:
cfg = _merge_config(base_cfg, exp)
total_steps += int(cfg.get("steps", 2000))
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
# ------------------------------------------------------------------
# 8. Run experiments in a worker thread
# ------------------------------------------------------------------
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
# inside torch.inference_mode(). inference_mode is thread-local — a
# new thread starts with it OFF, so all tensor operations produce
# normal autograd-compatible tensors.
_exc = [None]
def _worker():
try:
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
f"(already completed)", flush=True)
continue
cfg = _merge_config(base_cfg, exp)
# ── Extract experiment parameters ────────────────────
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
exp_steps = int(cfg.get("steps", 2000))
exp_lr = float(cfg.get("lr", 1e-4))
exp_bs = int(cfg.get("batch_size", 4))
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
exp_gafilter = bool(cfg.get("use_gafilter", True))
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
exp_phase = float(cfg.get("lambda_phase", 1.0))
exp_save = int(cfg.get("save_every", 500))
exp_seed = int(cfg.get("seed", 42))
exp_disc = str(cfg.get("discriminator_path", ""))
exp_lora = str(cfg.get("lora_adapter", ""))
segment_samples = int(exp_seg_s * sample_rate)
# Filter clips long enough for this experiment
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
if not exp_clips:
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
f"no clips >= {exp_seg_s}s", flush=True)
summary["experiments"].append({
"id": exp_id, "description": exp_desc,
"config": dict(cfg),
"results": {
"status": "failed",
"error": f"No clips >= {exp_seg_s}s",
"duration_seconds": 0,
},
"checkpoint_path": None,
"output_dir": str(output_root / exp_id),
})
_write_summary()
continue
# ── Resolve discriminator path ───────────────────────
disc_path = None
if exp_disc:
disc_path = Path(exp_disc.strip())
if not disc_path.is_absolute():
disc_path = (
Path(folder_paths.get_output_directory()) / disc_path
)
if not disc_path.exists():
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"discriminator not found: {disc_path}",
flush=True)
disc_path = None
# ── Pre-generate LoRA mels (disk-cached) ─────────────
lora_mel_pairs = None
if exp_lora:
lora_path = Path(exp_lora.strip())
if not lora_path.is_absolute():
lora_path = Path(folder_paths.base_path) / lora_path
if lora_path.exists():
seq_cfg = model["seq_cfg"]
lora_mel_pairs = _pregenerate_lora_mels(
model, data_dir, str(lora_path),
device, dtype, sample_rate,
seq_cfg.duration, seed=exp_seed,
cache_dir=str(output_root),
text_clip_cache=text_clip_cache,
)
if not lora_mel_pairs:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"no LoRA mel pairs generated",
flush=True)
lora_mel_pairs = None
if device.type == "cuda":
torch.cuda.empty_cache()
else:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"LoRA adapter not found: {lora_path}",
flush=True)
# ── Output dir ───────────────────────────────────────
exp_dir = output_root / exp_id
exp_dir.mkdir(parents=True, exist_ok=True)
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
flush=True)
if exp_desc:
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
print(f"[BigVGAN Scheduler] mode={train_mode} "
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"train_mode": train_mode, "steps": exp_steps,
"lr": exp_lr, "batch_size": exp_bs,
"segment_seconds": exp_seg_s,
"lambda_l2sp": exp_l2sp,
"use_gafilter": exp_gafilter,
"gafilter_kernel_size": exp_gaf_ks,
"lambda_phase": exp_phase,
"save_every": exp_save, "seed": exp_seed,
"discriminator_path": exp_disc,
"lora_adapter": exp_lora,
},
"results": {"status": "running"},
"checkpoint_path": None,
"output_dir": str(exp_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
# Ensure mel_converter is on device for this experiment
mel_converter.to(device)
# Fresh vocoder copy — _do_train modifies it in-place
vocoder_copy = copy.deepcopy(original_vocoder)
checkpoint_path = _do_train(
vocoder_copy, mel_converter, exp_clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, exp_steps, exp_lr, exp_bs,
exp_l2sp, exp_gafilter, exp_gaf_ks,
exp_phase, exp_save, exp_seed,
out_path, disc_path, pbar,
lora_mel_pairs,
)
duration = time.monotonic() - t_start
# Parse training CSV for loss history
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
loss_history = _parse_training_log(log_path)
log_interval = max(1, exp_steps // 20)
smoothed = (
_smooth_losses(loss_history)
if loss_history else []
)
final_loss = (
round(smoothed[-1], 6) if smoothed else None
)
min_loss = (
round(min(smoothed), 6) if smoothed else None
)
min_idx = (
smoothed.index(min(smoothed))
if smoothed else None
)
min_loss_step = (
(min_idx + 1) * log_interval
if min_idx is not None else None
)
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std = round(
float(np.std(loss_history[-quarter:])), 6
)
else:
loss_std = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval,
exp_save, exp_steps,
),
"loss_history": [
round(v, 6) for v in loss_history
],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["checkpoint_path"] = checkpoint_path
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
f"failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
finally:
# Clean up vocoder copy to free VRAM
soft_empty_cache()
_write_summary()
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if _exc[0] is not None:
raise _exc[0]
# ------------------------------------------------------------------
# 9. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[BigVGAN Scheduler] Sweep complete. "
f"Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 10. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
File diff suppressed because it is too large Load Diff
+106
View File
@@ -0,0 +1,106 @@
import json
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaDatasetBrowser:
"""Browse a dataset.json file entry by entry using an integer index.
Each entry in the JSON is expected to have:
- "path" : base path (no extension) — directory that holds frame images
- "label" : text description of the clip
Derived outputs:
- video_path : path + ".mp4"
- audio_path : path + ".wav"
- frames_dir : path (the directory itself, for image-sequence loaders)
- label : entry["label"]
- count : total number of entries in the file
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset_json": ("STRING", {
"default": "",
"tooltip": "Absolute or ComfyUI-relative path to a dataset.json file.",
}),
"index": ("INT", {
"default": 0,
"min": 0,
"max": 9999,
"step": 1,
"tooltip": "Zero-based index of the entry to inspect.",
}),
},
}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "INT")
RETURN_NAMES = ("video_path", "audio_wav", "audio_flac", "features_path", "frames_dir", "mask_dir", "label", "max_index")
OUTPUT_TOOLTIPS = (
"path + '.mp4'",
"features/ + name + '.wav'",
"features/ + name + '.flac'",
"features/ + name + '.npz' (pre-extracted SelVA features)",
"path (image-sequence directory)",
"path + '_mask' (mask image-sequence directory)",
"Text label for this clip",
"count - 1 — wire to a primitive INT's max to constrain the index widget",
)
FUNCTION = "browse"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Reads a dataset.json produced by the SelVA dataset preparation pipeline "
"and exposes one entry at a time via an integer index. "
"Outputs the video path, audio path, frames directory, label, and total entry count."
)
# Re-read the file every call so edits are picked up without restarting ComfyUI.
IS_CHANGED = classmethod(lambda cls, **_: float("nan"))
def browse(self, dataset_json: str, index: int):
p = Path(dataset_json.strip())
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Dataset Browser] File not found: {p}")
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list) or len(data) == 0:
raise ValueError(f"[SelVA Dataset Browser] Expected a non-empty JSON array in {p}")
count = len(data)
if index >= count:
raise IndexError(
f"[SelVA Dataset Browser] index {index} is out of range "
f"(dataset has {count} entries, last index is {count - 1})"
)
entry = data[index]
base = entry["path"]
label = entry.get("label", "")
p_base = Path(base)
feat_base = str(p_base.parent / "features" / p_base.name)
print(
f"[SelVA Dataset Browser] {index + 1}/{count} label='{label}' base={base}",
flush=True,
)
return (
base + ".mp4",
feat_base + ".wav",
feat_base + ".flac",
feat_base + ".npz",
base,
base + "_mask",
label,
count - 1,
)
+788
View File
@@ -0,0 +1,788 @@
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetCompressor (optional)
↓ AUDIO_DATASET
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
↓ AUDIO_DATASET
SelvaDatasetHfSmoother (optional — batch HF attenuation)
↓ AUDIO_DATASET
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"} # handled natively without FFmpeg
def _load_audio(path: Path):
"""Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues."""
if path.suffix.lower() in _SOUNDFILE_EXTS:
import soundfile as sf
wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True) # [L, C]
wav = torch.from_numpy(wav_np).T.unsqueeze(0) # [1, C, L]
else:
wav, sr = torchaudio.load(str(path)) # [C, L]
wav = wav.unsqueeze(0).float() # [1, C, L]
return wav, sr
class SelvaDatasetLoader:
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("STRING", {
"default": "",
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
def load(self, folder: str):
folder = Path(folder.strip())
if not folder.exists():
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
if not files:
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
dataset = []
for f in sorted(files):
try:
wav, sr = _load_audio(f)
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
except Exception as e:
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
return (dataset,)
class SelvaDatasetResampler:
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_sr": ("INT", {
"default": 44100, "min": 8000, "max": 192000,
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "resample"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
def resample(self, dataset, target_sr: int):
import soxr
out = []
changed = 0
for item in dataset:
sr = item["sample_rate"]
if sr == target_sr:
out.append(item)
continue
wav = item["waveform"][0] # [C, L]
# soxr expects [L, C] (time-first), float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
changed += 1
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
return (out,)
class SelvaDatasetLUFSNormalizer:
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_lufs": ("FLOAT", {
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
)
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
out = []
skipped = 0
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
meter = pyln.Meter(sr)
try:
loudness = meter.integrated_loudness(wav_np)
except Exception:
skipped += 1
out.append(item)
continue
if not np.isfinite(loudness):
skipped += 1
out.append(item)
continue
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_norm = wav * gain_linear
# True peak limit
peak = wav_norm.abs().max().item()
if peak > tp_linear:
wav_norm = wav_norm * (tp_linear / peak)
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
print(
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
flush=True,
)
return (out,)
class SelvaDatasetCompressor:
"""Apply mild parallel compression to reduce within-clip loudness variance.
Uses pedalboard.Compressor (2:13:1 ratio). Parallel (New York) style:
blends compressed signal with dry so transients are preserved while
the dynamic range is gently tightened. Apply after LUFS normalization.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"threshold_db": ("FLOAT", {
"default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.",
}),
"ratio": ("FLOAT", {
"default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5,
"tooltip": "Compression ratio. 2:13:1 is mild; stay below 4:1 to avoid pumping.",
}),
"attack_ms": ("FLOAT", {
"default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0,
"tooltip": "Attack time in ms. Slower attack preserves transients.",
}),
"release_ms": ("FLOAT", {
"default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0,
"tooltip": "Release time in ms.",
}),
"mix": ("FLOAT", {
"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.30.5 is typical.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "compress"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Mild parallel compression to reduce within-clip dynamic range. "
"Blends compressed signal with dry at the given mix ratio. "
"Apply after LUFS normalization."
)
def compress(self, dataset, threshold_db: float, ratio: float,
attack_ms: float, release_ms: float, mix: float):
from pedalboard import Compressor, Pedalboard
board = Pedalboard([Compressor(
threshold_db=threshold_db,
ratio=ratio,
attack_ms=attack_ms,
release_ms=release_ms,
)])
out = []
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pedalboard expects [C, L] float32 numpy
wav_np = wav.float().numpy() # [C, L]
compressed = board(wav_np, sr) # [C, L]
mixed = (1.0 - mix) * wav_np + mix * compressed
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]})
print(
f"[DatasetCompressor] {len(out)} clips compressed "
f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}",
flush=True,
)
return (out,)
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
Method: compare mean energy in 15 kHz band vs 1520 kHz band via STFT.
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
"""
if sr < 32000:
return False # can't assess HF at low sample rates
n_fft = 2048
hop = 512
mono = wav[0].mean(0) # [L]
window = torch.hann_window(n_fft, device=mono.device)
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000)
if band_hi.sum() == 0:
return False
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
return ratio_db > 40.0
def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L]
if mono.shape[0] < 2048:
return 60.0 # clip too short to frame — assume clean
frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item()
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
return 20.0 * np.log10(p95 / p05 + 1e-8)
class SelvaDatasetInspector:
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"skip_rejected": ("BOOLEAN", {
"default": True,
"tooltip": "If True, flagged clips are removed from the output dataset. "
"If False, all clips pass through but the report still lists issues.",
}),
"min_snr_db": ("FLOAT", {
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
"tooltip": "Clips with estimated SNR below this value are flagged.",
}),
"check_codec_artifacts": ("BOOLEAN", {
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
"max_silence_fraction": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
"(< -60 dBFS). Set to 0 to disable silence detection.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET, "STRING")
RETURN_NAMES = ("dataset", "report")
FUNCTION = "inspect"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Analyze each clip for clipping, low SNR, and codec artifacts. "
"Outputs a filtered AUDIO_DATASET and a text report. "
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
for item in dataset:
wav = item["waveform"]
sr = item["sample_rate"]
name = item["name"]
issues = []
# Clipping
peak = wav.abs().max().item()
if peak > 0.99:
issues.append(f"clipping (peak={peak:.3f})")
# Low SNR
snr = _estimate_snr(wav)
if snr < min_snr_db:
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
# Codec artifacts
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
# Silence detection
if max_silence_fraction > 0:
mono = wav[0].mean(0)
if mono.shape[0] >= 2048:
frames = mono.unfold(0, 2048, 512)
rms = frames.pow(2).mean(-1).sqrt()
silent_frac = (rms < 1e-3).float().mean().item()
if silent_frac > max_silence_fraction:
issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
if not skip_rejected:
clean.append(item)
else:
clean.append(item)
lines.append(f" OK {name}")
lines.append("=" * 40)
lines.append(
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
+ (" (removed)" if skip_rejected else " (kept)")
)
report = "\n".join(lines)
print(f"[DatasetInspector]\n{report}", flush=True)
return (clean, report)
class SelvaDatasetItemExtractor:
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
Bridges the dataset pipeline to any node that accepts a standard AUDIO
input — save audio, HF Smoother, Spectral Matcher, etc.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"index": ("INT", {
"default": 0, "min": 0, "max": 9999,
"tooltip": "0-based index. Wraps around if index >= dataset length.",
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "INT")
RETURN_NAMES = ("audio", "name", "total")
FUNCTION = "extract"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Extract one clip from an AUDIO_DATASET by index. "
"Returns standard AUDIO (compatible with all audio nodes), "
"the clip name, and the total dataset length."
)
def extract(self, dataset, index: int):
if not dataset:
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
idx = index % len(dataset)
item = dataset[idx]
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
print(
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
flush=True,
)
return (audio, item["name"], len(dataset))
class SelvaDatasetSaver:
"""Save all clips in an AUDIO_DATASET to disk as FLAC files.
Optionally copies matching .npz feature files from a source directory,
keeping FLAC/NPZ pairs in sync after the inspector has filtered clips.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"output_dir": ("STRING", {
"default": "",
"tooltip": "Absolute path to output folder. Created if it does not exist.",
}),
},
"optional": {
"npz_source_dir": ("STRING", {
"default": "",
"tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. "
"Missing NPZs are warned but do not abort the save.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("report",)
OUTPUT_NODE = True
FUNCTION = "save"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Save every clip in an AUDIO_DATASET to output_dir as FLAC. "
"If npz_source_dir is provided, copies the matching .npz file for each clip — "
"so rejected clips never get their NPZ copied."
)
def save(self, dataset, output_dir: str, npz_source_dir: str = ""):
import shutil
import soundfile as sf
out = Path(output_dir.strip())
out.mkdir(parents=True, exist_ok=True)
npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None
saved = 0
npz_copied = 0
npz_missing = []
for item in dataset:
name = item["name"]
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# soundfile wants [L] mono or [L, C] multichannel, float32
wav_np = wav.permute(1, 0).float().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
flac_path = out / f"{name}.flac"
sf.write(str(flac_path), wav_np, sr, subtype="PCM_24")
saved += 1
if npz_src is not None:
# Augmented clips store their origin name — use it to find the .npz
lookup = item.get("origin_name", name)
npz_path = npz_src / f"{lookup}.npz"
if npz_path.exists():
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
npz_copied += 1
else:
npz_missing.append(name)
lines = [
f"[DatasetSaver] Saved {saved} clips → {out}",
]
if npz_src is not None:
lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}")
for n in npz_missing:
lines.append(f" MISSING NPZ: {n}")
report = "\n".join(lines)
print(report, flush=True)
return (report,)
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
class SelvaDatasetSpectralMatcher:
"""Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.
Wraps SelvaSpectralMatcher so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply adaptive spectral matching to every clip in a dataset. "
"Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
"VAE's expected distribution."
)
def process(self, dataset, mode: str, strength: float, max_gain_db: float):
from .selva_audio_preprocessors import SelvaSpectralMatcher
matcher = SelvaSpectralMatcher()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = matcher.process(audio, mode, strength, max_gain_db)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
f"mode={mode} strength={strength}", flush=True)
return (out,)
class SelvaDatasetHfSmoother:
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
Wraps SelvaHfSmoother so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply soft HF attenuation to every clip in a dataset. "
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
"with the original to tame extreme HF content."
)
def process(self, dataset, cutoff_hz: float, blend: float):
from .selva_audio_preprocessors import SelvaHfSmoother
smoother = SelvaHfSmoother()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = smoother.process(audio, cutoff_hz, blend)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetHfSmoother] {len(out)} clips processed "
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
return (out,)
# ── Dataset augmenter ────────────────────────────────────────────────────────
class SelvaDatasetAugmenter:
"""Create augmented variants of each clip to expand a small dataset.
Supports gain variation (always available) and optionally pitch shift
and time stretch via audiomentations. Install audiomentations for the
full feature set: ``pip install audiomentations``
Each original clip produces ``variants_per_clip`` augmented copies.
Originals are kept by default (toggle ``keep_originals``).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"variants_per_clip": ("INT", {
"default": 2, "min": 1, "max": 20,
"tooltip": "Number of augmented copies per original clip.",
}),
"gain_range_db": ("FLOAT", {
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
}),
"seed": ("INT", {"default": 42}),
},
"optional": {
"pitch_range_semitones": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
}),
"time_stretch_range": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
"tooltip": "Random time stretch ±fraction (0.1 = 90%110% speed). "
"Requires audiomentations. 0 = disabled.",
}),
"keep_originals": ("BOOLEAN", {
"default": True,
"tooltip": "Include the original unaugmented clips in the output.",
}),
},
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "augment"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Create augmented variants of each clip (gain, pitch, time stretch) "
"to expand small training datasets. Gain is always available; pitch and "
"time stretch require audiomentations (pip install audiomentations)."
)
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
seed: int, pitch_range_semitones: float = 0.0,
time_stretch_range: float = 0.0, keep_originals: bool = True):
rng = np.random.RandomState(seed)
# Try audiomentations for pitch/stretch
use_am = False
am_compose = None
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
if needs_am:
try:
import audiomentations as am
transforms = []
if pitch_range_semitones > 0:
transforms.append(am.PitchShift(
min_semitones=-pitch_range_semitones,
max_semitones=pitch_range_semitones,
p=0.5,
))
if time_stretch_range > 0:
transforms.append(am.TimeStretch(
min_rate=1.0 - time_stretch_range,
max_rate=1.0 + time_stretch_range,
leave_length_unchanged=True,
p=0.5,
))
am_compose = am.Compose(transforms)
use_am = True
except ImportError:
print("[DatasetAugmenter] audiomentations not installed — "
"pitch_shift and time_stretch disabled. "
"Install: pip install audiomentations", flush=True)
out = []
if keep_originals:
out.extend(dataset)
for item in dataset:
wav = item["waveform"] # [1, C, L]
sr = item["sample_rate"]
name = item["name"]
for v in range(variants_per_clip):
# Gain variation (always applied)
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
gain_lin = 10.0 ** (gain_db / 20.0)
wav_aug = wav * gain_lin
# Pitch/stretch via audiomentations
if use_am and am_compose is not None:
wav_np = wav_aug[0].numpy() # [C, L] float32
if wav_np.shape[0] == 1:
wav_np = wav_np[0] # [L] mono for audiomentations
wav_np = am_compose(samples=wav_np, sample_rate=sr)
if wav_np.ndim == 1:
wav_np = wav_np[np.newaxis, :] # back to [1, L]
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
# Prevent clipping
peak = wav_aug.abs().max()
if peak > 1.0:
wav_aug = wav_aug / peak
out.append({
"waveform": wav_aug,
"sample_rate": sr,
"name": f"{name}_aug{v:02d}",
"origin_name": name,
})
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
f"gain=±{gain_range_db:.1f}dB"
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
flush=True)
return (out,)
+515
View File
@@ -0,0 +1,515 @@
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against target style reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
against target style reference clips. Runs entirely before the vocoder — optimization
only requires the DiT + VAE decoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
def _load_wav(path):
"""Load audio file to [channels, samples] float32 tensor."""
try:
return torchaudio.load(str(path))
except Exception:
pass
import soundfile as sf
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
wav = torch.from_numpy(data.T)
return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
Start at 0; enable only if mean-only optimization converges
without noise, then increase slowly (0.010.1).
"""
m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss — captures spectral envelope
gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix loss — captures timbral texture (can add noise if too high)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss computed directly in VAE latent space.
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
ref_mean: [C_lat] mean latent vector of reference clips
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
gram_weight: weight for Gram component — 0 = mean only (recommended start)
Operating in latent space avoids backprop through the VAE decoder, which
is @torch.inference_mode() and produces noisy, unstable gradients.
"""
# Mean latent loss — matches average activation per channel
gen_mean = z.mean(dim=0) # [C_lat]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix — inter-channel covariance, position-invariant
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
class SelvaDittoOptimizer:
"""DITTO inference-time noise optimization.
Freezes all model weights and optimizes only the initial noise latent x_0
to make the generated audio sound like the target style reference clips.
No training data or gradient updates to the model — per-video per-run.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"features": ("SELVA_FEATURES",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Sound description. Leave empty to use features prompt.",
}),
"negative_prompt": ("STRING", {
"default": "", "multiline": False,
}),
"reference_dir": ("STRING", {
"default": "",
"tooltip": "Directory with target style reference audio files (.wav/.flac/.mp3). "
"Reference mel statistics are precomputed from these once.",
}),
"n_opt_steps": ("INT", {
"default": 50, "min": 5, "max": 500,
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
"each step requires ~2 DiT forward passes.",
}),
"opt_lr": ("FLOAT", {
"default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
"tooltip": "Adam learning rate for x_0 optimization. "
"0.020.05 is recommended; 0.1 (paper default) causes oscillation.",
}),
"n_ode_steps": ("INT", {
"default": 10, "min": 5, "max": 50,
"tooltip": "Euler ODE steps run during each optimization iteration. "
"Lower = faster optimization (1015 is a good trade-off). "
"Final generation always uses the steps parameter below.",
}),
"n_grad_steps": ("INT", {
"default": 5, "min": 1, "max": 50,
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
"Higher = more accurate gradient, more VRAM. "
"Must be ≤ n_ode_steps. 5 is a good default.",
}),
"style_weight": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
"tooltip": "Weight of the target style style loss. High values push harder toward "
"target style style but add noise. Start at 0.1 and increase slowly.",
}),
"gram_weight": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
"0.1 adds texture matching but can introduce white noise.",
}),
"anchor_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
"Prevents optimization from pushing x0 out of the flow's "
"expected distribution (which causes white noise). "
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
}),
"steps": ("INT", {
"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the final generation pass (after optimization).",
}),
"cfg_strength": ("FLOAT", {
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"normalize": ("BOOLEAN", {"default": True}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward target style style.",)
FUNCTION = "optimize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"DITTO inference-time noise optimization (arXiv:2401.12179). "
"Optimizes the initial noise latent x_0 to match target style reference clips "
"via mel statistics style loss, backpropagating through the ODE. "
"All model weights frozen — zero quality degradation risk."
)
def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0):
import traceback
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
# Validate variant match
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
f"Re-run Feature Extractor."
)
if not prompt or not prompt.strip():
prompt = features.get("prompt", "")
# Resolve duration and seq_cfg
duration = features.get("duration", 0)
if duration <= 0:
raise ValueError("[DITTO] Features contain no duration field.")
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
sample_rate = seq_cfg.sampling_rate
# Load reference clips and encode to latent space.
# Style loss is computed in latent space (after net_generator.unnormalize)
# rather than mel space — this avoids backpropagating through the VAE
# decoder (which is @torch.inference_mode() and produces noisy gradients).
ref_dir = Path(reference_dir.strip())
if not ref_dir.is_absolute():
ref_dir = Path(folder_paths.models_dir) / ref_dir
if not ref_dir.exists():
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
ref_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
ref_files.extend(ref_dir.rglob(ext))
if not ref_files:
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
if not hasattr(feature_utils.tod.vae, "encoder"):
raise RuntimeError(
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
"Reload the model with the encoder enabled."
)
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
mel_converter.to(device, torch.float32) # cuFFT requires float32
ref_latents = []
with torch.no_grad():
for rf in ref_files:
try:
wav, sr = _load_wav(rf)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, torch.float32)
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
# encode → sample → VAE latent space (matches unnormalize(x) in loss)
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
except Exception as e:
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
if not ref_latents:
raise RuntimeError("[DITTO] No usable reference clips.")
# Precompute reference latent statistics (done once — detached, no grad)
with torch.no_grad():
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
ref_mean = all_means.mean(0) # [C_lat]
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
f"grad_steps={n_grad_steps}", flush=True)
if strategy == "offload_to_cpu":
net_generator.to(device)
feature_utils.to(device)
soft_empty_cache()
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
_result = [None]
_exc = [None]
def _worker():
try:
_result[0] = _do_optimize(
net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar,
)
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
if _exc[0] is not None:
raise _exc[0]
return (_result[0],)
def _do_optimize(net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
# Strip inference flags from ref stats (came from main thread) and cast to
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
ref_mean = ref_mean.clone().detach().to(dtype)
ref_gram = ref_gram.clone().detach().to(dtype)
torch.manual_seed(seed)
clip_f = features["clip_features"].to(device, dtype).clone()
sync_f = features["sync_features"].to(device, dtype).clone()
# Strip inference-mode flags from all model weights and buffers BEFORE any
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
# operations on inference tensors produce inference tensors, so conditions
# computed from tainted weights would also be tainted. clone() outside
# inference_mode produces a normal tensor regardless of the source flag.
def _strip_inference(module):
for mod in module.modules():
for name, buf in list(mod._buffers.items()):
if buf is not None:
mod._buffers[name] = buf.clone()
for name, param in list(mod._parameters.items()):
if param is not None:
mod._parameters[name] = torch.nn.Parameter(
param.data.clone(), requires_grad=False
)
_strip_inference(net_generator)
_strip_inference(feature_utils)
_strip_inference(mel_converter)
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
with torch.no_grad():
text_clip = feature_utils.encode_text_clip([prompt])
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Clone all tensors inside conditions/empty_conditions to ensure no inference
# flags survived from intermediate computations inside preprocess_conditions.
def _clone_nested(obj):
if isinstance(obj, torch.Tensor):
return obj.clone()
elif isinstance(obj, dict):
return {k: _clone_nested(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_clone_nested(v) for v in obj)
return obj
conditions = _clone_nested(conditions)
empty_conditions = _clone_nested(empty_conditions)
# Initial noise — x_0 is the parameter we optimize
x0_init = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=device, dtype=dtype,
)
x0 = torch.nn.Parameter(x0_init.clone())
x0_init = x0_init.detach() # anchor — kept fixed, no grad
optimizer = torch.optim.Adam([x0], lr=opt_lr)
# n_grad_steps must not exceed n_ode_steps
n_grad_steps = min(n_grad_steps, n_ode_steps)
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
print(f"[DITTO] Optimizing x_0 "
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
# Freeze all model weights (double-check — should already be frozen at inference)
net_generator.requires_grad_(False)
feature_utils.requires_grad_(False)
mel_converter.requires_grad_(False)
for opt_step in range(n_opt_steps):
comfy.model_management.throw_exception_if_processing_interrupted()
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
# Detach from x0 so Phase 1 does not build a computation graph.
with torch.no_grad():
x = x0.detach()
for i in range(n_free_steps):
t = ts[i]
dt = ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
# Straight-through estimator: reconnect x to x0's gradient path by
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
# creates a grad_fn pointing back to x0, so loss.backward() will
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
# treated as identity for gradient purposes (truncated BPTT).
#
# x may carry an inference tensor flag from Phase 1 (derived from
# conditions which were built outside inference_mode but may have
# propagated the flag). .clone() strips it so the STE addition does
# not try to save an inference tensor for backward.
x = x.clone()
x = x + (x0 - x0.detach())
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
for i in range(n_free_steps, n_ode_steps):
t = ts[i]
dt = ts[i + 1] - t
# Gradient checkpointing: recompute forward during backward,
# avoiding storage of DiT activations for each step.
def _ode_step(x_in, t=t):
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
flow = torch.utils.checkpoint.checkpoint(
_ode_step, x, use_reentrant=False
)
x = x + dt * flow
# ── Style loss in latent space ───────────────────────────────────────
# Unnormalize x back to VAE latent space — fully differentiable, no
# decode needed. ref_mean/ref_gram are computed from encoded reference
# clips in the same space. Avoids backprop through VAE decoder which
# is @torch.inference_mode() and produces noisy gradients.
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
# the ODE into an out-of-distribution region that decodes as white noise.
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
loss = style_loss + anchor_loss
optimizer.zero_grad()
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()
pbar.update(1)
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
f"x0_std={x0.data.std().item():.3f}", flush=True)
# ── Final generation with optimized x_0 ─────────────────────────────────
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
with torch.no_grad():
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
x = x0.detach()
for i in range(steps):
comfy.model_management.throw_exception_if_processing_interrupted()
t = fm_ts[i]
dt = fm_ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
pbar.update(1)
x1_unnorm = net_generator.unnormalize(x)
spec = feature_utils.decode(x1_unnorm)
audio = feature_utils.vocode(spec)
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
flush=True)
audio = audio.float()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return {"waveform": audio.cpu(), "sample_rate": sample_rate}
+30 -106
View File
@@ -35,66 +35,6 @@ def _resize_frames(frames, size):
return x.clamp(0.0, 1.0) # [N, C, H, W] return x.clamp(0.0, 1.0) # [N, C, H, W]
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
"""
Compute a bounding box around the union of all mask frames.
mask: [M, H', W'] float [0,1]
square: if True, expand bbox to a square and shift into frame bounds;
if False, apply margin independently on each axis (rect crop).
Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
"""
if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
m = F.interpolate(
mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
).squeeze(1)
else:
m = mask.float()
union = (m > 0.5).max(dim=0).values # [H, W] bool
if not union.any():
if square:
# Empty mask — center square crop
side = min(frame_h, frame_w)
cy, cx = frame_h // 2, frame_w // 2
y0 = max(0, cy - side // 2)
x0 = max(0, cx - side // 2)
return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
else:
# Empty mask — return full frame (no meaningful rect to crop to)
return 0, 0, frame_h, frame_w
ys = union.any(dim=1).nonzero(as_tuple=True)[0]
xs = union.any(dim=0).nonzero(as_tuple=True)[0]
y0, y1 = int(ys[0]), int(ys[-1]) + 1
x0, x1 = int(xs[0]), int(xs[-1]) + 1
if square:
side = max(y1 - y0, x1 - x0)
pad = int(side * margin)
side += 2 * pad
cy = (y0 + y1) // 2
cx = (x0 + x1) // 2
y0n = cy - side // 2
x0n = cx - side // 2
y1n = y0n + side
x1n = x0n + side
# Shift into frame bounds to preserve square shape
if y0n < 0: y1n -= y0n; y0n = 0
if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
if x0n < 0: x1n -= x0n; x0n = 0
if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w
return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
else:
pad_y = int(max(1, y1 - y0) * margin)
pad_x = int(max(1, x1 - x0) * margin)
return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0): def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
""" """
Apply a ComfyUI MASK to resized frames. Apply a ComfyUI MASK to resized frames.
@@ -128,9 +68,20 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
return frames * alpha + 0.5 * (1.0 - alpha) return frames * alpha + 0.5 * (1.0 - alpha)
def _resolve_named_path(cache_dir: str, name: str) -> str:
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
# Sanitize: replace path separators so the name stays inside cache_dir
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
i = 1
while True:
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
if not os.path.exists(p):
return p
i += 1
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None, def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True, mask_strength=1.0, mask_clip=True, mask_sync=True):
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
h = hashlib.sha256() h = hashlib.sha256()
raw = video_tensor.cpu().numpy().tobytes() raw = video_tensor.cpu().numpy().tobytes()
n = len(raw) n = len(raw)
@@ -148,10 +99,6 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
h.update(str(round(mask_strength, 4)).encode()) h.update(str(round(mask_strength, 4)).encode())
h.update(str(mask_clip).encode()) h.update(str(mask_clip).encode())
h.update(str(mask_sync).encode()) h.update(str(mask_sync).encode())
h.update(str(crop_to_mask).encode())
h.update(str(crop_rect).encode())
if crop_to_mask or crop_rect:
h.update(str(round(crop_margin, 4)).encode())
h.update(prompt.encode()) h.update(prompt.encode())
h.update(str(fps).encode()) h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
@@ -181,6 +128,8 @@ class SelvaFeatureExtractor:
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}), "tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
"cache_dir": ("STRING", {"default": "", "cache_dir": ("STRING", {"default": "",
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}), "tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
"name": ("STRING", {"default": "",
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
"mask": ("MASK", { "mask": ("MASK", {
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.", "tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
}), }),
@@ -196,18 +145,6 @@ class SelvaFeatureExtractor:
"default": True, "default": True,
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.", "tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
}), }),
"crop_to_mask": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
}),
"crop_rect": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
}),
"crop_margin": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
}),
}, },
} }
@@ -218,14 +155,14 @@ class SelvaFeatureExtractor:
"Source fps of the video — wire to VHS_VideoCombine frame_rate.", "Source fps of the video — wire to VHS_VideoCombine frame_rate.",
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.", "The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
) )
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
FUNCTION = "extract_features" FUNCTION = "extract_features"
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant." DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0, def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir="", mask=None, duration=0.0, cache_dir="", name="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True, mask_strength=1.0, mask_clip=True, mask_sync=True):
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
if video_info is not None: if video_info is not None:
fps = video_info["loaded_fps"] fps = video_info["loaded_fps"]
@@ -241,15 +178,19 @@ class SelvaFeatureExtractor:
if not cache_dir: if not cache_dir:
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features") cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path): if name.strip():
print(f"[SelVA] Using cached features: {cached_path}", flush=True) # Named mode: always extract and save to an incremented filename
cached = _load_cached(cached_path) cached_path = _resolve_named_path(cache_dir, name.strip())
return (cached, float(fps), cached.get("prompt", prompt)) else:
# Hash mode: skip extraction if identical inputs were already processed
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path):
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
cached = _load_cached(cached_path)
return (cached, float(fps), cached.get("prompt", prompt))
device = get_device() device = get_device()
dtype = model["dtype"] dtype = model["dtype"]
@@ -265,24 +206,10 @@ class SelvaFeatureExtractor:
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3) pbar = comfy.utils.ProgressBar(3)
# Pre-compute crop bbox once from the original-resolution mask
crop_bbox = None
if mask is not None and (crop_to_mask or crop_rect):
H_vid, W_vid = video.shape[1], video.shape[2]
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
cy0, cx0, cy1, cx1 = crop_bbox
_mode = "square" if _square else "rect"
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
try: try:
with torch.no_grad(): with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None and mask_clip: if mask is not None and mask_clip:
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength) clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
@@ -295,9 +222,6 @@ class SelvaFeatureExtractor:
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None and mask_sync: if mask is not None and mask_sync:
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength) sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
+421
View File
@@ -0,0 +1,421 @@
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
JSON format:
{
"name": "eval_batch_1",
"data_dir": "/path/to/features",
"output_dir": "/path/to/evals/batch1",
"steps": 25,
"seed": 42,
"adapters": [
{"id": "baseline"},
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
]
}
Empty / missing "path" = baseline (no LoRA applied).
"""
import copy
import json
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_pil_to_tensor,
_find_audio,
_load_audio,
)
from selva_core.model.lora import apply_lora, load_lora
def _avg_metrics(metrics_list: list) -> dict:
"""Average spectral metrics across multiple clips, ignoring None entries."""
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
"spectral_flatness", "temporal_variance"]
valid = [m for m in metrics_list if m]
if not valid:
return {}
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _safe_stem(adapter_id: str) -> str:
"""Replace characters illegal in filenames."""
for ch in r'/\:*?"<>|':
adapter_id = adapter_id.replace(ch, "_")
return adapter_id
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
"""
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
METRICS = [
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
("spectral_flatness", "Spectral Flatness"),
("temporal_variance", "Temporal Variance"),
]
COLORS = [
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
]
fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
for ax, (key, title) in zip(axes, METRICS):
values = []
colors = []
for i, m in enumerate(metrics_list):
v = m.get(key, 0.0) if m else 0.0
values.append(v)
colors.append(COLORS[i % len(COLORS)])
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
ax.set_title(title, fontsize=9)
ax.set_xlabel(key, fontsize=8)
ax.tick_params(axis="y", labelsize=7)
ax.tick_params(axis="x", labelsize=7)
# Value labels on bars
for bar, val in zip(bars, values):
w = bar.get_width()
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
f"{val:.3f}", va="center", ha="left", fontsize=6)
canvas = FigureCanvasAgg(fig)
canvas.draw()
canvas.print_figure(str(output_path), dpi=110)
buf = canvas.buffer_rgba()
w, h = canvas.get_width_height()
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
from PIL import Image
return _pil_to_tensor(Image.fromarray(arr))
class SelvaLoraEvaluator:
"""Evaluates a batch of LoRA adapters on a fixed reference clip.
Generates one audio sample per adapter, computes spectral metrics for each,
and produces a comparison chart. Use this after a sweep to compare candidates
before running the next round of training.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_image")
OUTPUT_TOOLTIPS = (
"Path to eval_summary.json — contains spectral metrics per adapter.",
"Bar chart comparing spectral metrics across all evaluated adapters.",
)
DESCRIPTION = (
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
"from a fixed reference clip, then collects spectral metrics for comparison. "
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"eval_file": ("STRING", {
"default": "eval_batch.json",
"tooltip": (
"Path to the JSON evaluation spec. Relative paths resolve "
"to the ComfyUI output directory. "
"Each adapter entry needs an 'id' and an optional 'path'. "
"Omit 'path' for a no-LoRA baseline."
),
}),
}
}
def run(self, model, eval_file):
# ------------------------------------------------------------------
# 1. Resolve and parse the JSON file
# ------------------------------------------------------------------
eval_path = Path(eval_file.strip())
if not eval_path.is_absolute():
candidate = Path(folder_paths.models_dir) / eval_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / eval_path
eval_path = candidate
if not eval_path.exists():
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
spec = json.loads(eval_path.read_text(encoding="utf-8"))
if "adapters" not in spec or not spec["adapters"]:
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
for i, a in enumerate(spec["adapters"]):
if "id" not in a:
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
if "data_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
if "output_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
name = spec.get("name", eval_path.stem)
data_dir = _resolve_path(spec["data_dir"])
output_dir = _resolve_path(spec["output_dir"])
steps = int(spec.get("steps", 25))
seed = int(spec.get("seed", 42))
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
# ------------------------------------------------------------------
# 2. Prepare dataset (VAE encode once)
# ------------------------------------------------------------------
device = get_device()
dtype = model["dtype"]
dataset = _prepare_dataset(model, data_dir, device)
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
# ------------------------------------------------------------------
# 3. Collect reference metrics for all dataset clips
# ------------------------------------------------------------------
import shutil
npz_files = sorted(data_dir.glob("*.npz"))
ref_dir = output_dir / "reference"
ref_dir.mkdir(exist_ok=True)
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
flush=True)
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
continue
try:
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
ref_wav = ref_wav.unsqueeze(0) # [1, L]
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
shutil.copy2(str(audio_path), str(ref_out))
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
ref_clips.append({
"clip": npz_path.stem,
"wav_path": str(ref_out),
"spectral_metrics": metrics,
})
except Exception as e:
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
# Average reference metrics across all clips
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
print(f"[LoRA Evaluator] Reference avg — "
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
# ------------------------------------------------------------------
# 4. Build summary skeleton
# ------------------------------------------------------------------
summary = {
"name": name,
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"data_dir": str(data_dir),
"output_dir": str(output_dir),
"n_clips": len(ref_clips),
"steps": steps,
"seed": seed,
"reference_avg": ref_avg,
"reference_clips": ref_clips,
"adapters": [],
}
summary_path = output_dir / "eval_summary.json"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Per-adapter evaluation loop (all clips)
# ------------------------------------------------------------------
n_clips = len(dataset)
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
for adapter_spec in spec["adapters"]:
adapter_id = adapter_spec["id"]
adapter_path = (adapter_spec.get("path") or "").strip()
safe_id = _safe_stem(adapter_id)
clip_dir = output_dir / safe_id
clip_dir.mkdir(exist_ok=True)
record = {
"id": adapter_id,
"path": adapter_path or None,
"meta": None,
"clips": [],
"avg_metrics": None,
"status": "running",
}
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
try:
with torch.inference_mode(False):
generator = copy.deepcopy(model["generator"])
if adapter_path:
pt_path = Path(adapter_path)
if not pt_path.is_absolute():
pt_path = Path(folder_paths.base_path) / pt_path
if not pt_path.exists():
raise FileNotFoundError(f"Adapter not found: {pt_path}")
ckpt = torch.load(str(pt_path), map_location="cpu",
weights_only=False)
if isinstance(ckpt, dict) and "state_dict" in ckpt:
state_dict = ckpt["state_dict"]
meta = ckpt.get("meta", {})
else:
state_dict = ckpt
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
dropout = float(meta.get("lora_dropout", 0.0))
use_rslora = meta.get("use_rslora", False)
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
# Always use standard init for loading — PiSSA checkpoints
# include linear.weight (residual) in state_dict
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target), dropout=dropout,
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"apply_lora matched 0 layers (target={target})"
)
load_lora(generator, state_dict)
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
f"(rank={rank}, {n} layers)", flush=True)
else:
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
generator = generator.to(device, dtype)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
clip_metrics_list = []
for clip_idx in range(n_clips):
clip_stem = npz_files[clip_idx].stem
wav, sr = _eval_sample(
generator, feature_utils_orig, dataset,
seq_cfg, device, dtype,
num_steps=steps, seed=seed, clip_idx=clip_idx,
)
if wav is None:
pbar.update(1)
continue
wav_path = clip_dir / f"{clip_stem}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
metrics = _spectral_metrics(wav, sr)
clip_metrics_list.append(metrics)
record["clips"].append({
"clip": clip_stem,
"wav_path": str(wav_path),
"spectral_metrics": metrics,
})
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
pbar.update(1)
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
record["status"] = "completed"
avg = record["avg_metrics"]
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
except Exception as e:
record["status"] = "failed"
record["error"] = str(e)
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
traceback.print_exc()
pbar.update(n_clips - len(record["clips"]))
finally:
try:
del generator
except NameError:
pass
soft_empty_cache()
summary["adapters"].append(record)
_write_summary()
# ------------------------------------------------------------------
# 5. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 6. Comparison chart
# ------------------------------------------------------------------
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
if completed:
ids = ["reference"] + [r["id"] for r in completed]
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
chart_path = output_dir / "metric_comparison.png"
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
else:
from PIL import Image
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
return (str(summary_path), comparison)
+109
View File
@@ -0,0 +1,109 @@
import copy
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from selva_core.model.lora import apply_lora, load_lora
class SelvaLoraLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"adapter_path": ("STRING", {
"default": "",
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
}),
"strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to all LoRA contributions. "
"1.0 = full adapter strength. "
"0.0 = effectively disables the adapter. "
"Values above 1.0 exaggerate the effect.",
}),
},
}
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
"The base model is not modified — a shallow copy of the model bundle is returned."
)
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
if not adapter_path.strip():
raise ValueError("[SelVA LoRA] adapter_path is empty.")
# Resolve path: allow absolute or relative to ComfyUI base
from pathlib import Path
p = Path(adapter_path)
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
# Support both raw state_dict and {state_dict, meta} formats
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
meta = checkpoint.get("meta", {})
else:
state_dict = checkpoint
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
init_mode = meta.get("init_mode", "standard")
use_rslora = meta.get("use_rslora", False)
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
f"init={init_mode} rslora={use_rslora} strength={strength}",
flush=True)
# Shallow-copy the model bundle so the original generator is not mutated
patched = {**model}
generator = copy.deepcopy(model["generator"])
# For PiSSA, use standard init (the base weights will be overwritten
# by load_state_dict since the checkpoint includes linear.weight)
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target),
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"[SelVA LoRA] No layers matched target={target}. "
"Check that the adapter was trained with the same target suffixes."
)
load_lora(generator, state_dict)
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
norms = [p.norm().item() for name, p in generator.named_parameters()
if "lora_A" in name]
if norms:
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
else:
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
# Apply strength scaling: multiply all lora_B params by strength
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
if strength != 1.0:
with torch.no_grad():
for name, param in generator.named_parameters():
if "lora_B" in name:
param.mul_(strength)
generator.to(model["generator"].parameters().__next__().device)
patched["generator"] = generator
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
return (patched,)
+539
View File
@@ -0,0 +1,539 @@
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "tier1_sweep",
"description": "optional human note",
"data_dir": "dataset/dog_bark",
"output_root": "lora_output/tier1_sweep",
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
...
]
}
"""
import copy
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
SelvaLoraTrainer,
SkipExperiment,
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
def _get_system_info() -> dict:
"""Collect GPU / torch version info for the summary header."""
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"alpha": 0.0,
"target": "attn.qkv",
"batch_size": 4,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"resume_path": "",
"seed": 42,
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": True,
"latent_mixup_alpha": 0.0,
"latent_noise_sigma": 0.0,
}
# Palette for comparison chart: one color per experiment (cycles if > 8)
_PALETTE = [
(66, 133, 244), # blue
(234, 67, 53), # red
(52, 168, 83), # green
(251, 188, 5), # yellow
(155, 89, 182), # purple
(26, 188, 156), # teal
(230, 126, 34), # orange
(149, 165, 166), # grey
]
def _resolve_path(raw: str) -> Path:
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge base defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
# Don't carry id/description into the training params
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
start_step: int, total_steps: int) -> dict:
"""Build a dict of {step: loss} at each save_every boundary.
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
start + (i+1)*log_interval].
"""
result = {}
targets = range(save_every, total_steps + 1, save_every)
for target in targets:
# index of the loss entry nearest to this step
idx = (target - start_step) // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
# Collect all smoothed series
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"log_interval": ed.get("log_interval", 50),
"start_step": ed.get("start_step", 0),
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
# Draw each curve
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
# Legend (right side)
lx = W - pr + 10
ly = pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
class SelvaLoraScheduler:
"""Runs a sweep of LoRA training experiments defined in a JSON file.
The dataset (VAE encoding + .npz loading) is performed once and shared
across all experiments. Each experiment deep-copies the generator and trains
independently. Results are written to `experiment_summary.json` after every
completed run so partial results are preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
"The dataset is encoded once and reused across all experiments. "
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is. "
"See LORA_TRAINING.md for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
# Try relative to ComfyUI models dir first, then output dir
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[LoRA Scheduler] {description}", flush=True)
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load + encode dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = [] # collected for comparison image
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 50),
"start_step": 0,
})
# Restore the original summary, clear completed_at so it gets set again
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaLoraTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
variant = model["variant"]
mode = model["mode"]
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
# Required training params
steps = int(cfg.get("steps", 2000))
rank = int(cfg.get("rank", 16))
lr = float(cfg.get("lr", 1e-4))
alpha = float(cfg.get("alpha", 0.0))
target = str(cfg.get("target", "attn.qkv"))
batch_size = int(cfg.get("batch_size", 4))
warmup = int(cfg.get("warmup_steps", 100))
grad_accum = int(cfg.get("grad_accum", 1))
save_every = int(cfg.get("save_every", 500))
resume_path = str(cfg.get("resume_path", ""))
seed = int(cfg.get("seed", 42))
ts_mode = str(cfg.get("timestep_mode", "uniform"))
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
curr_switch = float(cfg.get("curriculum_switch", 0.6))
dropout = float(cfg.get("lora_dropout", 0.0))
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
lr_schedule = str(cfg.get("lr_schedule", "constant"))
init_mode = str(cfg.get("init_mode", "pissa"))
use_rslora = bool(cfg.get("use_rslora", True))
alpha_val = alpha if alpha > 0.0 else float(2 * rank)
target_suffixes = tuple(target.strip().split())
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
"batch_size": batch_size, "warmup_steps": warmup,
"grad_accum": grad_accum, "save_every": save_every,
"seed": seed, "target": list(target_suffixes),
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
"curriculum_switch": curr_switch,
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
"lr_schedule": lr_schedule,
"init_mode": init_mode, "use_rslora": use_rslora,
},
"results": {"status": "running"},
"adapter_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup,
grad_accum, save_every, resume_path, seed,
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
lr_schedule, init_mode, use_rslora,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
grad_norm_history = r.get("grad_norm_history", [])
spectral_metrics = r.get("spectral_metrics", {})
run_start_step = r.get("start_step", 0)
smoothed = _smooth_losses(loss_history) if loss_history else []
# Scalar summary metrics
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (
run_start_step + (min_idx + 1) * log_interval
if min_idx is not None else None
)
# Stability: std-dev of raw loss over last 25% of steps
if loss_history:
quarter = max(1, len(loss_history) // 4)
last_q = loss_history[-quarter:]
loss_std_last_quarter = round(float(np.std(last_q)), 6)
else:
loss_std_last_quarter = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, run_start_step, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"grad_norm_history": grad_norm_history,
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["adapter_path"] = r["adapter_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except SkipExperiment as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
partial = getattr(e, "partial", {})
lh = partial.get("loss_history", [])
smoothed = _smooth_losses(lh) if lh else []
exp_record["results"] = {
"status": "skipped",
"stopped_at_step": partial.get("stopped_at_step"),
"final_loss": round(smoothed[-1], 6) if smoothed else None,
"loss_history": [round(v, 6) for v in lh],
"grad_norm_history": partial.get("grad_norm_history", []),
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
except Exception as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
# Continue to next experiment rather than aborting the whole sweep
continue
_write_summary()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -149,7 +149,7 @@ class SelvaModelLoader:
enable_conditions=True, enable_conditions=True,
mode=mode, mode=mode,
bigvgan_vocoder_ckpt=bigvgan_path, bigvgan_vocoder_ckpt=bigvgan_path,
need_vae_encoder=False, need_vae_encoder=True,
).to(device, dtype).eval() ).to(device, dtype).eval()
if strategy == "offload_to_cpu": if strategy == "offload_to_cpu":
+107 -3
View File
@@ -3,6 +3,7 @@ import comfy.utils
import comfy.model_management import comfy.model_management
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_textual_inversion_trainer import _inject_tokens
class SelvaSampler: class SelvaSampler:
@@ -31,9 +32,31 @@ class SelvaSampler:
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
}, },
"optional": { "optional": {
"steering_vectors": ("STEERING_VECTORS", {
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
"Nudges each DiT block's hidden state toward the extracted pattern.",
}),
"steering_strength": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to each steering vector before adding to block output. "
"Start around 0.10.3; higher values risk destabilizing the ODE.",
}),
"normalize": ("BOOLEAN", { "normalize": ("BOOLEAN", {
"default": True, "default": True,
"tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.", "tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.",
}),
"textual_inversion": ("TEXTUAL_INVERSION", {
"tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. "
"Injects style tokens into CLIP conditioning without modifying model weights.",
}),
"ti_strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Blends between original CLIP conditioning (0.0) and full TI injection (1.0). "
"Reduce toward 0.30.5 if TI produces buzz artifacts.",
}), }),
}, },
} }
@@ -45,7 +68,7 @@ class SelvaSampler:
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance." DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True): def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
import dataclasses import dataclasses
from selva_core.model.flow_matching import FlowMatching from selva_core.model.flow_matching import FlowMatching
@@ -110,6 +133,19 @@ class SelvaSampler:
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None if negative_prompt.strip() else None
# Inject textual inversion tokens into CLIP conditioning
if textual_inversion is not None:
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
K = emb.shape[0]
inject_mode = textual_inversion.get("inject_mode", "suffix")
ti_text = _inject_tokens(text_clip, emb, K, inject_mode)
text_clip = torch.lerp(text_clip, ti_text, ti_strength)
if neg_text_clip is not None:
ti_neg = _inject_tokens(neg_text_clip, emb, K, inject_mode)
neg_text_clip = torch.lerp(neg_text_clip, ti_neg, ti_strength)
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode} strength={ti_strength}",
flush=True)
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions( empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip bs=1, negative_text_features=neg_text_clip
@@ -123,6 +159,63 @@ class SelvaSampler:
device=gen_device, dtype=dtype, generator=rng, device=gen_device, dtype=dtype, generator=rng,
).to(device) ).to(device)
# Activation steering: apply only during the conditional predict_flow pass
# so steering gets amplified by cfg_strength rather than canceling out.
steering_handles = []
_orig_predict_flow = None
if steering_vectors is not None and steering_strength > 0.0:
vecs = steering_vectors["steering_vectors"]
n_joint = steering_vectors["n_joint"]
# Patch predict_flow to flag which pass is conditional.
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
# identity check tells us which is which.
_is_cond_pass = [False]
_orig_predict_flow = net_generator.predict_flow
def _tracked_predict_flow(latent, t, cond):
_is_cond_pass[0] = (cond is conditions)
return _orig_predict_flow(latent, t, cond)
net_generator.predict_flow = _tracked_predict_flow
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
vec = vec_cpu.to(dev, dt) # [seq, hidden]
def hook(module, input, output):
if not _is_cond_pass[0]:
return # skip unconditional pass; steering amplified by cfg_strength
# Interpolate steering vec to match actual output seq length
# (handles generation at different duration than extraction)
if is_joint:
out_seq = output[0].shape[1]
else:
out_seq = output.shape[1]
v = vec
if v.shape[0] != out_seq:
v = torch.nn.functional.interpolate(
v.T.unsqueeze(0), # [1, hidden, seq_orig]
size=out_seq,
mode="linear",
align_corners=False,
).squeeze(0).T # [seq_new, hidden]
if is_joint:
latent_out = output[0] + strength * v
return (latent_out,) + output[1:]
else:
return output + strength * v
return hook
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
for i, block in enumerate(blocks):
is_joint = i < n_joint
if i < len(vecs):
h = block.register_forward_hook(
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
)
steering_handles.append(h)
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
f"strength={steering_strength} (conditional pass only)", flush=True)
# Flow matching ODE (Euler) # Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
@@ -139,6 +232,11 @@ class SelvaSampler:
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy " "[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration." "to 'offload_to_cpu', using a smaller variant, or reducing duration."
) )
finally:
if _orig_predict_flow is not None:
net_generator.predict_flow = _orig_predict_flow
for h in steering_handles:
h.remove()
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
@@ -168,8 +266,14 @@ class SelvaSampler:
audio = audio.mean(dim=1, keepdim=True) # stereo → mono audio = audio.mean(dim=1, keepdim=True) # stereo → mono
if normalize: if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
# If RMS normalization pushes peaks into clipping, scale back to
# preserve dynamics rather than hard-clipping (no saturation)
peak = audio.abs().max().clamp(min=1e-8) peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1) if peak > 1.0:
audio = audio / peak
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True) print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},) return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
+50
View File
@@ -0,0 +1,50 @@
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaSkipExperiment:
"""Writes skip_current.flag into a sweep output_root.
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
the current experiment and move to the next one. The trainer picks up
the flag within 50 steps (~a few seconds).
"""
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"output_root": ("STRING", {
"default": "",
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("flag_path",)
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
FUNCTION = "skip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
"and move to the next one. Queue this node while the scheduler is running. "
"Partial scalars collected so far are saved in the summary."
)
def skip(self, output_root: str):
p = Path(output_root.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
flag = p / "skip_current.flag"
flag.touch()
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
return (str(flag),)
+70
View File
@@ -0,0 +1,70 @@
"""SelVA Textual Inversion Loader.
Loads a .pt file produced by SelvaTextualInversionTrainer and returns a
TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaTextualInversionLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. "
"Relative paths resolve to the ComfyUI output directory.",
}),
},
}
RETURN_TYPES = ("TEXTUAL_INVERSION",)
RETURN_NAMES = ("textual_inversion",)
OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. "
"Connect the output to the SelVA Sampler's optional textual_inversion input to guide "
"generation toward the training data style without degrading audio quality."
)
def load(self, path: str) -> tuple:
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[TI Loader] File not found: {p}")
data = torch.load(str(p), map_location="cpu", weights_only=False)
embeddings = data["embeddings"] # [K, 1024]
n_tokens = int(data.get("n_tokens", embeddings.shape[0]))
print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} "
f"shape={tuple(embeddings.shape)}", flush=True)
if data.get("init_text"):
print(f"[TI Loader] init_text='{data['init_text']}'", flush=True)
if data.get("step"):
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
f"lr={data.get('lr', '?')}", flush=True)
inject_mode = data.get("inject_mode", "suffix")
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
bundle = {
"embeddings": embeddings, # [K, 1024] float32 on CPU
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"path": str(p),
"init_text": data.get("init_text", ""),
}
return (bundle,)
+450
View File
@@ -0,0 +1,450 @@
"""SelVA Textual Inversion Trainer.
Learns K token embedding vectors in CLIP space that guide the base model
to generate audio in the style of the training clips — without modifying
any model weights.
Key difference from LoRA:
- ALL generator parameters are frozen (requires_grad=False)
- Only K×1024 token embeddings receive gradients
- Latents stay on the decoder's natural manifold → no quality degradation
- The learned tokens shift WHICH latents are generated, not HOW
Usage:
1. Train on your .npz audio features
2. Load result with SelVA Textual Inversion Loader
3. Connect to SelVA Sampler optional input
"""
import copy
import random
import traceback
from pathlib import Path
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.flow_matching import FlowMatching
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_smooth_losses,
_draw_loss_curve,
_pil_to_tensor,
)
# ---------------------------------------------------------------------------
# Eval helper with token injection
# ---------------------------------------------------------------------------
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
n_tokens: int, inject_mode: str) -> torch.Tensor:
"""Build a text_clip tensor with learned tokens injected.
inject_mode:
"suffix" — replace last n_tokens positions (EOS/padding zone)
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
Works for both training (tokens is a Parameter) and eval (tokens is detached).
"""
if inject_mode == "prefix":
bos = text_clip[:, :1, :].detach() # [B, 1, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D]
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
else: # suffix (default)
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
return torch.cat([front, toks], dim=1) # [B, 77, D]
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, num_steps=25, seed=42, clip_idx=0):
"""Inference pass with learned tokens injected into text conditioning."""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype).clone()
emb = learned_tokens.detach().to(device, dtype)
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_input,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
tod = feature_utils_orig.tod
tod_orig_dev = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
tod.to(tod_orig_dev)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
target_rms = 10 ** (-27.0 / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = (audio * (target_rms / rms))
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
return audio.squeeze(0), seq_cfg.sampling_rate
except Exception as e:
print(f"[TI Trainer] Eval sample failed: {e}", flush=True)
traceback.print_exc()
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTextualInversionTrainer:
"""Learns K CLIP token embeddings that steer SelVA toward a target audio style.
Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor
receives gradients, keeping generated latents on the decoder's natural manifold
and preserving base model audio quality while shifting generation style.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "train"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("embeddings_path", "loss_curve")
OUTPUT_TOOLTIPS = (
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
"Smoothed training loss curve.",
)
DESCRIPTION = (
"Trains K learnable CLIP token embeddings against your audio dataset "
"with all model weights frozen. The tokens are then injected into the "
"sampler to guide generation toward the training data style without "
"degrading audio quality."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).",
}),
"output_path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.",
}),
"n_tokens": ("INT", {
"default": 4, "min": 1, "max": 16,
"tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.",
}),
"steps": ("INT", {
"default": 3000, "min": 100, "max": 50000,
"tooltip": "Training steps. 3000 is a reasonable starting point.",
}),
"lr": ("FLOAT", {
"default": 2e-4, "min": 1e-5, "max": 1e-1, "step": 1e-5,
"tooltip": "Learning rate. 2e-4 matches the LoRA working regime. Higher LR (1e-3) causes token norm to drift without plateauing on small datasets.",
}),
"batch_size": ("INT", {
"default": 4, "min": 1, "max": 64,
"tooltip": "Clips sampled per training step. Smaller batch (48) gives more diverse gradients and helps token norm saturate rather than drift.",
}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
"save_every": ("INT", {
"default": 1000, "min": 100, "max": 10000,
"tooltip": "Save a checkpoint and generate an eval sample every N steps.",
}),
},
"optional": {
"inject_mode": (["suffix", "prefix"], {
"default": "suffix",
"tooltip": (
"Where to inject the learned tokens in the 77-token CLIP sequence. "
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
),
}),
"init_text": ("STRING", {
"default": "",
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
}),
"warmup_steps": ("INT", {
"default": 100, "min": 0, "max": 1000,
"tooltip": "Linear LR warmup steps.",
}),
},
}
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
batch_size, seed, save_every,
inject_mode="suffix", init_text="", warmup_steps=100):
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
# --- Resolve paths ---
data_dir = Path(data_dir.strip())
if not data_dir.is_absolute():
data_dir = Path(folder_paths.models_dir) / data_dir
if not data_dir.exists():
raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}")
out_path = Path(output_path.strip())
if not out_path.is_absolute():
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True)
print(f"[TI Trainer] data_dir = {data_dir}", flush=True)
print(f"[TI Trainer] output = {out_path}\n", flush=True)
# --- Load dataset (reuse LoRA trainer helper) ---
dataset = _prepare_dataset(model, data_dir, device)
# Training must run outside inference_mode so autograd works
with torch.inference_mode(False), torch.enable_grad():
r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode,
)
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
return (r["embeddings_path"], _pil_to_tensor(curve_img))
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
):
torch.manual_seed(seed)
# --- Generator (frozen) ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
generator.requires_grad_(False)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Init learned tokens ---
# Call encode_text_clip outside the grad context (it has @inference_mode),
# grab values only (no grad needed), then wrap as nn.Parameter.
if init_text.strip():
with torch.no_grad():
init_embed = feature_utils_orig.encode_text_clip([init_text.strip()])
# Positions 1:1+n_tokens — after BOS, before EOS — have actual content
init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float()
if init_vals.shape[0] < n_tokens:
# Prompt was very short; pad remaining with small noise
pad = torch.randn(n_tokens - init_vals.shape[0], init_vals.shape[1]) * 0.02
init_vals = torch.cat([init_vals, pad], dim=0)
learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype))
print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1{n_tokens})", flush=True)
else:
learned_tokens = torch.nn.Parameter(
torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02
)
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
# --- Measure CLIP token norm from the dataset (content positions 120) ---
# Learned tokens must stay within this range or the model treats them as
# out-of-distribution and produces buzz artifacts instead of style shift.
with torch.no_grad():
sample_norms = []
for item in dataset[:min(len(dataset), 20)]:
tc = item[3].squeeze(0) # [77, 1024]
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
clip_norm_ref = torch.cat(sample_norms).mean().item()
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
f"limit={clip_norm_limit:.4f}", flush=True)
# --- Optimizer + scheduler ---
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Checkpoint dir ---
ckpt_dir = out_path.parent / out_path.stem
ckpt_dir.mkdir(parents=True, exist_ok=True)
# --- Baseline sample (once, before any training) ---
print(f"[TI Trainer] Generating baseline sample...", flush=True)
baseline_wav, baseline_sr = _eval_sample(
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
)
if baseline_wav is not None:
baseline_path = ckpt_dir / "baseline.wav"
try:
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
except RuntimeError:
import soundfile as sf
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
try:
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
except Exception:
pass
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()
log_interval = 50
pbar = comfy.utils.ProgressBar(steps)
loss_history = []
running_loss = 0.0
print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True)
for step in range(1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
x1 = generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
loss = fm.loss(v_pred, x0, x1).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Clamp token norm to CLIP manifold — prevents out-of-distribution
# embeddings that cause buzz artifacts instead of style shift.
with torch.no_grad():
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
scale = (clip_norm_limit / norms).clamp(max=1.0)
learned_tokens.data.mul_(scale)
running_loss += loss.item()
pbar.update(1)
if step % log_interval == 0:
avg = running_loss / log_interval
loss_history.append(round(avg, 6))
running_loss = 0.0
lr_now = scheduler.get_last_lr()[0]
norm = learned_tokens.norm(dim=-1).mean().item()
print(f"[TI Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e} "
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
if step % save_every == 0 or step == steps:
# Save checkpoint
ckpt = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": step,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
torch.save(ckpt, str(ckpt_path))
# Eval sample
wav, sr = _eval_sample_ti(
generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, seed=seed,
)
if wav is not None:
wav_path = ckpt_dir / f"step_{step:05d}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
try:
metrics = _spectral_metrics(wav, sr)
_save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
print(f"[TI Trainer] step {step} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.4f} "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
except Exception as e:
print(f"[TI Trainer] Spectral/spectrogram failed: {e}", flush=True)
print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True)
# --- Final save ---
final = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": steps,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
torch.save(final, str(out_path))
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
soft_empty_cache()
return {
"embeddings_path": str(out_path),
"loss_history": loss_history,
}
+479
View File
@@ -0,0 +1,479 @@
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "ti_sweep_1",
"description": "optional human note",
"data_dir": "dataset/bj_sounds",
"output_root": "ti_output/sweep_1",
"base": {
"n_tokens": 4,
"lr": 1e-3,
"steps": 3000,
"batch_size": 16,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000
},
"experiments": [
{"id": "baseline", "description": "default 4 tokens"},
{"id": "n8_tokens", "n_tokens": 8},
{"id": "lr_5e4", "lr": 5e-4},
{"id": "warm_init", "init_text": "industrial sound design"},
{"id": "n4_more_steps", "steps": 5000}
]
}
"""
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------
def _get_system_info() -> dict:
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
_PARAM_DEFAULTS = {
"n_tokens": 4,
"lr": 2e-4,
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000,
"init_text": "",
"inject_mode": "suffix",
}
_PALETTE = [
(66, 133, 244),
(234, 67, 53),
(52, 168, 83),
(251, 188, 5),
(155, 89, 182),
(26, 188, 156),
(230, 126, 34),
(149, 165, 166),
]
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
result = {}
for target in range(save_every, total_steps + 1, save_every):
idx = target // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
from PIL import Image, ImageDraw
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))
lx, ly = W - pr + 10, pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTiScheduler:
"""Runs a sweep of Textual Inversion experiments defined in a JSON file.
The dataset is loaded once and reused. Each experiment calls
SelvaTextualInversionTrainer._train_inner() with its own config.
Results are written to experiment_summary.json after every completed run.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — compare runs across sweeps.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of Textual Inversion experiments from a JSON sweep file. "
"The dataset is encoded once and reused. Results (loss, config, embeddings "
"paths) are collected in experiment_summary.json after each run."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "ti_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"output directory. See node description for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate JSON
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[TI Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[TI Scheduler] {description}", flush=True)
print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
all_curve_data.append({
"id": rec["id"],
"loss_history": rec["results"].get("loss_history", []),
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
comparison_img_path = output_root / "loss_comparison.png"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
def _save_comparison():
try:
img = _draw_comparison_curves(all_curve_data)
img.save(str(comparison_img_path))
except Exception as e:
print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaTextualInversionTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
n_tokens = int(cfg["n_tokens"])
lr = float(cfg["lr"])
steps = int(cfg["steps"])
batch_size = int(cfg["batch_size"])
warmup = int(cfg["warmup_steps"])
seed = int(cfg["seed"])
save_every = int(cfg["save_every"])
init_text = str(cfg["init_text"])
inject_mode = str(cfg["inject_mode"])
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / "embeddings.pt"
print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[TI Scheduler] {exp_desc}", flush=True)
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
f"batch_size={batch_size} warmup={warmup} seed={seed} "
f"inject_mode={inject_mode}", flush=True)
if init_text:
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"n_tokens": n_tokens,
"lr": lr,
"steps": steps,
"batch_size": batch_size,
"warmup_steps": warmup,
"seed": seed,
"save_every": save_every,
"init_text": init_text,
"inject_mode": inject_mode,
},
"results": {"status": "running"},
"embeddings_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup, seed, save_every, init_text, inject_mode,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
smoothed = _smooth_losses(loss_history) if loss_history else []
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
loss_std_last_quarter = None
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["embeddings_path"] = r["embeddings_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
_write_summary()
_save_comparison()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image (final update, then return to ComfyUI)
# ------------------------------------------------------------------
_save_comparison()
comparison_img = _draw_comparison_curves(all_curve_data)
return (str(summary_path), _pil_to_tensor(comparison_img))
+157
View File
@@ -0,0 +1,157 @@
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
Useful for diagnosing codec reconstruction quality: if the output sounds
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
not the diffusion model or LoRA.
"""
import torch
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
class SelvaVaeRoundtrip:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"audio": ("AUDIO",),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio_reconstructed",)
OUTPUT_TOOLTIPS = (
"Audio after VAE encode → decode roundtrip. "
"Compare to the input to hear codec reconstruction quality.",
)
FUNCTION = "roundtrip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Encodes the input audio through the SelVA VAE then decodes it straight back. "
"Use this to isolate codec reconstruction quality from generation quality. "
"If the output sounds degraded compared to the input, the VAE/DAC is the "
"bottleneck — not the model or LoRA."
)
def roundtrip(self, model, audio):
from selva_core.model.utils.features_utils import FeaturesUtils
mode = model["mode"]
seq_cfg = model["seq_cfg"]
dtype = model["dtype"]
device = get_device()
generator = model["generator"]
feature_utils = model["feature_utils"]
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
# Load encoder only — decoder/vocoder come from model["feature_utils"]
# to mirror exactly what the sampler uses.
# AutoEncoderModule requires vocoder_ckpt_path even when only encoding,
# so pass the BigVGAN path (weights won't actually be used for decode here).
bigvgan_path = _SELVA_DIR / "ext" / "best_netG.pt"
print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
vae_enc = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
bigvgan_vocoder_ckpt=str(bigvgan_path) if bigvgan_path.exists() else None,
).to(device).eval()
try:
# Prepare input audio
waveform = audio["waveform"] # [1, C, L]
sr_in = audio["sample_rate"]
wav = waveform[0].mean(0) # mono [L]
if sr_in != seq_cfg.sampling_rate:
wav = torchaudio.functional.resample(
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
).squeeze(0)
print(f"[VAE Roundtrip] Resampled {sr_in}{seq_cfg.sampling_rate} Hz",
flush=True)
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
if wav.shape[0] > target_len:
wav = wav[:target_len]
elif wav.shape[0] < target_len:
wav = F.pad(wav, (0, target_len - wav.shape[0]))
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
with torch.no_grad():
# Encode: audio → raw latent [1, latent_dim, T]
dist = vae_enc.encode_audio(wav_b)
latent = dist.mode().clone()
# Trim/pad to exact model sequence length (same as _prepare_dataset)
tgt = seq_cfg.latent_seq_len
if latent.shape[2] < tgt:
latent = F.pad(latent, (0, tgt - latent.shape[2]))
elif latent.shape[2] > tgt:
latent = latent[:, :, :tgt]
# To [B, T, latent_dim] — layout the generator uses
latent_t = latent.transpose(1, 2).to(dtype)
print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
flush=True)
# Normalize → unnormalize mirrors the training/inference pipeline:
# training normalizes encoded latents; sampler unnormalizes before decode.
# This ensures the latent is in the same space the decoder expects.
latent_norm = generator.normalize(latent_t.clone())
latent_unnorm = generator.unnormalize(latent_norm)
print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
flush=True)
# Decode using model's feature_utils — same path as the sampler
tod = feature_utils.tod
tod_orig_device = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils.decode(latent_unnorm)
out = feature_utils.vocode(spec)
finally:
tod.to(tod_orig_device)
out = out.float().cpu()
if out.dim() == 1:
out = out.unsqueeze(0).unsqueeze(0)
elif out.dim() == 2:
out = out.unsqueeze(1)
elif out.dim() == 3 and out.shape[1] != 1:
out = out.mean(dim=1, keepdim=True)
rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
target_rms = 10 ** (-27.0 / 20.0)
out = out * (target_rms / rms)
out = out.clamp(-1.0, 1.0)
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
flush=True)
finally:
del vae_enc
soft_empty_cache()
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
+309
View File
@@ -0,0 +1,309 @@
"""
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
Supports two initialization modes:
- **standard**: Kaiming-uniform A, zero B (classic LoRA).
- **pissa**: A and B from the top-r SVD of the pretrained weight.
Starts on-manifold, eliminates intruder dimensions at init
(arXiv:2404.02948, NeurIPS 2024 Spotlight).
Supports two scaling modes:
- **standard**: alpha / rank
- **rslora**: alpha / sqrt(rank) — rank-stabilized scaling that prevents
gradient collapse at high ranks (arXiv:2312.03732).
Usage:
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
n = apply_lora(net_generator, rank=16, alpha=16.0)
print(f"Wrapped {n} linear layers with LoRA")
# ... train only LoRA params ...
torch.save(get_lora_state_dict(net_generator), "adapter.pt")
# Later, at inference:
apply_lora(net_generator, rank=16, alpha=16.0)
load_lora(net_generator, torch.load("adapter.pt"))
"""
import math
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
Output: base(x) + (dropout(x) @ A.T @ B.T) * scale
Standard init: A is Kaiming uniform, B is zero → adapter starts at zero.
PiSSA init: A and B from top-r SVD of pretrained weight → adapter starts
at the principal components, base weight stores the residual.
"""
def __init__(self, linear: nn.Linear, rank: int, alpha: float,
dropout: float = 0.0, init_mode: str = "standard",
use_rslora: bool = False):
super().__init__()
in_f = linear.in_features
out_f = linear.out_features
self.linear = linear
linear.weight.requires_grad_(False)
if linear.bias is not None:
linear.bias.requires_grad_(False)
ref_dtype = linear.weight.dtype
ref_device = linear.weight.device
if use_rslora:
self.scale = alpha / math.sqrt(rank)
else:
self.scale = alpha / rank
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
if init_mode == "pissa":
# PiSSA: init from top-r SVD of pretrained weight.
# SVD in float32 for numerical stability, then cast back.
W = linear.weight.data.float() # [out_f, in_f]
U, S, Vt = torch.linalg.svd(W, full_matrices=False)
sqrt_S = S[:rank].sqrt()
# A: [rank, in_f], B: [out_f, rank]
A_init = sqrt_S.unsqueeze(1) * Vt[:rank, :]
B_init = U[:, :rank] * sqrt_S.unsqueeze(0)
# Residual: W_res = W - B_init @ A_init * scale
# so that base(x) + LoRA(x) = W_res@x + (B@A)*scale@x = W@x at init
linear.weight.data = (W - B_init @ A_init * self.scale).to(ref_dtype)
self.lora_A = nn.Parameter(A_init.to(dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(B_init.to(dtype=ref_dtype, device=ref_device))
else:
# Standard LoRA: Kaiming A, zero B → starts at identity
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
def extra_repr(self) -> str:
rank = self.lora_A.shape[0]
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
def apply_lora(
model: nn.Module,
rank: int = 16,
alpha: float = None,
target_suffixes: tuple = ("attn.qkv",),
dropout: float = 0.0,
init_mode: str = "standard",
use_rslora: bool = False,
) -> int:
"""Replace matching nn.Linear layers with LoRALinear in-place.
Args:
model: The module to modify (typically net_generator).
rank: LoRA rank.
alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0).
target_suffixes: Tuple of module name suffixes to wrap. Default is
("attn.qkv",) which targets all SelfAttention QKV
projections in the MM-DiT generator.
Add "linear1" to also wrap post-attention output projections.
dropout: Dropout probability on the LoRA path (not the base linear).
0.050.1 helps regularize on small datasets.
Must be 0 when using PiSSA (principal components shouldn't be dropped).
init_mode: "standard" (Kaiming/zero) or "pissa" (SVD-based).
use_rslora: If True, scale by alpha/sqrt(rank) instead of alpha/rank.
Returns:
Number of linear layers wrapped.
"""
if alpha is None:
alpha = float(rank)
if init_mode == "pissa" and dropout > 0.0:
print("[LoRA] Warning: dropout forced to 0 for PiSSA init "
"(principal components should not be dropped).")
dropout = 0.0
count = 0
for name, module in list(model.named_modules()):
if not any(name.endswith(s) for s in target_suffixes):
continue
if not isinstance(module, nn.Linear):
continue
parts = name.split(".")
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], LoRALinear(
module, rank, alpha, dropout=dropout,
init_mode=init_mode, use_rslora=use_rslora,
))
count += 1
return count
def get_lora_state_dict(model: nn.Module) -> dict:
"""Return a state dict containing only LoRA parameters (lora_A and lora_B)."""
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
def get_lora_and_base_state_dict(model: nn.Module) -> dict:
"""Return state dict with LoRA params AND base linear weights.
Needed for PiSSA checkpoints where the base weight stores the residual
(W - top_r(W)*scale), not the original pretrained weight.
"""
result = {}
for name, module in model.named_modules():
if isinstance(module, LoRALinear):
prefix = name + "."
result[prefix + "lora_A"] = module.lora_A.data
result[prefix + "lora_B"] = module.lora_B.data
result[prefix + "linear.weight"] = module.linear.weight.data
if module.linear.bias is not None:
result[prefix + "linear.bias"] = module.linear.bias.data
return result
def spectral_surgery(
model: nn.Module,
calibration_fn,
n_calibration: int = 128,
policy: str = "smooth_abs",
):
"""Post-training Spectral Surgery: reweight LoRA singular values to suppress
intruder dimensions and amplify useful components (arXiv:2603.03995).
Args:
model: Model with LoRA applied.
calibration_fn: Callable that takes (model, step_idx) and runs one forward+backward
pass on a calibration sample. Must call loss.backward().
n_calibration: Number of calibration samples to average gradients over.
policy: Reweighting policy: "smooth_abs" (recommended), "hard" (binary).
Modifies LoRA A and B in-place. Returns number of layers processed.
"""
model.eval()
lora_layers = [(name, mod) for name, mod in model.named_modules()
if isinstance(mod, LoRALinear)]
if not lora_layers:
return 0
# Accumulate per-layer gradient sensitivity: g_k = u_k^T * (dL/dΔW) * v_k
sensitivities = {}
for name, mod in lora_layers:
sensitivities[name] = None
for step in range(n_calibration):
model.zero_grad()
# Enable grad temporarily on LoRA params
for _, mod in lora_layers:
mod.lora_A.requires_grad_(True)
mod.lora_B.requires_grad_(True)
calibration_fn(model, step)
for name, mod in lora_layers:
A = mod.lora_A.data.float() # [rank, in_f]
B = mod.lora_B.data.float() # [out_f, rank]
# ΔW = B @ A * scale → gradient dL/dΔW ≈ (dL/dB @ A + B^T @ dL/dA) / 2
# Per-component sensitivity: project onto SVD directions
delta_W = (B @ A * mod.scale).detach()
U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
r = A.shape[0]
U_r, S_r, Vt_r = U[:, :r], S[:r], Vt[:r, :]
# Compute sensitivity from LoRA gradients
if mod.lora_A.grad is not None and mod.lora_B.grad is not None:
grad_A = mod.lora_A.grad.float() # [rank, in_f]
grad_B = mod.lora_B.grad.float() # [out_f, rank]
# dL/d(ΔW) ≈ grad_B @ A + B^T @ grad_A (chain rule through B@A)
grad_dW = grad_B @ A + B.T @ grad_A # approximate
# Per-component: g_k = u_k^T @ grad_dW @ v_k
g = torch.einsum("ik,ij,jk->k", U_r, grad_dW, Vt_r.T) # [r]
else:
g = torch.zeros(r, device=A.device)
if sensitivities[name] is None:
sensitivities[name] = g
else:
sensitivities[name] += g
# Disable grad again
for _, mod in lora_layers:
mod.lora_A.requires_grad_(False)
mod.lora_B.requires_grad_(False)
# Apply reweighting per layer
count = 0
for name, mod in lora_layers:
g = sensitivities[name] / n_calibration
A = mod.lora_A.data.float()
B = mod.lora_B.data.float()
delta_W = B @ A * mod.scale
U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
r = A.shape[0]
S_r = S[:r]
if policy == "hard":
# Keep components with positive sensitivity, zero out negative
mask = (g > 0).float()
else:
# smooth_abs: sigmoid-weighted by sensitivity magnitude
# Normalize g to [-1, 1] range, apply sigmoid
g_norm = g / (g.abs().max() + 1e-8)
mask = torch.sigmoid(5.0 * g_norm) # steep sigmoid
# L1 norm preservation: scale mask so total nuclear norm is preserved
mask = mask * (S_r.sum() / (mask * S_r).sum().clamp(min=1e-8))
# Reconstruct: ΔW' = U_r @ diag(mask * S_r) @ Vt_r
S_new = mask * S_r
delta_W_new = U[:, :r] @ torch.diag(S_new) @ Vt[:r, :]
# Factor back into B' @ A' * scale: use SVD of ΔW'/scale
dW_unscaled = delta_W_new / mod.scale
U2, S2, Vt2 = torch.linalg.svd(dW_unscaled, full_matrices=False)
sqrt_S2 = S2[:r].sqrt()
A_new = sqrt_S2.unsqueeze(1) * Vt2[:r, :]
B_new = U2[:, :r] * sqrt_S2.unsqueeze(0)
ref_dtype = mod.lora_A.dtype
mod.lora_A.data = A_new.to(ref_dtype)
mod.lora_B.data = B_new.to(ref_dtype)
count += 1
kept = (mask > 0.5).sum().item()
print(f"[Spectral Surgery] {name}: kept {kept}/{r} components, "
f"sensitivity range [{g.min():.3f}, {g.max():.3f}]", flush=True)
return count
def load_lora(model: nn.Module, state_dict: dict) -> None:
"""Load LoRA weights into a model that has already had apply_lora() called.
Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model
parameters are not modified.
"""
missing, unexpected = model.load_state_dict(state_dict, strict=False)
bad = [k for k in unexpected if "lora_" not in k]
if bad:
print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}")
lora_missing = [k for k in missing if "lora_" in k]
if lora_missing:
print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
+465
View File
@@ -0,0 +1,465 @@
#!/usr/bin/env python3
"""
LoRA fine-tuning for SelVA / MMAudio generator.
Teaches the model new or partially-known sound classes from custom video+audio pairs.
Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model).
Data layout:
data/my_sound/
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage:
python train_lora.py \\
--data_dir data/my_sound \\
--output_dir lora_output \\
--variant large_44k \\
--selva_dir /path/to/ComfyUI/models/selva \\
--rank 16 --steps 2000 --lr 1e-4
"""
import argparse
import os
import sys
import random
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k"),
"small_44k": ("generator_small_44k_sup_5.pth", "44k"),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
"large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict:
"""Load filename → prompt overrides from prompts.txt."""
p = data_dir / "prompts.txt"
if not p.exists():
return {}
mapping = {}
for line in p.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if ":" in line:
fname, prompt = line.split(":", 1)
mapping[fname.strip()] = prompt.strip()
return mapping
def find_audio_for_npz(npz_path: Path) -> Path | None:
"""Find a paired audio file with the same stem as the .npz."""
for ext in _AUDIO_EXTS:
candidate = npz_path.with_suffix(ext)
if candidate.exists():
return candidate
return None
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
"""Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
waveform, sr = torchaudio.load(str(path))
# Stereo → mono
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
# Resample
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr
).squeeze(0)
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def load_npz(path: Path) -> dict:
"""Load a feature bundle produced by SelvaFeatureExtractor."""
data = np.load(str(path), allow_pickle=False)
bundle = {
"clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024]
"sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768]
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
if "variant" in data:
bundle["variant"] = str(data["variant"])
return bundle
# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
tokens = tokenizer(text).to(device)
with torch.inference_mode():
return clip_model.encode_text(tokens, normalize=True)
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
"""Encode a waveform to the generator's latent space via the VAE.
encode_audio is @inference_mode — .clone() is required before the autograd path.
"""
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
dist = feature_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
return dist.mode().clone().transpose(1, 2).cpu() # [1, seq_len, latent_dim]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
parser.add_argument("--output_dir", default="lora_output")
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)")
parser.add_argument("--target", nargs="+", default=["attn.qkv"],
help="Module name suffixes to wrap with LoRA. Also try 'linear1'.")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--steps", type=int, default=2000)
parser.add_argument("--warmup_steps",type=int, default=100)
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--save_every", type=int, default=500)
parser.add_argument("--resume", default=None,
help="Path to a step checkpoint (.pt) to resume training from.")
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
help="Spread of logit-normal distribution.")
parser.add_argument("--curriculum_switch", type=float, default=0.6,
help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
parser.add_argument("--lora_dropout", type=float, default=0.0,
help="Dropout on the LoRA path only. 0.050.1 helps on small datasets.")
parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
args = parser.parse_args()
torch.manual_seed(args.seed)
random.seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
print("[LoRA] bf16 not supported on this GPU — falling back to fp16")
args.precision = "fp16"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)
selva_dir = Path(args.selva_dir)
output_dir.mkdir(parents=True, exist_ok=True)
gen_filename, mode = _VARIANTS[args.variant]
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
duration = seq_cfg.duration
sample_rate = seq_cfg.sampling_rate
# --- Weight paths ---
def w(name): return str(selva_dir / name)
def wext(name): return str(selva_dir / "ext" / name)
vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
gen_weight = w(gen_filename)
for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
if not Path(path).exists():
print(f"[LoRA] Missing weight: {path} ({label})")
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
sys.exit(1)
# --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
print("[LoRA] Loading CLIP text encoder...")
clip_model = create_model_from_pretrained(
'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
).to(device, dtype).eval()
clip_model = patch_clip(clip_model)
tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
# --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
print("[LoRA] Loading VAE encoder...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=vae_weight,
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device, dtype).eval()
# --- Load generator ---
print(f"[LoRA] Loading generator ({args.variant})...")
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
net_generator.load_weights(
torch.load(gen_weight, map_location="cpu", weights_only=False)
)
# --- Apply LoRA ---
n_lora = apply_lora(
net_generator,
rank=args.rank,
alpha=args.alpha,
target_suffixes=tuple(args.target),
dropout=args.lora_dropout,
)
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
if n_lora == 0:
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
sys.exit(1)
# Freeze everything except LoRA params
for name, p in net_generator.named_parameters():
p.requires_grad_("lora_" in name)
trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad)
total = sum(p.numel() for p in net_generator.parameters())
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
f"({100 * trainable / total:.2f}%)")
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
print(f"[LoRA] No .npz files found in {data_dir}")
sys.exit(1)
prompt_map = load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
dataset = []
for npz_path in npz_files:
audio_path = find_audio_for_npz(npz_path)
if audio_path is None:
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
continue
bundle = load_npz(npz_path)
# Prompt priority: prompts.txt override > embedded in .npz > directory name
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
try:
audio = load_audio(audio_path, sample_rate, duration)
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
# have fewer frames and would cause stack() to fail during batching
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
if not dataset:
print("[LoRA] No clips could be loaded.")
sys.exit(1)
print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler ---
# LoRA+: separate param groups for A and B with different LRs.
# ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": lora_A_params, "lr": args.lr},
{"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
], weight_decay=1e-2)
if args.lora_plus_ratio != 1.0:
print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")
def lr_lambda(step):
if step < args.warmup_steps:
return step / max(1, args.warmup_steps)
return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Resume ---
start_step = 0
if args.resume:
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
if "step" not in ckpt:
print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
sys.exit(1)
start_step = ckpt["step"]
if start_step >= args.steps:
print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
sys.exit(0)
net_generator.load_state_dict(ckpt["state_dict"], strict=False)
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step}{args.steps})")
# --- Training loop ---
net_generator.train()
optimizer.zero_grad()
remaining = args.steps - start_step
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), "
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
_curriculum_switched = False
total_loss = 0.0
for step in range(start_step + 1, args.steps + 1):
batch = random.choices(dataset, k=args.batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
net_generator.normalize(x1)
if args.timestep_mode == "logit_normal" or (
args.timestep_mode == "curriculum" and step <= curriculum_switch_step
):
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
t = torch.sigmoid(u)
else:
t = torch.rand(args.batch_size, device=device, dtype=dtype)
if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
_curriculum_switched = True
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
loss.backward()
total_loss += loss.item() * args.grad_accum
if step % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 50 == 0:
avg = total_loss / 50
lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}")
total_loss = 0.0
if step % args.save_every == 0 or step == args.steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
"state_dict": get_lora_state_dict(net_generator),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"meta": {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
"timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
},
}, ckpt_path)
print(f"[LoRA] Saved {ckpt_path}")
# Save final adapter with embedded metadata
# Increment filename if a previous final already exists (resume case)
final = output_dir / "adapter_final.pt"
if final.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final = output_dir / f"adapter_final_{i:03d}.pt"
meta = {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
"timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
}
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA] Training complete. Adapter saved to {final}")
if __name__ == "__main__":
main()