221 Commits

Author SHA1 Message Date
Ethanfel 40d29bcaf8 feat: add experiment configs for logit+cosine combo and BigVGAN decoder fine-tuning
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 16:48:21 +02:00
Ethanfel 65dc549494 feat: add reference audio comparison metrics to LoRA trainer eval
New _reference_metrics() computes LSD, MCD, and per-band correlation
between eval samples and the original source audio at each checkpoint.
Loads reference audio once before the training loop and logs metrics
alongside existing spectral metrics.

Also fix batch_size in lora_optimized_dataset.json (4 -> 16).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 15:04:07 +02:00
Ethanfel f745e241c4 chore: sanitize tooltips/comments + add experiment configs
- Replace all BJ references with generic "target style/audio" in
  activation steering, DITTO optimizer, and BigVGAN trainer
- Add latent_mixup_alpha/latent_noise_sigma to LoRA scheduler defaults
- Add bigvgan_disc_fm_retest.json and lora_optimized_dataset.json

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 13:44:37 +02:00
Ethanfel 082a2da438 fix: restore dtype after float32 STFT in discriminator spectrogram
torch.stft requires float32 input, but the .float() cast was not
reversed before the spectrogram hit bfloat16 Conv2d weights. Save
the original dtype and cast back after abs().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 12:13:55 +02:00
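
Sketch for 082a2da438 (an illustration, not the repository's code): the save-dtype, compute-in-float32, cast-back pattern the message describes might look roughly like this. The helper name magnitude_spectrogram and the n_fft/hop values are assumptions made for the example.

```python
import torch


def magnitude_spectrogram(wav: torch.Tensor, n_fft: int = 1024, hop: int = 256) -> torch.Tensor:
    """Return |STFT(wav)| in the caller's dtype (hypothetical helper)."""
    orig_dtype = wav.dtype                                  # e.g. torch.bfloat16
    window = torch.hann_window(n_fft, device=wav.device)    # float32 by default
    spec = torch.stft(
        wav.float(),                                        # run the STFT in float32
        n_fft=n_fft,
        hop_length=hop,
        window=window,
        return_complex=True,
    )
    # abs() yields float32; cast back so downstream bfloat16 convs accept it.
    return spec.abs().to(orig_dtype)


wav = torch.randn(2, 4096).to(torch.bfloat16)
mag = magnitude_spectrogram(wav)
```
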
Ethanfel c28e090196 fix: cast discriminator inputs to match bfloat16 dtype in BigVGAN FM loss
The frozen discriminators are loaded in model dtype (bfloat16) but vocoder
waveform outputs are float32, causing a Conv2d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 11:36:02 +02:00
Ethanfel af6c225f53 feat: add dataset pipeline nodes + latent augmentation for LoRA trainer
New dataset pipeline nodes:
- SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution
- SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility
- SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking

Improvements:
- Inspector: silence detection (max_silence_fraction param)
- Saver: origin_name lookup for augmented clips' npz pairing
- LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization
- LoRA trainer: one-time SR mismatch warning in _load_audio

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 11:32:00 +02:00
Ethanfel 30127c13ca feat: add BigVGAN vocoder sweep scheduler node
Runs a series of BigVGAN fine-tuning experiments from a JSON sweep file.
Audio clips loaded once, vocoder deep-copied per experiment, results
collected in experiment_summary.json with comparison loss curves.
Resume-aware — skips completed experiments on re-run.

Includes overnight sweep config (8 experiments): snake alpha steps,
GAFilter ablation, phase loss weight, discriminator FM, all_params.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 02:39:56 +02:00
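
Sketch for 30127c13ca (an illustration under stated assumptions): a resume-aware sweep loop could be as small as the following. Only the file name experiment_summary.json comes from the message; the "name" key, the summary layout, and the train_fn callback are hypothetical.

```python
import json
from pathlib import Path
from typing import Callable


def run_sweep(experiments: list[dict], summary_path: Path,
              train_fn: Callable[[dict], dict]) -> dict:
    """Run each experiment once; skip any already recorded in the summary file."""
    summary = json.loads(summary_path.read_text()) if summary_path.exists() else {}
    for exp in experiments:
        name = exp["name"]              # hypothetical key identifying the experiment
        if name in summary:             # finished in an earlier run, so skip it
            continue
        summary[name] = train_fn(exp)
        # Persist after every experiment so an interrupted sweep can resume.
        summary_path.write_text(json.dumps(summary, indent=2))
    return summary
```

Writing the summary after each experiment, rather than once at the end, is what lets a re-run pick up where an interrupted sweep stopped.
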
Ethanfel 4226297735 chore: remove debug VRAM logging
Training confirmed working — VRAM usage is normal backward-pass
activation memory, not a leak. Removed all debug _vram_log and _vram
calls. Kept the video_enc offload and torch.cuda.empty_cache fixes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:50:08 +02:00
Ethanfel 4297715a08 debug: add driver-level VRAM reporting + offload video_enc
torch.cuda.memory_allocated only tracks PyTorch allocator. Added
torch.cuda.mem_get_info to see actual CUDA driver memory usage.
Also offload video_enc (TextSynch) which was missed in the original
offload — stays on GPU when strategy != offload_to_cpu.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:48:04 +02:00
Ethanfel 9af4bbdd91 fix: force torch.cuda.empty_cache() after pre-generation and CLIP encoding
PyTorch's caching allocator reserves GPU memory from pre-generation
(~90 GiB for generator + tod) and doesn't return it to CUDA/OS.
soft_empty_cache may not call torch.cuda.empty_cache(). Force a full
cache release after CLIP encoding and after LoRA mel pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:42:45 +02:00
Ethanfel 89d6fccd28 debug: add per-operation VRAM logging in first training step
Logs VRAM at: after target_mel, after vocoder forward, before loss,
after loss computation, and after backward. Only logs for step 0 to
avoid spam. Will identify which operation causes the 94 GiB spike.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:35:54 +02:00
Ethanfel bd84242fa1 debug: add VRAM logging at offload and training checkpoints
Logs torch.cuda.memory_allocated/reserved at each step: before unload,
after unload_all_models, after feature_utils.to(cpu), after generator
to(cpu), after cache clear, after mel_converter to(device), and before
training loop. This will identify what's holding VRAM.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:28:31 +02:00
Ethanfel 5a2c003fb2 fix: move baseline sample after inference flag stripping
_save_sample("baseline") was called before the vocoder's inference
tensors were sanitized, causing "Inference tensors do not track version
counter". Moved it after the clone/detach loop and vocoder.to(device).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:26:11 +02:00
Ethanfel f8d4d77b0d fix: pre-compute text CLIP embeddings in main thread to avoid inference tensor crash
CLIP weights are inference tensors from ComfyUI loading. inference_mode
is thread-local, so the worker thread can't use CLIP even with a context
manager. Pre-compute all text embeddings in the main thread (where
inference_mode IS active), clone+detach to normal tensors, and pass them
to the worker via text_clip_cache dict. CLIP no longer needs to be on
GPU during pre-generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:19:44 +02:00
Ethanfel 32e5344ea2 fix: wrap CLIP encoding in inference_mode during pre-generation
CLIP weights are inference tensors from ComfyUI loading. The worker
thread runs without inference_mode, so PyTorch rejects inference tensors
in multi_head_attention_forward (version counter tracking). Wrap the
encode_text_clip call in torch.inference_mode() since text encoding
doesn't need gradients.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 01:10:58 +02:00
Ethanfel 10a71b0c4f fix: offload entire model to CPU in main thread before worker starts
The previous offload ran inside the worker thread, but by then ComfyUI
had already loaded the full model to GPU. Now feature_utils.to('cpu')
and generator.to('cpu') run in the main thread right after
unload_all_models(), before the worker starts. vocoder.to(device, dtype)
is called explicitly after inference flag stripping in _do_train to
bring only the vocoder back to GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:56:13 +02:00
Ethanfel 37a27160aa fix: match mel dtype to vocoder in baseline sample generation
ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16
before inference flag stripping. Cast mel to vocoder's dtype to prevent
input/bias type mismatch during baseline sample save.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:45:31 +02:00
Ethanfel cb9a1eef01 fix: stop loading full feature_utils to GPU before training
feature_utils.to(device) was loading CLIP ViT-H, synchformer, T5, VAE,
and vocoder (~90 GiB) to GPU for the entire training run. Now only
mel_converter (tiny) is moved to GPU. Pre-generation manages its own
device placement: temporarily moves CLIP and tod to GPU, then moves them
back when done. This frees ~90 GiB for the backward pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:44:38 +02:00
Ethanfel d70c611bf7 fix: offload CLIP, synchformer, T5, generator, VAE to CPU before training
Only the vocoder and mel_converter are needed during BigVGAN training.
The rest of the SelVA pipeline (CLIP ViT-H, synchformer, T5, generator,
VAE) was staying on GPU and consuming ~90 GiB, leaving no room for
backward pass activations. Now offloaded individually to CPU before
the training loop starts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:33:07 +02:00
Ethanfel 4e6cc4d519 feat: cache pre-generated LoRA mels to disk for reuse
LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow.
Cache results to a .pt file next to the output, keyed by a SHA-256 hash
of the LoRA adapter content + generation parameters (seed, steps, CFG,
duration, sample rate, npz file list). Automatically reused on subsequent
runs when parameters haven't changed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:30:20 +02:00
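
Sketch for 4e6cc4d519 (an illustration, not the repository's code): a cache key built from the adapter content plus the generation parameters might be derived like this. The function name, the dict layout, and all values other than cfg=4.5 are assumptions.

```python
import hashlib
import json


def mel_cache_key(adapter_bytes: bytes, params: dict) -> str:
    """SHA-256 over the adapter content plus canonicalized generation parameters."""
    h = hashlib.sha256()
    h.update(adapter_bytes)
    # sort_keys makes the digest independent of dict insertion order.
    h.update(json.dumps(params, sort_keys=True).encode("utf-8"))
    return h.hexdigest()


# Hypothetical parameter values, for illustration only.
params = {"seed": 42, "steps": 25, "cfg": 4.5, "duration": 8.0,
          "sample_rate": 44100, "npz_files": ["a.npz", "b.npz"]}
key = mel_cache_key(b"fake-adapter-weights", params)
```

Any change to the adapter bytes or to a single parameter yields a different key, so a stale cache file is simply never matched.
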
Ethanfel 0854bd2638 fix: cast discriminators to model dtype to match vocoder output
Discriminators are constructed as float32 but receive bfloat16 tensors
from the vocoder. Cast to model dtype on load to prevent conv dtype
mismatch in feature matching loss.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:25:04 +02:00
Ethanfel 187b2e3169 fix: cast GAFilter to model dtype after injection
GAFilter conv weights are created as float32 but the rest of the vocoder
is bfloat16. vocoder.to(device) missed the dtype cast, causing conv1d
dtype mismatch when Snake bfloat16 output flows into GAFilter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:24:11 +02:00
Ethanfel 608746ce7b fix: cast input mel to model dtype before vocoder forward pass
mel_converter outputs float32 (cuFFT requirement) but vocoder weights are
bfloat16 from model loading. Cast input_mel back to model dtype before
feeding the vocoder to avoid conv1d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:18:05 +02:00
Ethanfel bba5aec7a5 fix: add CFG to LoRA mel pre-generation to match inference conditions
Pre-generated mels were using a bare forward pass with no classifier-free
guidance, producing mels that don't match what the vocoder sees at inference
(where cfg_strength=4.5 is the default). Now uses ode_wrapper with
preprocess_conditions/get_empty_conditions, same as the sampler node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:17:16 +02:00
Ethanfel d06936802b fix: cast mel_converter buffers to float32 to match STFT input dtype
mel_basis and hann_window buffers inherit bfloat16 from model loading.
Since all mel_converter inputs are cast to float32 for cuFFT, the
internal buffers must also be float32 to avoid matmul dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:10:52 +02:00
Ethanfel bee518a855 fix: cast all STFT inputs to float32 to prevent cuFFT bfloat16 crash
cuFFT does not support bfloat16 tensors. When the model is loaded in
bfloat16, all torch.stft calls (mel_converter, discriminator spectrogram,
multi-resolution STFT loss) crash. Add .float() at every STFT boundary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 23:53:36 +02:00
Ethanfel 48b72c0be0 feat: add LoRA mel pre-generation to BigVGAN vocoder trainer
When a lora_adapter path is provided, the trainer pre-generates
LoRA-distorted mels for each training clip (full ODE generation +
VAE decode) and trains the vocoder to produce clean audio from them.
This teaches the vocoder to compensate for LoRA latent distribution
shift without requiring perfectly aligned training pairs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 23:26:36 +02:00
Ethanfel e16480b4c9 feat: add PiSSA/rsLoRA support to scheduler and PiSSA sweep experiment
Thread init_mode and use_rslora through the scheduler's config parsing,
experiment record, and _train_inner call. Default alpha changed to 2*rank
to match trainer. Add pissa_sweep.json with 7 experiments ablating PiSSA
init vs standard, rsLoRA scaling, and learning rate variations at rank 128.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 22:07:27 +02:00
Ethanfel 784fb2753f feat: PiSSA init, rsLoRA scaling, Spectral Surgery, and training fixes
LoRA quality improvements addressing intruder dimension problem:

1. PiSSA initialization (arXiv:2404.02948): init A,B from top-r SVD of
   pretrained weight. Starts on-manifold, eliminates intruder dimensions
   at init. Base weight stores residual W_res = W - B@A*scale.

2. rsLoRA scaling (arXiv:2312.03732): alpha/sqrt(rank) instead of
   alpha/rank. Prevents gradient collapse at high ranks (128+).

3. Post-training Spectral Surgery (arXiv:2603.03995): SVD of trained
   LoRA update, gradient-sensitivity reweighting to suppress remaining
   intruder dimensions. Runs automatically after training completes.

4. alpha default changed to 2*rank (was 1*rank). Produces fewer intruder
   dimensions per arXiv:2410.21228.

5. weight_decay reduced from 1e-2 to 0.0 (standard for LoRA, prevents
   erasing learned style weights).

6. random.choices replaced with random.sample when batch_size <= dataset
   size (eliminates duplicate samples per batch).

PiSSA checkpoints include base weights (residual). Loader/evaluator
updated to handle both standard and PiSSA checkpoint formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 21:54:36 +02:00
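
Sketch for 784fb2753f, items 2 and 6 (an illustration, not the repository's code): the two scaling rules and the batch-sampling change can be compared side by side. The function names lora_scale and draw_batch are hypothetical.

```python
import math
import random


def lora_scale(alpha: float, rank: int, use_rslora: bool) -> float:
    """Standard LoRA scales by alpha/rank; rsLoRA scales by alpha/sqrt(rank)."""
    return alpha / math.sqrt(rank) if use_rslora else alpha / rank


def draw_batch(dataset: list, batch_size: int, rng: random.Random) -> list:
    """Sample without replacement when possible, else fall back to replacement."""
    if batch_size <= len(dataset):
        return rng.sample(dataset, batch_size)    # no duplicates within the batch
    return rng.choices(dataset, k=batch_size)     # duplicates are unavoidable here


rank = 128
alpha = 2 * rank                                        # default from item 4
standard = lora_scale(alpha, rank, use_rslora=False)    # 256 / 128
rs = lora_scale(alpha, rank, use_rslora=True)           # 256 / sqrt(128)
```

At rank 128 the standard factor is 2.0 while the rsLoRA factor is about 22.6, which shows how quickly alpha/rank shrinks relative to alpha/sqrt(rank) as the rank grows.
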
Ethanfel ecf828b007 fix: move vocoder to correct device after GAFilter injection
inject_gafilters creates Conv1d modules on CPU. load_state_dict
preserves existing param devices but GAFilter params stay on CPU,
causing device mismatch during vocode. Save target device before
injection, then move entire vocoder after loading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 20:28:55 +02:00
Ethanfel 793368af18 fix: strip inference flag before unnormalize in LoRA trainer eval
x1_pred is an inference tensor (computed from inference-mode weights
loaded by ComfyUI). generator.unnormalize() uses in-place mul_/add_
which fails on inference tensors. Clone strips the flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 20:01:53 +02:00
Ethanfel 1d1ae61409 fix: move only VAE+vocoder to GPU during eval to prevent device mismatch
The previous check (next(feature_utils_orig.parameters()).device) only
inspected the first parameter (from CLIP), missing CPU-stranded vocoder
weights when the module was in a mixed-device state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 19:36:02 +02:00
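
Sketch for 1d1ae61409 (an illustration, not the repository's code): why inspecting only the first parameter can miss a mixed-device module. The "meta" device stands in for a second device so the example runs without a GPU; the submodule names are placeholders.

```python
import torch
from torch import nn


def devices_in_use(module: nn.Module) -> set[torch.device]:
    """Collect the device of every parameter, not just the first one."""
    return {p.device for p in module.parameters()}


# Stand-in for a composite module in a mixed-device state.
bundle = nn.ModuleDict({
    "clip": nn.Linear(4, 4),                      # registered first, lives on CPU
    "vocoder": nn.Linear(4, 4, device="meta"),    # lives on a different device
})

first = next(bundle.parameters()).device          # reports only the first parameter's device
all_devices = devices_in_use(bundle)              # reveals both devices
```
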
Ethanfel 8fa2699551 fix: correct DITTO reference latent space mismatch
References were stored in normalized flow-matching space
(net_generator.normalize(z_sample)) but the style loss compares against
unnormalize(x) which is in VAE latent space. The optimizer was minimizing
L1 between tensors at different scales, pushing the ODE endpoint out of
distribution and producing noise.

Fix: store reference latents in VAE space (z_sample directly) so both
ref_mean/ref_gram and x_un are in the same coordinate system.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:57:08 +02:00
Ethanfel 14fabf01f9 fix: reduce opt_lr step to 0.001 to allow finer lr control in DITTO
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:40:21 +02:00
Ethanfel 445da1e69b fix: replace std clamp with anchor regularization to prevent OOD noise
The std clamp was post-hoc and only addressed magnitude, not direction.
x0 was drifting to mean=-0.55/std=3.1 (ODE expected mean=0/std=1).

Replace with anchor_weight * MSE(x0, x0_init) added directly to the loss.
The optimizer now balances style matching against staying near the initial
N(0,1) noise — gradient-aware, prevents both magnitude and mean drift.

Also logs style/anchor losses and x0_std per step for diagnostics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:30:05 +02:00
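
Sketch for 445da1e69b (an illustration, not the repository's code): the anchor term added to the loss could be written as below. The placeholder style loss and the tensor shapes are assumptions; only anchor_weight * MSE(x0, x0_init) comes from the message.

```python
import torch
import torch.nn.functional as F


def total_loss(style_loss: torch.Tensor, x0: torch.Tensor,
               x0_init: torch.Tensor, anchor_weight: float) -> torch.Tensor:
    """Style term plus a penalty for drifting away from the initial noise."""
    anchor = F.mse_loss(x0, x0_init)
    return style_loss + anchor_weight * anchor


torch.manual_seed(0)
x0_init = torch.randn(1, 8, 16)
x0 = (x0_init.clone() + 0.5).requires_grad_(True)    # pretend x0 has drifted by 0.5
style = (x0 ** 2).mean()                              # placeholder for the real style loss
loss = total_loss(style, x0, x0_init, anchor_weight=1.0)
loss.backward()                                       # gradient pulls x0 back toward x0_init
```

Because the penalty is part of the loss, its gradient acts on every element of x0 at every step, unlike a rescale applied after the optimizer has already moved.
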
Ethanfel fa6c4fa834 fix: clamp x0 std after each optimizer step to prevent OOD noise
Optimized x0 was reaching std=2.72 vs expected ~1.0 for flow matching.
An out-of-distribution initial condition maps to white noise in the output.
After each step, rescale x0 back toward unit std if it exceeds 1.5.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:23:39 +02:00
Ethanfel 286681edff fix: cast mel to model dtype before VAE encode in DITTO reference loading
mel_converter outputs float32 (cuFFT requirement), but VAE encoder weights
are bfloat16. Cast mel to dtype before encode to avoid type mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:18:41 +02:00
Ethanfel 056a7b973d fix: enable VAE encoder in model loader — required for DITTO reference encoding
need_vae_encoder=False was deleting the encoder to save a small amount of VRAM.
DITTO now needs it to encode reference clips to latent space for style loss.
The spectrogram VAE encoder is small enough that the overhead is negligible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:15:27 +02:00
Ethanfel 633fe36fbb fix: compute DITTO style loss in latent space to eliminate VAE decoder noise
Root cause of white noise: backpropagating through vae.decode produces
unstable gradients — the VAE decoder was designed for inference only.

Fix: encode reference clips to VAE latent space once (no grad), compute
mean + Gram matrix statistics there, and compute style loss directly on
net_generator.unnormalize(x) — a single differentiable linear operation.
The gradient path is now: loss → x (unnormalized) → ODE → x0, with no
decoder in the backward pass.

Also adds VAE encoder availability check (fails cleanly if encoder was
deleted to save VRAM).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:12:31 +02:00
Ethanfel 8862089fd0 fix: remove 32-clip cap on DITTO reference loading — use all available clips
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:10:10 +02:00
Ethanfel 608e7df04b feat: add gram_weight param to DITTO, reduce default style_weight to 0.1
White noise on output was caused by the Gram matrix loss pushing the latent
into incoherent regions. Now gram_weight defaults to 0 (mean spectrum only)
and style_weight defaults to 0.1 instead of 1.0. Users can enable Gram
gradually once mean-only optimization converges cleanly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:03:32 +02:00
Ethanfel 101b1bdb41 fix: _do_optimize returns dict not tuple — prevent double-wrapping AUDIO output
optimize() already wraps the result for ComfyUI via return (_result[0],). _do_optimize was
returning (dict,) instead of dict, causing double-wrapping: ((dict,),).
ComfyUI then received a tuple as audio and failed on audio["waveform"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:56:59 +02:00
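
Sketch for 101b1bdb41 (an illustration, not the repository's code): the double-wrapping can be reproduced with plain tuples. The function bodies and the _result list are simplified stand-ins for the node's worker hand-off.

```python
def do_optimize_buggy() -> tuple:
    return ({"waveform": [0.0], "sample_rate": 44100},)    # already a 1-tuple


def do_optimize_fixed() -> dict:
    return {"waveform": [0.0], "sample_rate": 44100}        # bare dict


def optimize(inner) -> tuple:
    _result = [inner()]           # stand-in for the result handed back by the worker
    return (_result[0],)          # the node wraps exactly once for ComfyUI


buggy = optimize(do_optimize_buggy)    # ((dict,),): first output is a tuple
fixed = optimize(do_optimize_fixed)    # (dict,): first output is the audio dict
```
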
Ethanfel 732df151b0 fix: cast ref_mean/ref_gram to model dtype before loss computation
ref_mean and ref_gram are float32 (mel computed via cuFFT which requires
float32). mel_gen is bfloat16. F.l1_loss(bfloat16, float32) promotes to
float32, producing a float32 loss. loss.backward() then pushes float32
gradients through bfloat16 ops → 'Found dtype Float but expected BFloat16'.

Fix: clone().detach().to(dtype) at the start of _do_optimize.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:48:41 +02:00
Ethanfel 817b75df49 fix: bypass @torch.inference_mode() on decode to preserve gradient chain
feature_utils.decode and autoencoder.decode are both decorated with
@torch.inference_mode(), which unconditionally destroys grad_fn on all
outputs — making loss.backward() fail with 'does not require grad'.

Fix: call feature_utils.tod.vae.decode() directly, which has no decorator
and is fully differentiable. Transpose matches the original wrapper signature.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:44:35 +02:00
Ethanfel 1f02d73a3e fix: remove checkpoint wrapper on decode — direct call preserves grad chain
_unnorm_decode was wrapped in checkpoint(use_reentrant=False) to avoid saving
inference-mode weight tensors during backward. Since _strip_inference() now
cleans all params/buffers before any forward pass, the checkpoint is no longer
needed and was silently breaking the gradient chain from mel_gen back to x0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:40:00 +02:00
Ethanfel fb255edaf0 fix: strip inference-mode tensor flags in DITTO before conditions computation
Root cause: net_generator/feature_utils/mel_converter parameters were loaded
in ComfyUI's inference_mode; operations on inference tensors propagate the flag,
so conditions computed from tainted weights were also tainted. checkpoint()
with use_reentrant=False then failed trying to save inference tensors during
the backward recompute pass.

Fix: _strip_inference() clones all params/buffers of all three models before
any forward pass, and _clone_nested() cleans any residual inference flags in
the conditions/empty_conditions output tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:35:15 +02:00
Ethanfel 8ccc2438e4 fix: remove FlashSR (audiosr incompatible with Python 3.12), add training loss CSV
- Drop SelvaFlashSR node — audiosr pins numpy<=1.23.5 which cannot build
  on Python 3.12 (pkgutil.ImpImporter removed); use Saganaki22/ComfyUI-AudioSR instead
- BigVGAN trainer now writes <output_stem>_training_log.csv alongside the
  checkpoint: step, total, fm, mel, stft, phase, l2sp columns, line-buffered
  so loss can be tailed live during training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:18:34 +02:00
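
Sketch for 8ccc2438e4 (an illustration, not the repository's code): a line-buffered CSV log with the listed columns might be opened like this. The helper name open_training_log is hypothetical.

```python
import csv

COLUMNS = ["step", "total", "fm", "mel", "stft", "phase", "l2sp"]


def open_training_log(path: str):
    """Open a line-buffered CSV so each row reaches disk as soon as it is written."""
    # buffering=1 selects line buffering in text mode; newline="" is what the
    # csv module expects so it can control line endings itself.
    f = open(path, "w", newline="", buffering=1)
    writer = csv.DictWriter(f, fieldnames=COLUMNS)
    writer.writeheader()
    return f, writer
```

With line buffering, a tool such as tail -f sees each step's row without waiting for the file to be closed.
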
Ethanfel 8371466e44 fix: guarantee length preservation in _ActivationWithGAFilter
Activation1d's anti-alias Kaiser sinc resampling (asymmetric pad_left /
pad_right) can produce ±1-2 sample rounding in edge cases, causing the
BigVGAN AMPBlock residual addition (xt + x) to fail with a size mismatch.

Trim or pad the output to exactly match the input length so the resblock
skip connection always has matching dimensions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:39:03 +02:00
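
Sketch for 8371466e44 (an illustration, not the repository's code): trimming or padding an activation output to the input length could look like this. Zero-padding on the right is an assumption; the message does not say which side or mode is used.

```python
import torch
import torch.nn.functional as F


def match_length(y: torch.Tensor, target_len: int) -> torch.Tensor:
    """Trim or right-pad the last dimension of y to exactly target_len samples."""
    cur = y.shape[-1]
    if cur > target_len:
        return y[..., :target_len]
    if cur < target_len:
        return F.pad(y, (0, target_len - cur))    # zero-pad on the right (assumed)
    return y


x = torch.randn(1, 4, 100)
longer = match_length(torch.randn(1, 4, 102), x.shape[-1])
shorter = match_length(torch.randn(1, 4, 99), x.shape[-1])
residual = x + longer + shorter                   # shapes now agree
```
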
Ethanfel ba0499b77c fix: FlashSR device handling and remove unused tmp_out
Use device="auto" for audiosr.build_model — safer than passing a device
string that may not be accepted in all audiosr versions.
Remove unused tmp_out temp file that was created but never written to.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:32:02 +02:00
Ethanfel ce62bccc1f feat: add post-generation audio enhancement nodes
Three new nodes for post-generation quality improvement:

- SelvaHarmonicExciter: multi-band exciter (HPF → tanh saturation → mix)
  restores harmonic richness lost in BigVGAN HF reconstruction

- SelvaFlashSR: audio super-resolution via FlashSR basic model
  (haoheliu/versatile_audio_super_resolution, requires pip install audiosr)
  predicts missing HF content above vocoder reconstruction ceiling

- SelvaOutputNormalizer: BS.1770-4 LUFS normalization + true peak limiting
  for consistent loudness on generated outputs (pyloudnorm)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:27:39 +02:00
Ethanfel 45fced55bc fix: exclude GAFilter params from L2-SP regularization
L2-SP anchors trainable params to their pretrained values. GAFilter is a
newly initialized module (identity FIR filter) with no pretrained values —
anchoring it to identity initialization would resist learning. Exclude
gafilter params from the L2-SP loss so they train freely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:19:52 +02:00
Ethanfel db112394e8 feat: add AF-Vocoder GAFilter to BigVGAN trainer and loader
Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel
depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN
residual blocks. Initialized as identity so training starts from pretrained
behaviour.

- inject_gafilters() walks resblocks.*.activations and wraps each Activation1d
  with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically
- Trained alongside Snake alphas in snake_alpha_only mode
- Checkpoint saves has_gafilter + gafilter_kernel_size metadata
- Loader detects metadata and injects before load_state_dict so weights populate correctly
- Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:15:14 +02:00
Ethanfel c53ea5517c feat: add FA-GAN phase-aware STFT loss to BigVGAN trainer
Adds L1 loss on real, imaginary, and magnitude STFT components across
three resolutions (FA-GAN, arXiv:2407.04575). Penalizes phase smearing
directly — magnitude-only losses cannot distinguish correct spectrum
with wrong phase from a smeared spectrum.

Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top
of both the discriminator FM path and the fallback mel+STFT path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 16:09:31 +02:00
Ethanfel 82e449681c fix: cast mel_converter and wav to float32 before cuFFT in DITTO
cuFFT does not support bfloat16. mel_converter was being moved to device
without an explicit dtype, inheriting bfloat16 from the model context.
Force float32 for both mel_converter.to() and wav.to() so the STFT
inside the mel converter runs in a supported dtype.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:59:55 +02:00
Ethanfel 15fc5f0793 feat: add SelvaDatasetCompressor node for parallel compression
Mild 2:1-3:1 parallel compression via pedalboard.Compressor to reduce
within-clip loudness variance after LUFS normalization. Blend ratio
keeps transients intact while tightening dynamics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:36:27 +02:00
Ethanfel 48493a3f0d feat: add SelvaDatasetSaver node with NPZ sidecar copy
Saves all clips in an AUDIO_DATASET to FLAC. When npz_source_dir is
provided, copies the matching .npz for each clip so FLAC/NPZ pairs
stay in sync after the inspector filters out bad clips.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:27:48 +02:00
Ethanfel becb38c27e fix: use soundfile for WAV/FLAC/OGG to bypass torchcodec/FFmpeg dependency
torchaudio was defaulting to the torchcodec backend which requires FFmpeg
shared libraries not present in the ComfyUI venv, silently skipping every
clip and producing an empty dataset.

Also add experiments/vocoder_finetune.json for the BJ vocoder LoRA run
(lr=3e-4, rank=128, 10k steps).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 15:16:22 +02:00
Ethanfel b9f95cfd7e fix: detect silent discriminator load failure and fall back explicitly
If no matching key was found for MPD or MRD in the checkpoint, the for-loops
completed silently and randomly-initialized discriminators were used as frozen
feature extractors — producing meaningless feature matching loss while
appearing to work. Now raises RuntimeError (caught by outer except) which
triggers the existing fallback to mel+STFT losses with a clear warning.
Also prints available checkpoint keys to help diagnose format mismatches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:39:55 +02:00
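
Sketch for b9f95cfd7e (an illustration, not the repository's code): failing loudly when no checkpoint key matches, then falling back, might be structured like this. The candidate key names and both function names are hypothetical.

```python
def find_state_dict(checkpoint: dict, candidate_keys: list[str], name: str) -> dict:
    """Return the first matching sub-dict, or fail loudly if none is present."""
    for key in candidate_keys:
        if key in checkpoint:
            return checkpoint[key]
    # Without this raise, the caller would silently keep random weights.
    raise RuntimeError(
        f"No {name} weights found; available keys: {sorted(checkpoint)}"
    )


def load_discriminators(checkpoint: dict):
    """Return (mpd, mrd) state dicts, or None to signal the mel+STFT fallback."""
    try:
        mpd = find_state_dict(checkpoint, ["mpd", "discriminator_mpd"], "MPD")
        mrd = find_state_dict(checkpoint, ["mrd", "discriminator_mrd"], "MRD")
        return mpd, mrd
    except RuntimeError as err:
        print(f"[warning] {err}; falling back to mel+STFT losses")
        return None
```
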
Ethanfel f50afa9796 fix: guard _estimate_snr against short clips, fix freqs device in _check_hf_shelf
Bug 1: mono.unfold(0, 2048, 512) returns an empty tensor for clips shorter
than 2048 samples (~46ms). torch.quantile on an empty tensor crashes with
"quantile() input tensor must be non-empty". Guard: return 60.0 (assume
clean) for clips too short to frame — the pipeline has no minimum-length
filter so any short file in the dataset folder would crash the Inspector.

Bug 2: torch.linspace(...) in _check_hf_shelf created a CPU tensor, making
band_lo/band_hi CPU boolean masks. Indexing a GPU mag_sq tensor with CPU
masks crashes. Pass device=mono.device so freqs lands on the same device
as the audio.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:28:36 +02:00
Ethanfel 8a85819f97 feat: register audio dataset pipeline nodes in __init__.py
2026-04-09 14:25:57 +02:00
Ethanfel f1c4654bab feat: add SelvaDatasetItemExtractor node
2026-04-09 14:24:58 +02:00
Ethanfel 2d06cb2f52 fix: pass device to hann_window in _check_hf_shelf to avoid GPU mismatch
2026-04-09 14:22:13 +02:00
Ethanfel 0731addea9 feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:20:03 +02:00
Ethanfel 7eb9bd5745 feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:17:44 +02:00
Ethanfel 057bfb813d feat: add SelvaDatasetResampler node (soxr VHQ)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:13:45 +02:00
Ethanfel 2c71d4c184 feat: add SelvaDatasetLoader node
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 14:09:43 +02:00
Ethanfel d25df10aa5 feat: add audio dataset pipeline skeleton
2026-04-09 14:05:31 +02:00
Ethanfel d70a4d2123 docs: add audio dataset pipeline implementation plan
2026-04-09 14:02:46 +02:00
Ethanfel 2b10205657 fix: raise segment_seconds max from 4s to 30s
Hardcoded max of 4.0 prevented using full 8s clips. Raised to 30s.
Also bumped default from 1.0 to 2.0 as a more sensible starting point.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:49:50 +02:00
Ethanfel 8166c56552 perf: gradient checkpointing on vocoder forward to reduce activation memory
BigVGAN's 512x upsampling stack stores huge intermediate activations for
backward even in snake_alpha_only mode (only 5K trainable params, but
activation graph runs through the full network after each snake op).

Wrapping vocoder() in checkpoint(use_reentrant=False) recomputes activations
during backward instead of storing them — ~2x compute cost, large reduction
in peak VRAM. Should allow batch_size > 1 on 96 GB without OOM.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:45:24 +02:00
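
Sketch for 8166c56552 (an illustration, not the repository's code): wrapping a forward pass in checkpoint(use_reentrant=False) trades compute for memory without changing the gradients. The tiny Sequential below is a stand-in for the vocoder.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Small placeholder network; the real vocoder is far larger.
net = nn.Sequential(nn.Conv1d(1, 8, 3, padding=1), nn.Tanh(), nn.Conv1d(8, 1, 3, padding=1))
mel = torch.randn(2, 1, 64, requires_grad=True)

# Plain forward: intermediate activations are stored for backward.
net(mel).sum().backward()
grad_plain = mel.grad.clone()
mel.grad = None

# Checkpointed forward: activations are recomputed during backward instead.
checkpoint(net, mel, use_reentrant=False).sum().backward()
grad_ckpt = mel.grad.clone()
```
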
Ethanfel eece79ccae fix: correct MRD channel width to 128 and unload models before training
Two bugs:

1. _DiscriminatorR used channels=32 but the BigVGAN pretrained discriminator
   checkpoint has channels=128. All convs in _DiscriminatorR now use 128,
   matching the checkpoint architecture so state_dict loads without error.

2. BigVGAN trainer OOM: SelVA generator and other ComfyUI models remain in
   VRAM during training (~90 GiB used). Add unload_all_models() + cache
   flush before the training loop to reclaim VRAM headroom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 13:40:01 +02:00
Ethanfel 357b875e5e fix: strip inference tensor flags in DITTO optimizer
Two crash paths under "RuntimeError: Inference tensors cannot be
saved for backward":

1. clip_f / sync_f loaded from main-thread inference_mode carry the
   inference flag. Clone them on entry to the worker thread so the
   conditions built from them are clean non-inference tensors.
   Also clone x after Phase 1 before the STE reconnection — Phase 1
   runs under no_grad and produces outputs that may still carry the
   flag through the conditions path.

2. net_generator.unnormalize + feature_utils.decode called outside
   any checkpoint wrapper with requires_grad=True input. Backward
   tried to save inference-flagged model weights. Wrapped both calls
   in checkpoint(use_reentrant=False) so they recompute on backward
   instead of storing activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:18:20 +02:00
Ethanfel 211494a91c fix: DITTO gradient never reached x0, remove unused imports and dead code
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (a non-leaf with
a grad_fn) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.

Fix: use a straight-through estimator after the no-grad prefix:
  x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.

Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:10:02 +02:00
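
Sketch for 211494a91c: a toy reproduction of the straight-through reconnection `x = x + (x0 - x0.detach())`. The two "phases" here are stand-ins (a scale-and-shift under `no_grad`, then a multiply), not the real ODE steps; only the reconnection line comes from the commit message.

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(4, requires_grad=True)   # the noise being optimised (a leaf)

# Phase 1 (assumed stand-in): prefix computed without building a graph.
with torch.no_grad():
    x = 0.5 * x0 + 1.0
x_prefix = x.clone()

# Straight-through reconnection: adds exactly zero to the value, but the
# result now has a grad_fn that leads back to x0.
x = x + (x0 - x0.detach())

# Phase 2 (assumed stand-in): the differentiated tail of the computation.
y = 3.0 * x
loss = (y ** 2).sum()
loss.backward()

# d(loss)/dx = 18 * x, and the reconnection treats dx/dx0 as the identity,
# so x0.grad receives that gradient unchanged.
print(x0.grad)
```

Without the reconnection line, `x` would have no graph back to `x0` and `x0.grad` would stay `None`, which matches the no-op behaviour described above.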
Ethanfel 1e9551152e feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes
BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 12:04:05 +02:00
Ethanfel f17f6f0863 feat: save ground truth spectrogram once for direct comparison
Writes _gt_spec.png from ref_mel before training starts so each step's
_spec.png can be compared against the unmodified vocoder roundtrip target.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:05:47 +02:00
Ethanfel 304d9d01bf feat: save mel spectrogram PNG alongside each eval sample
Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample
call now writes both a .wav and a _spec.png so training progress is visible
without listening. Colour map is blue→green→yellow (viridis-ish), low
frequencies at the bottom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:03:28 +02:00
Ethanfel 0128a81cc2 fix: use full first clip for eval samples instead of 1s segment
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:01:52 +02:00
Ethanfel 710261f5be fix: add soundfile fallback for torchaudio.save in sample writing
Same environment has no compatible ffmpeg/torchcodec for saving.
Mirror the _load_wav pattern: try torchaudio, fall back to soundfile.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:58:07 +02:00
Ethanfel 5df2abd6dd fix: handle all three inference-tensor sources in vocoder sanitization
remove_parametrizations() stores weight as a plain __dict__ tensor (not
nn.Parameter), making it invisible to _parameters iteration. Also, buffers
(Activation1d anti-aliasing filters) are inference tensors that break the
backward graph mid-network. Fix all three categories:
1. _parameters: clone().detach(), wrap as Parameter
2. plain __dict__ tensors: clone(), register_parameter (also makes trainable)
3. _buffers: clone() to strip inference flag without parametrizing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:54:41 +02:00
Ethanfel b243908873 debug: inspect conv_pre parametrizations and _parameters keys
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:46:16 +02:00
Ethanfel 9df855ee0e debug: print is_inference() status before failing conv_pre call
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:41:51 +02:00
Ethanfel 78f8aa98ad fix: clone inference tensors at thread entry to strip the inference flag
torch.inference_mode is thread-local, but the inference flag lives on the
tensor object. Operations on inference tensors always propagate it, even in
a clean thread. The only escape is .clone() called outside inference_mode.
At thread entry (inference_mode disabled): clone clips and mel_converter
buffers to get clean normal tensors before any training computation.
Vocoder parameter clone() also now works correctly in this thread context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:35:48 +02:00
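
Sketch for 78f8aa98ad: the flag lives on the tensor, and a `.clone()` made while inference mode is disabled yields a normal tensor that autograd can use. `sanitize` is a hypothetical helper name; the tensor is a stand-in for a buffer loaded by the host application.

```python
import torch

# Stand-in for a tensor created while the host runs under inference_mode.
with torch.inference_mode():
    weight = torch.ones(3)


def sanitize(t: torch.Tensor) -> torch.Tensor:
    """Return a normal copy of t. Must be called with inference_mode disabled."""
    return t.clone()


clean = sanitize(weight)          # called outside inference_mode
w = clean.requires_grad_(True)    # fine on a normal tensor

loss = (w * 2.0).sum()
loss.backward()                   # autograd can now save what it needs

print(weight.is_inference(), clean.is_inference())
```

In the trainer this would be done once at worker-thread entry, before any training computation touches the clips or buffers.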
Ethanfel e870446b0f fix: run BigVGAN training in a fresh thread to escape inference_mode
torch.inference_mode is thread-local. ComfyUI sets it on the node-execution
thread; inference_mode(False) alone is insufficient to escape it in some
environments (e.g. async wrappers, lora-manager hook). A new thread always
starts clean. Moved all training logic into _do_train() called via
threading.Thread so every tensor is a normal autograd tensor by default.
Simplified parameter cloning: clone().detach().requires_grad_(True).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:30:53 +02:00
Ethanfel df63b147e9 fix: sanitize all submodule buffers of mel_converter + guarantee target_mel output
Previous fix only iterated mel_converter._buffers (direct buffers). Submodules
(e.g. Spectrogram.window) still held inference tensors. Switch to .modules()
to cover all nested buffers, matching the vocoder parameter sanitization.
Also add a zeros+copy_ safety net on target_mel output so conv can save it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:14:12 +02:00
Ethanfel 51ac099073 fix: sanitize target_flat — clips are inference tensors from outer inference_mode
The clips list is built inside ComfyUI's inference_mode context, so every
element is an inference tensor. torch.stack().clone() propagates the flag.
Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor,
so mel_converter(target_flat) inside no_grad produces a saveable input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:09:26 +02:00
Ethanfel b7565ec458 fix: sanitize inference tensors in BigVGAN trainer via zeros+copy_ pattern
param.data.clone() and tensor.detach() on inference tensors both produce
inference tensors — the flag propagates through all operations on them.
Inside inference_mode(False), torch.zeros() creates genuine normal tensors.
Use zeros+copy_ to sanitize both vocoder parameters and mel_converter
buffers once before training, so autograd can save inputs for backward.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:05:36 +02:00
Ethanfel 0fcb6d3106 fix(bigvgan-trainer): replace parameter objects to fully strip inference tensor flag
param.data = clone() only replaces storage — the nn.Parameter object itself
retains the inference tensor flag set when the model was loaded. Replace each
parameter with a fresh nn.Parameter(data.clone()) created inside
inference_mode(False) so both the object and its data are normal tensors.
Move optimizer creation to after re-creation so it references the new objects.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:58:57 +02:00
Ethanfel c86306bde8 fix(bigvgan-trainer): clone vocoder parameters to strip inference tensor flag
The vocoder is loaded inside ComfyUI's torch.inference_mode(), making all
its parameters inference tensors. Autograd cannot save inference tensors
for backward even with requires_grad=True. Clone all parameters inside
torch.inference_mode(False) before training to get normal tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:55:16 +02:00
Ethanfel f04d59fe63 fix(bigvgan-trainer): clone mel outputs to strip inference tensor flag from buffers
mel_converter buffers (mel_basis, hann_window) are inference tensors
because the model was loaded inside ComfyUI's torch.inference_mode().
Operations on them propagate the flag to outputs. Clone both target_mel
and pred_mel to get normal autograd-compatible tensors. .clone() is
differentiable so the grad graph to vocoder parameters is preserved.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:51:28 +02:00
Ethanfel daa36a5f7b fix(bigvgan-trainer): clone target tensor to exit inference mode before backward
Clips loaded outside torch.inference_mode(False) are inference tensors.
Autograd cannot save them for backward. .clone() creates a normal tensor,
same fix pattern as selva_lora_trainer's dist.mode().clone().

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:47:47 +02:00
Ethanfel 16e20b30ce fix(bigvgan-trainer): cast audio to model dtype to match bf16 mel_converter buffers
Model loaded in bf16 causes mel_basis buffer to be bf16. Audio loaded
from disk is float32, causing matmul dtype mismatch. Cast all audio
tensors to model["dtype"] before passing to mel_converter/vocoder.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:46:01 +02:00
Ethanfel ea7dfed27a fix(bigvgan-trainer): fallback to soundfile when torchaudio ffmpeg backend fails
torchcodec/libavutil soname mismatch causes torchaudio to fail on every
file load, silently emptying clips. Add _load_wav() that tries torchaudio
first then falls back to soundfile (handles wav/flac without ffmpeg).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:41:59 +02:00
Ethanfel 81ff0d46c9 fix(bigvgan-trainer): resolve device mismatch in _save_sample after offload
After the finally block, offload_to_cpu moves the vocoder to CPU while
ref_mel stays on GPU. Fix: detect vocoder's current device via
next(vocoder.parameters()).device and move ref_mel there before vocoding.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:35:07 +02:00
Ethanfel 9fdeb65182 feat(bigvgan-trainer): add eval samples at checkpoints and end
Saves baseline.wav (ground truth roundtrip before training), stepN.wav
at each save_every checkpoint, and final.wav after training completes.
All use the same fixed reference segment (clip 0, position 0) for
direct comparison across checkpoints.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:30:34 +02:00
Ethanfel 790a53e3df fix(bigvgan): add 44k/BigVGANv2 support to trainer and loader
44k variants use BigVGANv2 directly as the vocoder (no wrapper, no
@inference_mode decorator), accessible at feature_utils.tod.vocoder.
16k wraps BigVGANVocoder inside BigVGAN, accessed at .vocoder.vocoder.
Both trainer and loader now branch on model["mode"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:28:32 +02:00
Ethanfel 9c784b4bdb feat: add BigVGAN vocoder fine-tuner and loader nodes
Spectral-loss-only fine-tuning of the BigVGAN vocoder (mel→waveform)
on BJ audio clips. DiT and VAE are completely frozen. Losses: mel L1
reconstruction + multi-resolution STFT magnitude L1 (same three
resolutions as the BigVGAN discriminator config). Saves in
{'generator': state_dict} format compatible with the original BigVGAN
checkpoint. Loader replaces vocoder weights in the loaded SELVA_MODEL
in-place so no full model reload is needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:26:12 +02:00
Ethanfel 115a0c3718 feat(steering): conditional-only injection + per-position vectors
Two improvements for stronger steering effect:

1. Apply steering only during the conditional predict_flow pass by
   monkey-patching predict_flow to set a flag via identity check
   (cond is conditions). Hooks skip the unconditional pass, so
   steering is amplified by cfg_strength (~4.5x) instead of canceling
   out in the CFG guidance term.

2. Restore per-position [seq, hidden] steering vectors instead of
   seq-averaged [hidden]. More spatially specific — captures positional
   activation patterns rather than a global mean. Seq length mismatch
   at inference time handled via linear interpolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 01:02:51 +02:00
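
Sketch for 115a0c3718: why steering only the conditional pass is amplified by `cfg_strength`. This assumes the common classifier-free guidance form `uncond + s * (cond - uncond)`; the exact combination used in the sampler is not shown in the commit, and scalars stand in for tensors.

```python
def cfg(v_cond: float, v_uncond: float, cfg_strength: float) -> float:
    """Assumed CFG combination: uncond + s * (cond - uncond)."""
    return v_uncond + cfg_strength * (v_cond - v_uncond)


v_cond, v_uncond, s, delta = 1.0, 0.5, 4.5, 0.25

base = cfg(v_cond, v_uncond, s)

# Steering added to both passes: delta cancels inside (cond - uncond),
# so the output moves by delta only.
shift_both = cfg(v_cond + delta, v_uncond + delta, s) - base

# Steering added to the conditional pass only: delta survives in the
# guidance term and is scaled by cfg_strength.
shift_cond_only = cfg(v_cond + delta, v_uncond, s) - base

print(shift_both, shift_cond_only)
```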
Ethanfel 95923cdf42 feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual
inversion. Extractor runs frozen generator on dataset with BJ vs empty
conditions, records mean hidden-state delta per block, saves [hidden_dim]
vectors (seq-averaged so they broadcast to any inference duration). Loader
reads the bundle. Sampler registers forward hooks during the ODE that add
strength × vec to each block output, cleaned up in a finally block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:38:26 +02:00
Ethanfel 28ee3db337 feat(sampler): add ti_strength blend for TI injection
TI via text conditioning produces buzz because SelVA's text path is
mean-pooled into a global DiT bias — not rich per-token cross-attention
like SD. The optimizer learns a constant spectral artifact rather than
semantic style shift.

ti_strength=1.0 (default) = full injection as before.
ti_strength<1.0 = lerp between original and injected text_clip,
allowing the effect to be dialled back without retraining.
Applies to both text_clip and neg_text_clip symmetrically.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:07:57 +02:00
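
Sketch for 28ee3db337: the `ti_strength` blend as a plain linear interpolation between the original and injected embeddings. Lists of floats stand in for the `text_clip` tensors; `blend_text_clip` is a hypothetical name.

```python
def blend_text_clip(original, injected, ti_strength=1.0):
    """Lerp per element: 0.0 keeps the original, 1.0 is full injection."""
    return [o + ti_strength * (i - o) for o, i in zip(original, injected)]


original = [0.0, 1.0]
injected = [2.0, 3.0]

full = blend_text_clip(original, injected, 1.0)   # same as injected
half = blend_text_clip(original, injected, 0.5)   # halfway between the two
none = blend_text_clip(original, injected, 0.0)   # same as original
```

The same call would be applied to `neg_text_clip` so both conditions are blended symmetrically.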
Ethanfel b89167cfae fix(ti-trainer): clamp token norm to CLIP manifold to prevent buzz artifacts
Diagnosis: learned tokens grew to norm ~3.2 while real CLIP content tokens
sit at ~1.0. Model never trained on embeddings that large — activates buzz
artifact instead of semantic style shift.

Fix: measure mean token norm from content positions (1–20) of dataset CLIP
embeddings at startup, clamp learned_tokens per-token after every optimizer
step to max 1.5× that reference (50% headroom). Token norm is now logged
as current/limit for easy monitoring.

ti_sweep_1.json: rebuild around norm_clamp group — n4_clamped (primary
diagnostic), prefix_clamped, n8_prefix_clamped, warm_clamped.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:54:23 +02:00
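
Sketch for b89167cfae: per-token norm clamping after each optimizer step. Nested lists stand in for the `[K, 1024]` tensor; `clamp_token_norms` and `headroom` are hypothetical names, with `headroom=1.5` taken from the 1.5x limit described above.

```python
import math


def clamp_token_norms(tokens, ref_norm, headroom=1.5):
    """Rescale any token whose L2 norm exceeds headroom * ref_norm."""
    limit = headroom * ref_norm
    clamped = []
    for tok in tokens:
        norm = math.sqrt(sum(v * v for v in tok))
        # Only shrink; tokens already inside the limit are left untouched.
        scale = min(1.0, limit / norm) if norm > 0 else 1.0
        clamped.append([v * scale for v in tok])
    return clamped


tokens = [[3.0, 4.0],    # norm 5.0, above the limit
          [0.6, 0.8]]    # norm 1.0, inside the limit
out = clamp_token_norms(tokens, ref_norm=1.0)
norms = [math.sqrt(sum(v * v for v in tok)) for tok in out]
```

Rescaling preserves each token's direction, so only its magnitude is pulled back toward the range the model saw during training.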
Ethanfel f9d092158a fix(ti): lower default lr/batch, add lr_batch sweep group
n4_baseline showed token_norm growing linearly without plateau — classic
sign of lr too high relative to parameter count. With only K×1024 params,
gradient signal per param is already high-magnitude; high lr causes
overshoot rather than convergence.

- Default lr: 1e-3 → 2e-4 (matches LoRA working regime)
- Default batch_size: 16 → 4 (more diverse gradients, helps norm saturate)
- ti_sweep_1.json: add lr_batch group (lr_low_b4, lr_mid_b8,
  lr_low_b4_prefix, lr_2e3), restructure with clearer groups,
  annotate n4_baseline as completed with findings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:42:22 +02:00
Ethanfel 92535deab2 fix(ti-scheduler): save comparison image after each completed experiment
Previously the comparison PNG was only written at the very end of the sweep,
so an interrupted run produced no image at all. Now _save_comparison() is
called right after _write_summary() for every successful experiment, keeping
loss_comparison.png current throughout the sweep.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:39:30 +02:00
Ethanfel 0b24207ca5 feat(ti-trainer): generate baseline.wav once before training starts
Saves baseline.wav + baseline.png in the checkpoint dir using the same
seed as the TI eval samples — direct A/B comparison at every checkpoint
without re-generating the baseline each time.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:33:28 +02:00
Ethanfel e1a2f0ed7d feat: add inject_mode (suffix/prefix) to TI pipeline
Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps),
token_norm grew linearly without plateau — generator likely ignores last-K
CLIP positions (EOS/padding zone) where suffix injects.

Fix: add inject_mode parameter throughout the pipeline:
- "suffix": replace last K positions (original behavior, model may ignore)
- "prefix": replace positions 1:1+K right after BOS — highest attention
  weight in CLIP, much stronger gradient signal expected

Changes:
- selva_textual_inversion_trainer.py: _inject_tokens() helper centralises
  the torch.cat construction for both modes; used in training loop and eval;
  inject_mode stored in checkpoint files
- selva_textual_inversion_loader.py: reads inject_mode from checkpoint,
  includes in TEXTUAL_INVERSION bundle
- selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field
- selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and
  _train_inner call
- ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm);
  n4_baseline marked completed; suffix experiments retained for comparison

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:31:52 +02:00
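
Sketch for e1a2f0ed7d: the two `inject_mode` layouts on a 77-position sequence. The real `_inject_tokens()` builds the result with `torch.cat`; this list-based version (hypothetical name `inject_tokens_sketch`) only illustrates which positions are replaced.

```python
def inject_tokens_sketch(text_clip, learned, inject_mode="suffix"):
    """Replace K positions of text_clip with the learned tokens."""
    k = len(learned)
    n = len(text_clip)
    if inject_mode == "suffix":
        # Replace the last K positions.
        return text_clip[:n - k] + learned
    if inject_mode == "prefix":
        # Keep position 0 (BOS), replace positions 1 .. K.
        return text_clip[:1] + learned + text_clip[1 + k:]
    raise ValueError(f"unknown inject_mode: {inject_mode!r}")


seq = list(range(77))            # stand-in for 77 CLIP positions
learned = ["t0", "t1"]           # stand-in for K=2 learned tokens

suffix = inject_tokens_sketch(seq, learned, "suffix")
prefix = inject_tokens_sketch(seq, learned, "prefix")
```

Both modes keep the sequence length unchanged, so downstream shapes do not depend on the mode.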
Ethanfel f96265da23 feat(ti-trainer): add loss curve IMAGE output
Reuses _draw_loss_curve + _smooth_losses + _pil_to_tensor from the LoRA
trainer — raw loss in light blue, smoothed overlay in blue, matches the
LoRA trainer's visual style.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:20:44 +02:00
Ethanfel c0d95ce356 feat: add ti_sweep_1 experiment file
First TI sweep covering the three most impactful axes:
- token_count group: n_tokens 4 / 8 / 16 (capacity vs overfitting)
- learning_rate group: 5e-4 / 1e-3 / 2e-3 with n_tokens=4
- warm_init group: n4 and n8 seeded from 'mechanical impact sound design'

7 experiments total, 3000 steps each, same data_dir as LoRA sweeps.
n4_baseline (lr=1e-3, random init) is the primary reference point.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:14:31 +02:00
Ethanfel e37bfe1b1c feat: add SelVA TI Scheduler for sweep-based textual inversion experiments
- SelvaTiScheduler: runs a JSON-defined sweep of TI training experiments,
  loading the dataset once and reusing it across runs
- Collects per-experiment loss history, final/min loss, stability metric
  (loss_std_last_quarter), and duration — written to experiment_summary.json
  after each completed run so partial sweeps survive interruption
- Resume-aware: skips experiments already marked completed in an existing
  summary file
- Outputs smoothed loss comparison chart (same axes, one curve per experiment)
- SelvaTextualInversionTrainer._train_inner now returns a dict
  {embeddings_path, loss_history} so the scheduler can read results;
  train() extracts just the path for ComfyUI

JSON format: name, description, data_dir, output_root, base config,
experiments list with id + param overrides

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:13:04 +02:00
Ethanfel bb07bc8169 fix(ti-trainer): guard spectral metrics, drop unused imports
- Wrap _spectral_metrics + _save_spectrogram in try-except so a matplotlib
  or STFT error doesn't abort the checkpoint save (matches LoRA trainer)
- Remove unused `import math` and `_pil_to_tensor` import
- Drop dead `img` variable (_save_spectrogram returns None)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:10:19 +02:00
Ethanfel e36cdd7947 fix(ti-trainer): fix gradient flow and spectral metric shapes
- Replace in-place text_clip assignment with torch.cat so the computation
  graph correctly links text_input → learned_tokens; in-place assignment
  into a requires_grad=False leaf severs the graph and learned_tokens
  receives no gradients
- _spectral_metrics(wav, sr): was passing wav.unsqueeze(0) [1,1,L] instead
  of wav [1,L]; stft mean(dim=1) would return wrong shape [1,T] not [n_freqs]
- _save_spectrogram(wav, sr, ...): was passing wav.squeeze(0) [L] (1D)
  instead of wav [1,L] as the function expects

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:08:13 +02:00
Ethanfel e56ece9c1c feat: add SelVA Textual Inversion Trainer and Loader nodes
Learns K CLIP token embeddings ([K, 1024]) with all model weights frozen,
keeping generated latents on the decoder's natural manifold — avoids the
quality degradation that affects LoRA on BJ's audio dataset.

- selva_textual_inversion_trainer.py: trains learned_tokens via AdamW,
  injects into last K positions of 77-token CLIP embedding, checkpoints
  with eval audio + spectral metrics
- selva_textual_inversion_loader.py: loads .pt bundle, returns
  TEXTUAL_INVERSION dict for sampler
- selva_sampler.py: optional textual_inversion input; injects into both
  text_clip and neg_text_clip before preprocess_conditions
- __init__.py: registers both new nodes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 23:01:44 +02:00
Ethanfel eed7eefeac feat: add SelVA HF Smoother and Spectral Matcher preprocessing nodes
Two ComfyUI nodes to reduce domain mismatch between custom training audio
and the MMAudio VAE's expected spectral distribution:

SelvaHfSmoother: blends a low-pass filtered copy (biquad) with the original
at a configurable cutoff and blend ratio. Attenuates extreme HF content that
BigVGANv2 handles poorly. RMS-preserving.

SelvaSpectralMatcher: computes the log-mel energy profile of the clip,
compares it per-band to the VAE's normalization means (DATA_MEAN_80D/128D),
and applies a smooth STFT-domain gain correction to match the codec's training
distribution. Configurable strength and max_gain_db clamp. RMS-preserving.

Recommended workflow: SpectralMatcher → HfSmoother → feature extraction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 20:28:16 +02:00
Ethanfel 107bb05f17 fix(vae-roundtrip): pass bigvgan path to encoder-only FeaturesUtils
AutoEncoderModule unconditionally asserts vocoder_ckpt_path is not None
even when need_vae_encoder=True. Pass best_netG.pt to satisfy the assert;
the vocoder weights are not actually used since decode+vocode go through
model["feature_utils"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 20:05:44 +02:00
Ethanfel 10e6095e31 fix(vae-roundtrip): use model feature_utils for decode, add normalize/unnormalize, normalize output
- Load fresh FeaturesUtils only for encoding; use model["feature_utils"] for
  decode+vocode to mirror the exact path the sampler takes
- Apply generator.normalize() → unnormalize() around the encoded latent so the
  decoder receives latents in the same space it expects from inference
- Log both encoded and norm→unnorm latent stats to diagnose round-trip fidelity
- Normalize output to -27 dBFS (matching training clip RMS) and clamp to [-1, 1]
  to prevent clipping artifacts in the output waveform

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:50:01 +02:00
Ethanfel 528d33be39 fix: trim/pad latent to seq_cfg.latent_seq_len before decoding
Without this the decoder produced 7s instead of 8s due to STFT rounding.
Same fix as _prepare_dataset uses for training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:22:09 +02:00
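
Sketch for 528d33be39: trimming or padding a latent sequence to a fixed length before decoding. A list stands in for the time axis of the latent; `fit_length` is a hypothetical name and zero padding is an assumption, since the commit does not say how the pad is filled.

```python
def fit_length(frames, target_len, pad_value=0.0):
    """Trim or right-pad frames so the result has exactly target_len items."""
    if len(frames) >= target_len:
        return frames[:target_len]
    return frames + [pad_value] * (target_len - len(frames))


short = fit_length([1.0, 2.0, 3.0], 5)        # padded on the right
long = fit_length([1.0, 2.0, 3.0, 4.0], 2)    # trimmed from the end
```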
Ethanfel 8195c3114a feat: add SelVA VAE Roundtrip node
Encodes audio through the VAE then decodes straight back, bypassing the
diffusion model entirely. Use this to isolate whether saturation artifacts
are introduced by the codec reconstruction (VAE/DAC) or by the LoRA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:15:20 +02:00
Ethanfel c8e6b91f67 feat: add alpha_scale_sweep to fix LoRA noise contamination
Previous sweep used alpha=rank (scale=1.0) which at rank 128/256 drowned
base model priors — spectral flatness went from 0.013 (baseline) to 0.094.
This sweep tests alpha dramatically below rank across r16/r32/r128 to find
the scale where LoRA nudges rather than overwrites.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:55:05 +02:00
Ethanfel fdce9cbbf1 feat: evaluate adapters on all dataset clips, not just clip_001
- _eval_sample gains clip_idx param (default 0, backward compatible)
- Evaluator loops over all dataset clips per adapter, saves one WAV per clip
- Reference metrics computed for all clips and averaged
- Comparison chart and summary use avg_metrics across all clips
- Eliminates bias from evaluating on an unrepresentative single clip

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:42:55 +02:00
Ethanfel 42ceb4b153 fix: preserve original audio extension when copying reference file
shutil.copy2 was writing FLAC binary to reference.wav — unplayable.
Now copies as reference{.flac/.wav/etc} matching the source extension.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:31:26 +02:00
Ethanfel 4505b89db1 feat: add reference audio to LoRA evaluator
Loads the first clip's original audio (same clip used for inference),
copies it to output_dir/reference.wav, runs spectral metrics and
saves a spectrogram. Appears first in the comparison chart so generated
samples can be judged against the target sound.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:30:33 +02:00
Ethanfel dbfa7b23fe feat: add eval_r128_candidates.json
Evaluates top 5 adapters from r128_sweet_spot: baseline, lr_5e4_r128,
lr_3e4_r256, lr_3e4_r128, curriculum_lr_3e4 final + step 6000 checkpoint
(before regression) for spectral comparison.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:28:28 +02:00
Ethanfel d2e1ea7b80 feat: add SelVA LoRA Evaluator node
Generates audio samples from a list of adapters against a fixed reference
clip, collects spectral metrics for each, and outputs a comparison bar
chart + eval_summary.json. Useful for comparing sweep candidates before
committing to a next round of training.

JSON format: name, data_dir, output_dir, steps, seed, adapters[{id, path}].
Empty path = baseline (no LoRA).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:26:50 +02:00
Ethanfel 9a47508d2d fix: lower RMS normalization target from -23/-20 to -27 dBFS
Training clips at -23 LUFS measure -25 to -31 dBFS RMS (avg ~-27).
Normalizing output to -23 dBFS was 2-8 dB too loud (~4 dB on average),
causing saturation on clips with high crest factor and peaks near 0 dBFS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 17:19:20 +02:00
Ethanfel 678c050f11 fix: make normalize(x1) assignment explicit in training loop
normalize() uses in-place ops so it worked, but reading the return value
makes the intent clear and guards against future refactors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 15:43:42 +02:00
Ethanfel 1be07a80d2 feat: add cosine LR decay schedule to trainer and scheduler
- Add lr_schedule param (constant|cosine) to SelvaLoraTrainer
- Cosine decays LR from initial value to ~0 after warmup, preventing
  the oscillation observed at steps 6000-8000 with lr=2e-4 flat
- Wire lr_schedule through scheduler _PARAM_DEFAULTS and _train_inner call
- Add g5_r128_lr_2e4_cosine and g5_r128_lr_3e4_cosine to r128_sweet_spot sweep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:25:01 +02:00
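
Sketch for 1be07a80d2: one way a `constant|cosine` switch with warmup could be computed. `lr_at`, `total_steps`, and `warmup_steps` are hypothetical names, and the linear warmup shape is an assumption; the commit only states that cosine decays the LR from its initial value to about 0 after warmup.

```python
import math


def lr_at(step, base_lr, total_steps, warmup_steps=0, schedule="cosine"):
    """Learning rate for a given step under a constant or cosine schedule."""
    if warmup_steps > 0 and step < warmup_steps:
        # Assumed linear warmup up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    if schedule == "constant":
        return base_lr
    # Cosine decay from base_lr at the end of warmup to 0 at total_steps.
    span = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / span)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


start = lr_at(0, 2e-4, 8000)
middle = lr_at(4000, 2e-4, 8000)
end = lr_at(8000, 2e-4, 8000)
```

Because the rate shrinks steadily toward the end of the run, late updates become small, which is the mechanism the commit relies on to damp the oscillation seen with a flat rate.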
Ethanfel 58e1985af2 feat: SelVA Skip Experiment node + save partial scalars on skip
- New node: SelVA Skip Experiment — writes skip_current.flag from UI,
  queue in a second workflow tab while scheduler is running
- SkipExperiment now attaches partial loss/grad/spectral data to the
  exception so the scheduler saves all collected scalars in the summary

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:10:43 +02:00
Ethanfel 264dc49d42 feat: skip_current.flag to cancel experiment and move to next
Create the flag file in the sweep output_root to skip the running
experiment at the next log interval (every 50 steps):
  touch /path/to/experiment/skip_current.flag

Scheduler marks it as 'skipped' in the summary and continues.
Skipped experiments are NOT resumed on restart (unlike failed ones).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 13:09:01 +02:00
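
Sketch for 264dc49d42: polling for `skip_current.flag` at the log interval. `check_skip` is a hypothetical helper, and deleting the flag after it is seen is an assumption so that the next experiment is not skipped too; the commit does not say how the flag is cleared.

```python
from pathlib import Path


class SkipExperiment(Exception):
    """Raised inside the training loop to abandon the current experiment."""


def check_skip(output_root, step, log_every=50):
    """Raise SkipExperiment if the flag file exists at a log-interval step."""
    if step % log_every != 0:
        return
    flag = Path(output_root) / "skip_current.flag"
    if flag.exists():
        flag.unlink()   # assumed: consume the flag so it fires only once
        raise SkipExperiment(f"skip requested at step {step}")
```

The scheduler would catch `SkipExperiment`, mark the run as skipped in the summary, and move on to the next experiment.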
Ethanfel fec5c86f09 feat: add spectral_flatness and temporal_variance to eval metrics
spectral_flatness (Wiener entropy) — 0=tonal, 1=white noise.
Rising value across steps directly flags noise contamination.
temporal_variance — RMS std/mean per frame. Low = lifeless/compressed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:45:40 +02:00
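
Sketch for fec5c86f09: spectral flatness as the ratio of the geometric mean to the arithmetic mean of a power spectrum. The input is assumed to be a list of non-negative power values for one frame; `eps` is a hypothetical floor that keeps the logarithm defined for empty bins.

```python
import math


def spectral_flatness(power, eps=1e-12):
    """Wiener entropy: close to 1 for a flat spectrum, close to 0 for a tonal one."""
    p = [max(v, eps) for v in power]
    log_mean = sum(math.log(v) for v in p) / len(p)
    geometric_mean = math.exp(log_mean)
    arithmetic_mean = sum(p) / len(p)
    return geometric_mean / arithmetic_mean


flat = spectral_flatness([1.0] * 64)             # white-noise-like spectrum
tonal = spectral_flatness([0.0] * 63 + [1.0])    # single dominant bin
```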
Ethanfel 2861327016 feat: spectral metrics per eval sample in experiment summary
Computes hf_energy_ratio (>4kHz), spectral_centroid_hz, spectral_rolloff_hz
at each save_every checkpoint. Logged to console and stored in
experiment_summary.json under results.spectral_metrics[step].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:44:43 +02:00
Ethanfel c4687521ef feat: save spectrogram PNG alongside each eval sample
Log-frequency dB spectrogram (inferno colormap, 100Hz–16kHz) saved as
step_XXXXX.png next to step_XXXXX.wav in samples/ subfolder.
Makes high-frequency rolloff (low bitrate signature) immediately visible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:42:34 +02:00
Ethanfel 8717af2728 fix: prevent saturation from RMS normalization clipping peaks
RMS normalize to target then scale back if peaks exceed 1.0,
preserving dynamics instead of hard-clipping transients.
Eval sample target updated to -23 dBFS to match training data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:29:29 +02:00
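
Sketch for 8717af2728: RMS normalization that backs the gain off instead of clipping when peaks would exceed full scale. A list of floats stands in for the waveform; `rms_normalize` and `peak_limit` are hypothetical names.

```python
import math


def rms(samples):
    """Root-mean-square level of a list of samples."""
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def rms_normalize(samples, target_dbfs, peak_limit=1.0):
    """Scale to the target RMS, then scale back if the peak exceeds the limit."""
    level = rms(samples)
    if level == 0.0:
        return list(samples)          # silence: nothing to scale
    gain = 10 ** (target_dbfs / 20) / level
    peak = max(abs(s) for s in samples) * gain
    if peak > peak_limit:
        # Reduce the gain uniformly so dynamics are preserved.
        gain *= peak_limit / peak
    return [s * gain for s in samples]


quiet = rms_normalize([0.01, -0.01] * 100, -20.0)     # reaches the target
spiky = rms_normalize([1.0] + [0.0] * 999, -20.0)     # limited by its peak
```

For the high-crest-factor input the final RMS ends up below the target, which is the trade-off this approach accepts in exchange for keeping transients intact.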
Ethanfel 78e9838a83 fix: replace peak normalization with RMS normalization at -20 dBFS
Peak norm was slamming output to full scale regardless of content level,
making generated audio several times louder than training clips.
RMS norm to -20 dBFS matches typical processed audio level.
Sampler exposes target_lufs (-40 to -6, default -20) for user control.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 12:06:48 +02:00
Ethanfel 94610b8943 feat: r128_sweet_spot sweep — noise-free LR search + rank 256
9 experiments targeting loss 0.25-0.35 without LoRA+ noise.
Tests higher base LR (2e-4/3e-4/5e-4), curriculum combos, conservative
LoRA+ ratio=4, and rank 256 baseline + lr=3e-4.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:46:08 +02:00
Ethanfel f5f7f2ae68 fix: eval sample seed 0 -> 42
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:32:43 +02:00
Ethanfel 1663b39833 fix: bump eval sample to 25 ODE steps (was 8)
Inference is fast on RTX PRO 6000 — 8 steps was washing out quality
differences between experiments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:32:27 +02:00
Ethanfel a7923d5fb7 feat: r64_overnight sweep — focused rank-64 ablation at 8000 steps
15 experiments across rank (64/128), alpha, regularisation, LR, target
layers, and combined stacks. Based on tier1_thorough early results
confirming rank 64 sounds best perceptually.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 01:32:23 +02:00
Ethanfel 786a57c424 feat: sweep resume + 5 additional experiments (LR, target, extended)
Scheduler: on re-run, reads existing experiment_summary.json and skips
already-completed experiments — safe to stop and restart mid-sweep.

tier1_thorough: adds g5 (lr 3e-5/3e-4), g6 (full target attn.qkv+linear1
at r16 and r64), and g4_full_r64_6k (6000-step extended run) — 17 total.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:59:16 +02:00
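
Sketch for 786a57c424: skipping experiments that an existing `experiment_summary.json` already marks as completed. The summary layout used here (`{"experiments": [{"id": ..., "status": ...}]}`) and the name `pending_experiments` are assumptions; the commit does not show the file's schema.

```python
import json
from pathlib import Path


def pending_experiments(experiments, summary_path, done_statuses=("completed",)):
    """Return the experiments that still need to run."""
    path = Path(summary_path)
    done = set()
    if path.exists():
        summary = json.loads(path.read_text())
        for entry in summary.get("experiments", []):
            if entry.get("status") in done_statuses:
                done.add(entry["id"])
    # Failed or missing entries stay in the list and are run again.
    return [e for e in experiments if e["id"] not in done]
```

With no summary file present every experiment is returned, so a first run and a resumed run go through the same code path.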
Ethanfel f15e02b0b8 fix: eval samples use fixed clip/seed, save to samples/ subfolder
- Always sample dataset[0] with fixed noise seed so checkpoints are
  directly comparable (hear the model improve step by step)
- Save to output_dir/samples/step_XXXXX.wav instead of alongside checkpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:54:37 +02:00
Ethanfel 0682a536cb fix: point data_dir to features/ subdir where .npz and audio live
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:45:32 +02:00
Ethanfel 0000878e76 feat: thorough overnight sweep + dataset browser updates
- Dataset browser: audio/features now resolve through features/ subdir
- tier1_sweep.json: update data_dir to BJ dataset path
- tier1_thorough.json: 12-experiment overnight sweep across 4 groups
  (rank 16/32/64, alpha scaling, LoRA+/dropout/curriculum isolation,
  full Tier 1 stack at r16 and r64) — output to BJ/experiment/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:38:19 +02:00
Ethanfel 675644189d feat: add SelVA Dataset Browser node
Companion node for inspecting dataset.json entries by integer index.
Outputs video (.mp4), audio (.wav/.flac), features (.npz), frames dir,
mask dir, label, and max_index for constraining the index widget range.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:55:27 +02:00
Ethanfel 82fb7a0009 docs: note AudioX shows no perceptual quality gain on V2A vs SelVA
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 09:12:00 +02:00
Ethanfel af4777d2d7 docs: add AudioX vs SelVA evaluation
Architecture comparison, capability matrix, integration cost estimate,
LoRA training difficulty analysis, and license implications.
Verdict: SelVA remains preferred for V2A + LoRA fine-tuning; AudioX
adds value for music generation, inpainting, and text-to-audio tasks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 09:11:09 +02:00
Ethanfel ed8abf7a5b docs: add video format recommendations to dataset preparation section
New section 1.1 covers aspect ratio (16:9 landscape preferred), resolution
(≥480p), frame rate (any, use VHS_VIDEOINFO), and portrait handling
(center-crop to square). Based on CLIP 384px and Synchformer 224px internals.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:44:14 +02:00
Ethanfel 21ed93d3ee docs: add audio dataset pipeline reference doc
Full research notes on cleaning, augmentation, and quality metrics for
generative model training. Covers LUFS normalization, AudioSep, waveform
augmentation (pitch shift, RIR, EQ), latent mixup, DNSMOS gating, tool
install commands, and key paper references.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:37:48 +02:00
Ethanfel f1e2bbd55b feat: add first experiment sweep file for Tier 1 ablation
6 experiments: baseline, LoRA+ (ratio=16), dropout 0.05, dropout 0.1,
curriculum sampling, and all three combined. bf16 batch 16, 2000 steps,
seed 42. data_dir placeholder needs to be updated before running.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:15:06 +02:00
Ethanfel 3d9221c248 fix: three bugs in scheduler and trainer
- trainer: raise ValueError early when remaining steps < log_interval (50)
  instead of UnboundLocalError on smoothed_img/final_path at return
- trainer: use None in grad_norm_history instead of silent 0.0 when
  grad_accum > log_interval and no optimizer step fired in the interval
- trainer: include start_step in _train_inner return dict
- scheduler: use start_step from result dict for min_loss_step and
  loss_at_steps (fixes wrong step labels on resumed experiments)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:11:25 +02:00
Ethanfel 2d200395af feat: add grad norm logging and richer experiment summary output
trainer:
- Track gradient norm before clipping at each optimizer step
- Log avg grad_norm per log_interval alongside loss in console output
- Include grad_norm_history in _train_inner return dict

scheduler:
- Add system block to summary (GPU name, VRAM, torch/CUDA version)
- Include full loss_history and grad_norm_history arrays in each
  experiment result (50-step resolution, not just save_every checkpoints)
- Add loss_std_last_quarter stability metric (std dev of raw loss over
  last 25% of steps — high value indicates unstable training)
- Add log_interval field so consumers know the x-axis resolution
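
A minimal sketch of the `loss_std_last_quarter` metric, assuming `loss_history` is a flat list of raw loss values logged once per `log_interval`. Whether the trainer uses the population or the sample standard deviation is not stated in the commit; the population form is used here as an assumption.

```python
import statistics


def loss_std_last_quarter(loss_history):
    """Std dev of the raw loss over the final 25% of logged points."""
    if not loss_history:
        return None
    # Keep at least one point so very short runs still return a value.
    n_tail = max(1, len(loss_history) // 4)
    tail = loss_history[-n_tail:]
    return statistics.pstdev(tail)
```

A run whose loss has flattened gives a value near zero; a run that is still oscillating gives a visibly larger one.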

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:06:39 +02:00
Ethanfel 3ec380a27e feat: add SelVA LoRA Scheduler node for automated experiment sweeps
- Extract _prepare_dataset() from SelvaLoraTrainer.train() as a module-level
  function so the dataset can be encoded once and reused across experiments
- Change _train_inner() return value from tuple to dict (adds loss_history,
  meta, completed; train() unpacks for ComfyUI — no change to node outputs)
- New SelvaLoraScheduler node: reads a JSON sweep file, runs N experiments
  sequentially, writes experiment_summary.json (updated after each run) and
  loss_comparison.png with all smoothed curves overlaid on the same axes
- Register SelvaLoraScheduler in nodes/__init__.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 13:03:21 +02:00
Ethanfel 9bc2568543 docs: document LoRA dropout, LoRA+, and curriculum timestep sampling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 12:45:53 +02:00
Ethanfel eb63c1ead7 feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not frozen base weights),
  0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with
  configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal for first N% of steps then uniform,
  directly addresses early convergence + fine-detail degradation at
  boundaries (arXiv:2603.12517)
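
Two of the three features above can be sketched without any framework dependency. The function names, the `switch_frac` parameter, and the substring match on `lora_A` / `lora_B` are assumptions for illustration. Giving the higher rate to `lora_B` follows the LoRA+ paper; the commit itself only says the ratio is configurable.

```python
import math
import random


def sample_timestep(step, total_steps, switch_frac=0.5, sigma=1.0, rng=random):
    """Curriculum sampling: logit-normal early in training, uniform afterwards."""
    if step < switch_frac * total_steps:
        # Logit-normal: squash a Gaussian draw through a sigmoid.
        z = rng.gauss(0.0, sigma)
        return 1.0 / (1.0 + math.exp(-z))
    return rng.random()


def lora_plus_groups(named_params, base_lr, ratio=16.0):
    """Split LoRA parameters into two optimizer groups with different LRs.

    named_params is an iterable of (name, parameter) pairs. The returned
    list of dicts has the shape torch optimizers accept as param groups.
    """
    named_params = list(named_params)
    a_params = [p for name, p in named_params if "lora_A" in name]
    b_params = [p for name, p in named_params if "lora_B" in name]
    return [
        {"params": a_params, "lr": base_lr},
        {"params": b_params, "lr": base_lr * ratio},
    ]
```

With `ratio=1.0` both groups share one learning rate, which corresponds to plain LoRA.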

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 12:43:18 +02:00
Ethanfel 5baa070e61 docs: add observations section with fp32/batch/precision findings
Work-in-progress empirical notes: fp32 batch 32 reaches same quality as
bf16 batch 16 in 1/3 the steps but overfits past ~2000 steps on 10 clips.
Lower loss does not reliably mean better audio on small datasets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 02:34:53 +02:00
Ethanfel 9fc739fe9e docs: add prompt guide and masking note to dataset preparation section
Poor prompts and missing masks are a common source of white noise in LoRA
training — imprecise sync features force the adapter to compensate with noise.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:43:28 +02:00
Ethanfel 57fae4a8ce chore: default timestep_mode back to uniform
logit_normal reaches lower loss but perceptual improvement over uniform
is dataset-dependent. Keeping uniform as default to match original MMAudio
training behavior; logit_normal remains available as an option.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:21:08 +02:00
Ethanfel 8e919c0459 fix: resolve relative and Unix-style output_dir paths to ComfyUI output folder
On Windows, /folder is drive-relative (no drive letter) rather than a real
absolute path. Redirect these to ComfyUI's output directory so files don't
land at C:\folder. Also redirects plain relative paths (e.g. lora_output)
to output/ instead of the process working directory.
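
One way to express the redirect rule is shown below, as a sketch rather than the node's actual code. `resolve_output_dir` and its arguments are hypothetical names. `PureWindowsPath` is used so the Windows parsing rules apply on any host OS.

```python
from pathlib import PureWindowsPath


def resolve_output_dir(user_path, comfy_output_dir):
    """Redirect drive-less paths under ComfyUI's output directory."""
    win = PureWindowsPath(user_path)
    if win.drive:
        # A drive letter (or UNC share) is present: leave the path alone.
        return user_path
    # "/folder" parses with a root but no drive; drop the root so the
    # remaining parts can be joined below the output directory.
    parts = win.parts[1:] if win.root else win.parts
    return str(PureWindowsPath(comfy_output_dir).joinpath(*parts))
```

For example, `resolve_output_dir("/folder", "C:/ComfyUI/output")` yields a path below `C:\ComfyUI\output` instead of `C:\folder`.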

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:14:04 +02:00
Ethanfel fec8eaac95 fix: save adapter and loss curves on cancel, not only on normal completion
Wraps training loop in try/finally so adapter_final.pt and loss PNGs are
always written. On cancellation the adapter is named
adapter_cancelled_stepXXXXX.pt so it can be used with --resume.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 01:06:44 +02:00
Ethanfel d83632e754 fix: pad/trim clip and sync features to fixed seq_len at dataset load time
Clips from shorter videos produce fewer CLIP frames (e.g. 2s → 16 frames,
8s → 64 frames). Mixed-length datasets would cause torch.stack() to fail
during batching. Normalize to seq_cfg.clip_seq_len / sync_seq_len at load,
same as latents are already normalized to latent_seq_len.
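
The normalisation step can be illustrated with plain nested lists standing in for `[T, C]` feature tensors. Zero right-padding is an assumption here; the commit does not say which padding value or side the trainer uses, and `fit_to_length` is a hypothetical name.

```python
def fit_to_length(frames, target_len, pad_value=0.0):
    """Trim or right-pad a [T, C] nested list to exactly target_len rows."""
    if len(frames) >= target_len:
        return frames[:target_len]
    width = len(frames[0]) if frames else 0
    # Build independent padding rows so later in-place edits do not alias.
    padding = [[pad_value] * width for _ in range(target_len - len(frames))]
    return frames + padding
```

Once every clip has the same number of rows, stacking them into a batch no longer fails on mixed-length datasets.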

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:54:05 +02:00
Ethanfel a5014e49eb feat: add logit-normal timestep sampling to reduce white noise artifacts
Uniform timestep sampling undertrained t>0.8 (the final denoising steps),
leaving residual noise that CFG amplifies at inference. Logit-normal sampling
concentrates training near t=0.5 while still covering the full range, improving
high-t coverage and reducing noise floor in generated audio.

Default changed from uniform to logit_normal (sigma=1.0). Previous behavior
available with timestep_mode=uniform.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:35:42 +02:00
Ethanfel 8ae0ba3c7d fix: increment adapter_final filename on resume to avoid overwriting previous final
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:15:31 +02:00
Ethanfel 2b2b438307 fix: set OUTPUT_NODE=True on SelVA Feature Extractor so it runs without connected outputs
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:11:16 +02:00
Ethanfel 39984f73c2 docs: add observed batching results to training guide
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:05:16 +02:00
Ethanfel 1f8cd6f930 docs: rewrite LORA_TRAINING.md with real-world findings
- Added batch_size VRAM table and updated step recommendations for batched training
- Added adapter strength section with practical guidance (0.6-0.7 for noise)
- Added ComfyUI node as Option A for training (not just CLI)
- Noted .mp3 as not recommended, soundfile fallback implied
- Added output files section with sample_*.wav and loss curve PNGs
- Added "LoRA has no effect" troubleshooting (wrong node wired)
- Updated loss convergence targets based on observed training runs
- Clarified linear1 target: 150+ clips recommended

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 00:00:36 +02:00
Ethanfel 20f8138146 chore: show batch_size in training step log
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:45:43 +02:00
Ethanfel 09b3b94ddd feat: add batch_size parameter to training (default 4)
Replaces single-sample steps with batched sampling via random.choices().
Tensors are stacked to [B, T, C] before the forward pass; t is now [B].
Default grad_accum lowered to 1 since real batching gives stable gradients.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:36:12 +02:00
Ethanfel 3f67de694c feat: save loss_raw.png and loss_smoothed.png to output_dir
Raw curve shown in light blue, EMA-smoothed (beta=0.9) overlay in darker
blue. Both saved as PNG at end of training. The node IMAGE output now
returns the smoothed version. Live preview also uses the smoothed overlay.
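
The EMA smoothing with beta=0.9 can be sketched as below. Seeding the average with the first value (rather than zero with bias correction) is an assumption; the commit only gives the beta.

```python
def ema_smooth(values, beta=0.9):
    """Exponential moving average; the first value seeds the average."""
    smoothed = []
    avg = None
    for v in values:
        avg = v if avg is None else beta * avg + (1.0 - beta) * v
        smoothed.append(avg)
    return smoothed
```

A higher beta gives a smoother curve that reacts more slowly to real changes in the loss.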

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:15:48 +02:00
Ethanfel 423e174b88 debug: print lora_A norm after loading to confirm adapter applied
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 23:05:23 +02:00
Ethanfel 4806daa4ca chore: lower default warmup_steps from 500 to 100
500 warmup steps is 25% of a 2000-step run — too long. 100 steps lets
the full lr kick in much earlier without sacrificing stability.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:51:27 +02:00
Ethanfel 16b3eb11cc fix: pass max_size=800 to progress bar preview (was 85px wide)
The third element in ComfyUI's preview tuple is max_size in pixels, not
JPEG quality. Passing 85 was capping the live loss curve at 85×40px.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:48:56 +02:00
Ethanfel 004ea63f62 fix: fall back to soundfile for torchaudio.save when torchcodec unavailable
Same torchcodec/FFmpeg issue as the load path, now on the eval sample save.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:44:04 +02:00
Ethanfel afb3242eca fix: disable inference_mode entirely for training via inference_mode(False)
torch.enable_grad() alone is insufficient: operations on inference tensors
(created inside ComfyUI's outer inference_mode context) produce inference
tensors even inside enable_grad, breaking autograd. inference_mode(False)
exits the inference context so the deepcopy, apply_lora, and training loop
run with a fully clean autograd context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:40:50 +02:00
Ethanfel 849f31e2a6 fix: create LoRA params inside torch.enable_grad() to escape inference_mode
torch.enable_grad() re-enables grad tracking but nn.Parameters created while
torch.inference_mode() is active are inference tensors that can't enter autograd
regardless. Splitting into _train_inner() and calling it inside enable_grad()
ensures the deepcopy, apply_lora, and the training loop all run with a clean
autograd context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:36:28 +02:00
Ethanfel 505d445eb3 fix: wrap training loop in torch.enable_grad()
ComfyUI executes all nodes inside torch.no_grad(), which prevents gradient
tracking and makes loss.backward() fail. torch.enable_grad() overrides it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:32:00 +02:00
Ethanfel 8fade1b0e3 fix: initialize LoRA params on same device as wrapped linear
apply_lora() is called after generator.to(device), so lora_A/lora_B were
being created on CPU while the rest of the model was on CUDA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:17:29 +02:00
Ethanfel ad57432803 fix: pad/trim latent to exact latent_seq_len after VAE encoding
STFT hop-size rounding produces ±1 latent frame vs the expected seq length.
Clamp to seq_cfg.latent_seq_len after transpose so generator.forward assertion passes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:12:20 +02:00
Ethanfel 43f732f904 fix: transpose VAE latent from [B,C,T] to [B,T,C] before generator
VAE encoder returns channels-first [B, latent_dim, T]; the generator
expects time-first [B, T, latent_dim] (same convention as decode which
already does .transpose(1,2)). Fixes normalize() size mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:08:00 +02:00
Ethanfel 6b9adf0816 fix: fall back to soundfile when torchcodec FFmpeg libs are missing
Recent torchaudio defaults to torchcodec as the audio backend, which requires
FFmpeg shared libraries. Falls back to soundfile for envs where torchcodec
can't load (e.g. containerised ComfyUI without system FFmpeg).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 22:03:57 +02:00
Ethanfel 52434a053a fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure
torch.stft requires float32 input — casting vae_utils to bf16 caused silent
failures during dataset pre-loading. Also adds traceback.print_exc() so future
clip-load errors are visible in the ComfyUI log.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 21:57:20 +02:00
Ethanfel 56c8d5d6b4 feat: save eval audio sample alongside each checkpoint
At every save_every steps, run a quick 8-step no-CFG inference pass on
a random training clip and save the decoded waveform as
sample_stepXXXXX.wav next to the checkpoint. Uses the existing
generator.unnormalize + feature_utils.decode + vocode pipeline from
the sampler. Failure is non-fatal (logged and skipped).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 21:47:02 +02:00
Ethanfel b430953602 feat: live loss curve preview during training
- Send updated loss curve to ComfyUI frontend every 50 steps via
  pbar_train.update_absolute() with a JPEG preview tuple — same
  mechanism as KSampler's denoising previews.
- Fix x-axis step labels for resumed runs (previously always started
  at 0; now correctly shows start_step + offset).
- Split _draw_loss_curve (returns PIL Image) from _pil_to_tensor
  (converts for ComfyUI IMAGE output).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:11:38 +02:00
Ethanfel 57cd3dd4b4 fix: use load_lora for resume and remove redundant inference_mode wrapper
- Resume now calls load_lora() instead of load_state_dict() directly,
  giving proper warnings for missing/unexpected LoRA keys.
- Remove redundant `with torch.inference_mode():` around encode_audio
  (already @inference_mode decorated); dist.mode().clone() pattern
  is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:09:35 +02:00
Ethanfel f206a1b38c feat: add SelVA LoRA Trainer ComfyUI node
Runs the full training loop inside ComfyUI. Reuses the already-loaded
CLIP model from the inference model for text encoding; loads only a
minimal VAE encoder separately (freed after dataset pre-loading).

Outputs:
- SELVA_MODEL with LoRA applied (ready to connect directly to Sampler)
- adapter_path STRING (for SelVA LoRA Loader in future sessions)
- loss_curve IMAGE (PIL-rendered line chart of training loss per 50 steps)

Progress is shown via ComfyUI ProgressBar (two phases: dataset loading,
then training steps). Resume is supported via resume_path input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 17:07:38 +02:00
Ethanfel 2f4641247a feat: add resume support to train_lora.py
Step checkpoints now save optimizer state, scheduler state, and step
number alongside the LoRA weights. Pass --resume path/to/adapter_stepXXXXX.pt
to continue training from that checkpoint. --steps always means total steps,
so resuming from 1000 with --steps 2000 trains 1000 more steps.

adapter_final.pt format is unchanged (state_dict + meta only) so
SelvaLoraLoader remains compatible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 16:59:30 +02:00
Ethanfel 8e9114b92c docs: add clip length and scalable dataset size recommendations
- Clip length section: fixed 8s duration, padding/trim behavior, per-sound-type
  strategies (continuous, short events, repeating, onset placement).
- Dataset size table: 5-10 / 15-30 / 30-60 / 60-150 / 150-300 / 300+ clips
  with scenario and expected result for each tier.
- Note on diversity vs quantity.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 16:34:50 +02:00
Ethanfel 63b4391573 fix: named .npz files always start at _001
dog_bark_001.npz, dog_bark_002.npz instead of dog_bark.npz, dog_bark_001.npz.
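
A sketch of the numbering rule, combined with the name sanitising described in c88e27742c. `next_named_path` is a hypothetical name, and replacing the unsafe characters with `_` is an assumption (the commit says they are replaced but not with what).

```python
import re
from pathlib import Path


def next_named_path(cache_dir, name):
    """Return the first unused <name>_NNN.npz path, starting at _001."""
    # Replace "/", "\" and NUL so the name cannot escape cache_dir.
    safe = re.sub(r"[/\\\x00]", "_", name)
    index = 1
    while True:
        candidate = Path(cache_dir) / f"{safe}_{index:03d}.npz"
        if not candidate.exists():
            return candidate
        index += 1
```

Starting at `_001` keeps every file in a series on the same naming pattern, so sorted listings stay in creation order.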

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:44:26 +02:00
Ethanfel 89af5a468c docs: add LoRA training guide
Covers dataset preparation (ComfyUI feature extraction + clean audio),
training CLI reference, tuning guide (rank/steps/lr), adapter loading
in ComfyUI, and troubleshooting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:43:09 +02:00
Ethanfel c88e27742c fix: sanitize name field and remove double load_npz call
- _resolve_named_path: replace / \ and null in name to prevent path
  traversal outside cache_dir (would cause a confusing FileNotFoundError
  at np.savez time instead of at path resolution).
- train_lora: load_npz was called twice per clip when prompt was in
  prompts.txt; consolidate to a single call before prompt resolution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:30:25 +02:00
Ethanfel cbcd154c96 feat: add name field with auto-increment to SelvaFeatureExtractor
When name is provided, features are saved as name.npz (or name_001.npz,
name_002.npz etc. if the file already exists) instead of a content hash —
useful for building a named training dataset. Hash-based caching is
unchanged when name is left empty.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:16:51 +02:00
Ethanfel 1eb82d8050 refactor: train_lora accepts .npz + audio pairs instead of raw video
- Input is now pre-extracted .npz files (from SelvaFeatureExtractor) paired
  with clean audio files (same stem). Visual features no longer re-extracted
  during training.
- FeaturesUtils loaded with enable_conditions=False (VAE only) — Synchformer
  and T5 are no longer loaded, saving ~3-4 GB VRAM.
- CLIP text encoder loaded separately via patch_clip so text prompt can differ
  from the one used during feature extraction.
- Prompt priority: prompts.txt override > embedded in .npz > directory name.
- Removed: torchvision video loading, frame sampling/resizing, net_video_enc,
  synchformer path check.
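
The prompt priority above can be written as a short lookup chain. This is a sketch: the function names are hypothetical, and the `filename: description` line format is taken from the earlier `train_lora.py` commit (437c62b28f).

```python
def parse_prompts_txt(text):
    """Parse 'filename: description' lines into a dict of overrides."""
    overrides = {}
    for line in text.splitlines():
        if ":" not in line:
            continue  # skip blank or malformed lines
        name, description = line.split(":", 1)
        overrides[name.strip()] = description.strip()
    return overrides


def resolve_prompt(filename, overrides, embedded_prompt, directory_name):
    """prompts.txt override > prompt embedded in .npz > directory name."""
    override = overrides.get(filename)
    if override:
        return override
    if embedded_prompt:
        return embedded_prompt
    return directory_name
```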

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:14:26 +02:00
Ethanfel cde280049b fix: correct LoRALinear dtype and remove unused import
- LoRALinear now creates lora_A/lora_B with dtype matching the base
  linear's weight, preventing a float32/bf16 mismatch at forward time
  when the generator is loaded in bf16 or fp16.
- Remove unused `import math` from train_lora.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:57:09 +02:00
Ethanfel 437c62b28f feat: LoRA fine-tuning for SelVA generator
Teaches the model new/partial sound classes from custom video+audio pairs.
Only ~10 MB of adapter weights are trained vs ~4.4 GB for the full model.

selva_core/model/lora.py
  LoRALinear: wraps nn.Linear with frozen base + trainable A/B matrices.
  B initialised to zero → zero adapter contribution at init.
  apply_lora(): walks named_modules, replaces matching nn.Linear in-place.
  Default target: "attn.qkv" (all 21 SelfAttention QKV projections in
  large_44k). Add "linear1" to also wrap post-attention output projections.
  get_lora_state_dict() / load_lora() for ~10 MB save/load.

train_lora.py (standalone script, no ComfyUI dependency)
  Data format: directory of video files + optional prompts.txt
  ("filename: description"). Falls back to directory name as prompt.
  Pre-extracts features for all clips into RAM, then trains from those.
  Training loop: encode audio→latent (need_vae_encoder=True), flow
  matching MSE loss on velocity prediction, backward on LoRA params only.
  Saves adapter_stepNNNNN.pt checkpoints + adapter_final.pt with metadata.
  Key verified interfaces used:
    encode_audio() → DiagonalGaussianDistribution; .mode().clone() required
    normalize() is in-place
    forward(latent, clip_f, sync_f, text_f, t) takes raw tensors

nodes/selva_lora_loader.py (SelVA LoRA Loader ComfyUI node)
  Loads .pt adapter, deep-copies the generator, applies LoRA, loads weights.
  strength param scales lora_B to adjust adapter contribution at inference.
  Reads rank/alpha/target from embedded metadata if present.
  Returns a patched SELVA_MODEL bundle for use with the existing Sampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:38:46 +02:00
Ethanfel b519b042e2 docs: document mask inputs and normalize toggle in README
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 10:43:42 +02:00
Ethanfel f28759f1e3 feat: improve mask support with neutral fill, mask_strength, and per-path toggles
- Replace zero-fill with neutral gray (0.5) fill so masked background
  pixels stay in-distribution: 0.5 maps to ~0 in CLIP normalized space
  and exactly 0 after sync's [-1,1] normalization
- Add mask_strength float (0–1) for partial background suppression
- Add mask_clip / mask_sync booleans to toggle masking independently
  on the CLIP (384px) and TextSynchformer (224px) encoding paths
- Fix temporal mask sampling: use fps-accurate index formula (same as
  _sample_frames) instead of proportional int(i*M/N)
- Include mask_strength, mask_clip, mask_sync in cache hash when mask
  is connected, so changing any param correctly busts the cache
- Log lines now report masked/skipped state and strength per path
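
The arithmetic behind the neutral-gray fill can be checked directly. The mean/std constants are the standard OpenAI CLIP preprocessing values, and it is an assumption that the SelVA CLIP path uses them. The sync path is assumed to map [0,1] to [-1,1] via (x - 0.5) / 0.5. The `apply_mask` blending formula is a hypothetical reading of `mask_strength`, not the node's confirmed implementation.

```python
# Standard OpenAI CLIP preprocessing constants (assumed to apply here).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def clip_normalize(value):
    """Per-channel CLIP normalization of a single gray level."""
    return tuple((value - m) / s for m, s in zip(CLIP_MEAN, CLIP_STD))


def sync_normalize(value):
    """Map a pixel in [0, 1] to [-1, 1]."""
    return (value - 0.5) / 0.5


def apply_mask(pixel, mask, strength=1.0, fill=0.5):
    """Blend background pixels (mask=0) toward the fill value."""
    keep = 1.0 - strength * (1.0 - mask)
    return pixel * keep + fill * (1.0 - keep)
```

Under these assumptions a 0.5 fill lands within roughly 0.35 of zero on every CLIP channel and exactly at zero on the sync path, whereas a zero fill lands below -1.4 on every CLIP channel.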

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 10:43:01 +02:00
Ethanfel 3dd6badfd9 fix: guarantee offload cleanup on exception with try/finally
Both nodes moved models to GPU before work then back to CPU after.
Any exception (OOM, cancellation, bad input) would skip the cleanup,
leaving models on GPU permanently until ComfyUI restarts.

Wrap the entire work block in try/finally so offload_to_cpu cleanup
always runs regardless of how the node exits. Also removes the unused
`mode` variable in SelvaSampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 08:40:39 +02:00
Ethanfel 8bb2fb7015 fix: extend OOM catch to decode/vocode, add (masked) to sync log line
- selva_sampler: wrap decode+vocode in their own OOM catch — previously
  OOM during mel decode or vocoding gave a raw CUDA traceback instead
  of the actionable hint
- selva_feature_extractor: sync frames log line now shows (masked) when
  a mask is active, matching the CLIP log line

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 08:38:59 +02:00
Ethanfel f4a7292cde feat: add optional MASK input to SelVA Feature Extractor
Allows per-frame or static segmentation masks to be applied before CLIP
and sync encoding, zeroing background pixels. Useful when multiple objects
compete for the same sound and text prompting alone is insufficient.

- _apply_mask(): resizes mask spatially (nearest-exact), samples temporally
  to match sampled frame count, multiplies into frames
- _hash_inputs(): includes mask bytes in cache key (begin/mid/end sampling)
- INPUT_TYPES: mask added to optional inputs with tooltip
- extract_features(): mask=None parameter, applied after _resize_frames for
  both CLIP (384px) and sync (224px) paths, before normalization
- Log line notes when masking is active

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 08:34:13 +02:00
Ethanfel bd53744e2d feat: comprehensive node improvements
Model Loader:
- bf16 support check — auto-falls back to fp16 on unsupported GPUs
- DESCRIPTION and OUTPUT_TOOLTIPS

Feature Extractor:
- Store variant in features dict and .npz cache
- Progress bar (3 steps: CLIP encode, T5 encode, sync encode)
- Expand cache hash to 32 hex chars
- DESCRIPTION and OUTPUT_TOOLTIPS

Sampler:
- Variant mismatch validation against extracted features
- Cancellation support via throw_exception_if_processing_interrupted()
- OOM catch with actionable error message
- normalize toggle (optional BOOLEAN, default true) for peak normalization
- Remove empty optional: {} block
- DESCRIPTION and OUTPUT_TOOLTIPS

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 18:16:03 +02:00
Ethanfel 429810db5b docs: improve tooltips on all three SelVA nodes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 18:10:05 +02:00
Ethanfel 57f56c04e2 feat: update demo workflow with VHS_VideoCombine output
- Replace PreviewAudio with VHS_VideoCombine — outputs video+audio together
- Wire fps from FeatureExtractor to VideoCombine frame_rate
- Wire audio from Sampler into VideoCombine
- Clear hardcoded video filename
- Set filename_prefix to SelVA, save_output=true

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 18:07:56 +02:00
Ethanfel ff26d0b87d fix: bug sweep and improvements
- nodes/__init__.py: fix [PrismAudio] leftover label in error print
- selva_feature_extractor: hash beginning, middle and end of video tensor
  instead of just first 1MB, avoiding collisions on videos with same opening frames
- selva_sampler: derive SequenceConfig from model template via dataclasses.replace
  instead of hardcoding sampling_rate/spectrogram_frame_rate per mode
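
The begin/middle/end hashing used for the feature cache key could look like the sketch below. SHA-256, the 1 MB window, mixing in the total length, and the name `sampled_digest` are all assumptions. The 32-character truncation matches the cache hash width mentioned in bd53744e2d.

```python
import hashlib


def sampled_digest(data, chunk=1 << 20, hex_chars=32):
    """Hash the first, middle and last `chunk` bytes of a buffer."""
    h = hashlib.sha256()
    n = len(data)
    # Include the length so buffers of different size never collide
    # merely because their sampled windows match.
    h.update(n.to_bytes(8, "little"))
    if n <= 3 * chunk:
        h.update(data)  # small input: hash everything
    else:
        mid = (n - chunk) // 2
        h.update(data[:chunk])
        h.update(data[mid:mid + chunk])
        h.update(data[-chunk:])
    return h.hexdigest()[:hex_chars]
```

Sampling keeps hashing cheap on long videos, at the cost of missing changes that fall entirely between the three windows.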

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 18:04:35 +02:00
Ethanfel 83b1da9520 chore: remove all PrismAudio code from main branch
- Delete prismaudio_core/, data_utils/, scripts/, docs/plans/
- Delete PrismAudio nodes (feature_extractor, feature_loader, model_loader, sampler, text_only)
- Delete PrismAudio workflows (video_to_audio, text_to_audio)
- Clean nodes/utils.py: rename PRISMAUDIO_CATEGORY → SELVA_CATEGORY, remove unused helpers
- Strip PrismAudio-only deps from requirements.txt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 17:58:31 +02:00
Ethanfel 679a607a85 feat: wire prompt output from feature extractor to sampler in demo workflow
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 17:13:23 +02:00
Ethanfel d495939367 docs: rewrite README for SelVA
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 17:12:28 +02:00
Ethanfel 982d66e078 chore: remove PrismAudio nodes from selva-integration branch
This branch registers only the three SelVA nodes. PrismAudio nodes stay
on master/feature/lora-trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 17:01:21 +02:00
Ethanfel b4124f58b3 fix: BigVGANv2._from_pretrained() compat with newer huggingface_hub
Newer hf_hub stopped passing proxies/resume_download/local_files_only/token
to _from_pretrained(). Give them defaults so the call doesn't fail when
these kwargs are omitted.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:51:48 +02:00
Ethanfel 2c9d521565 fix: 44k generator HF paths use 44khz suffix (not 44k)
Actual filenames in jnwnlee/SelVA: generator_*_44khz_sup_5.pth.
download_utils.py had the wrong names so those MD5s are unverified — set to
None to skip MD5 check for 44k generators. All other files verified/unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:46:20 +02:00
Ethanfel 28229d62ce fix: MD5 validation on existing files — re-download if corrupt
Previously _ensure() trusted any existing file. Files downloaded by the
broken requests-based code (HTML error pages) would be silently reused.
Now checks MD5 on every load; deletes and re-downloads on mismatch.
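
The verify-then-redownload flow can be sketched as follows. `ensure_file` and the `download` callable are hypothetical stand-ins, not the real `_ensure()` signature. Passing `expected_md5=None` skips the check, mirroring how 2c9d521565 handles the unverified 44k generator files.

```python
import hashlib
from pathlib import Path


def md5_of(path, block_size=1 << 20):
    """Stream a file through MD5 without loading it fully into memory."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for block in iter(lambda: handle.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def ensure_file(path, expected_md5, download):
    """Return path, re-downloading if it is missing or fails the MD5 check."""
    path = Path(path)
    if path.exists():
        if expected_md5 is None or md5_of(path) == expected_md5:
            return path
        path.unlink()  # corrupt or stale: delete before fetching again
    download(path)
    if expected_md5 is not None and md5_of(path) != expected_md5:
        raise RuntimeError(f"MD5 mismatch after download: {path}")
    return path
```

Checking on every load costs one full read of each weight file, which is the price of catching a previously saved HTML error page.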

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:42:38 +02:00
Ethanfel 92593189f0 fix: use huggingface_hub for downloads instead of raw requests
download_utils.py used requests without auth — jnwnlee/SelVA returned an
HTML error page which torch then failed to unpickle ('E' / opcode 69).
huggingface_hub.hf_hub_download() handles HF_TOKEN auth automatically,
validates downloads, and retries. Files are still copied to models/selva/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:41:29 +02:00
Ethanfel 614a2e02aa fix: weights_only=False for SelVA checkpoints (PyTorch 2.6 compat)
PyTorch 2.6 changed the default to weights_only=True. SelVA checkpoints
contain non-tensor types (numpy scalars etc.) that fail strict unpickling.
All weights come from trusted sources (jnwnlee/SelVA HF repo).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:38:31 +02:00
Ethanfel 40388ba6de fix: negative_prompt inline (multiline:false) + VAE filename v1-44.pth not v1-44k.pth
- SelvaSampler: multiline:false puts negative_prompt inline above sliders
- SelvaModelLoader: VAE filenames in download_utils are v1-16.pth/v1-44.pth,
  not v1-{mode}.pth (mode includes the 'k' suffix)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:35:17 +02:00
Ethanfel 789e09535d fix: SelvaSampler — negative_prompt above settings
Move negative_prompt to required inputs, right after prompt, so it appears
above duration/steps/cfg/seed in the ComfyUI node layout.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:31:53 +02:00
Ethanfel 4da4858e4a fix: inline prune helpers when removed from both transformers locations
find_pruneable_heads_and_indices and prune_linear_layer were removed from
both pytorch_utils and modeling_utils in some transformers builds. Provide
minimal inline implementations as final fallback — prune_heads() is never
called at inference time so correctness is only needed for completeness.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:30:58 +02:00
Ethanfel ab8e1e5b7b feat: SelvaFeatureExtractor outputs prompt as STRING
Users can now wire the prompt output directly to SelvaSampler's prompt input,
making the data flow explicit instead of relying on the implicit features fallback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:27:49 +02:00
Ethanfel e3a3384727 fix: SelvaSampler input order — prompt required, negative_prompt optional
ComfyUI renders required inputs above optional ones. Moving negative_prompt
to optional puts prompt first (natural order) and negative_prompt at the
bottom where it belongs as a power-user input. Also guards against
negative_prompt=None when not connected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:27:07 +02:00
Ethanfel 9a985499e7 feat: auto-download SelVA weights on first use
Uses selva_core/utils/download_utils.py (already has URLs + MD5s for all
weights). Models download to models/selva/ on first load. Synchformer reuses
models/prismaudio/synchformer_state_dict.pth if already present (no duplicate
download for PrismAudio users), otherwise downloads to models/selva/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:25:36 +02:00
Ethanfel 27b4424e1a feat: prompt entered once in SelvaFeatureExtractor, reused by SelvaSampler
SelvaFeatureExtractor now stores the prompt in SELVA_FEATURES (both in the
returned dict and the .npz cache). SelvaSampler's prompt is now optional —
when left empty it falls back to the prompt stored in features. A non-empty
override can still be passed when CLIP text should differ from the sync text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:22:59 +02:00
Ethanfel 0e417f4078 fix: transformers compat — find_pruneable_heads_and_indices import
Some transformers builds removed these from pytorch_utils. Fall back to
modeling_utils which exposes them in all known versions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:21:26 +02:00
Ethanfel 6474e2816c fix: two bugs in SelVA nodes
- selva_feature_extractor: cache hash now includes resolved duration;
  same video + different duration override no longer returns stale features
- selva_sampler: MPS-safe noise generation (torch.Generator on CPU then
  move to device, same pattern as PrismAudioSampler)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:39:57 +02:00
Ethanfel c23d210ab2 feat: SelVA video-to-audio example workflow
LoadVideo → SelvaFeatureExtractor → SelvaSampler → PreviewAudio.
Defaults: medium_44k, bf16, 25 steps, cfg=4.5.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:31:53 +02:00
Ethanfel b59b657b6f feat: SelvaSampler — flow matching ODE with CFG and negative prompts
Calls update_seq_lengths with actual feature dimensions (not seq_cfg) to
avoid rounding assertion mismatches. Progress bar tracks each Euler step.
Supports negative prompts for steering, normalizes output to [-1,1].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:31:18 +02:00
Ethanfel 578b501d38 feat: SelvaFeatureExtractor — inline CLIP + TextSynchformer feature extraction
CLIP frames at 8fps→384px (normalize inside FeaturesUtils).
Sync frames at 25fps→224px, normalized to [-1,1] externally.
T5 text encoded via FeaturesUtils, sup tokens prepended, then text-conditioned
sync features extracted via TextSynch.encode_video_with_sync(). Results cached
as .npz keyed by hash(frames[:1MB] + prompt + fps + variant).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:23:40 +02:00
Ethanfel fe94438356 feat: SelvaModelLoader node — loads TextSynch + MMAudio + FeaturesUtils
Resolves weights from models/selva/. Reuses synchformer_state_dict.pth from
models/prismaudio/ (no duplicate download). Supports four variants:
small_16k / small_44k / medium_44k / large_44k.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:21:03 +02:00
Ethanfel 6bc3fd6443 chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:18:09 +02:00
189 changed files with 23219 additions and 11938 deletions
+459
@@ -0,0 +1,459 @@
# LoRA Training for SelVA
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
---
## Overview
Training is split into two steps:
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 3–4 GB of VRAM.
---
## Requirements
Same environment as SelVA inference. Additional Python packages:
```
torchaudio
soundfile
```
---
## Step 1 — Prepare the dataset
### 1.1 Video format
The feature extractor accepts any input but internally resamples frames to fixed square resolutions (384×384 for CLIP, 224×224 for Synchformer). Both encoders were trained on standard video datasets — predominantly landscape footage. This has two practical implications:
**Aspect ratio** — use **16:9 landscape** whenever possible. Portrait clips (9:16) are mechanically supported but the bicubic stretch into square distorts the image relative to the encoders' training distribution, which can degrade sync feature quality. If your source is portrait, center-crop to square before extraction. Square (1:1) is also fine.
**Resolution** — anything ≥ 480p is sufficient. The extractor downscales to 384px and 224px regardless of source resolution; higher resolution adds no benefit.
**Frame rate** — any. Connect `VHS_VIDEOINFO` from VHS LoadVideo to the feature extractor so fps is read automatically from the file instead of being entered manually.
| Format | Recommendation |
|---|---|
| Aspect ratio | 16:9 landscape (preferred) or 1:1 square |
| Resolution | ≥ 480p (720p+ is fine, no upper limit that matters) |
| Frame rate | Any — set via VHS_VIDEOINFO |
| Portrait (9:16) | Center-crop to square before extraction |
### 1.2 Extract visual features in ComfyUI
For each video clip you want to train on:
1. Load the video with a VHS LoadVideo node.
2. Connect it to **SelVA Feature Extractor**.
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
### Prompt guide
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
**Good prompts are specific about:**
- The sound source (what object is making the sound)
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
- The action producing the sound (if applicable)
| Sound | Weak prompt | Strong prompt |
|---|---|---|
| Dog bark | `dog` | `a large dog barking loudly` |
| Footsteps | `walking` | `heavy boots on a wooden floor` |
| Water | `water` | `water dripping into a metal bucket` |
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
| Door | `door` | `a heavy wooden door slamming shut` |
**Rules of thumb:**
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
### Masking note
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
### 1.3 Collect clean audio
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
```
dataset/my_sound/
  dog_bark_001.npz    ← from SelVA Feature Extractor
  dog_bark_001.wav    ← clean isolated audio recording
  dog_bark_002.npz
  dog_bark_002.wav
  dog_bark_003.npz
  dog_bark_003.wav
```
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
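Before training, it can help to confirm that every `.npz` has a matching audio file. The following is a minimal sketch, not part of `train_lora.py`; `find_pairs` is a hypothetical helper, and it assumes the matching rule is exactly "same stem, one of the supported extensions, same directory" as described above.

```python
from pathlib import Path

# Extensions listed as supported in this guide.
AUDIO_EXTS = (".wav", ".flac", ".ogg", ".aiff", ".aif")

def find_pairs(data_dir):
    """Return (pairs, missing): matched (npz, audio) paths and .npz files with no audio."""
    pairs, missing = [], []
    for npz in sorted(Path(data_dir).glob("*.npz")):
        # First supported extension that exists next to the .npz wins.
        audio = next(
            (npz.with_suffix(ext) for ext in AUDIO_EXTS if npz.with_suffix(ext).exists()),
            None,
        )
        if audio is None:
            missing.append(npz)
        else:
            pairs.append((npz, audio))
    return pairs, missing
```

Calling `find_pairs("dataset/my_sound")` and printing the `missing` list gives a quick view of which clips would likely trigger the `No audio file found` error described under Troubleshooting.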
### 1.4 Optional: prompts.txt
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
```
# One line per file: filename: prompt text
dog_bark.npz: a large dog barking aggressively
dog_bark_001.npz: a dog barking in the distance
```
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
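The priority order above can be sketched as follows. This is an illustration under stated assumptions, not the trainer's actual code: `parse_prompts` and `resolve_prompt` are hypothetical names, and the embedded prompt is passed in as a plain string because the key under which it is stored in the `.npz` is not documented here.

```python
from pathlib import Path

def parse_prompts(text):
    """Parse `filename: prompt` lines, skipping blank lines and `#` comments."""
    prompts = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, sep, prompt = line.partition(":")
        if sep and prompt.strip():
            prompts[name.strip()] = prompt.strip()
    return prompts

def resolve_prompt(npz_path, overrides, embedded_prompt=None):
    """Apply the documented priority: prompts.txt > embedded prompt > directory name."""
    npz_path = Path(npz_path)
    if npz_path.name in overrides:
        return overrides[npz_path.name]
    if embedded_prompt:
        return embedded_prompt
    return npz_path.parent.name
```

With this rule, a clip in `dataset/my_sound/` that has neither a `prompts.txt` entry nor an embedded prompt would be trained with the prompt `my_sound`, which is one reason to name the dataset directory descriptively.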
---
## Step 2 — Train
### Option A — SelVA LoRA Trainer node (ComfyUI)
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
```
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
```
### Option B — Command line
```bash
python train_lora.py \
    --data_dir dataset/my_sound \
    --output_dir lora_output/my_sound \
    --variant large_44k \
    --selva_dir /path/to/ComfyUI/models/selva \
    --rank 16 \
    --steps 4000 \
    --batch_size 4 \
    --lr 1e-4
```
The script will:
1. Load the VAE, CLIP text encoder, and generator.
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
3. Train LoRA adapters for the specified number of steps.
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
---
## CLI Reference
| Argument | Default | Description |
|---|---|---|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
| `--selva_dir` | required | Path to SelVA model weights directory |
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
| `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps |
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
| `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
| `--seed` | `42` | Random seed |
| `--timestep_mode` | `uniform` | Timestep sampling: `uniform`, `logit_normal`, or `curriculum` |
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` / `curriculum` |
| `--curriculum_switch` | `0.6` | Fraction of steps to use logit_normal before switching to uniform. Only with `curriculum` |
| `--lora_dropout` | `0.0` | Dropout on the LoRA path only. `0.05`–`0.1` helps regularize on small datasets |
| `--lora_plus_ratio` | `1.0` | LoRA+ LR ratio: `lr_B = lr × ratio`. `1.0` = standard LoRA, `16.0` = LoRA+ |
---
## Step 3 — Load the adapter in ComfyUI
Connect **SelVA LoRA Loader** between the model loader and the sampler:
```
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
```
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
| Input | Description |
|---|---|
| `model` | SELVA_MODEL from the model loader |
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
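To make the role of `strength` concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It assumes the common formulation where the low-rank update is scaled by `alpha / rank` (consistent with "alpha = rank means scale = 1.0" in the CLI reference) and that `strength` simply multiplies that update; the loader's actual implementation may differ. `matvec` and `lora_forward` are hypothetical names.

```python
def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def lora_forward(x, base_w, lora_a, lora_b, alpha, rank, strength=1.0):
    """y = W x + strength * (alpha / rank) * B (A x); W itself is never modified."""
    base = matvec(base_w, x)                     # frozen base projection
    update = matvec(lora_b, matvec(lora_a, x))   # low-rank path: A first, then B
    scale = strength * alpha / rank              # assumption: strength scales the update
    return [b + scale * u for b, u in zip(base, update)]
```

Under this assumption, `strength=0.0` reproduces the base model output exactly, and values between 0 and 1 interpolate linearly between the base and the fully adapted layer.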
---
## Tuning Guide
### Clip length
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep this in mind for very short sounds.
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
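The padding and trimming rules above can be sketched with plain Python lists. This is an illustration only: the trainer presumably works on tensors, and `fit_to_duration` and `repeat_last_frame` are hypothetical helper names.

```python
def fit_to_duration(samples, sample_rate, duration_s=8.0):
    """Zero-pad or trim a mono sample list to exactly duration_s seconds."""
    target = int(round(sample_rate * duration_s))
    if len(samples) >= target:
        return list(samples[:target])            # trim: content past the target is lost
    return list(samples) + [0.0] * (target - len(samples))  # pad with silence

def repeat_last_frame(frames, target_count):
    """Extend a non-empty frame list to target_count by repeating the last frame."""
    if len(frames) >= target_count:
        return list(frames[:target_count])
    return list(frames) + [frames[-1]] * (target_count - len(frames))
```

For a 2 s bark at 44.1 kHz, `fit_to_duration` would append 6 s of zeros, which is why the recommendations below suggest deliberate placement or looping for very short events.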
**Practical recommendations:**
| Sound type | Clip strategy |
|---|---|
| Continuous sound (rain, engine, wind) | 8 s recordings, taken from as many different positions in the source audio as possible |
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 2–3 times per clip |
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
| Sound with a clear onset (explosion, splash) | Put the onset at ~1–2 s from the start, not at 0 s — gives the model context |
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
### How many clips do I need?
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
| Dataset size | Scenario | Expected result |
|---|---|---|
| **5–10 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
| **15–30 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
| **30–60 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
| **60–150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
| **150–300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
### Batch size
| Batch size | VRAM (large_44k) | Use case |
|---|---|---|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
| `8` | ~15 GB | Better convergence on larger datasets |
| `16` | ~20 GB | Best gradient quality when VRAM allows |
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
### Rank
| Rank | Use case |
|---|---|
| `8` | Fine details on a sound the model already knows well |
| `16` | Default — good balance of capacity and VRAM |
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
Higher rank increases VRAM usage and overfitting risk on small datasets.
### Steps
With `batch_size=4` as the default, these are rough guidelines:
| Dataset size | Recommended steps |
|---|---|
| 10–20 clips | 2000–4000 |
| 20–50 clips | 4000–8000 |
| 50+ clips | 6000–15000 |
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
### Learning rate
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
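A minimal sketch of the warmup described above. The linear ramp from 0 follows the text; keeping the LR constant after warmup is an assumption, since the schedule after the ramp is not specified here. `warmup_lr` is a hypothetical name.

```python
def warmup_lr(step, base_lr=1e-4, warmup_steps=100):
    """Linear warmup from 0 to base_lr; constant afterwards (assumed)."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    return base_lr * step / warmup_steps
```

With the defaults, step 50 trains at half the base rate, so a loss spike within the first 200 steps happens at or shortly after the point where the full LR is first applied.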
### Target layers
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
```bash
--target attn.qkv linear1
```
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
### Timestep sampling mode
Controls how training timesteps are sampled at each step.
`uniform` (default) samples all timesteps equally — equivalent to original MMAudio training.
`logit_normal` concentrates more steps near t=0.5 via `sigmoid(N(0, σ))`. This is the semantically rich mid-noise region. Consistently reaches a lower loss floor but the perceptual improvement on small datasets is marginal.
`curriculum` uses logit_normal for the first `curriculum_switch` fraction of steps (default 60%), then switches to uniform for the remainder. The motivation: logit_normal accelerates early structure learning but undertrains the high-t boundary region; uniform then fills in the fine detail. A switch message is logged when the transition happens.
| Mode | When to use |
|---|---|
| `uniform` (default) | Baseline — safe, equivalent to original training |
| `logit_normal` | When you want a lower loss floor; marginal on small datasets |
| `curriculum` | Experimental — may improve convergence quality on small datasets |
The `logit_normal_sigma` parameter controls the width of the logit-normal distribution (used by both `logit_normal` and the first phase of `curriculum`):
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
- σ=0.5: sharper peak, less coverage of extremes
- σ=2.0: broader, approaches uniform
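The three modes can be sketched with the standard library. This is a sketch under assumptions, not the trainer's code: σ is treated as the standard deviation of the Gaussian, the curriculum switch is taken to happen at `curriculum_switch * total_steps`, and `sample_timestep` is a hypothetical name.

```python
import math
import random

def sample_timestep(mode, step, total_steps, sigma=1.0, curriculum_switch=0.6, rng=random):
    """Draw one training timestep in [0, 1) for the given sampling mode."""
    if mode == "curriculum":
        # logit_normal for the first fraction of training, uniform afterwards
        mode = "logit_normal" if step < curriculum_switch * total_steps else "uniform"
    if mode == "uniform":
        return rng.random()
    if mode == "logit_normal":
        # sigmoid of a zero-mean Gaussian: mass concentrates around t = 0.5
        return 1.0 / (1.0 + math.exp(-rng.gauss(0.0, sigma)))
    raise ValueError(f"unknown timestep mode: {mode}")
```

Drawing a few thousand samples per mode and plotting a histogram is a quick way to see how much of the range near t=0 and t=1 a given σ leaves uncovered.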
### LoRA dropout
`lora_dropout` applies dropout to the input of the LoRA path (not the frozen base linear). It regularizes the low-rank update without disturbing pretrained weights — helpful on small datasets where the LoRA would otherwise overfit to the training clips.
| Value | Use case |
|---|---|
| `0.0` (default) | No regularization — fine for 30+ clips |
| `0.05` | Light regularization — recommended starting point on 10–20 clips |
| `0.1` | Stronger regularization — use if loss plateaus but audio is still noisy |
Dropout is not saved in the adapter file — it only affects training. Loading the adapter at inference does not require setting dropout.
### LoRA+ (asymmetric learning rate)
`lora_plus_ratio` splits the learning rate between LoRA A and B matrices: `lr_B = lr × ratio`. The B matrix is the output-side projection and benefits from a higher LR. Setting ratio to 16 enables the LoRA+ scheme from arXiv:2402.12354.
| Ratio | Effect |
|---|---|
| `1.0` (default) | Standard LoRA — identical A and B learning rates |
| `4.0` | Mild asymmetry |
| `16.0` | LoRA+ — faster convergence, especially on early steps |
LoRA+ is orthogonal to dropout and curriculum sampling — all three can be combined.
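The asymmetric learning rate amounts to putting the A and B matrices into separate optimizer parameter groups. The sketch below builds such groups as plain dictionaries in the list-of-dicts form that PyTorch optimizers accept. The `lora_B` suffix is assumed by analogy with the `lora_A` key shown in the adapter format, and `lora_plus_param_groups` is a hypothetical helper.

```python
def lora_plus_param_groups(named_params, lr=1e-4, ratio=16.0):
    """Split LoRA parameters into two groups: A at lr, B at lr * ratio."""
    group_a, group_b = [], []
    for name, param in named_params:
        if name.endswith("lora_B"):      # assumed suffix for the output-side matrix
            group_b.append(param)
        elif name.endswith("lora_A"):
            group_a.append(param)
    return [
        {"params": group_a, "lr": lr},
        {"params": group_b, "lr": lr * ratio},
    ]
```

In a training script this would typically be used as `torch.optim.AdamW(lora_plus_param_groups(model.named_parameters()))`; with `ratio=1.0` both groups share the same LR and the setup reduces to standard LoRA.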
### Adapter strength at inference
| Strength | Effect |
|---|---|
| `0.5–0.7` | Conservative — blends adapter with base model, less noise |
| `1.0` | Full adapter strength (default) |
| `>1.0` | Exaggerated effect, may introduce artifacts |
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.6–0.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
### Loss interpretation
A typical loss curve:
- Starts around `0.8–1.0`
- Should reach `0.55–0.65` after convergence on a clean sound class with 10–30 clips
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
- Below `0.1` on a small dataset means overfitting
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
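The "flat for 2000+ steps" check can be automated on a list of logged loss values. The sketch below is an assumption-laden illustration: the smoothing factor `beta=0.98` and the `tolerance` of 0.01 are arbitrary choices, not values taken from the trainer, and both function names are hypothetical.

```python
def ema_smooth(losses, beta=0.98):
    """Exponential moving average with bias correction (beta is an assumed value)."""
    smoothed, avg = [], 0.0
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1.0 - beta) * loss
        smoothed.append(avg / (1.0 - beta ** i))  # correct the startup bias toward 0
    return smoothed

def has_plateaued(smoothed, window=2000, tolerance=0.01):
    """True if the smoothed loss moved by less than tolerance over the last window steps."""
    if len(smoothed) <= window:
        return False
    recent = smoothed[-window:]
    return max(recent) - min(recent) < tolerance
```

A plateau detected this way only says the loss has stopped moving; as noted elsewhere in this guide, listening to intermediate checkpoints remains the deciding test.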
### Precision
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
---
## Output files
```
lora_output/my_sound/
  adapter_step00500.pt    ← step checkpoint (includes optimizer state for resume)
  adapter_step01000.pt
  ...
  adapter_final.pt        ← final adapter with embedded metadata (inference only)
  meta.json               ← human-readable metadata
  sample_step00500.wav    ← quick eval sample at each checkpoint
  loss_raw.png            ← raw loss curve
  loss_smoothed.png       ← EMA-smoothed loss curve
```
`adapter_final.pt` format:
```python
{
    "state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
    "meta": {
        "variant": "large_44k",
        "rank": 16,
        "alpha": 16.0,
        "target": ["attn.qkv"],
        "steps": 2000
    }
}
```
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
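A small sketch for sanity-checking an adapter dictionary against the layout shown above. `describe_adapter` is a hypothetical helper that works on an already-loaded dict; in practice the dict would presumably come from `torch.load(path, map_location="cpu")`, which is an assumption about how the `.pt` file is stored. The `alpha / rank` scale follows the CLI reference.

```python
REQUIRED_META = ("variant", "rank", "alpha", "target", "steps")

def describe_adapter(checkpoint):
    """Validate the documented adapter layout and return a one-line summary."""
    missing = [k for k in ("state_dict", "meta") if k not in checkpoint]
    if missing:
        raise KeyError(f"adapter is missing top-level keys: {missing}")
    meta = checkpoint["meta"]
    absent = [k for k in REQUIRED_META if k not in meta]
    if absent:
        raise KeyError(f"adapter meta is missing: {absent}")
    scale = meta["alpha"] / meta["rank"]  # alpha == rank gives scale 1.0
    return (
        f"{meta['variant']} rank={meta['rank']} scale={scale:.2f} "
        f"targets={','.join(meta['target'])} tensors={len(checkpoint['state_dict'])}"
    )
```

Running this on a checkpoint before wiring it into the LoRA Loader is a cheap way to catch a file trained for a different variant than the one loaded.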
---
## Troubleshooting
**`No layers matched target=...`**
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
**`No .npz files found in ...`**
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
**`No audio file found for clip.npz`**
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
**The sound is audible but there is white noise on top**
Lower the adapter strength to `0.6–0.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
**LoRA appears to have no effect**
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
**Loss does not decrease**
- Increase `batch_size` for more stable gradients.
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
- Check that the audio files are clean and actually contain the target sound.
- Check that the `.npz` features were extracted with a relevant prompt.
**Loss explodes or NaN**
- Lower the learning rate (`5e-5`).
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
**Loss plateaus early (above 0.7)**
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
---
## Observations (work in progress)
These are empirical findings from ongoing experiments. They will be promoted to the main guide once they are better validated.
### Precision and batch size
| Config | Smoothed loss at step 2000 | Notes |
|---|---|---|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
| bf16 batch 16 | ~0.65 | Stable, plateaued around steps 6000–8000 at ~0.59 |
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
| fp32 batch 32 | ~0.58 | At step 2000, already matches the loss bf16 batch 16 reaches at step 6000 |
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
### logit_normal vs uniform
logit_normal consistently reaches a lower loss floor than uniform. However, the perceptual improvement is dataset-dependent: on 10 clips the difference is marginal. It may matter more with larger datasets. No conclusion yet.
### White noise
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
- Too few clips for the model to confidently predict the target sound
- Imprecise extraction prompts producing unfocused sync features
- Missing mask when multiple objects are in frame
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.0–3.5 or adapter strength to 0.6–0.7 helps at inference.
+299 -75
@@ -1,156 +1,380 @@
# ComfyUI-PrismAudio # ComfyUI-SelVA
Custom nodes for [PrismAudio](https://huggingface.co/FunAudioLLM/PrismAudio) (ICLR 2026) — video-to-audio and text-to-audio generation using decomposed Chain-of-Thought reasoning with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE. Custom nodes for [SelVA](https://github.com/jnwnlee/selva) — video-to-audio generation driven by text prompts. SelVA conditions audio synthesis on both visual content and natural language, letting you describe *what* sounds to generate rather than just *when*.
## Installation Built on [MMAudio](https://github.com/hkchengrex/MMAudio) with a TextSynchformer encoder that injects text guidance directly into the visual sync stream.
Clone into your ComfyUI custom nodes directory: ---
```bash
cd ComfyUI/custom_nodes
git clone https://github.com/Ethanfel/ComfyUI-Prismaudio.git ComfyUI-PrismAudio
pip install -r ComfyUI-PrismAudio/requirements.txt
```
**flash-attn** is optional — detected at runtime, falls back to PyTorch SDPA if unavailable.
## Nodes ## Nodes
### PrismAudio Model Loader ### SelVA Model Loader
Loads the DiT diffusion model and VAE. Auto-downloads weights from HuggingFace on first use. Loads the generator, TextSynchformer encoder, and all feature utilities (CLIP, T5, Synchformer, VAE). Weights are auto-downloaded from HuggingFace on first use.
| Input | Options | Description | | Input | Options | Description |
|-------|---------|-------------| |-------|---------|-------------|
| `precision` | auto / fp32 / fp16 / bf16 | DiT and conditioner dtype. VAE is always fp32. | | `variant` | small_16k / small_44k / medium_44k / large_44k | Model size and output sample rate |
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management. | | `precision` | bf16 / fp16 / fp32 | Compute dtype |
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management |
**Output:** `model` (SELVA_MODEL)
--- ---
### PrismAudio Feature Extractor ### SelVA Feature Extractor
Extracts video features (VideoPrism LvT, Synchformer) and text features (T5-Gemma) from a video in a subprocess. Results are cached on disk. Extracts CLIP visual features and text-guided sync features from a video. Results are cached on disk — re-running with the same inputs is instant.
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From SelVA Model Loader |
| `video` | IMAGE tensor from any ComfyUI video loader | | `video` | IMAGE tensor from any ComfyUI video loader |
| `caption_cot` | Chain-of-thought description of the audio scene | | `prompt` | Text description of the audio to generate |
| `video_info` | *(optional)* `VHS_VIDEOINFO` from VHS LoadVideo — sets fps automatically | | `video_info` | *(optional)* VHS_VIDEOINFO from VHS LoadVideo — sets fps automatically |
| `fps` | Source fps — ignored if `video_info` is connected | | `fps` | Source fps — ignored if `video_info` is connected |
| `python_env` | `managed_env` (auto-created isolated venv, recommended) or `comfyui_env` (current Python, see warning below) | | `duration` | Override clip duration in seconds. `0` = infer from video length |
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir. | | `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir |
| `hf_token` | HuggingFace token for gated models. Prefer `HF_TOKEN` env var instead. | | `mask` | *(optional)* Segmentation mask `[T,H,W]` float [0,1] — static (1 frame) or per-frame |
| `mask_strength` | Background suppression strength. `1.0` = full neutral fill, `0.0` = no effect |
| `mask_clip` | Apply mask to CLIP features (384px path). Disable to let CLIP see the full scene |
| `mask_sync` | Apply mask to TextSynchformer sync features (224px path) |
**Outputs:** `features` (PRISMAUDIO_FEATURES), `fps` (FLOAT) **Outputs:** `features` (SELVA_FEATURES), `fps` (FLOAT), `prompt` (STRING)
**`managed_env`** auto-creates a venv at `_extract_env/` inside the plugin directory on first use and installs JAX, TF, VideoPrism, and Synchformer. This takes several minutes the first time. Connect `prompt` output to the Sampler's `prompt` input to avoid entering it twice.
**`comfyui_env`** uses the current ComfyUI Python — JAX/TF/videoprism must already be installed. Installing them into the ComfyUI environment may conflict with existing packages. #### Masking
Connect a segmentation mask (SAM2, Grounding DINO+SAM, or any ComfyUI mask node) to isolate a specific object's motion before encoding. Background pixels are filled with a neutral value (0.5) rather than zeroed — this keeps them in-distribution for CLIP and maps to exactly 0 after sync's `[-1,1]` normalization, minimising the influence of background motion on the generated audio.
Use `mask_sync=true, mask_clip=false` if you want sync features focused on the target object while CLIP still sees the full scene for broader context. Changing any mask parameter correctly busts the feature cache.
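The neutral-fill behaviour can be illustrated per pixel. This is a sketch of one blend formula that matches the description (`mask_strength=1.0` fills the background with 0.5, `0.0` leaves the frame untouched); the node's exact formula is not shown here, so treat the blend as an assumption. `apply_mask` and `sync_normalize` are hypothetical names.

```python
def apply_mask(pixel, mask, strength=1.0, neutral=0.5):
    """Blend a [0,1] pixel toward the neutral value where mask is 0 (assumed formula)."""
    weight = strength * (1.0 - mask)   # 0 inside the mask, `strength` in the background
    return pixel * (1.0 - weight) + neutral * weight

def sync_normalize(value):
    """Map a [0,1] value to [-1,1], as done for sync frames."""
    return value * 2.0 - 1.0
```

A fully masked-out pixel becomes 0.5, and `sync_normalize(0.5)` is exactly 0, which is the property the paragraph above relies on.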
--- ---
### PrismAudio Feature Loader ### SelVA Sampler
Loads a pre-computed `.npz` feature file. Use this to re-use extracted features without re-running the extractor. Generates audio from video features. Runs the rectified flow ODE with classifier-free guidance.
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `npz_path` | Path to a `.npz` file produced by the Feature Extractor | | `model` | From SelVA Model Loader (or any loader/loader chain) |
| `features` | From SelVA Feature Extractor |
| `prompt` | Text description — leave empty to use the prompt stored in features |
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
| `duration` | Audio duration in seconds. `0` = use duration from features |
| `steps` | Sampling steps (default: 25) |
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | RMS-normalize output to `target_lufs` (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
**Output:** `AUDIO`
--- ---
### PrismAudio Sampler ### SelVA LoRA Loader
Video-to-audio generation. Takes model + features, produces AUDIO. Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From Model Loader | | `model` | SELVA_MODEL from Model Loader |
| `features` | From Feature Extractor or Feature Loader | | `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
| `duration` | Audio duration in seconds. Set to `0` to use the video duration from features automatically. | | `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
| `steps` | Sampling steps (default: 100) |
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) | **Output:** `model` (SELVA_MODEL with adapter injected)
---
### SelVA LoRA Trainer
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
---
### SelVA LoRA Scheduler
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `experiments_file` | Path to JSON sweep config |
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Skip Experiment
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
**Output:** `flag_path` (STRING)
---
### SelVA LoRA Evaluator
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
---
### SelVA Dataset Browser
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
**Outputs:** video path, audio path, frames directory, label, total count
---
### SelVA VAE Roundtrip
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `audio` | AUDIO to test |
**Output:** `audio_reconstructed` (AUDIO)
---
### SelVA HF Smoother
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
**Output:** `audio` (AUDIO)
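The blend described above can be sketched with a first-order low-pass filter. This is only an illustration under stated assumptions: the node's actual filter design is not documented here, and `one_pole_lowpass`, `smooth_hf`, `cutoff_hz`, and `blend` are hypothetical names and parameters.

```python
import math


def one_pole_lowpass(samples, cutoff_hz, sample_rate):
    """First-order low-pass (RC smoothing) over a mono signal."""
    dt = 1.0 / sample_rate
    rc = 1.0 / (2.0 * math.pi * cutoff_hz)
    a = dt / (rc + dt)  # smoothing coefficient in (0, 1)
    out, y = [], 0.0
    for s in samples:
        y += a * (s - y)
        out.append(y)
    return out


def smooth_hf(samples, cutoff_hz=8000.0, sample_rate=44100, blend=0.5):
    """Blend the original with its low-passed copy (blend=0 keeps the input)."""
    low = one_pole_lowpass(samples, cutoff_hz, sample_rate)
    return [(1.0 - blend) * s + blend * l for s, l in zip(samples, low)]


# Worst case for high-frequency content: a tone alternating at Nyquist.
nyquist_tone = [1.0 if n % 2 == 0 else -1.0 for n in range(200)]
untouched = smooth_hf(nyquist_tone, blend=0.0)
smoothed = smooth_hf(nyquist_tone, blend=1.0)
```

With `blend=0.0` the signal passes through unchanged, while `blend=1.0` noticeably reduces the amplitude of the Nyquist-rate tone; values in between trade attenuation against fidelity.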
---
### SelVA Spectral Matcher
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
**Output:** `audio` (AUDIO)
---
### SelVA Textual Inversion Trainer
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
---
### SelVA Textual Inversion Loader
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
---
### SelVA TI Scheduler
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Activation Steering Extractor
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with `.npz` feature files |
| `output_path` | Where to save `steering_vectors.pt` |
| `n_samples` | Clips to average over (default: 16) |
| `seed` | RNG seed | | `seed` | RNG seed |
**Output:** `steering_path` (STRING)
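For a single block, the mean-difference idea behind the extractor can be sketched as follows. This is a simplified illustration, not the node's code: hidden states are reduced to short lists of floats, and `mean_vector`, `steering_vector`, and `apply_steering` are hypothetical helpers.

```python
def mean_vector(vectors):
    """Element-wise mean over a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def steering_vector(cond_states, empty_states):
    """Mean hidden state under style conditioning minus mean under empty conditioning."""
    mc, me = mean_vector(cond_states), mean_vector(empty_states)
    return [c - e for c, e in zip(mc, me)]


def apply_steering(hidden, vector, strength=0.1):
    """Nudge a hidden state along the steering direction at inference."""
    return [h + strength * v for h, v in zip(hidden, vector)]


cond = [[1.0, 2.0], [3.0, 4.0]]    # toy states with conditioning
empty = [[0.0, 0.0], [2.0, 2.0]]   # toy states with empty conditioning
vec = steering_vector(cond, empty)
steered = apply_steering([0.0, 0.0], vec, strength=0.5)
```

Averaging over more clips (the `n_samples` input) reduces the per-clip noise in the difference, at the cost of more forward passes.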
--- ---
### PrismAudio Text Only ### SelVA Activation Steering Loader
Text-to-audio generation without video. Uses the T5-Gemma encoder. Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
**Output:** `steering_vectors` (STEERING_VECTORS)
---
### SelVA BigVGAN Trainer
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From Model Loader | | `model` | SELVA_MODEL |
| `text_prompt` | Chain-of-thought audio scene description. Longer, more detailed prompts produce better results. | | `data_dir` | Directory with target-style audio files (searched recursively) |
| `duration` | Audio duration in seconds | | `output_path` | Where to save the fine-tuned vocoder `.pt` |
| `steps` | Sampling steps (default: 100) | | `train_mode` | `snake_alpha_only` (default) or `all_params` |
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) | | `steps` | Training steps (default: 2000) |
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
| `batch_size` | Clips per step (default: 4) |
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) |
| `save_every` | Checkpoint interval in steps (default: 500) |
| `seed` | RNG seed | | `seed` | RNG seed |
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
---
### SelVA BigVGAN Loader
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
---
### SelVA DITTO Optimizer
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `features` | From SelVA Feature Extractor |
| `prompt` | Sound description (leave empty to use features prompt) |
| `negative_prompt` | Sounds to suppress |
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) |
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
| `steps` | Euler steps for the final generation pass (default: 25) |
| `cfg_strength` | CFG scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | *(optional)* RMS normalize output (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
**Output:** `AUDIO`
--- ---
## Workflows ## Workflows
### Video-to-Audio ### Basic generation
``` ```
VHS LoadVideo ──► PrismAudio Feature Extractor ──► PrismAudio Sampler ──► Save Audio VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
(video_info) ──────────────────► (fps auto) (video_info)
(features) ────────────────────► (features) (features) ──────────────────────────────────►│
duration=0 ─────────────────────► (auto from features) │ (prompt) ────────────────────────────────────►│
``` ```
### Pre-computed Features ### DITTO style transfer (recommended first approach)
``` ```
PrismAudio Feature Loader (.npz) ──► PrismAudio Sampler ──► Save Audio SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
SelVA Feature Extractor ──(features)────────────────────────────────────►│
(prompt) ──────────────────────────────────────►│
BJ reference_dir ───────────────────────────────────────────────────────►│
``` ```
### Text-to-Audio No training required. Each run optimizes x₀ independently for the current video and reference set.
### Vocoder fine-tuning
``` ```
PrismAudio Text Only ──► Save Audio SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
▲ ▲
checkpoint .pt SelVA Feature Extractor
``` ```
## HuggingFace Authentication ### LoRA training
Required for T5-Gemma (gated model) and PrismAudio weights. See [LORA_TRAINING.md](LORA_TRAINING.md).
1. Visit <https://huggingface.co/FunAudioLLM/PrismAudio> and accept the license. ---
2. Authenticate via one of:
- **Environment variable:** `export HF_TOKEN=hf_...`
- **CLI login:** `huggingface-cli login`
There is no `hf_token` widget on the main nodes by design — ComfyUI saves all STRING values to workflow JSON, which would expose your token. The Feature Extractor has an `hf_token` input as a convenience but using `HF_TOKEN` env var is preferred. ## Installation
## Model Files ```bash
cd ComfyUI/custom_nodes
git clone https://github.com/Ethanfel/ComfyUI-SelVA.git
pip install -r ComfyUI-SelVA/requirements.txt
```
Weights are auto-downloaded to `ComfyUI/models/prismaudio/`: ---
## Model Weights
Weights are auto-downloaded to `ComfyUI/models/selva/` on first load. No manual setup required.
| File | Size | Description | | File | Size | Description |
|------|------|-------------| |------|------|-------------|
| `prismaudio.ckpt` | ~2.7 GB | Diffusion model (DiT) | | `video_enc_sup_5.pth` | ~300 MB | TextSynchformer encoder |
| `vae.ckpt` | ~2.5 GB | Stable Audio 2.0 VAE | | `generator_small_16k_sup_5.pth` | ~340 MB | Small generator, 16 kHz output |
| `synchformer_state_dict.pth` | ~950 MB | Synchformer visual encoder | | `generator_small_44k_sup_5.pth` | ~340 MB | Small generator, 44.1 kHz output |
| `generator_medium_44k_sup_5.pth` | ~860 MB | Medium generator, 44.1 kHz output |
| `generator_large_44k_sup_5.pth` | ~2.0 GB | Large generator, 44.1 kHz output |
| `v1-16.pth` | ~1.1 GB | VAE for 16 kHz |
| `v1-44.pth` | ~1.1 GB | VAE for 44.1 kHz |
| `best_netG.pt` | ~90 MB | BigVGAN vocoder for 16 kHz |
| `synchformer_state_dict.pth` | ~950 MB | Synchformer (shared with PrismAudio if present) |
T5-Gemma and VideoPrism LvT are cached in `~/.cache/huggingface/`. CLIP (DFN5B-ViT-H-14-384) and T5 (flan-t5-base) are downloaded automatically from HuggingFace to `~/.cache/huggingface/`.
---
## VRAM Requirements ## VRAM Requirements
| VRAM | Recommended settings | | VRAM | Recommended settings |
|------|----------------------| |------|----------------------|
| 24 GB+ | `keep_in_vram`, any precision | | 24 GB+ | `keep_in_vram`, any variant |
| 12–24 GB | `offload_to_cpu`, bf16/fp16 | | 12–24 GB | `offload_to_cpu`, medium or smaller |
| 8–12 GB | `offload_to_cpu`, fp16 | | 8–12 GB | `offload_to_cpu`, small variant, fp16 |
| < 8 GB | May work with `offload_to_cpu` + fp16 |
## Troubleshooting The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available, otherwise `offload_to_cpu`.
- **Gated model errors** — Accept the license at <https://huggingface.co/FunAudioLLM/PrismAudio> and set `HF_TOKEN`. ---
- **VRAM errors** — Switch `offload_strategy` to `offload_to_cpu` and/or use `fp16` precision.
- **Feature extraction fails** — Ensure `synchformer_state_dict.pth` is in `models/prismaudio/`. On first run with `managed_env`, installation takes several minutes. ## Style Transfer
- **flash-attn** — Optional. Auto-detected at runtime; falls back to PyTorch SDPA.
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
---
## Credits ## Credits
PrismAudio by [FunAudioLLM](https://github.com/FunAudioLLM) (ICLR 2026). [Model & weights](https://huggingface.co/FunAudioLLM/PrismAudio). - [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization
+158
@@ -0,0 +1,158 @@
# Style Transfer for SelVA
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
---
## Why standard fine-tuning is hard
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse tail region of that space: the VAE roundtrip already shows a ~10–15 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
**LoRA** makes this worse: it introduces "intruder dimensions" (new high-rank singular vectors absent from the pretrained weight spectrum) that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.21–0.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
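For readers unfamiliar with the metric quoted above, spectral flatness is commonly defined as the geometric mean of the power spectrum divided by its arithmetic mean. The sketch below uses that textbook definition; whether the evaluator computes it exactly this way (framing, averaging, epsilon) is an assumption, and `spectral_flatness` is a hypothetical helper.

```python
import math


def spectral_flatness(power, eps=1e-12):
    """Geometric mean / arithmetic mean of a power spectrum.

    Values near 1 indicate noise-like spectra, values near 0 indicate
    tonal (peaky) spectra. eps avoids log(0) on empty bins.
    """
    log_mean = sum(math.log(p + eps) for p in power) / len(power)
    arithmetic_mean = sum(power) / len(power)
    return math.exp(log_mean) / (arithmetic_mean + eps)


noise_like = spectral_flatness([1.0, 1.0, 1.0, 1.0])
tonal = spectral_flatness([1.0, 0.0, 0.0, 0.0])
```

Read this way, a jump from roughly 0.013 to above 0.2 means the output spectrum has moved from mostly tonal toward noise-like.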
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
The approaches below are ordered by expected quality and ease of use.
---
## Tier 1 — DITTO (recommended first try)
**Node: SelVA DITTO Optimizer**
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
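The two terms of the style loss can be sketched on toy mel spectrograms. This is a minimal illustration, assuming a mel is a list of bands with one list of frame values each; the helper names, the mean-squared-error form, and the equal weighting of both terms are assumptions rather than the node's exact loss.

```python
def mean_spectrum(mel):
    """Average energy per band; mel is a list of bands, each a list of frames."""
    return [sum(band) / len(band) for band in mel]


def gram_matrix(mel):
    """Band-by-band inner products, averaged over time."""
    n_frames = len(mel[0])
    return [
        [sum(a * b for a, b in zip(bi, bj)) / n_frames for bj in mel]
        for bi in mel
    ]


def style_loss(gen_mel, ref_mel):
    """Mean-spectrum MSE plus Gram-matrix MSE (assumed equal weighting)."""
    mg, mr = mean_spectrum(gen_mel), mean_spectrum(ref_mel)
    mean_term = sum((g - r) ** 2 for g, r in zip(mg, mr)) / len(mr)
    gg, gr = gram_matrix(gen_mel), gram_matrix(ref_mel)
    n = len(gr)
    gram_term = sum(
        (gg[i][j] - gr[i][j]) ** 2 for i in range(n) for j in range(n)
    ) / (n * n)
    return mean_term + gram_term


reference = [[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]]
time_reversed = [band[::-1] for band in reference]
brighter = [[1.0, 2.0, 3.0, 4.0], [8.0, 6.0, 4.0, 2.0]]
```

Here `style_loss(time_reversed, reference)` is zero even though the frames are in a different order, which is the "no temporal alignment" property, while `brighter` is penalized for its changed band balance.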
**How it works:**
For each video clip you want to process:
1. Run SelVA Feature Extractor as usual.
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
```
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
SelVA Feature Extractor ──(features)────────────────────────►│
(prompt) ──────────────────────────►│
BJ clips ───────────────────────────(reference_dir) ─────────►│
```
**Tuning guide:**
| Parameter | Starting value | When to adjust |
|---|---|---|
| `n_opt_steps` | 50 | Increase to 100–200 if style shift is too subtle |
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps |
| `style_weight` | 1.0 | Increase to 2–5 for stronger BJ character; watch for incoherence |
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At `n_grad_steps=5`, expect ~4–6 GB additional VRAM over baseline inference.
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 1,025 ODE-step evaluations, which takes roughly 5–15 minutes depending on GPU.
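The cost estimate above is simple arithmetic, and it can help to recompute it when changing the tuning parameters. The sketch below encodes that formula plus the `n_grad_steps ≤ n_ode_steps` constraint from the tuning table; `ditto_step_evaluations` and `checkpoint_passes` are hypothetical names, and the count ignores any extra passes for classifier-free guidance.

```python
def ditto_step_evaluations(n_opt_steps=50, n_ode_steps=10, n_grad_steps=5,
                           final_steps=25, checkpoint_passes=2):
    """Approximate number of ODE-step evaluations for one DITTO run."""
    if n_grad_steps > n_ode_steps:
        raise ValueError("n_grad_steps must be <= n_ode_steps")
    optimization = n_opt_steps * n_ode_steps * checkpoint_passes
    return optimization + final_steps


default_cost = ditto_step_evaluations()
longer_run = ditto_step_evaluations(n_opt_steps=200)
```

Raising `n_opt_steps` from 50 to 200 roughly quadruples the cost, so it is usually worth trying a higher `opt_lr` or `style_weight` first.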
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
---
## Tier 2 — Vocoder Fine-tuning
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
### Why plain mel L1 loss fails
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting: they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem. The model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 3–8 kHz harmonics, which is visible as "green smear" in spectrograms. This happens regardless of LR or step count.
### `snake_alpha_only` mode (default, recommended)
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
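To make the role of α concrete, the Snake formula quoted above can be evaluated directly. The sketch below is a scalar illustration only (the real activation operates per channel on tensors), and the parameter counts are the approximate figures stated in this document.

```python
import math


def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x)."""
    return x + (1.0 / alpha) * math.sin(alpha * x) ** 2


# Larger alpha means a shorter period and a smaller periodic bump.
low_alpha = snake(math.pi / 2, alpha=1.0)   # pi/2 + 1
high_alpha = snake(math.pi / 2, alpha=2.0)  # sin(pi) is ~0, so about pi/2

# Share of the vocoder that snake_alpha_only leaves trainable (approximate).
trainable_percent = 100.0 * 27_000 / 112_000_000
```

Changing α shifts where the periodic term peaks without touching any convolution weight, which is why such a small parameter set can still move harmonic emphasis.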
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
### `all_params` mode with discriminator
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
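The loss in the steps above can be sketched as a weighted sum with an L2-SP anchor. This is an illustration under assumptions: parameters are flattened to a list of floats, the penalty is a plain sum of squared differences (the trainer may use a mean or a factor of 1/2), and `l2sp_penalty` and `total_loss` are hypothetical helpers.

```python
def l2sp_penalty(params, pretrained, lambda_l2sp=1e-3):
    """Penalize squared drift of each parameter from its pretrained value."""
    return lambda_l2sp * sum((p - q) ** 2 for p, q in zip(params, pretrained))


def total_loss(l_fm, l_mel, params, pretrained, lambda_l2sp=1e-3):
    """2 * L_fm + 0.1 * L_mel + L2-SP anchor, as described in the list above."""
    return 2.0 * l_fm + 0.1 * l_mel + l2sp_penalty(params, pretrained, lambda_l2sp)


pretrained = [0.0, 1.0]
at_start = l2sp_penalty(pretrained, pretrained)
drifted = total_loss(1.0, 1.0, [1.0, 1.0], pretrained, lambda_l2sp=0.5)
```

At initialization the anchor term is zero, so it only starts to matter once fine-tuning pulls weights away from the pretrained vocoder.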
### Workflow
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor
```
### Tuning guide
| Parameter | Default | Notes |
|---|---|---|
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
| `steps` | 2000 | 1000–2000 for snake_alpha_only; 3000–5000 for all_params |
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
| `batch_size` | 4 | 4–8 for stable gradients |
| `segment_seconds` | 1.0 | 1–2 s segments recommended |
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
---
## Tier 3 — DITTO + Vocoder (combined)
Stack both:
```
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
```
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
---
## What doesn't work (and why)
### Standard LoRA
LoRA introduces "intruder dimensions" — high-rank singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
### Textual inversion
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
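A toy example shows what mean-pooling does to a learned token. The sketch assumes 77 token embeddings of dimension 2 and a plain arithmetic mean; `mean_pool` is a hypothetical helper and the numbers are chosen only to keep the arithmetic exact.

```python
def mean_pool(tokens):
    """Average token embeddings into one global vector."""
    n = len(tokens)
    return [sum(col) / n for col in zip(*tokens)]


def pooled_with_token_at(index, value, n_tokens=77, dim=2):
    """Pool a sequence of zeros with a single learned token at `index`."""
    tokens = [[0.0] * dim for _ in range(n_tokens)]
    tokens[index] = list(value)
    return mean_pool(tokens)


early = pooled_with_token_at(5, [77.0, 0.0])
late = pooled_with_token_at(60, [77.0, 0.0])
```

The learned token is scaled down by 1/77 and its position is lost entirely (`early == late`), so the only thing it can express is a global offset.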
### Activation steering (global mean difference)
The raw mean difference between BJ and empty conditions is not a clean style basis: it carries noise from the diversity of the training clips and from the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 3–6 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
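One way to picture the per-layer ranking mentioned above is to sort blocks by the L2 norm of their steering delta and keep the top few. This is a sketch of that idea only: using the L2 norm as the ranking criterion is an assumption, and `rank_blocks_by_delta` and the block names are hypothetical.

```python
import math


def rank_blocks_by_delta(block_deltas, top_k=3):
    """Return the names of the top_k blocks with the largest delta norm."""
    norms = {
        name: math.sqrt(sum(v * v for v in delta))
        for name, delta in block_deltas.items()
    }
    return sorted(norms, key=norms.get, reverse=True)[:top_k]


deltas = {
    "block_0": [0.1, 0.0],
    "block_1": [3.0, 4.0],   # norm 5.0
    "block_2": [0.0, 2.0],   # norm 2.0
    "block_3": [0.2, 0.1],
}
selected = rank_blocks_by_delta(deltas, top_k=2)
```

Steering would then be applied only to the selected blocks, leaving the rest of the network at its unsteered activations.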
---
## Reference dataset preparation
Use the same audio clips for both DITTO and vocoder fine-tuning:
- **Minimum:** 20–30 clips. DITTO works from 5+; vocoder benefits from 40+.
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
+1 -1
@@ -1,5 +1,5 @@
""" """
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026). ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
""" """
import sys import sys
import os import os
-337
@@ -1,337 +0,0 @@
"""
PrismAudio feature extraction utilities.
Implements FeaturesUtils used by scripts/extract_features.py to extract:
- Text features via T5-Gemma (transformers)
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
- Sync features via Synchformer visual encoder (PyTorch)
"""
import os
import torch
import torch.nn as nn
import numpy as np
class FeaturesUtils:
    def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
        self.device = device or torch.device("cpu")
        self._t5_tokenizer = None
        self._t5_encoder = None
        self._vp_model = None
        self._vp_state = None
        self._vp_text_tokenizer = None
        self._sync_model = None
        self._synchformer_ckpt = synchformer_ckpt
        self._load_synchformer()

    # ------------------------------------------------------------------
    # T5-Gemma text encoding
    # ------------------------------------------------------------------
    def _ensure_t5(self):
        if self._t5_encoder is not None:
            return
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
        self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._t5_encoder = (
            AutoModelForSeq2SeqLM.from_pretrained(model_id)
            .get_encoder()
            .to(self.device)
            .eval()
        )

    def encode_t5_text(self, texts):
        """
        Args:
            texts: list of str
        Returns:
            Tensor [seq_len, 1024]
        """
        self._ensure_t5()
        tokens = self._t5_tokenizer(
            texts, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            out = self._t5_encoder(**tokens)
        # Move encoder off GPU to save VRAM
        self._t5_encoder.to("cpu")
        torch.cuda.empty_cache()
        return out.last_hidden_state.squeeze(0)  # [seq_len, 1024]
    # ------------------------------------------------------------------
    # VideoPrism video + text encoding (JAX)
    # ------------------------------------------------------------------
    def _ensure_videoprism(self):
        if self._vp_model is not None:
            return
        from videoprism import models as vp
        import jax
        model_name = "videoprism_lvt_public_v1_large"
        print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
        self._vp_model = vp.get_model(model_name)
        self._vp_state = vp.load_pretrained_weights(model_name)
        self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
        jax_dev = jax.devices()[0]
        self._jax_forward = jax.jit(
            lambda x, y, z: self._vp_model.apply(
                self._vp_state, x, y, z, train=False, return_intermediate=True
            ),
            device=jax_dev,
        )

    def encode_video_and_text_with_videoprism(self, clip_input, texts):
        """
        Args:
            clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
            texts: list of str - CoT captions, passed to VideoPrism LvT text tower
        Returns:
            global_video_features: Tensor [1, D]
            video_features: Tensor [T, D] - per-frame L2-normalized embeddings
            global_text_features: Tensor [1, D]
        """
        self._ensure_videoprism()
        import jax.numpy as jnp
        from videoprism import models as vp

        # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
        frames = clip_input.squeeze(0)  # [T, C, H, W]
        frames = (frames + 1.0) / 2.0  # [-1,1] → [0,1]
        frames = frames.permute(0, 2, 3, 1)  # [T, H, W, C]
        frames_np = frames.cpu().numpy().astype(np.float32)
        frames_jax = jnp.array(frames_np)[None]  # [1, T, H, W, C]

        # Tokenize text (padding value 1.0 = pad, 0.0 = real token)
        text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)

        # Joint video+text forward with intermediate outputs
        video_embeddings, text_embeddings, outputs = self._jax_forward(
            frames_jax, text_ids, text_paddings
        )

        # Per-frame features: [B, T, 1024] L2-normalized
        frame_embed_np = np.array(outputs["frame_embeddings"])  # [1, T, 1024]
        per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device)  # [T, 1024]

        # Global video embedding: [1024] → [1, 1024]
        global_video = torch.from_numpy(
            np.array(video_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        # Global text embedding: [1024] → [1, 1024]
        global_text = torch.from_numpy(
            np.array(text_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        return global_video, per_frame, global_text
    # ------------------------------------------------------------------
    # Synchformer sync feature encoding
    # ------------------------------------------------------------------
    def _load_synchformer(self):
        if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
            return
        print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
        state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)
        # Checkpoint may be raw state_dict or wrapped in {"model": ...}
        if isinstance(state, dict) and "model" in state:
            state_dict = state["model"]
        else:
            state_dict = state
        self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
        self._sync_model.eval()

    def encode_video_with_sync(self, sync_input):
        """
        Args:
            sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
        Returns:
            sync_features: Tensor [num_segments, 768]
        """
        if self._sync_model is None:
            raise RuntimeError(
                "[FeaturesUtils] Synchformer checkpoint not loaded. "
                "Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
            )
        frames = sync_input.squeeze(0).to(self.device)  # [T, C, H, W]
        with torch.no_grad():
            return self._sync_model(frames)
# ------------------------------------------------------------------
# Synchformer visual encoder - TimeSformer-style ViT-B/16
# Architecture reverse-engineered from synchformer_state_dict.pth
# ------------------------------------------------------------------
import torch.nn.functional as F


class _PatchEmbed(nn.Module):
    """2D patch embedding: [B, 3, 224, 224] → [B, 196, 768]."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)


class _ViTAttn(nn.Module):
    """ViT-style QKV attention (timm convention: qkv as single Linear)."""

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D))


class _BlockMLP(nn.Module):
    """Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint."""

    def __init__(self, dim=768, mlp_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))
class _TimeSformerBlock(nn.Module):
    """
    Factorized space-time attention block.
    norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = _ViTAttn(dim, num_heads)
        self.norm3 = nn.LayerNorm(dim)
        self.timeattn = _ViTAttn(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _BlockMLP(dim)

    def forward(self, x, T):
        # x: [T, N, D] (T frames treated as batch, N=197 spatial tokens)
        x = x + self.attn(self.norm1(x))
        # Temporal attention: for each spatial position, attend across T frames
        # [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
        xt = x.permute(1, 0, 2)
        xt = xt + self.timeattn(self.norm3(xt))
        x = xt.permute(1, 0, 2)
        x = x + self.mlp(self.norm2(x))
        return x


class _SpatialAttnAgg(nn.Module):
    """
    Aggregates 196 spatial patches → 1 feature per frame using a
    TransformerEncoderLayer with a learnable CLS token.
    Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [T, 196, 768] - spatial patches (CLS stripped)
        T = x.shape[0]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [T, 197, 768]
        xn = self.norm1(x)
        x = x + self.self_attn(xn, xn, xn)[0]
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x[:, 0, :]  # [T, 768] - CLS per frame
class _SynchformerVisualEncoder(nn.Module):
    """
    TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
    Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
    """
    def __init__(self, state_dict, device):
        super().__init__()
        self.device = device
        self.segment_frames = 8
        self.patch_embed = _PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
        self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
        self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
        self.norm = nn.LayerNorm(768)
        self.spatial_attn_agg = _SpatialAttnAgg()
        # Load weights from vfeat_extractor.* prefix
        prefix = "vfeat_extractor."
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        # Exclude 3D patch embed (we use 2D only)
        sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
        missing, unexpected = self.load_state_dict(sub, strict=False)
        print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
        if missing:
            print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")
        self.to(device)
    def forward(self, frames):
        """
        Args:
            frames: [T, C, H, W] float32 in [-1, 1], at 25fps
        Returns:
            [T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8)
        """
        T = frames.shape[0]
        seg = self.segment_frames
        num_seg = max(1, T // seg)
        T_aligned = num_seg * seg
        results = []
        for i in range(num_seg):
            chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W]
            results.append(self._forward_segment(chunk))
        return torch.cat(results, dim=0) # [T_aligned, 768]
    def _forward_segment(self, x):
        # x: [8, 3, 224, 224]
        T = x.shape[0] # 8
        # Patch embedding + CLS token
        x = self.patch_embed(x) # [8, 196, 768]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1) # [8, 197, 768]
        # Positional + temporal embeddings
        x = x + self.pos_embed # broadcast (1,197,768)
        x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast
        # Transformer blocks (factorized space-time)
        for block in self.blocks:
            x = block(x, T)
        x = self.norm(x)
        # Aggregate spatial patches → 1 feature per frame
        return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768]
@@ -0,0 +1,170 @@
# Audio Dataset Pipeline for Generative Model Training
Research notes on audio cleaning, augmentation, and quality metrics for LoRA fine-tuning of MMAudio/SelVA. Based on papers and tooling survey (April 2026).
---
## Core Principle
Augmentation for generative models ≠ augmentation for classifiers.
The goal is **not invariance** — it is expanding the training manifold so the model learns the distribution of a sound rather than memorizing a fixed set of waveforms.
With 10 clips, velocity field collapse (arXiv:2410.23594) is mathematically expected: the flow-matching model memorizes the training trajectories instead of generalizing. More diverse data is the only real fix.
---
## Recommended Pipeline
### Step 1 — Quality Screening
```python
import numpy as np
# Clipping check (audio: float NumPy array scaled to [-1, 1])
clip_ratio = np.sum(np.abs(audio) >= 0.99) / audio.size  # fraction of samples; flag if > 0.001 (0.1%)
# DC offset check + removal
dc = np.mean(audio)
audio -= dc
# LUFS normalization to -14 LUFS (essential for training consistency)
# pip install pyloudnorm
import pyloudnorm as pyln
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio)
audio = pyln.normalize.loudness(audio, loudness, -14.0)
# Or via ffmpeg: ffmpeg -af loudnorm=I=-14:LRA=7:TP=-1
# DNSMOS quality gate (discard if OVRL < 3.5 for training; < 2.5 is unusable)
# from Microsoft DNS-Challenge repo
```
### Step 2 — Cleaning
| Tool | Install | Use |
|---|---|---|
| **AudioSep** | `pip install audiosep` | Isolate target sound from background — most impactful tool |
| **noisereduce** | `pip install noisereduce` | Light stationary/non-stationary denoising, preserves character |
| **librosa** | `pip install librosa` | Silence trimming: `librosa.effects.trim(audio, top_db=30)` |
| **torchaudio.transforms.Fade** | (torchaudio) | Prevent click artifacts at clip edges |
| **DeepFilterNet** | `pip install deepfilternet` | Heavy denoising — good for speech, may alter tonal sounds |
**AudioSep usage:**
```python
from audiosep import AudioSep
model = AudioSep.from_pretrained("audio-agi/audiosep")
# ~1.5 GB checkpoint, ~4 GB VRAM
model.inference(audio_path, "a dog barking loudly", output_path)
```
### Step 3 — Waveform Augmentation (10 clips → 50–100)
Apply stochastically per clip:
| Transform | Params | Notes |
|---|---|---|
| **PitchShift** | ±1–3 semitones | 3 variants per clip. Limit to ±1 st for tonal/pitched sounds |
| **ApplyImpulseResponse** | 5 different RIRs | 5 variants per clip — EchoThief (~150 free IRs) or pyroomacoustics |
| **LoudnessNormalization** | ±2 dB random | Subtle level variation |
| **SevenBandParametricEQ** | ±3 dB | Gentle spectral variation |
| **TimeStretch** | 0.9–1.1× only | Do NOT use 2× to pad short clips — breaks video sync |
```python
# pip install audiomentations pedalboard pyroomacoustics
import audiomentations as A
augment = A.Compose([
    A.PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
    # audiomentations names this argument `ir_path` (a file or a folder of IRs)
    A.ApplyImpulseResponse(ir_path="path/to/irs/", p=0.5),
    A.SevenBandParametricEQ(min_gain_db=-3, max_gain_db=3, p=0.3),
    A.LoudnessNormalization(min_lufs=-16, max_lufs=-12, p=0.5),
    A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3),
])
audio_aug = augment(samples=audio, sample_rate=sr)
```
**RIR sources:**
- EchoThief: ~150 free real-world IRs (churches, caves, parking garages)
- pyroomacoustics: synthetic room simulation, fully controllable
### Step 4 — Latent Augmentation (at training time)
After VAE encoding:
**Latent mixup** between same-category pairs:
```python
# Mix latents BEFORE flow-matching noise is added
# Only mix clips from the same sound category — cross-category mixing produces garbage
lam = torch.distributions.Beta(0.4, 0.4).sample()
z_mix = lam * z1 + (1 - lam) * z2
```
With 10 clips: C(10,2) = 45 possible pairs → significant expansion without new recordings.
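The pair count can be checked with a short, dependency-free sketch. The clip names and category labels are hypothetical placeholders, not files from the actual dataset; the point is only that mixup partners are drawn within a category, so 45 is the upper bound reached when all 10 clips share one category.

```python
from itertools import combinations
from math import comb

def same_category_pairs(clips):
    """Return all unordered (name_a, name_b) pairs that share a category.

    `clips` maps a clip name to its category label (both hypothetical here).
    """
    by_category = {}
    for name, category in clips.items():
        by_category.setdefault(category, []).append(name)
    pairs = []
    for names in by_category.values():
        pairs.extend(combinations(sorted(names), 2))
    return pairs

# 10 clips in one category: every pair is a valid mixup partner.
single = {f"clip_{i:02d}": "target" for i in range(10)}
print(len(same_category_pairs(single)), comb(10, 2))

# 6 + 4 clips in two categories: cross-category pairs are excluded.
split = {**{f"a_{i}": "bark" for i in range(6)},
         **{f"b_{i}": "meow" for i in range(4)}}
print(len(same_category_pairs(split)))
```

Splitting the same 10 clips into groups of 6 and 4 leaves 15 + 6 = 21 usable pairs, so the expansion factor depends on how the clips are categorized.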
**Small Gaussian noise:**
```python
z_noised = z + torch.randn_like(z) * 0.02 * z.std()
```
Prevents trivial memorization of exact latent coordinates.
MusicLDM (arXiv:2308.01546) shows latent mixup > waveform mixup for generative quality.
---
## Transforms to AVOID for Generative Training
| Transform | Why |
|---|---|
| ClippingDistortion, BitCrush, TanhDistortion, Mp3Compression | Model learns the artifact |
| Reverse | Breaks temporal structure for video-to-audio task |
| TimeMask (creating silence gaps) | Unnatural — model learns to produce silence |
| TimeStretch > 1.3× | Phase vocoder artifacts become part of the target distribution |
| Heavy background noise (< 15 dB SNR) | Model learns to reproduce the noise |
---
## Quality Metrics
| Metric | Tool | Threshold |
|---|---|---|
| DNSMOS P.835 (SIG/BAK/OVRL) | Microsoft DNS-Challenge | OVRL > 3.5 for training |
| LUFS | pyloudnorm | Normalize all clips to -14 LUFS |
| WADA-SNR | (standalone) | No-reference SNR estimate |
| Clipping ratio | NumPy | Flag if > 0.1% of samples at ±0.99 |
---
## Tool Reference
| Tool | Install | Purpose |
|---|---|---|
| audiomentations | `pip install audiomentations` | Primary augmentation library |
| pedalboard | `pip install pedalboard` | Higher quality pitch shift, IR convolution |
| AudioSep | `pip install audiosep` | Source separation / isolation |
| noisereduce | `pip install noisereduce` | Non-stationary denoising |
| DeepFilterNet | `pip install deepfilternet` | Heavy denoising (speech-optimized) |
| pyloudnorm | `pip install pyloudnorm` | LUFS normalization |
| Silero VAD | `pip install silero-vad` | Voice/silence detection |
| pyroomacoustics | `pip install pyroomacoustics` | Synthetic RIR generation |
---
## Integration with PrismAudio / SelVA
No established ComfyUI audio preprocessing ecosystem as of early 2026. Build thin wrapper nodes around the tools above. PrismAudio already has all required patterns (subprocess isolation, AUDIO type transport).
**Target node set:**
- `SelVA Dataset Cleaner` — wraps noisereduce + LUFS normalization + trim + DNSMOS gate
- `SelVA Dataset Augmenter` — wraps audiomentations Compose pipeline
Steps 1–3 are preprocessing (run once before feature extraction).
Step 4 (latent mixup) is a training loop modification — integrate into `selva_lora_trainer.py`.
---
## Key Papers
| Paper | ArXiv | Finding |
|---|---|---|
| MusicLDM | 2308.01546 | Latent mixup > waveform mixup for generative quality |
| EDMSound | 2311.08667 | Memorization documented — same failure mode as 10-clip training |
| Synthio | 2410.02056 | Synthetic audio as augmentation data (ICLR 2025) |
| HunyuanVideo-Foley | 2508.16930 | V2A data pipeline at scale (100K hrs) |
| FM memorization | 2410.23594 | Velocity field collapse theory — proves early overfitting on small datasets |
@@ -0,0 +1,184 @@
# AudioX vs SelVA — Evaluation
AudioX (arXiv:2503.10522, ICLR 2026) is a unified multimodal audio generation model from HKUST.
This document compares it against SelVA/MMAudio and assesses the cost of adding it to PrismAudio.
---
## Quick Decision Guide
| Situation | Use |
|---|---|
| Video → realistic sound effects | **SelVA** — faster, purpose-built, MIT license |
| Music generation from video or text | **AudioX** — SelVA cannot do this |
| Audio inpainting / music continuation | **AudioX** — SelVA cannot do this |
| LoRA fine-tuning on a custom sound | **SelVA** — full training infrastructure already exists |
| Variable output duration | **AudioX** — SelVA is fixed at 8 s |
| Inference speed matters | **SelVA** — 25 steps vs 250 (10× fewer sampling steps) |
| Non-commercial research | Either |
| Any commercial use | **SelVA only** — AudioX is CC-BY-NC-4.0 |
---
## Architecture
| Dimension | SelVA (MMAudio) | AudioX-MAF |
|---|---|---|
| Core paradigm | Flow matching | Diffusion (k-diffusion / DPM++) |
| Inference steps | 25 ODE steps (Euler) | 250 diffusion steps (DPM++ 3M SDE) |
| Sample rate | 44.1 kHz (large) / 16 kHz (small) | 48 kHz (fixed) |
| Generator | MM-DiT, velocity prediction | ContinuousMMDiTTransformer |
| Video encoder | Synchformer | Synchformer (AudioX custom re-impl, same concept) |
| VAE / codec | DAC (descript-audio-codec) | DAC + AudioCraft options |
| Text encoder | T5-large | T5 (configurable small → XXL) |
| Video-audio fusion | Cross-attention in MM-DiT | MAF: dual-projection (dim alignment + seq length alignment) |
| Output duration | Fixed 8 s | Configurable via `sample_size` (default ~44 s at 48kHz) |
| Training data | ~2 M samples (MMAudio paper) | 7 M samples (IF-caps dataset, curated) |
| License | MIT | CC-BY-NC-4.0 |
**MAF (Multimodal Adaptive Fusion):** AudioX's key architectural contribution. Instead of directly
concatenating multimodal tokens into the DiT's cross-attention, MAF projects each modality to
match the latent's sequence length via a dedicated linear + transposed-conv stack, then applies
`MMDitSingleBlock` layers for cross-modal fusion. The paper reports this improves cross-modal
alignment particularly for video-to-audio tasks.
**Flow matching vs diffusion:** Flow matching (SelVA) trains a single velocity field to move
directly from noise to data along a straight trajectory — this is why 25 steps suffice. Standard
diffusion (AudioX) approximates a longer stochastic path, requiring 250 steps for quality output.
This is not a quality difference per se; flow matching is simply more efficient.
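To make the step-count argument concrete, here is a toy sketch; it is not SelVA's sampler. It assumes an idealized, perfectly straight path whose velocity is the known constant `data - noise` (a trained model only approximates this) and shows that Euler integration then reaches the target for any step count. `euler_sample` and the scalar values are illustrative names.

```python
def euler_sample(x0, velocity_fn, num_steps):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with Euler steps."""
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)
    return x

noise, data = -0.7, 0.4  # toy scalar "noise" and "data" samples

def straight(x, t):
    """Idealized straight-line flow: the velocity is constant along the path."""
    return data - noise

for steps in (1, 25, 250):
    print(steps, euler_sample(noise, straight, steps))
```

With a curved velocity field the Euler error grows as the step count drops, which is the regime where many-step samplers become necessary.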
---
## Capabilities
| Task | SelVA | AudioX |
|---|---|---|
| Video → sound effects | ✓ (primary use case) | ✓ |
| Text → sound effects | Partial (T5 conditions quality but not primary) | ✓ (strong benchmark scores) |
| Video → music | ✗ | ✓ |
| Text → music | ✗ | ✓ |
| Audio inpainting | ✗ | ✓ (mask_args parameter) |
| Music continuation | ✗ | ✓ (init_audio parameter) |
| Variable output duration | ✗ (fixed 8 s) | ✓ |
| Multiple input modalities simultaneously | Partial | ✓ (text + video + audio at once) |
AudioX benchmarks claim superior results on text-to-audio (AudioCaps) and text-to-music
(MusicCaps) vs prior models. Video-to-audio comparison against MMAudio specifically is not
prominently featured in the paper. Perceptual evaluation confirms this: AudioX does not sound
noticeably better than SelVA on video-to-audio tasks. AudioX's advantage is **breadth**
(music, inpainting, variable duration), not raw video-to-audio quality.
---
## Integration Cost
Adding AudioX inference-only nodes to PrismAudio would require:
### New nodes (3 files)
```
nodes/
  audiox_model_loader.py        AUDIOX_MODEL loader — get_pretrained_model("HKUSTAudio/AudioX-MAF")
  audiox_sampler.py             wraps generate_diffusion_cond(), inputs: model + text + video + audio
  audiox_feature_extractor.py   optional — pre-extract Synchformer sync features (caching)
```
### Installation
```bash
pip install git+https://github.com/ZeyueT/AudioX.git
```
New dependencies not currently in PrismAudio:
- `pytorch-lightning==2.4.0`
- `k-diffusion==0.1.1`
- `v-diffusion-pytorch==0.0.2`
- `descript-audio-codec==1.0.0` (already used by SelVA — no conflict, same package)
- `gradio==4.44.1` (optional — only for the upstream Gradio UI)
Model weights: `HKUSTAudio/AudioX-MAF` on HuggingFace (several GB).
### Inference API surface
```python
from audiox import get_pretrained_model
from audiox.inference.generation import generate_diffusion_cond
model, config = get_pretrained_model("HKUSTAudio/AudioX-MAF")
output = generate_diffusion_cond(
    model,
    steps=250,
    cfg_scale=6.0,
    conditioning={
        "text_prompt": "a dog barking",
        "video_prompt": {"video": frames_tensor, "sync_features": sync_feat},
        "seconds_total": 8.0,
    },
    sample_size=384000, # 8 s at 48kHz
    sample_rate=48000,
    device="cuda",
)
# output: torch.Tensor (batch, channels, num_samples) float32 [-1, 1]
```
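The `sample_size` value in the call above is simply duration times sample rate. A small helper makes the arithmetic explicit; `samples_for_duration` is a hypothetical name and not part of the AudioX API.

```python
def samples_for_duration(seconds: float, sample_rate: int = 48000) -> int:
    """Number of audio samples covering `seconds` at `sample_rate`."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return int(round(seconds * sample_rate))

print(samples_for_duration(8.0))  # 8 s at 48 kHz
```

Whether AudioX additionally requires this value to be a multiple of the codec's downsampling factor is not covered by these notes and should be checked against the upstream code.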
---
## LoRA Training
Adding AudioX LoRA training to PrismAudio is **significantly harder** than the SelVA trainer:
| Aspect | SelVA LoRA | AudioX LoRA |
|---|---|---|
| Loss function | Single MSE velocity loss | Diffusion loss over 250-step schedule |
| Training steps needed | ~2000 steps practical | Unknown — likely much more |
| Step cost | Fast (1 velocity prediction) | Slow (full diffusion forward pass per step) |
| Existing infrastructure | Full trainer + scheduler + experiments | Nothing — would need to build from scratch |
| Noise schedule | Trivial (linear interpolation) | Cosine alpha-sigma schedule |
| Prior art for LoRA | LoRA on flow matching well-studied | Less explored; closer to Stable Diffusion LoRA |
**Conclusion:** AudioX LoRA training is feasible (it would follow SD-style LoRA, trained against the
cosine alpha-sigma noise schedule; DPM++ is only the sampler) but would be a substantial new
project. Not worth building until inference nodes
are stable and there is a clear use case that SelVA cannot serve.
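The "single MSE velocity loss" and "linear interpolation" rows of the table can be illustrated with a minimal sketch of a flow-matching training target. This is a generic rectified-flow-style formulation on plain Python lists, assumed for illustration only; it is not taken from `selva_lora_trainer.py`, and the time convention (t=0 is noise, t=1 is data) is an assumption.

```python
import random

def flow_matching_example(x_data, rng):
    """Build one training example: interpolated input x_t, time t, target velocity."""
    x_noise = [rng.gauss(0.0, 1.0) for _ in x_data]
    t = rng.random()
    # Linear interpolation between noise (t=0) and data (t=1).
    x_t = [(1.0 - t) * n + t * d for n, d in zip(x_noise, x_data)]
    # Along a straight path the target velocity is constant: data - noise.
    v_target = [d - n for n, d in zip(x_noise, x_data)]
    return x_t, t, v_target

def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)

rng = random.Random(0)
x_data = [0.5, -0.25, 0.1]
x_t, t, v_target = flow_matching_example(x_data, rng)
# A perfect velocity prediction gives zero loss.
print(mse(v_target, v_target))
```

Each training step needs only one forward pass at a single random `t`, which is why the schedule is described as trivial.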
---
## License
AudioX weights are released under **CC-BY-NC-4.0** (Creative Commons Non-Commercial).
- Free for personal use, research, and non-commercial projects
- **Cannot be used in commercial products or services** without a separate agreement
- Attribution required
- SelVA/MMAudio: MIT (no restrictions)
If PrismAudio is ever distributed as part of a commercial tool, AudioX nodes must be clearly
opt-in with a license warning, or excluded entirely.
---
## Recommendation
**Short term:** AudioX is not a replacement for SelVA for the current use case (video → custom
sound effects with LoRA fine-tuning). SelVA is faster, has full training infrastructure, and
is MIT licensed.
**When AudioX becomes worth integrating:**
- If you need to generate background music synchronized to video
- If you need audio inpainting (fill a gap in an existing audio track)
- If you need text-to-audio generation without a video input
- After verifying the CC-BY-NC-4.0 license is acceptable for your use
**Estimated integration effort for inference nodes only:** 2–3 days of work (3 new node files,
dependency management, testing). No changes to existing SelVA nodes required — they would
coexist in the same package.
---
## References
- Paper: arXiv:2503.10522 — *AudioX: Diffusion Transformer for Anything-to-Audio Generation*
- GitHub: https://github.com/ZeyueT/AudioX
- Model weights: https://huggingface.co/HKUSTAudio/AudioX-MAF
- Demo: https://huggingface.co/spaces/Zeyue7/AudioX
- Project page: https://zeyuet.github.io/AudioX/
@@ -1,194 +0,0 @@
# ComfyUI-PrismAudio Design Document
**Date:** 2026-03-27
**Status:** Approved
## Overview
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
## Architecture
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
## Project Structure
```
ComfyUI-PrismAudio/
├── __init__.py # Node registration
├── nodes/
│ ├── __init__.py
│ ├── model_loader.py # PrismAudioModelLoader
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
│ ├── sampler.py # PrismAudioSampler
│ ├── text_only.py # PrismAudioTextOnly
│ └── utils.py # Shared helpers
├── prismaudio_core/ # Extracted inference code from PrismAudio
│ ├── __init__.py
│ ├── configs/
│ │ └── prismaudio.json
│ ├── models/ # DiT, conditioners, autoencoders, etc.
│ ├── inference/ # sampling.py, generation.py
│ └── factory.py # create_model_from_config
├── scripts/
│ ├── extract_features.py # Standalone VideoPrism feature extraction
│ └── environment.yml # Conda env for extraction (JAX + TF)
├── requirements.txt # PyTorch-only deps (no JAX/TF)
└── README.md
```
## Nodes
### PrismAudioModelLoader
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
| **Output** | | |
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
**Token resolution order** (no widget — env/CLI only for security):
1. `HF_TOKEN` environment variable
2. `huggingface-cli login` cached token
3. None — fails on gated models with clear error message linking to license page
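The resolution order above can be sketched as follows. This is a standalone illustration, not the loader's actual code: `resolve_hf_token` is a hypothetical helper, and the default cached-token location is an assumption about where `huggingface-cli login` stores its token. In practice `huggingface_hub` performs a similar lookup itself when no token is passed.

```python
import os
from pathlib import Path

def resolve_hf_token(env=None, cached_token_path=None):
    """Return a Hugging Face token, or None if none is configured.

    Order: HF_TOKEN environment variable, then a cached login token file.
    """
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if token:
        return token
    if cached_token_path is None:
        # Assumed default location written by `huggingface-cli login`.
        cached_token_path = Path.home() / ".cache" / "huggingface" / "token"
    path = Path(cached_token_path)
    if path.is_file():
        cached = path.read_text().strip()
        if cached:
            return cached
    return None

print(resolve_hf_token({"HF_TOKEN": "hf_example"}, "/nonexistent/token"))
```

Returning `None` rather than raising lets the caller decide whether the model is gated and produce the error message with the license link.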
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
### PrismAudioFeatureLoader
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| npz_path | STRING | Path to .npz file |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
### PrismAudioFeatureExtractor
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| video | IMAGE | ComfyUI video frames tensor |
| caption_cot | STRING | CoT description text |
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
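A content-hash cache key along these lines could look like the sketch below. `feature_cache_key` is a hypothetical helper; in a real node the frame bytes would presumably come from the IMAGE tensor (for example via its raw byte buffer), which is an assumption and not shown here.

```python
import hashlib

def feature_cache_key(frame_bytes: bytes, caption: str) -> str:
    """Hash video content and caption text into a stable cache filename stem."""
    h = hashlib.sha256()
    # Length prefix keeps (frames, caption) boundaries unambiguous.
    h.update(len(frame_bytes).to_bytes(8, "little"))
    h.update(frame_bytes)
    h.update(caption.encode("utf-8"))
    return h.hexdigest()[:16]

key_a = feature_cache_key(b"\x00\x01\x02", "a dog barking")
key_b = feature_cache_key(b"\x00\x01\x02", "a cat meowing")
print(key_a != key_b, f"{key_a}.npz")
```

Hashing both inputs means that editing only the CoT caption triggers a re-extraction, while re-running an unchanged workflow reuses the cached `.npz`.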
### PrismAudioSampler
Main generation node — takes model + features, produces audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
| cot_description | STRING | Multiline CoT text |
| duration | FLOAT | 1.0-30.0, defaults to video length |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
**Pipeline:**
1. Encode CoT text via T5-Gemma -> text_features
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
3. Compute latent_seq_len = round(44100 / 2048 * duration)
4. Generate noise [1, 64, latent_seq_len] from seed
5. Discrete Euler sampling (rectified flow) with CFG
6. VAE decode -> stereo waveform at 44100 Hz
7. Normalize to [-1, 1], return as AUDIO
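Steps 3 and 4 of the pipeline reduce to a little arithmetic, sketched here without any tensor library. `latent_shape` is an illustrative name; the constants (44100 Hz, 2048× downsampling, 64 latent channels) are the ones stated in the list.

```python
def latent_shape(duration_s: float, sample_rate: int = 44100,
                 downsample: int = 2048, channels: int = 64):
    """Shape of the initial noise tensor for a clip of `duration_s` seconds."""
    latent_seq_len = round(sample_rate / downsample * duration_s)
    return (1, channels, latent_seq_len)

print(latent_shape(10.0))
```

For a 10 s clip, 44100 / 2048 × 10 ≈ 215.33, so the noise tensor has shape `[1, 64, 215]`. Note that Python's `round` uses banker's rounding for exact halves, which may differ from other rounding conventions at those edge cases.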
### PrismAudioTextOnly
Text-to-audio without video input.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| text_prompt | STRING | Text description |
| duration | FLOAT | 1.0-30.0 |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt.
## VRAM Management
Adaptive strategy using `comfy.model_management`:
| Available VRAM | Behavior |
|---|---|
| 24GB+ | Keep diffusion + VAE in VRAM |
| 12-24GB | Sequential offload between stages |
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
| <8GB | Warn user, attempt with aggressive offload + fp16 |
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
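The table can be read as a simple threshold lookup. The sketch below is an assumption-laden illustration: `pick_offload_strategy` and the returned keys are hypothetical, the boundary handling (exactly 24 GB keeps everything resident, exactly 12 GB uses sequential offload) is a guess the table does not settle, and the input is assumed to be free VRAM already converted to GB.

```python
def pick_offload_strategy(free_vram_gb: float) -> dict:
    """Map available VRAM (GB) to the behaviour described in the table."""
    if free_vram_gb >= 24:
        return {"mode": "keep_in_vram", "force_fp16": False, "warn": False}
    if free_vram_gb >= 12:
        return {"mode": "sequential_offload", "force_fp16": False, "warn": False}
    if free_vram_gb >= 8:
        return {"mode": "aggressive_offload", "force_fp16": True, "warn": False}
    # Below 8 GB: still attempt the run, but tell the user it may fail.
    return {"mode": "aggressive_offload", "force_fp16": True, "warn": True}

for gb in (32, 16, 10, 6):
    print(gb, pick_offload_strategy(gb))
```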
## Feature Extraction Paths
### Path 1: Pre-computed .npz (FeatureLoader)
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.
### Path 2: Subprocess bridge (FeatureExtractor)
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.
### Path 3: Text-only (TextOnly node)
No video features needed. T5-Gemma text encoding only (PyTorch-native).
## Dependencies
### ComfyUI environment (`requirements.txt`)
```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
```
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
### Extraction environment (`scripts/environment.yml`)
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
## Model Files
Stored in `ComfyUI/models/prismaudio/`:
| File | Size | Source |
|------|------|--------|
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
## Design Decisions
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL.
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
@@ -0,0 +1,606 @@
# Audio Dataset Pipeline Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Add 5 chainable ComfyUI nodes for in-memory audio dataset preprocessing: load → resample → LUFS normalize → inspect/filter → extract single item.
**Architecture:** Single new file `nodes/selva_dataset_pipeline.py` defines a custom `AUDIO_DATASET` type (list of dicts) and all 5 node classes. Nodes are stateless transforms — each takes `AUDIO_DATASET` and returns `AUDIO_DATASET`. No disk I/O except in the Loader. Register all nodes in `nodes/__init__.py`.
**Tech Stack:** `pyloudnorm` (BS.1770-4 LUFS), `soxr` (VHQ resampling), `torchaudio`, `torch`. `pyloudnorm` and `soxr` are both confirmed present in the ComfyUI environment at `/media/p5/miniforge3/envs/latestcomfyui`.
---
## The `AUDIO_DATASET` type
Used as the ComfyUI type string `"AUDIO_DATASET"`. At runtime it is a Python list of dicts:
```python
[
    {
        "waveform": torch.Tensor, # shape [1, C, L], float32, range [-1, 1]
        "sample_rate": int,
        "name": str, # original filename stem, for reporting
    },
    ...
]
```
---
### Task 1: Create the file skeleton and AUDIO_DATASET constant
**Files:**
- Create: `nodes/selva_dataset_pipeline.py`
**Step 1: Write the file with imports and type constant only**
```python
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
```
**Step 2: Verify import works (no test framework needed — just a quick smoke check)**
```bash
cd /media/p5/Comfyui-Prismaudio
python3 -c "from nodes.selva_dataset_pipeline import AUDIO_DATASET; print(AUDIO_DATASET)"
```
Expected output: `AUDIO_DATASET`
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add audio dataset pipeline skeleton"
```
---
### Task 2: SelvaDatasetLoader
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Loader class**
```python
class SelvaDatasetLoader:
    """Load all audio files in a folder into an in-memory AUDIO_DATASET."""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to folder containing audio files. Searched recursively.",
                }),
            }
        }
    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "load"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
    def load(self, folder: str):
        folder = Path(folder.strip())
        if not folder.exists():
            raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
        files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
        if not files:
            raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
        dataset = []
        for f in sorted(files):
            try:
                wav, sr = torchaudio.load(str(f)) # [C, L]
                wav = wav.unsqueeze(0).float() # [1, C, L]
                dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
            except Exception as e:
                print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
        print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
        return (dataset,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader
node = SelvaDatasetLoader()
ds, = node.load('/media/unraid/davinci/Selva/BJ')
print(len(ds), 'clips', ds[0]['name'], ds[0]['waveform'].shape, ds[0]['sample_rate'])
"
```
Expected: prints clip count, first clip name, shape like `torch.Size([1, 2, 352800])`, sample rate.
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLoader node"
```
---
### Task 3: SelvaDatasetResampler
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Resampler class**
Uses `soxr` directly for VHQ quality. `soxr.resample` operates on numpy arrays, shape `[L, C]` (time-first).
```python
class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ."""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }
    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
    def resample(self, dataset, target_sr: int):
        import soxr
        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                out.append(item)
                continue
            wav = item["waveform"][0] # [C, L]
            # soxr expects [L, C] (time-first), float64
            wav_np = wav.permute(1, 0).double().numpy() # [L, C]
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
            wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
            out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
            changed += 1
        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetResampler
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetResampler().resample(ds, 44100)
print('ok', ds2[0]['sample_rate'], ds2[0]['waveform'].shape)
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetResampler node (soxr VHQ)"
```
---
### Task 4: SelvaDatasetLUFSNormalizer
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the LUFS normalizer class**
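`pyloudnorm.Meter` requires a NumPy float64 array of shape `[L]` (mono) or `[L, C]` (multichannel, channels last). The peak ceiling is applied after the LUFS gain. Note that the class limits the sample peak (`abs().max()`), which approximates the true peak but does not catch inter-sample overs.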
```python
class SelvaDatasetLUFSNormalizer:
    """Normalize each clip to a target integrated LUFS level + true peak limit."""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_lufs": ("FLOAT", {
                    "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
                    "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
                }),
                "true_peak_dbtp": ("FLOAT", {
                    "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
                    "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
                }),
            }
        }
    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "normalize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
        "Skips clips that are too short for LUFS measurement (< 0.4 s)."
    )
    def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
        import pyloudnorm as pyln
        tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
        out = []
        skipped = 0
        for item in dataset:
            wav = item["waveform"][0] # [C, L]
            sr = item["sample_rate"]
            # pyloudnorm wants [L] mono or [L, C] multichannel, float64
            wav_np = wav.permute(1, 0).double().numpy() # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0] # [L] mono
            meter = pyln.Meter(sr)
            try:
                loudness = meter.integrated_loudness(wav_np)
            except Exception:
                skipped += 1
                out.append(item)
                continue
            if not np.isfinite(loudness):
                skipped += 1
                out.append(item)
                continue
            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)
            wav_norm = wav * gain_linear
            # True peak limit
            peak = wav_norm.abs().max().item()
            if peak > tp_linear:
                wav_norm = wav_norm * (tp_linear / peak)
            out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
        print(
            f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
            f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
            flush=True,
        )
        return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetLUFSNormalizer
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetLUFSNormalizer().normalize(ds, -23.0, -1.0)
print('ok', ds2[0]['name'], ds2[0]['waveform'].abs().max().item())
"
```
Expected: peak ≤ ~0.89 (≈ -1 dBTP).
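The 0.89 figure follows from converting the -1 dBTP ceiling to linear amplitude. The sketch below reproduces that arithmetic together with the LUFS gain step, using only the standard library. The helper names (`db_to_linear`, `apply_gain_with_ceiling`) are illustrative and not part of the node, and the measured loudness is passed in as a plain number rather than computed with `pyloudnorm`. Like the node, it limits the sample peak rather than an oversampled true peak.

```python
import math

def db_to_linear(db: float) -> float:
    """Convert a level in dB to a linear amplitude factor."""
    return 10.0 ** (db / 20.0)

def apply_gain_with_ceiling(samples, measured_lufs, target_lufs=-23.0, ceiling_db=-1.0):
    """Scale samples toward target_lufs, then scale down if the peak exceeds the ceiling."""
    gain = db_to_linear(target_lufs - measured_lufs)
    scaled = [s * gain for s in samples]
    ceiling = db_to_linear(ceiling_db)
    peak = max(abs(s) for s in scaled)
    if peak > ceiling:
        scaled = [s * (ceiling / peak) for s in scaled]
    return scaled

ceiling = db_to_linear(-1.0)  # roughly 0.891

# A quiet clip measured at -35 LUFS needs +12 dB of gain,
# which pushes its 0.5 peak over the ceiling, so it is scaled back down.
out = apply_gain_with_ceiling([0.5, -0.25, 0.1], measured_lufs=-35.0)
peak_out = max(abs(s) for s in out)
print(f"ceiling={ceiling:.3f} peak_out={peak_out:.3f}")
```

When the ceiling engages, the whole clip is scaled down, so that clip ends up quieter than `target_lufs`. A smoke-test peak of exactly ~0.89 therefore indicates a clip that hit the ceiling rather than the loudness target.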
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)"
```
---
### Task 5: SelvaDatasetInspector
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add helper functions for artifact detection**
```python
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
    """Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).

    Method: compare mean energy in the 1-5 kHz band vs the 15-20 kHz band via STFT.
    A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
    """
    if sr < 32000:
        return False  # can't assess HF at low sample rates
    n_fft = 2048
    hop = 512
    window = torch.hann_window(n_fft)
    mono = wav[0].mean(0)                        # [L]
    stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
    mag_sq = stft.abs().pow(2).mean(-1)          # [n_freqs]
    freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
    band_lo = (freqs >= 1000) & (freqs < 5000)
    band_hi = (freqs >= 15000) & (freqs < 20000)
    if band_hi.sum() == 0:
        return False
    energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
    energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
    ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
    return ratio_db > 40.0


def _estimate_snr(wav: torch.Tensor) -> float:
    """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
    mono = wav[0].mean(0)                        # [L]
    if mono.shape[0] < 2048:
        # unfold() raises on clips shorter than one frame; do not flag these on SNR
        return float("inf")
    frames = mono.unfold(0, 2048, 512)           # [N, 2048]
    rms = frames.pow(2).mean(-1).sqrt()          # [N]
    p95 = torch.quantile(rms, 0.95).item()
    p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
    return 20.0 * np.log10(p95 / p05 + 1e-8)
```
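To make the percentile-based SNR heuristic concrete, here is a dependency-free sketch of the same idea on a synthetic clip. It is a re-implementation for illustration only: the function names are hypothetical, the percentile uses simple linear interpolation, and returning `None` for clips shorter than one frame is an assumption of this sketch.

```python
import math

def frame_rms(samples, frame=2048, hop=512):
    """RMS of each full frame; empty list if the clip is shorter than one frame."""
    out = []
    for start in range(0, len(samples) - frame + 1, hop):
        chunk = samples[start:start + frame]
        out.append(math.sqrt(sum(x * x for x in chunk) / frame))
    return out

def percentile(values, q):
    """Linearly interpolated percentile, q in [0, 1]."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)

def estimate_snr_db(samples):
    """Ratio of loud-frame RMS (95th pct) to quiet-frame RMS (5th pct), in dB."""
    rms = frame_rms(samples)
    if not rms:
        return None  # too short to estimate
    loud = percentile(rms, 0.95)
    quiet = max(percentile(rms, 0.05), 1e-8)
    return 20.0 * math.log10(loud / quiet + 1e-8)

# Synthetic clip: a quiet floor (amplitude 0.001) around a loud burst (amplitude 0.5).
clip = [0.001] * 20480 + [0.5] * 20480 + [0.001] * 20480
snr = estimate_snr_db(clip)
print(f"estimated SNR: {snr:.1f} dB")
```

The estimate treats the quietest frames as the noise floor, so it only works when a clip contains some pauses. A clip that is uniformly loud yields a ratio near 0 dB and would be flagged as low SNR even if it is clean.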
**Step 2: Add the Inspector class**
```python
class SelvaDatasetInspector:
    """Analyze each clip for quality issues and optionally filter out flagged clips."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "skip_rejected": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "If True, flagged clips are removed from the output dataset. "
                               "If False, all clips pass through but the report still lists issues.",
                }),
                "min_snr_db": ("FLOAT", {
                    "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
                    "tooltip": "Clips with estimated SNR below this value are flagged.",
                }),
                "check_codec_artifacts": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET, "STRING")
    RETURN_NAMES = ("dataset", "report")
    FUNCTION = "inspect"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Analyze each clip for clipping, low SNR, and codec artifacts. "
        "Outputs a filtered AUDIO_DATASET and a text report. "
        "Connect report to a ShowText node to preview in the UI."
    )

    def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
        clean = []
        flagged = []
        lines = ["SelVA Dataset Inspector Report", "=" * 40]
        for item in dataset:
            wav = item["waveform"]
            sr = item["sample_rate"]
            name = item["name"]
            issues = []
            # Clipping
            peak = wav.abs().max().item()
            if peak > 0.99:
                issues.append(f"clipping (peak={peak:.3f})")
            # Low SNR
            snr = _estimate_snr(wav)
            if snr < min_snr_db:
                issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
            # Codec artifacts
            if check_codec_artifacts and _check_hf_shelf(wav, sr):
                issues.append("codec artifact (HF shelf > 15 kHz)")
            if issues:
                flagged.append(name)
                lines.append(f" FLAGGED {name}: {', '.join(issues)}")
                if not skip_rejected:
                    clean.append(item)
            else:
                clean.append(item)
                lines.append(f" OK {name}")
        lines.append("=" * 40)
        lines.append(
            f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
            + (" (removed)" if skip_rejected else " (kept)")
        )
        report = "\n".join(lines)
        print(f"[DatasetInspector]\n{report}", flush=True)
        return (clean, report)
```
**Step 3: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetInspector
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
clean, report = SelvaDatasetInspector().inspect(ds, skip_rejected=False, min_snr_db=15.0, check_codec_artifacts=True)
print(report)
"
```
Expected: report with per-clip OK/FLAGGED lines and summary counts.
**Step 4: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)"
```
---
### Task 6: SelvaDatasetItemExtractor
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the extractor class**
```python
class SelvaDatasetItemExtractor:
    """Extract a single AUDIO item from an AUDIO_DATASET by index.

    Bridges the dataset pipeline to any node that accepts a standard AUDIO
    input (save audio, HF Smoother, Spectral Matcher, etc.).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "index": ("INT", {
                    "default": 0, "min": 0, "max": 9999,
                    "tooltip": "0-based index. Wraps around if index >= dataset length.",
                }),
            }
        }

    RETURN_TYPES = ("AUDIO", "STRING", "INT")
    RETURN_NAMES = ("audio", "name", "total")
    FUNCTION = "extract"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Extract one clip from an AUDIO_DATASET by index. "
        "Returns standard AUDIO (compatible with all audio nodes), "
        "the clip name, and the total dataset length."
    )

    def extract(self, dataset, index: int):
        if not dataset:
            raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
        idx = index % len(dataset)
        item = dataset[idx]
        audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
        print(
            f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
            f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
            flush=True,
        )
        return (audio, item["name"], len(dataset))
```
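The wrap-around behaviour means an ever-increasing counter can drive the `index` input without going out of range. A minimal sketch of that modulo rule, using a hypothetical `wrap_index` helper and placeholder clip names:

```python
def wrap_index(index: int, length: int) -> int:
    """Map any non-negative index onto 0..length-1, as the extractor does."""
    if length <= 0:
        raise ValueError("dataset is empty")
    return index % length

# Placeholder names standing in for dataset items.
names = ["a.wav", "b.wav", "c.wav"]
picked = [names[wrap_index(i, len(names))] for i in range(5)]
print(picked)
```

Because the node also returns `total`, a downstream counter can stop after one full pass instead of relying on the wrap.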
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetItemExtractor
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
audio, name, total = SelvaDatasetItemExtractor().extract(ds, 0)
print(name, total, audio['waveform'].shape, audio['sample_rate'])
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetItemExtractor node"
```
---
### Task 7: Register all nodes in __init__.py
**Files:**
- Modify: `nodes/__init__.py:4-25`
**Step 1: Add the 5 new entries to `_NODES`**
Add inside the `_NODES` dict, after `"SelvaDittoOptimizer"`:
```python
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
```
**Step 2: Verify registration**
```bash
python3 -c "
import sys; sys.path.insert(0, '/media/p5/Comfyui-Prismaudio')
from nodes import NODE_CLASS_MAPPINGS
keys = [k for k in NODE_CLASS_MAPPINGS if 'Dataset' in k]
print(keys)
"
```
Expected: list of 5 dataset node keys.
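For context, the sketch below shows one way a `(module, class name, display name)` table like `_NODES` could be turned into the `NODE_CLASS_MAPPINGS` and `NODE_DISPLAY_NAME_MAPPINGS` dicts that ComfyUI reads from a custom node package. How `nodes/__init__.py` actually does this is not shown here, so `build_mappings` and the `_DEMO_NODES` entries are assumptions, with standard-library classes standing in for the real node modules so the example runs anywhere.

```python
import importlib

# Hypothetical registry in the same (module, class name, display name) shape as `_NODES`.
_DEMO_NODES = {
    "DemoCounter": ("collections", "Counter", "Demo Counter"),
    "DemoPath": ("pathlib", "PurePosixPath", "Demo Path"),
    "DemoMissing": ("collections", "NoSuchClass", "Demo Missing"),
}

def build_mappings(registry, package=None):
    """Import each entry; skip (and report) entries that fail to load."""
    class_map, name_map = {}, {}
    for key, (module_name, class_name, display_name) in registry.items():
        try:
            # Relative names such as ".selva_dataset_pipeline" would need `package` set.
            module = importlib.import_module(module_name, package=package)
            class_map[key] = getattr(module, class_name)
            name_map[key] = display_name
        except (ImportError, AttributeError) as exc:
            print(f"skipping {key}: {exc}")
    return class_map, name_map

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = build_mappings(_DEMO_NODES)
print(sorted(NODE_CLASS_MAPPINGS))
```

Under a loader of this kind, a node that fails to import is dropped rather than raising. If the verification step lists fewer than 5 keys, a skipped import is worth checking before suspecting the `_NODES` entries themselves.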
**Step 3: Final commit**
```bash
git add nodes/__init__.py
git commit -m "feat: register audio dataset pipeline nodes in __init__.py"
```
---
## Summary
5 nodes in `nodes/selva_dataset_pipeline.py`, all registered in `__init__.py`:
| Node | In | Out |
|------|----|-----|
| SelvaDatasetLoader | folder path | AUDIO_DATASET |
| SelvaDatasetResampler | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetLUFSNormalizer | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetInspector | AUDIO_DATASET | AUDIO_DATASET + STRING |
| SelvaDatasetItemExtractor | AUDIO_DATASET + index | AUDIO + name + total |
Dependencies: `pyloudnorm`, `soxr` — both confirmed present in the ComfyUI env.
@@ -0,0 +1,77 @@
{
"name": "alpha_scale_sweep",
"description": "Fix LoRA noise contamination (flatness 0.013→0.094 at alpha=rank). Root cause: alpha=rank (scale=1.0) at high rank drowns base model priors. Testing dramatically lower alpha to nudge rather than overwrite. All runs at lr=3e-4 (best stable LR from r128_sweet_spot).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/alpha_scale_sweep",
"base": {
"steps": 6000,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "g1_r16_alpha4",
"group": "conservative",
"description": "Back to basics: rank=16 alpha=4 (scale=0.25). Small adapter, gentle scale — cleanest possible LoRA signal.",
"rank": 16,
"alpha": 4.0
},
{
"id": "g1_r16_alpha16",
"group": "conservative",
"description": "rank=16 alpha=16 (scale=1.0) — the original default. Reference point: is the noise issue rank-specific or universal?",
"rank": 16,
"alpha": 16.0
},
{
"id": "g2_r32_alpha8",
"group": "mid",
"description": "rank=32 alpha=8 (scale=0.25). More capacity than r16 but still gentle scale.",
"rank": 32,
"alpha": 8.0
},
{
"id": "g2_r32_alpha32",
"group": "mid",
"description": "rank=32 alpha=32 (scale=1.0). Same rank, full scale — isolates whether scale or rank is causing noise.",
"rank": 32,
"alpha": 32.0
},
{
"id": "g3_r128_alpha8",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=8 (scale=0.0625). High capacity, very gentle contribution — can r128 stay clean at low alpha?",
"rank": 128,
"alpha": 8.0
},
{
"id": "g3_r128_alpha16",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=16 (scale=0.125). Slightly more signal than alpha=8.",
"rank": 128,
"alpha": 16.0
},
{
"id": "g3_r128_alpha32",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=32 (scale=0.25). Same scale as r16_alpha4 and r32_alpha8 — comparable contribution across ranks.",
"rank": 128,
"alpha": 32.0
}
]
}
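The `scale=` values quoted in the experiment descriptions above follow the classic LoRA convention `scale = alpha / rank`. The short check below recomputes them from the `(rank, alpha)` pairs in this config. It assumes the trainer applies that classic rule for this sweep; `lora_scale` is an illustrative helper, not a function from the repo.

```python
def lora_scale(rank: int, alpha: float) -> float:
    """Classic LoRA scaling factor applied to the low-rank update: alpha / rank."""
    return alpha / rank

# (rank, alpha) pairs copied from the alpha_scale_sweep experiments.
sweep = {
    "g1_r16_alpha4": (16, 4.0),
    "g1_r16_alpha16": (16, 16.0),
    "g2_r32_alpha8": (32, 8.0),
    "g2_r32_alpha32": (32, 32.0),
    "g3_r128_alpha8": (128, 8.0),
    "g3_r128_alpha16": (128, 16.0),
    "g3_r128_alpha32": (128, 32.0),
}

scales = {exp_id: lora_scale(rank, alpha) for exp_id, (rank, alpha) in sweep.items()}
for exp_id, scale in scales.items():
    print(f"{exp_id:16s} scale={scale}")
```

Three runs share scale 0.25 at ranks 16, 32, and 128, which is what lets the sweep separate the effect of rank from the effect of scale.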
@@ -0,0 +1,31 @@
{
"name": "bigvgan_disc_fm_retest",
"description": "Retest discriminator feature matching after bfloat16 dtype fix. Uses optimal config from overnight sweep (snake_alpha, GAFilter, lr=1e-4, phase=1.0, L2-SP=1e-3, 5000 steps).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_disc_fm_retest",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_5k_control",
"description": "Control: best config from overnight sweep without discriminator. Baseline for A/B comparison."
},
{
"id": "disc_fm_5k",
"description": "Discriminator feature matching at 5k steps. Tests if perceptual FM loss improves over mel+phase alone.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
}
]
}
@@ -0,0 +1,35 @@
{
"name": "bigvgan_optimized_dataset",
"description": "BigVGAN fine-tuning on optimized dataset (134 clips, 44.1kHz, LUFS-normalized). Standard mode (no LoRA) — trains decoder to faithfully reconstruct target domain audio from mel spectrograms. Uses optimal config from prior sweeps.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_optimized_dataset",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42
},
"experiments": [
{
"id": "standard_5k",
"description": "Standard mode: mel from clean FLAC → BigVGAN → reconstruct FLAC. No LoRA. Directly improves VAE roundtrip quality."
},
{
"id": "disc_fm_5k",
"description": "Standard mode + discriminator feature matching. Tests if perceptual loss helps on clean audio reconstruction.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "standard_10k",
"description": "Extended 10k steps. More data passes on 134 clips may extract more from the optimized dataset.",
"steps": 10000
}
]
}
@@ -0,0 +1,65 @@
{
"name": "bigvgan_overnight",
"description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
"base": {
"train_mode": "snake_alpha_only",
"steps": 3000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_3k_baseline",
"description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
},
{
"id": "snake_5k",
"description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
"steps": 5000
},
{
"id": "snake_no_gafilter",
"description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
"use_gafilter": false
},
{
"id": "snake_no_phase",
"description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
"lambda_phase": 0.0
},
{
"id": "snake_phase_2",
"description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
"lambda_phase": 2.0
},
{
"id": "snake_lr5e-5",
"description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
"lr": 5e-5,
"steps": 5000
},
{
"id": "snake_disc_fm",
"description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "all_2k_l2sp1e-2",
"description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
"train_mode": "all_params",
"steps": 2000,
"lr": 1e-5,
"lambda_l2sp": 1e-2
}
]
}
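Each config in this commit pairs a shared `base` block with per-experiment overrides. A plausible reading is that an experiment's effective settings are `base` with its own keys layered on top, with `id`, `group`, and `description` treated as metadata. The sketch below implements that reading on a trimmed copy of the config above; the runner's real merge logic is not shown in this diff, so `resolve_experiments` and `META_KEYS` are assumptions.

```python
import json

META_KEYS = {"id", "group", "description"}  # assumed to be non-hyperparameter fields

def resolve_experiments(config):
    """Return {experiment id: base settings overlaid with that experiment's overrides}."""
    base = config.get("base", {})
    resolved = {}
    for exp in config["experiments"]:
        overrides = {k: v for k, v in exp.items() if k not in META_KEYS}
        resolved[exp["id"]] = {**base, **overrides}
    return resolved

# Trimmed copy of bigvgan_overnight: three experiments, a subset of base keys.
config = json.loads("""
{
  "base": {"train_mode": "snake_alpha_only", "steps": 3000, "lr": 1e-4,
           "lambda_l2sp": 1e-3, "use_gafilter": true},
  "experiments": [
    {"id": "snake_3k_baseline", "description": "baseline"},
    {"id": "snake_5k", "description": "longer", "steps": 5000},
    {"id": "all_2k_l2sp1e-2", "description": "all params",
     "train_mode": "all_params", "steps": 2000, "lr": 1e-5, "lambda_l2sp": 1e-2}
  ]
}
""")

resolved = resolve_experiments(config)
for exp_id, settings in resolved.items():
    print(exp_id, settings)
```

Under this reading an experiment with no override keys, such as `snake_3k_baseline`, runs exactly the `base` settings.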
@@ -0,0 +1,39 @@
{
"name": "eval_r128_candidates",
"description": "Top candidates from r128_sweet_spot. Comparing the two lowest-loss runs, the stable lr=3e-4, and the curriculum run that hit 0.161 before regressing. Baseline included as perceptual reference.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_dir": "/media/unraid/davinci/Selva/BJ/evals/r128_candidates",
"steps": 25,
"seed": 42,
"adapters": [
{
"id": "baseline",
"description": "No LoRA — base model output for perceptual reference"
},
{
"id": "lr_5e4_r128",
"description": "Best loss overall (0.137), still descending at step 10k",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_5e4/adapter_final.pt"
},
{
"id": "lr_3e4_r256",
"description": "Tied with lr_5e4 at 0.139, higher rank — does extra capacity help perceptually?",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g4_r256_lr_3e4/adapter_final.pt"
},
{
"id": "lr_3e4_r128",
"description": "Stable plateau from step 4k to 10k (0.221) — visually confirmed clean spectrograms",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_3e4/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4",
"description": "Best min loss of all (0.161 at step 6k), regressed to 0.193 after curriculum switch — curious if the early checkpoint sounds better",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4_step6000",
"description": "Same run at its actual best step (before regression) — compare against adapter_final to hear the regression",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_step06000.pt"
}
]
}
@@ -0,0 +1,33 @@
{
"name": "lora_logit_cosine_combo",
"description": "Combine the two best findings from optimized dataset sweep: logit-normal timestep sampling + cosine LR schedule. Both individually outperformed baseline by large margins (56% and 68% lower loss). Tests if gains stack.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_logit_cosine_combo",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "logit_normal_cosine",
"description": "Logit-normal timesteps + cosine LR decay. Combining the two best individual improvements.",
"timestep_mode": "logit_normal",
"lr_schedule": "cosine"
},
{
"id": "logit_normal_control",
"description": "Control: logit-normal only (constant LR). Reproduces previous winner for direct comparison.",
"timestep_mode": "logit_normal"
}
]
}
@@ -0,0 +1,64 @@
{
"name": "lora_optimized_dataset",
"description": "LoRA training on optimized dataset (134 clips: resampled 44.1kHz, LUFS-normalized, spectral matched, HF smoothed, gain-augmented). Tests latent augmentation and schedule variants on top of known-best config (PiSSA, rank=128, lr=3e-4).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_optimized_dataset",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "baseline",
"description": "Control: known-best config (PiSSA r128 lr=3e-4) on the optimized dataset. No latent augmentation."
},
{
"id": "latent_mixup",
"description": "Latent mixup alpha=0.4 (MusicLDM). Tests if mixing training latents reduces memorization on 134 clips.",
"latent_mixup_alpha": 0.4
},
{
"id": "latent_noise",
"description": "Latent noise sigma=0.02. Mild Gaussian noise on training latents for regularization.",
"latent_noise_sigma": 0.02
},
{
"id": "mixup_and_noise",
"description": "Both latent mixup (0.4) and noise (0.02). Combined regularization.",
"latent_mixup_alpha": 0.4,
"latent_noise_sigma": 0.02
},
{
"id": "cosine_schedule",
"description": "Cosine LR decay. lr=3e-4 was stable with constant, but cosine may extract more from 5k steps.",
"lr_schedule": "cosine"
},
{
"id": "cosine_mixup",
"description": "Cosine LR + latent mixup. Best regularization combo candidate.",
"lr_schedule": "cosine",
"latent_mixup_alpha": 0.4
},
{
"id": "logit_normal",
"description": "Logit-normal timestep sampling (sigma=1.0). Concentrates training near t=0.5 where flow matching is hardest.",
"timestep_mode": "logit_normal"
},
{
"id": "curriculum_mixup",
"description": "Curriculum timesteps (logit_normal first 60%, then uniform) + latent mixup. Full regularization stack.",
"timestep_mode": "curriculum",
"latent_mixup_alpha": 0.4
}
]
}
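The `timestep_mode` values above select how the training timestep `t` is drawn. As a rough illustration, logit-normal sampling passes a normal draw through a sigmoid, which concentrates samples around `t=0.5`, and the curriculum mode switches from logit-normal to uniform partway through training. The sketch below is an assumption-laden stand-in for the trainer's sampler: the function name, the 0-based step convention, and the exact switch rule (`step < curriculum_switch * total_steps`) are not taken from the repo.

```python
import math
import random

def sample_timestep(step, total_steps, mode="uniform", sigma=1.0,
                    curriculum_switch=0.6, rng=random):
    """Draw one training timestep t in [0, 1)."""
    if mode == "curriculum":
        # Assumed rule: logit-normal for the first curriculum_switch fraction of steps.
        mode = "logit_normal" if step < curriculum_switch * total_steps else "uniform"
    if mode == "logit_normal":
        z = rng.gauss(0.0, sigma)
        return 1.0 / (1.0 + math.exp(-z))  # sigmoid of a normal draw
    return rng.random()

def fraction_near_middle(mode, n=20000, seed=0):
    """Share of draws landing in [0.25, 0.75] for the given mode."""
    rng = random.Random(seed)
    hits = sum(0.25 <= sample_timestep(0, 5000, mode, rng=rng) <= 0.75 for _ in range(n))
    return hits / n

frac_uniform = fraction_near_middle("uniform")
frac_logit = fraction_near_middle("logit_normal")
print(f"uniform: {frac_uniform:.2f}  logit_normal: {frac_logit:.2f}")
```

With `sigma=1.0`, roughly 73% of logit-normal draws fall in `[0.25, 0.75]`, against 50% for uniform sampling, which is the concentration effect the `logit_normal` experiment description refers to.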
@@ -0,0 +1,62 @@
{
"name": "pissa_sweep",
"description": "PiSSA vs standard init ablation at rank 128. Best prior config (lr=3e-4, bs=16, 10k steps) as baseline. PiSSA starts on-manifold via SVD init — should eliminate intruder dimensions. rsLoRA stabilises scaling at high rank.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": true
},
"experiments": [
{
"id": "standard_baseline",
"description": "Standard Kaiming init + classic alpha/rank scaling. Replicates best prior config for A/B comparison.",
"init_mode": "standard",
"use_rslora": false
},
{
"id": "pissa_rslora",
"description": "PiSSA init + rsLoRA scaling. Full Tier-S config. Should start on-manifold and avoid intruder dimensions."
},
{
"id": "pissa_classic_scale",
"description": "PiSSA init + classic alpha/rank scaling. Isolates PiSSA contribution from rsLoRA.",
"use_rslora": false
},
{
"id": "standard_rslora",
"description": "Standard init + rsLoRA only. Isolates rsLoRA contribution from PiSSA.",
"init_mode": "standard"
},
{
"id": "pissa_rslora_lr1e-4",
"description": "PiSSA+rsLoRA at lower lr=1e-4. PiSSA starts closer to optimum — may need less aggressive lr.",
"lr": 1e-4
},
{
"id": "pissa_rslora_lr5e-4",
"description": "PiSSA+rsLoRA at higher lr=5e-4. Test if on-manifold start tolerates faster learning.",
"lr": 5e-4
},
{
"id": "pissa_rslora_dropout",
"description": "PiSSA+rsLoRA with dropout 0.05. Note: PiSSA forces dropout=0 (principal components should not be dropped) — this tests standard init with rsLoRA + dropout.",
"init_mode": "standard",
"lora_dropout": 0.05
}
]
}
@@ -0,0 +1,103 @@
{
"name": "r128_sweet_spot",
"description": "Find the noise-free sweet spot on rank 128. LoRA+ ratio=16 caused noise — testing higher base LR without LoRA+ as a cleaner alternative. Target loss range 0.250.35. Also probing rank 256 since 102GB VRAM allows it.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r128_lr_2e4",
"group": "lr",
"description": "LR=2e-4. Conservative 2× step up from baseline — noise-free descent toward sweet spot.",
"lr": 2e-4
},
{
"id": "g1_r128_lr_3e4",
"group": "lr",
"description": "LR=3e-4. 3× baseline — landed at 0.41 on r64, should reach 0.250.35 on r128.",
"lr": 3e-4
},
{
"id": "g1_r128_lr_5e4",
"group": "lr",
"description": "LR=5e-4. Aggressive but no LoRA+ B-matrix asymmetry — cleaner noise profile.",
"lr": 5e-4
},
{
"id": "g2_r128_curriculum",
"group": "curriculum",
"description": "Curriculum only at baseline LR. Clean slow descent — reference for what curriculum contributes alone.",
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum",
"group": "curriculum",
"description": "LR=3e-4 + curriculum. Speed of higher LR with coverage of curriculum — no LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum_dropout",
"group": "curriculum",
"description": "LR=3e-4 + curriculum + dropout=0.05. Full controlled stack without LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum",
"lora_dropout": 0.05
},
{
"id": "g3_r128_lora_plus_4",
"group": "lora_plus",
"description": "LoRA+ ratio=4 (lr_B=4e-4). Much more conservative than ratio=16 — tests if noise came from ratio not the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g4_r256_baseline",
"group": "rank256",
"description": "Rank 256 at baseline LR. 102GB VRAM makes this viable — does more capacity keep helping?",
"rank": 256
},
{
"id": "g4_r256_lr_3e4",
"group": "rank256",
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
"rank": 256,
"lr": 3e-4
},
{
"id": "g5_r128_lr_2e4_cosine",
"group": "cosine",
"description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 60008000 by decaying LR to ~0 instead of staying flat.",
"lr": 2e-4,
"lr_schedule": "cosine"
},
{
"id": "g5_r128_lr_3e4_cosine",
"group": "cosine",
"description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
"lr": 3e-4,
"lr_schedule": "cosine"
}
]
}
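The `g5` cosine experiments above rely on the learning rate decaying toward zero by the final step instead of staying flat. A minimal sketch of such a schedule, assuming linear warmup over `warmup_steps` followed by a half-cosine decay to 0; the trainer's actual schedule (warmup shape, minimum LR) is not shown in this diff, and `lr_at` is a hypothetical helper.

```python
import math

def lr_at(step, total_steps, peak_lr, warmup_steps=200, schedule="constant"):
    """Learning rate at a 0-based step: linear warmup, then constant or cosine decay."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    if schedule == "cosine":
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
    return peak_lr

# Values matching g5_r128_lr_3e4_cosine: 10000 steps, 200 warmup steps, lr=3e-4.
for step in (0, 199, 5100, 8000, 10000):
    constant = lr_at(step, 10000, 3e-4, 200, "constant")
    cosine = lr_at(step, 10000, 3e-4, 200, "cosine")
    print(f"step {step:5d}  constant={constant:.2e}  cosine={cosine:.2e}")
```

Under this schedule the cosine run is already well below the peak LR by step 6000, the region where the constant-LR runs were observed to oscillate.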
@@ -0,0 +1,130 @@
{
"name": "r64_overnight",
"description": "Focused rank-64 overnight sweep. All experiments use rank 64 as base — confirmed best from tier1_thorough early results. 8000 steps to reach convergence (none converged at 4000).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r64_overnight",
"base": {
"steps": 8000,
"rank": 64,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r64_baseline",
"group": "rank",
"description": "Rank 64 baseline — clean reference at 8000 steps."
},
{
"id": "g1_r128_baseline",
"group": "rank",
"description": "Rank 128 — 102GB VRAM makes this free. Does doubling rank from 64 help further?",
"rank": 128
},
{
"id": "g2_r64_alpha_32",
"group": "alpha",
"description": "Rank 64 alpha=32 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 32.0
},
{
"id": "g2_r64_alpha_16",
"group": "alpha",
"description": "Rank 64 alpha=16 (scale=0.25). More aggressive scale reduction — may over-constrain.",
"alpha": 16.0
},
{
"id": "g3_r64_lora_plus",
"group": "regularisation",
"description": "LoRA+ ratio=16. lr_B = 16 × lr_A. Faster convergence at constant step budget.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_r64_dropout_0.05",
"group": "regularisation",
"description": "Dropout=0.05. Light sparsity regularisation on LoRA path.",
"lora_dropout": 0.05
},
{
"id": "g3_r64_dropout_0.1",
"group": "regularisation",
"description": "Dropout=0.1. Stronger regularisation — tests if 49 clips needs heavier constraint.",
"lora_dropout": 0.1
},
{
"id": "g3_r64_curriculum",
"group": "regularisation",
"description": "Curriculum sampling: logit_normal for steps 1-4800, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_r64_lr_low",
"group": "lr",
"description": "LR=3e-5. 3× lower — checks if 1e-4 is overshooting at rank 64.",
"lr": 3e-5
},
{
"id": "g4_r64_lr_high",
"group": "lr",
"description": "LR=3e-4. 3× higher — may converge faster but risk instability.",
"lr": 3e-4
},
{
"id": "g5_r64_target_full",
"group": "target",
"description": "Rank 64 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g5_r128_target_full",
"group": "target",
"description": "Rank 128 + full target. Maximum possible coverage with available VRAM.",
"rank": 128,
"target": "attn.qkv linear1"
},
{
"id": "g6_r64_full_tier1",
"group": "combined",
"description": "All Tier 1 at rank 64: LoRA+ 16 + dropout 0.05 + curriculum. Full stack at 8000 steps.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r64_alpha32_full",
"group": "combined",
"description": "Rank 64 alpha=32 + all Tier 1. Best alpha scaling + best regularisation stack.",
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r128_full_tier1",
"group": "combined",
"description": "Rank 128 + all Tier 1. Tests if more capacity + regularisation beats rank 64 full.",
"rank": 128,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
}
]
}
@@ -0,0 +1,52 @@
{
"name": "ti_sweep_1",
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
"base": {
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_text": "",
"lr": 2e-4,
"n_tokens": 4,
"inject_mode": "suffix"
},
"experiments": [
{
"id": "n4_baseline",
"group": "reference",
"description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
},
{
"id": "n4_clamped",
"group": "norm_clamp",
"description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
},
{
"id": "n4_prefix_clamped",
"group": "norm_clamp",
"description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
"inject_mode": "prefix"
},
{
"id": "n8_prefix_clamped",
"group": "norm_clamp",
"description": "8 tokens, prefix, clamped. More capacity without the artifact.",
"n_tokens": 8,
"inject_mode": "prefix"
},
{
"id": "n4_prefix_warm_clamped",
"group": "norm_clamp",
"description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
"inject_mode": "prefix",
"init_text": "mechanical impact sound design"
}
]
}
@@ -0,0 +1,61 @@
{
"name": "tier1_sweep",
"description": "Ablation of Tier 1 improvements: LoRA+, dropout, curriculum sampling. Baseline = uniform, no regularisation.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "lora_sweeps/tier1_sweep",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "baseline",
"description": "Standard LoRA — no Tier 1 changes. Reference point."
},
{
"id": "lora_plus_16",
"description": "LoRA+ only: lr_B = 16 * lr_A. Should converge faster in early steps.",
"lora_plus_ratio": 16.0
},
{
"id": "dropout_0.05",
"description": "LoRA dropout 0.05 only. Light regularisation for 49-clip dataset.",
"lora_dropout": 0.05
},
{
"id": "dropout_0.1",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "curriculum",
"description": "Curriculum sampling only: logit_normal for steps 1-2400, then uniform. Should improve convergence vs pure uniform.",
"timestep_mode": "curriculum"
},
{
"id": "full_tier1",
"description": "All Tier 1 combined: LoRA+ + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "rank_64",
"description": "Rank 64 baseline — MMAudio LoRA guide default. More expressive adapter for 49-clip dataset.",
"rank": 64
}
]
}
@@ -0,0 +1,144 @@
{
"name": "tier1_thorough",
"description": "Full overnight Tier 1 ablation on 49-clip BJ dataset. 4 groups: rank, alpha, regularisation, and best combinations. ~10-12h depending on GPU.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/tier1_thorough",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 1000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_rank_16",
"group": "rank",
"description": "Rank 16 baseline — reference point for all groups."
},
{
"id": "g1_rank_32",
"group": "rank",
"description": "Rank 32 — midpoint. Does doubling rank improve quality without overfitting?",
"rank": 32
},
{
"id": "g1_rank_64",
"group": "rank",
"description": "Rank 64 — MMAudio LoRA guide default. Maximum expressiveness at 49 clips.",
"rank": 64
},
{
"id": "g2_alpha_half_r16",
"group": "alpha",
"description": "Alpha=8 with rank 16 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 8.0
},
{
"id": "g2_alpha_half_r64",
"group": "alpha",
"description": "Alpha=32 with rank 64 (scale=0.5). Best-practice scaling for high-rank adapters.",
"rank": 64,
"alpha": 32.0
},
{
"id": "g3_lora_plus_4",
"group": "regularisation",
"description": "LoRA+ ratio=4 — conservative asymmetric LR. Lower bound for the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g3_lora_plus_16",
"group": "regularisation",
"description": "LoRA+ ratio=16 — standard from FLUX LoRA literature. Faster early convergence.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_dropout_0.05",
"group": "regularisation",
"description": "LoRA dropout 0.05 only. Light sparsity regularisation (arXiv:2404.09610).",
"lora_dropout": 0.05
},
{
"id": "g3_dropout_0.1",
"group": "regularisation",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "g3_curriculum",
"group": "regularisation",
"description": "Curriculum sampling only: logit_normal steps 1-2400, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r16",
"group": "combined",
"description": "All Tier 1 at rank 16: LoRA+ 16 + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r64",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32. Best expressiveness + best regularisation.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g5_lr_low",
"group": "lr",
"description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
"lr": 3e-5
},
{
"id": "g5_lr_high",
"group": "lr",
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
"lr": 3e-4
},
{
"id": "g6_target_full_r16",
"group": "target",
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g6_target_full_r64",
"group": "target",
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
"rank": 64,
"alpha": 32.0,
"target": "attn.qkv linear1"
},
{
"id": "g4_full_r64_6k",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum",
"steps": 6000,
"save_every": 1000
}
]
}
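Several descriptions above quote a LoRA scale ("Alpha=8 with rank 16 (scale=0.5)") and a LoRA+ ratio ("lr_B = 16 * lr_A"). The arithmetic can be sketched as below. Treating `alpha: 0.0` in `base` as "use alpha = rank" is an assumption made for this example; the diff does not show how the trainer interprets that value. Both function names are hypothetical.

```python
def lora_scale(rank, alpha):
    """Scale applied to the LoRA update: alpha / rank.

    ASSUMPTION: alpha == 0.0 means "default to alpha = rank", i.e. scale 1.0.
    """
    effective_alpha = alpha if alpha > 0 else float(rank)
    return effective_alpha / rank


def lora_plus_lrs(lr, ratio):
    """LoRA+ learning rates: the B matrices train `ratio` times faster than A."""
    return {"lr_A": lr, "lr_B": lr * ratio}


for rank, alpha in [(16, 0.0), (16, 8.0), (64, 32.0)]:
    print(f"rank={rank} alpha={alpha} scale={lora_scale(rank, alpha)}")

print(lora_plus_lrs(1e-4, 16.0))
```

Under these assumptions, both `g2_alpha_half_*` experiments halve the update scale relative to the default, and `lora_plus_ratio: 16.0` at `lr: 1e-4` gives the B matrices a learning rate of about 1.6e-3.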
+30
@@ -0,0 +1,30 @@
{
"name": "vocoder_finetune",
"description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "r128_lr_3e4_bj_vocoder",
"description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
}
]
}
+34 -6
@@ -2,11 +2,39 @@ NODE_CLASS_MAPPINGS = {}
 NODE_DISPLAY_NAME_MAPPINGS = {}
 _NODES = {
-    "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
-    "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
-    "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
-    "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
-    "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
+    "SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
+    "SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
+    "SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
+    "SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
+    "SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
+    "SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
+    "SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
+    "SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
+    "SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
+    "SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
+    "SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"),
+    "SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
+    "SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
+    "SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
+    "SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
+    "SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
+    "SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
+    "SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
+    "SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
+    "SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
+    "SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
+    "SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
+    "SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
+    "SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
+    "SelvaDatasetCompressor": (".selva_dataset_pipeline", "SelvaDatasetCompressor", "SelVA Dataset Compressor"),
+    "SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
+    "SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
+    "SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
+    "SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
+    "SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
+    "SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
+    "SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
+    "SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
 }
 for key, (module_path, class_name, display_name) in _NODES.items():
@@ -16,4 +44,4 @@ for key, (module_path, class_name, display_name) in _NODES.items():
         NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
         NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
     except (ImportError, AttributeError) as e:
-        print(f"[PrismAudio] Skipping {key}: {e}")
+        print(f"[SelVA] Skipping {key}: {e}")
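The two hunks above show only the edges of the registration loop: a `_NODES` table of `(module_path, class_name, display_name)` tuples and an `except (ImportError, AttributeError)` branch that skips a node instead of failing the whole plugin. A self-contained sketch of that pattern is given below. The import line itself is not visible in the diff, so `importlib.import_module` is an assumption here, and stdlib modules stand in for the plugin's relative `.selva_*` modules.

```python
import importlib

# Stand-in registry: stdlib modules replace the plugin's own node modules.
_NODES = {
    "OrderedDictNode": ("collections", "OrderedDict", "Ordered Dict"),
    # Module cannot be imported -> ImportError -> entry is skipped.
    "BrokenNode": ("module_that_does_not_exist_xyz", "Nope", "Broken Node"),
    # Module imports but lacks the class -> AttributeError -> entry is skipped.
    "MissingClassNode": ("collections", "NoSuchClass", "Missing Class"),
}

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

for key, (module_path, class_name, display_name) in _NODES.items():
    try:
        # ASSUMPTION: the real loop imports relative to the package, e.g.
        # importlib.import_module(module_path, package=__name__).
        mod = importlib.import_module(module_path)
        NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
        NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
    except (ImportError, AttributeError) as e:
        print(f"[SelVA] Skipping {key}: {e}")

print(sorted(NODE_CLASS_MAPPINGS))
```

Because the class lookup runs before the display name is stored, a failed entry leaves neither mapping half-populated. That lets optional nodes with heavy dependencies drop out without affecting the rest.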
-207
@@ -1,207 +0,0 @@
import os
import sys
import hashlib
import subprocess
import tempfile
import torch

from .utils import PRISMAUDIO_CATEGORY
from .feature_loader import PrismAudioFeatureLoader

# Managed venv created automatically when python_env is left as default
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")

_EXTRACT_PACKAGES = [
    "torch", "torchaudio", "torchvision",
    # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
    "tensorflow-cpu>=2.16.0",
    # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
    "jax[cuda13]", "flax",
    "transformers", "decord", "einops", "numpy", "mediapy",
    "git+https://github.com/google-deepmind/videoprism.git",
]


def _pip_install(pip, *packages, label=None):
    """Install one or more packages with visible output; raise on failure."""
    tag = label or packages[0]
    print(f"[PrismAudio] installing {tag} ...", flush=True)
    result = subprocess.run(
        [pip, "install", "--progress-bar", "on"] + list(packages),
        capture_output=False,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). "
            "See pip output above for details."
        )
    print(f"[PrismAudio] {tag} OK", flush=True)


def _ensure_extract_env():
    """Create and populate the managed venv on first use."""
    if os.path.exists(_MANAGED_PYTHON):
        return _MANAGED_PYTHON

    import shutil
    if os.path.exists(_MANAGED_VENV):
        print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
        shutil.rmtree(_MANAGED_VENV)

    print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
    subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)

    pip = os.path.join(_MANAGED_VENV, "bin", "pip")
    print("[PrismAudio] Upgrading pip...", flush=True)
    subprocess.run([pip, "install", "--upgrade", "pip"], check=True)

    total = len(_EXTRACT_PACKAGES)
    print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
    for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
        label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
        print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
        _pip_install(pip, pkg, label=label)

    print("[PrismAudio] Feature-extraction env ready.", flush=True)
    return _MANAGED_PYTHON


def _hash_inputs(video_tensor, cot_text):
    """Create a hash of the inputs for caching."""
    h = hashlib.sha256()
    h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])  # First 1MB for speed
    h.update(cot_text.encode())
    return h.hexdigest()[:16]


def _save_frames_to_npy(video_tensor, output_path):
    """Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.

    Lossless — avoids H.264 encode/decode roundtrip.
    """
    import numpy as np
    frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
    np.save(output_path, frames_np)


class PrismAudioFeatureExtractor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
                "python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
        # Resolve fps from VHS video_info if connected
        if video_info is not None:
            fps = video_info["loaded_fps"]

        # Resolve python binary
        if python_env == "comfyui_env":
            print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
                  "Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
            python_bin = sys.executable
        else:
            python_bin = _ensure_extract_env()

        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)

        # Check cache
        cache_hash = _hash_inputs(video, caption_cot)
        cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            features, = loader.load_features(cached_path)
            return (features, float(fps))

        # Save frames to temp file (lossless .npy, no codec roundtrip)
        import time
        t0 = time.perf_counter()
        frames = video.shape[0]
        print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            tmp_video = tmp.name
        _save_frames_to_npy(video, tmp_video)
        print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)

        # Build subprocess command
        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )
        import folder_paths
        synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
        if not os.path.exists(synchformer_ckpt):
            raise RuntimeError(
                f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
                "Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
            )
        cmd = [
            python_bin,
            script_path,
            "--video", tmp_video,
            "--cot_text", caption_cot,
            "--output", cached_path,
            "--source_fps", str(fps),
            "--synchformer_ckpt", synchformer_ckpt,
        ]

        # Build env: inherit current env, inject HF token if provided
        import copy
        env = copy.copy(os.environ)
        token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "")
        if token:
            env["HF_TOKEN"] = token
            env["HUGGING_FACE_HUB_TOKEN"] = token
        else:
            print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
                  "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)

        print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
        try:
            # capture_output=False: let stdout/stderr stream directly to ComfyUI logs
            result = subprocess.run(
                cmd,
                capture_output=False,
                timeout=600,  # 10 minute timeout
                env=env,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
                    "See output above for details."
                )
            print("[PrismAudio] Feature extraction subprocess finished successfully.")
        finally:
            if os.path.exists(tmp_video):
                os.unlink(tmp_video)

        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        features, = loader.load_features(cached_path)
        return (features, float(fps))
-53
@@ -1,53 +0,0 @@
import os
import numpy as np
import torch

from .utils import PRISMAUDIO_CATEGORY

# Keys consumed by the conditioners (video_features, text_features, sync_features)
# global_video_features and global_text_features are NOT consumed by any conditioner
# in the prismaudio.json config — they are unused.
REQUIRED_KEYS = [
    "video_features",
    "text_features",
    "sync_features",
]


class PrismAudioFeatureLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        data = np.load(npz_path, allow_pickle=True)
        features = {}
        for key in REQUIRED_KEYS:
            if key in data:
                features[key] = torch.from_numpy(data[key]).float()
            else:
                print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
                # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                if key == "sync_features":
                    features[key] = torch.zeros(8, 768)
                else:
                    features[key] = torch.zeros(1, 1024)

        # Load duration if present
        if "duration" in data:
            features["duration"] = float(data["duration"])

        return (features,)
-154
@@ -1,154 +0,0 @@
import os
import json
import torch
import folder_paths
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
    get_device, get_offload_device, determine_precision, determine_offload_strategy,
    soft_empty_cache, resolve_hf_token,
)

# HuggingFace repo for auto-download
HF_REPO_ID = "FunAudioLLM/PrismAudio"
REQUIRED_FILES = {
    "diffusion": "prismaudio.ckpt",
    "vae": "vae.ckpt",
    "synchformer": "synchformer_state_dict.pth",
}


def _download_if_missing(filename, model_dir, hf_token=None):
    """Download a model file from HuggingFace if not present locally."""
    filepath = os.path.join(model_dir, filename)
    if os.path.exists(filepath):
        return filepath

    from huggingface_hub import hf_hub_download
    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        downloaded = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
        return downloaded
    except Exception as e:
        if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
            raise RuntimeError(
                f"[PrismAudio] Model '{filename}' requires license acceptance. "
                f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
                f"then set HF_TOKEN env var or run: huggingface-cli login"
            ) from e
        raise


class PrismAudioModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        register_model_folder()
        return {
            "required": {
                "precision": (["auto", "fp32", "fp16", "bf16"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()

        # Auto-download missing files
        for key, filename in REQUIRED_FILES.items():
            _download_if_missing(filename, model_dir, hf_token=token)

        # Load config
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "prismaudio_core", "configs", "prismaudio.json"
        )
        with open(config_path) as f:
            model_config = json.load(f)

        # Create model from config
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # Load diffusion weights
        diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        diffusion_state = comfy.utils.load_torch_file(diffusion_path)
        # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        diff_result = model.load_state_dict(diffusion_state, strict=False)
        print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
        print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
        if diff_result.missing_keys:
            print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
        if diff_result.unexpected_keys:
            print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
        # Sample a few ckpt keys to verify prefix alignment
        sample_keys = list(diffusion_state.keys())[:5]
        print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)

        # Load VAE weights separately
        # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
        vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
        vae_full_state = comfy.utils.load_torch_file(vae_path)
        print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
        # Sample raw keys to see actual prefix
        vae_sample_keys = list(vae_full_state.keys())[:8]
        print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
        # Strip "autoencoder." prefix from keys
        vae_state = {}
        prefix = "autoencoder."
        for k, v in vae_full_state.items():
            if k.startswith(prefix):
                vae_state[k[len(prefix):]] = v
            else:
                vae_state[k] = v
        print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
        # Sample model keys to compare
        model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
        print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
        # strict=False: vae.ckpt is a training checkpoint that also contains
        # discriminator, loss modules, and EMA wrappers not present in the
        # inference AudioAutoencoder — ignore those extra keys.
        # Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
        # (AutoencoderPretransform.load_state_dict doesn't return the result)
        vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
        print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
        if vae_result.missing_keys:
            print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)

        # Apply precision: DiT + conditioners in user-selected dtype,
        # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
        model.model.to(dtype)  # DiTWrapper
        model.conditioner.to(dtype)  # MultiConditioner
        # model.pretransform stays in fp32

        if strategy == "keep_in_vram":
            model = model.to(device)
        else:
            model = model.to(get_offload_device())

        model.eval()
        return ({
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        },)
-165
@@ -1,165 +0,0 @@
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache,
)


class PrismAudioSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "features": ("PRISMAUDIO_FEATURES",),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
                "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
                "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Resolve duration: 0 means use video duration from features
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
            duration = features["duration"]
            print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
        # Note: no seq length config needed — the model adapts to input tensor shapes
        # dynamically via its transformer architecture.

        # Determine if video features are present (not all zeros)
        has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
        video_feat = features["video_features"].to(device, dtype=dtype)
        sync_feat = features["sync_features"].to(device, dtype=dtype)

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": video_feat,
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": sync_feat,
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)
            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

            fakes_f = fakes.float()
            print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

            # Offload diffusion model and conditioner before VAE decode
            if strategy == "offload_to_cpu":
                diffusion.model.to(get_offload_device())
                diffusion.conditioner.to(get_offload_device())
                soft_empty_cache()
            diffusion.pretransform.to(device)

            # VAE decode in fp32 (snake activations overflow in fp16)
            with torch.amp.autocast(device_type=device.type, enabled=False):
                audio = diffusion.pretransform.decode(fakes_f)

        # Offload VAE
        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        pre_norm_std = audio.std().item()
        pre_norm_peak = audio.abs().max().item()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


def _substitute_empty_features(diffusion, conditioning, device, dtype):
    """Replace video/sync conditioning with learned empty embeddings when video is absent.

    empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
    output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
    near-zero activations, NOT the learned null signal the model was trained with.

    The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
    """
    dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model

    # Substitute video_features with learned empty_clip_feat
    if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
        empty = dit.empty_clip_feat.to(device, dtype=dtype)  # [1, 1024]
        batch_size = conditioning['video_features'][0].shape[0]
        empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1)  # [B, 1, 1024]
        conditioning['video_features'][0] = empty_expanded
        conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)

    # Substitute sync_features with learned empty_sync_feat
    if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
        empty = dit.empty_sync_feat.to(device, dtype=dtype)  # [1, 1024]
        batch_size = conditioning['sync_features'][0].shape[0]
        empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1)  # [B, 1, 1024]
        conditioning['sync_features'][0] = empty_expanded
        conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
@@ -0,0 +1,201 @@
"""SelVA Activation Steering Extractor.

Computes per-block steering vectors by running the frozen generator on the
training dataset and recording how the target style's conditioning shifts the
DiT hidden states relative to empty (unconditional) conditioning.

For each block i:
    steering[i] = mean(latent_hidden | target style conditions)
                - mean(latent_hidden | empty conditions)

The resulting vectors are injected at inference time (via SelVA Sampler's
steering_strength input) to nudge the denoising trajectory toward the target
style's activation patterns without modifying any model weights.
"""

import random
from pathlib import Path

import torch
import comfy.utils
import folder_paths

from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import _prepare_dataset


def _collect_activations(generator, conditions, latent, t_tensor):
    """Run one predict_flow call, collecting latent hidden states per block.

    Returns a list of [seq, hidden_dim] float32 CPU tensors,
    one per block (joint_blocks first, then fused_blocks).
    """
    activations = []

    def make_hook(is_joint):
        def hook(module, input, output):
            h = output[0] if is_joint else output
            activations.append(h.detach().float().mean(0).cpu())  # [seq, hidden]
        return hook

    handles = []
    for block in generator.joint_blocks:
        handles.append(block.register_forward_hook(make_hook(is_joint=True)))
    for block in generator.fused_blocks:
        handles.append(block.register_forward_hook(make_hook(is_joint=False)))

    try:
        with torch.no_grad():
            generator.predict_flow(latent, t_tensor, conditions)
    finally:
        for h in handles:
            h.remove()

    return activations  # list of n_blocks tensors [seq, hidden]


class SelvaActivationSteeringExtractor:
    """Computes activation steering vectors from a training dataset.

    Runs the frozen generator on N clips at random timesteps with both
    target-style-conditioned and empty-conditioned inputs, then saves the mean
    difference per DiT block to a .pt file.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "extract"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("steering_path",)
    OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
    DESCRIPTION = (
        "Computes per-block activation steering vectors: mean(target style activations) "
        "- mean(empty activations) at each DiT block. Load the result with "
        "SelVA Activation Steering Loader and connect to the Sampler."
    )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "data_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
                }),
                "output_path": ("STRING", {
                    "default": "steering_vectors.pt",
                    "tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
                }),
                "n_samples": ("INT", {
                    "default": 16, "min": 1, "max": 256,
                    "tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
                }),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    def extract(self, model, data_dir, output_path, n_samples, seed):
        device = get_device()
        dtype = model["dtype"]
        seq_cfg = model["seq_cfg"]

        data_dir = Path(data_dir.strip())
        if not data_dir.is_absolute():
            data_dir = Path(folder_paths.models_dir) / data_dir
        if not data_dir.exists():
            raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")

        out_path = Path(output_path.strip())
        if not out_path.is_absolute():
            out_path = Path(folder_paths.get_output_directory()) / out_path
        out_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
        print(f"[Steering] data_dir = {data_dir}", flush=True)
        print(f"[Steering] output = {out_path}\n", flush=True)

        dataset = _prepare_dataset(model, data_dir, device)
        generator = model["generator"]
        generator.eval()

        torch.manual_seed(seed)
        random.seed(seed)
        indices = random.choices(range(len(dataset)), k=n_samples)

        n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
        style_sums = [None] * n_blocks
        empty_sums = [None] * n_blocks
        counts = [0] * n_blocks

        pbar = comfy.utils.ProgressBar(n_samples)
        for sample_i, clip_idx in enumerate(indices):
            x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
            clip_f = clip_f_cpu.to(device, dtype)  # [1, T_clip, 1024]
            sync_f = sync_f_cpu.to(device, dtype)  # [1, T_sync, 768]
            text_clip = text_clip_cpu.to(device, dtype)  # [1, 77, 1024]
            # x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
clip_latent_seq_len = x1_cpu.shape[1]
generator.update_seq_lengths(
latent_seq_len=clip_latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = generator.get_empty_conditions(bs=1)
# Random timestep and noise latent for this clip
t_val = torch.rand(1).item()
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
latent = torch.randn(
1, clip_latent_seq_len, generator.latent_dim,
device=device, dtype=dtype,
)
style_acts = _collect_activations(generator, conditions, latent, t_tensor)
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
if style_sums[i] is None:
style_sums[i] = st.clone()
empty_sums[i] = em.clone()
else:
style_sums[i] += st
empty_sums[i] += em
counts[i] += 1
pbar.update(1)
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
# Steering vector per block: mean(target style) - mean(empty)
steering_vectors = []
for i in range(n_blocks):
vec = (style_sums[i] - empty_sums[i]) / counts[i] # [seq, hidden]
steering_vectors.append(vec)
norm = vec.norm().item()
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
n_joint = len(generator.joint_blocks)
payload = {
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
"n_blocks": n_blocks,
"n_joint": n_joint,
"n_fused": len(generator.fused_blocks),
"latent_seq_len": seq_cfg.latent_seq_len,
"n_samples": n_samples,
"seed": seed,
"mode": model["mode"],
"variant": model["variant"],
}
torch.save(payload, str(out_path))
print(f"\n[Steering] Saved: {out_path}", flush=True)
soft_empty_cache()
return (str(out_path),)
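The accumulation in the extractor reduces to a per-block mean difference between the conditioned and empty passes. A minimal sketch of that arithmetic follows, using plain Python lists in place of tensors; `steering_vectors` and its nested-list layout are hypothetical stand-ins for illustration, not the node's actual code.

```python
from typing import List

Vector = List[float]


def steering_vectors(style_runs: List[List[Vector]],
                     empty_runs: List[List[Vector]]) -> List[Vector]:
    """Return one mean-difference vector per block.

    style_runs[s][b] is the pooled activation of block b for sample s under
    style conditioning; empty_runs uses the same layout for the empty pass.
    """
    n_samples = len(style_runs)
    n_blocks = len(style_runs[0])
    vectors = []
    for b in range(n_blocks):
        dim = len(style_runs[0][b])
        diff_sum = [0.0] * dim
        for s in range(n_samples):
            for d in range(dim):
                diff_sum[d] += style_runs[s][b][d] - empty_runs[s][b][d]
        # mean(style) - mean(empty) equals mean(style - empty) for paired samples
        vectors.append([v / n_samples for v in diff_sum])
    return vectors


# Two samples, two blocks, two hidden units (made-up numbers).
style = [[[1.0, 2.0], [0.0, 4.0]], [[3.0, 2.0], [2.0, 0.0]]]
empty = [[[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [0.0, 2.0]]]
vectors = steering_vectors(style, empty)
print(vectors)
```

Because both passes share the same noise latent and timestep, the difference isolates the effect of conditioning; summing first and dividing once, as the node does, gives the same result as averaging each side separately.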
+62
@@ -0,0 +1,62 @@
"""SelVA Activation Steering Loader.
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaActivationSteeringLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("STEERING_VECTORS",)
RETURN_NAMES = ("steering_vectors",)
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
DESCRIPTION = (
"Loads activation steering vectors from a .pt file produced by "
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
"denoising toward the target activation patterns."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "steering_vectors.pt",
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[Steering] File not found: {p}")
payload = torch.load(str(p), map_location="cpu", weights_only=False)
n_blocks = payload["n_blocks"]
n_joint = payload["n_joint"]
n_fused = payload["n_fused"]
n_vecs = len(payload["steering_vectors"])
print(f"[Steering] Loaded: {p}", flush=True)
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
f"latent_seq_len={payload['latent_seq_len']} "
f"n_samples={payload['n_samples']}", flush=True)
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
mean_norm = sum(norms) / len(norms)
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
return (payload,)
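The loader reads `n_blocks`, `n_joint`, `n_fused`, and the vector list without cross-checking them. A small consistency check, sketched below under the assumption that the payload keeps the layout written by the extractor, could catch a bundle saved from a different model variant; `validate_steering_payload` is a hypothetical helper, not part of the node.

```python
def validate_steering_payload(payload: dict) -> None:
    """Raise if the bundle's block counts disagree with its vector list."""
    required = ("steering_vectors", "n_blocks", "n_joint", "n_fused")
    missing = [key for key in required if key not in payload]
    if missing:
        raise KeyError(f"steering payload is missing keys: {missing}")
    if payload["n_joint"] + payload["n_fused"] != payload["n_blocks"]:
        raise ValueError("n_joint + n_fused does not equal n_blocks")
    if len(payload["steering_vectors"]) != payload["n_blocks"]:
        raise ValueError("expected one steering vector per block")


# Stand-in payload: short lists replace the real [seq, hidden] tensors.
good = {"steering_vectors": [[0.1], [0.2], [0.3]],
        "n_blocks": 3, "n_joint": 1, "n_fused": 2}
validate_steering_payload(good)  # passes silently

bad = dict(good, n_blocks=4)
try:
    validate_steering_payload(bad)
    rejected = False
except ValueError:
    rejected = True
print("bad payload rejected:", rejected)
```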
+153
@@ -0,0 +1,153 @@
"""SelVA Audio Post-Processing nodes.
Post-generation enhancement applied to standard AUDIO outputs:
SelvaHarmonicExciter — multi-band harmonic exciter (HPF → tanh → mix)
SelvaOutputNormalizer — LUFS normalization + true peak limiting
"""
import numpy as np
import torch
from .utils import SELVA_CATEGORY
class SelvaHarmonicExciter:
"""Multi-band harmonic exciter for post-generation enhancement.
Isolates high-frequency content above a cutoff, applies tanh saturation
to generate odd-order harmonics (3rd, 5th, ...), then mixes back with the dry signal.
Restores harmonic richness lost during BigVGAN vocoder reconstruction.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 3000.0, "min": 500.0, "max": 16000.0, "step": 100.0,
"tooltip": "Highpass cutoff frequency in Hz. Only content above this is excited. "
"3000 Hz targets the upper harmonics BigVGAN tends to smear.",
}),
"drive": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.5,
"tooltip": "Saturation drive. Higher = more harmonics generated. "
"2-3 is subtle, 5+ is aggressive.",
}),
"mix": ("FLOAT", {
"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Wet/dry blend. 0.1-0.2 is subtle enhancement, "
"0.5+ is aggressive harmonic addition.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "excite"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Multi-band harmonic exciter. Applies tanh saturation to the high-frequency band "
"to restore harmonics lost during BigVGAN vocoder reconstruction. "
"Uses pedalboard.HighpassFilter for band isolation."
)
def excite(self, audio, cutoff_hz: float, drive: float, mix: float):
from pedalboard import Pedalboard, HighpassFilter
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
wav_np = wav.float().numpy() # [C, T]
# Isolate HF band
board = Pedalboard([HighpassFilter(cutoff_frequency_hz=cutoff_hz)])
hf = board(wav_np, sr) # [C, T]
# Tanh saturation; dividing by drive keeps small-signal gain near unity
# and bounds the wet signal to [-1/drive, 1/drive]
excited = np.tanh(hf * drive) / max(drive, 1.0)
# Mix back with dry
mixed = wav_np + mix * excited
# Soft clip to prevent going over
mixed = np.tanh(mixed)
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, T]
print(
f"[HarmonicExciter] cutoff={cutoff_hz}Hz drive={drive} mix={mix:.0%}",
flush=True,
)
return ({"waveform": wav_out, "sample_rate": sr},)
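Which harmonics a tanh stage adds can be checked numerically. The sketch below is a standalone illustration using only the standard library, not the node's pedalboard path: it drives one period of a sine through `tanh` and measures the first five harmonics with a direct DFT. `harmonic_amplitudes` is a hypothetical helper name.

```python
import cmath
import math


def harmonic_amplitudes(drive: float, n_harmonics: int = 5, n: int = 64):
    """Amplitude of harmonics 1..n_harmonics of tanh(drive * sin) over one period."""
    samples = [math.tanh(drive * math.sin(2.0 * math.pi * i / n)) for i in range(n)]
    amps = []
    for k in range(1, n_harmonics + 1):
        # Direct DFT bin k; the 2/n factor scales it to a sine-wave amplitude.
        acc = sum(s * cmath.exp(-2j * math.pi * k * i / n)
                  for i, s in enumerate(samples))
        amps.append(2.0 * abs(acc) / n)
    return amps


amps = harmonic_amplitudes(drive=2.0)
for k, a in enumerate(amps, start=1):
    print(f"harmonic {k}: {a:.4f}")
```

Because `tanh` is an odd function, the even harmonics come out at numerical zero while the 3rd and 5th are clearly present. An asymmetric shaper would be needed to add even harmonics.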
class SelvaOutputNormalizer:
"""Normalize generated audio to a target LUFS level with true peak limiting.
Apply as the final node before saving — brings generated audio to a
consistent loudness target regardless of input video loudness variation.
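Uses pyloudnorm (BS.1770-4). The peak ceiling is checked on sample peaks
(no oversampling); if it is exceeded, the whole clip is scaled down, so the
final loudness can end up below the LUFS target.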
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"target_lufs": ("FLOAT", {
"default": -14.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. "
"-14 LUFS for streaming (Spotify/YouTube), "
"-9 to -7 for production masters.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP applied after LUFS gain.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize output audio to a target LUFS level (BS.1770-4) with true peak limiting. "
"Apply as the last node before saving. Uses pyloudnorm."
)
def normalize(self, audio, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
wav_np = wav.permute(1, 0).double().numpy() # [T, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [T] mono
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(wav_np)
if not np.isfinite(loudness):
print("[OutputNormalizer] Could not measure loudness — clip too short or silent. Passing through.", flush=True)
return (audio,)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_out = wav * gain_linear
peak = wav_out.abs().max().item()
if peak > tp_linear:
wav_out = wav_out * (tp_linear / peak)
print(
f"[OutputNormalizer] {loudness:.1f} LUFS → {target_lufs} LUFS "
f"gain={gain_db:+.1f}dB TP={true_peak_dbtp}dBTP",
flush=True,
)
return ({"waveform": wav_out.unsqueeze(0), "sample_rate": sr},)
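The interaction between the loudness gain and the peak ceiling is easy to miss. The following sketch works through the same two-step arithmetic on scalar values; `normalize_gain` is a hypothetical helper, and the resulting loudness is a first-order estimate that assumes a uniform gain shifts integrated loudness by the same number of dB.

```python
import math


def db_to_linear(db: float) -> float:
    return 10.0 ** (db / 20.0)


def normalize_gain(measured_lufs: float, target_lufs: float,
                   peak: float, ceiling_db: float):
    """Return (applied_gain_db, estimated_lufs) for loudness gain plus peak rescale."""
    gain = db_to_linear(target_lufs - measured_lufs)
    ceiling = db_to_linear(ceiling_db)
    if peak * gain > ceiling:
        # Same rule as the node: rescale the whole clip so the peak meets the ceiling.
        gain = ceiling / peak
    gain_db = 20.0 * math.log10(gain)
    return gain_db, measured_lufs + gain_db


quiet = normalize_gain(-20.0, -14.0, peak=0.3, ceiling_db=-1.0)
loud = normalize_gain(-20.0, -14.0, peak=0.6, ceiling_db=-1.0)
print(f"peak 0.3: gain {quiet[0]:+.2f} dB, about {quiet[1]:.2f} LUFS")
print(f"peak 0.6: gain {loud[0]:+.2f} dB, about {loud[1]:.2f} LUFS")
```

With a sample peak of 0.6 the ceiling takes over and the clip lands a few dB under the target. A true limiter would instead reduce only the peaks and keep the loudness closer to the requested value.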
+293
@@ -0,0 +1,293 @@
"""SelVA Audio Preprocessors — condition training clips for codec compatibility.
Two nodes that reduce the domain mismatch between custom training audio and the
MMAudio VAE's expected spectral distribution, improving LoRA training quality:
SelvaHfSmoother — soft low-pass blend to attenuate extreme HF content
SelvaSpectralMatcher — adaptive per-band EQ toward the codec's training distribution
Root cause they address: MMAudio was trained on natural sounds (speech, foley, env)
with limited engineered HF content. The BigVGANv2 vocoder (frozen, pre-trained) handles
the codec's HF reconstruction poorly for sound design / music training clips, because
those clips land in a latent-space region the vocoder never saw during training.
Recommended order: SpectralMatcher → HfSmoother → feature extraction → LoRA training.
"""
import numpy as np
import torch
import torchaudio.functional as AF
from .utils import SELVA_CATEGORY
# ── Mel filterbank (same algorithm as selva_core/ext/mel_converter.py) ────────
def _mel_filterbank(sr: int, n_fft: int, n_mels: int,
                    fmin: float, fmax: float) -> torch.Tensor:
    """Returns mel filterbank matrix [n_mels, n_fft//2+1]."""
    def hz_to_mel(f):
        return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)
    def mel_to_hz(m):
        return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)
    n_freqs = n_fft // 2 + 1
    fft_freqs = np.linspace(0.0, sr / 2.0, n_freqs)
    mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_pts = mel_to_hz(mel_pts)
    fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
    for m in range(1, n_mels + 1):
        lo, mid, hi = hz_pts[m - 1], hz_pts[m], hz_pts[m + 1]
        up = (fft_freqs - lo) / (mid - lo + 1e-12)
        down = (hi - fft_freqs) / (hi - mid + 1e-12)
        fb[m - 1] = np.maximum(0.0, np.minimum(up, down))
    return torch.from_numpy(fb)
# ── VAE target log-mel means (source: selva_core/ext/autoencoder/vae.py) ──────
# These are the per-band expected log-mel energy means from MMAudio's training data.
# Used as the spectral matching target: clips are EQ'd to match this profile.
_TARGET_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439,
-1.2922, -1.2927, -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912,
-1.4313, -1.4152, -1.4527, -1.4728, -1.4568, -1.5101, -1.5051, -1.5172,
-1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, -1.6081, -1.6331,
-1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377,
-1.8417, -1.8643, -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673,
-1.9824, -2.0042, -2.0215, -2.0436, -2.0766, -2.1064, -2.1418, -2.1855,
-2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, -2.4659, -2.5072,
-2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673,
]
_TARGET_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006,
-2.2357, -2.4597, -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047,
-2.7483, -2.5926, -2.7462, -2.7033, -2.7386, -2.8112, -2.7502, -2.9594,
-2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, -3.1191, -2.9893,
-3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509,
-3.5089, -3.4647, -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747,
-3.7072, -3.7279, -3.7283, -3.7795, -3.8259, -3.8447, -3.8663, -3.9182,
-3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, -4.1488, -4.1874,
-4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053,
-5.4927, -5.5712, -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103,
-6.0955, -6.1673, -6.2362, -6.3120, -6.3926, -6.4797, -6.5565, -6.6511,
-6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, -7.6136, -7.7469,
-7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861,
]
_MEL_CONFIGS = {
"16k": dict(sr=16_000, n_fft=1024, n_mels=80, hop=256, fmin=0, fmax=8_000,
target=_TARGET_MEAN_80D, log10=True),
"44k": dict(sr=44_100, n_fft=2048, n_mels=128, hop=512, fmin=0, fmax=22_050,
target=_TARGET_MEAN_128D, log10=False),
}
# ── Node 1: HF Smoother ───────────────────────────────────────────────────────
class SelvaHfSmoother:
"""Soft high-frequency attenuation for LoRA training clip preprocessing.
Blends a low-pass filtered copy of the audio with the original. Attenuates
the extreme HF content common in engineered sound design that the BigVGANv2
vocoder handles poorly, bringing the clip closer to the spectral region the
MMAudio codec was trained on (natural sounds with limited HF energy).
A blend of 0.7 at 12 kHz is a transparent starting point — audible only on
close comparison. Increase blend or lower cutoff if roundtrip quality is still
poor after spectral matching.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered. 0.7 is a transparent starting point.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Blends a low-pass filtered version of the audio with the original to gently attenuate "
"high-frequency content that the SelVA codec handles poorly. "
"Use before feature extraction to improve LoRA training targets. "
"Run after SelVA Spectral Matcher for best results."
)
def process(self, audio, cutoff_hz: float, blend: float):
waveform = audio["waveform"].float() # [1, C, L]
sr = audio["sample_rate"]
filtered = AF.lowpass_biquad(waveform, sr, cutoff_hz)
out = blend * filtered + (1.0 - blend) * waveform
# Preserve RMS level — LPF removes energy, keep the clip at its original loudness
rms_in = waveform.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = out.pow(2).mean().sqrt().clamp(min=1e-8)
out = out * (rms_in / rms_out)
peak = out.abs().max()
if peak > 1.0:
out = out / peak
print(f"[HF Smoother] cutoff={cutoff_hz:.0f} Hz blend={blend:.2f} "
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f} "
f"peak={out.abs().max():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr},)
# ── Node 2: Spectral Matcher ──────────────────────────────────────────────────
class SelvaSpectralMatcher:
"""Adaptive per-band EQ toward the SelVA VAE's expected spectral distribution.
Computes the log-mel energy profile of the clip and compares it to the per-band
means stored in the VAE's normalization buffers (the statistics MMAudio was trained
on). Applies a smooth frequency-domain gain correction so the clip's spectral shape
matches what the codec expects, improving encode→decode roundtrip quality and
therefore LoRA training target quality.
The correction is additive in log space (multiplicative in linear), so it only
changes spectral balance — not the waveform's timing or phase structure.
max_gain_db clamps the correction to prevent extreme boosts on very quiet bands.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution. "
"0.8 is a good starting point.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB. Prevents extreme boosts on "
"very quiet frequency bands. 12 dB is conservative.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Applies a smooth per-band gain correction to bring the audio's spectral profile "
"in line with the MMAudio VAE's expected distribution, derived from the per-band "
"normalization statistics baked into the VAE weights. "
"Use before feature extraction to improve LoRA training target quality. "
"Run before SelVA HF Smoother."
)
def process(self, audio, mode: str, strength: float, max_gain_db: float):
cfg = _MEL_CONFIGS[mode]
waveform = audio["waveform"].float() # [1, C, L]
sr_in = audio["sample_rate"]
sr_tgt = cfg["sr"]
n_fft = cfg["n_fft"]
hop = cfg["hop"]
# ── flatten to mono and resample if needed ────────────────────────────
wav = waveform[0].mean(0) # [L]
if sr_in != sr_tgt:
wav = AF.resample(wav.unsqueeze(0), sr_in, sr_tgt).squeeze(0)
device = wav.device
window = torch.hann_window(n_fft, device=device)
# ── STFT ──────────────────────────────────────────────────────────────
stft = torch.stft(wav, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True, return_complex=True) # [n_freqs, T]
mag = stft.abs() # [n_freqs, T]
# ── current log-mel mean per band ─────────────────────────────────────
fb = _mel_filterbank(sr_tgt, n_fft, cfg["n_mels"],
cfg["fmin"], cfg["fmax"]).to(device) # [n_mels, n_freqs]
mel_mag = torch.matmul(fb, mag).clamp(min=1e-5) # [n_mels, T]
if cfg["log10"]:
mel_log = torch.log10(mel_mag)
else:
mel_log = torch.log(mel_mag)
current_mean = mel_log.mean(dim=-1) # [n_mels]
target_mean = torch.tensor(cfg["target"], device=device) # [n_mels]
# ── per-mel-band gain (log space) ─────────────────────────────────────
mel_gain = (target_mean - current_mean) * strength # [n_mels]
# Clamp to ±max_gain_db
if cfg["log10"]:
max_log = max_gain_db / 20.0 # log10: dB = 20 * log10(gain)
else:
max_log = max_gain_db / 8.6859 # ln: 20 * log10(e) ≈ 8.686
mel_gain = mel_gain.clamp(-max_log, max_log)
# ── map mel gains → STFT frequency bins (weighted average) ────────────
fb_sum = fb.sum(0).clamp(min=1e-8) # [n_freqs]
freq_gain = (mel_gain @ fb) / fb_sum # [n_freqs]
if cfg["log10"]:
linear_gain = 10.0 ** freq_gain # [n_freqs]
else:
linear_gain = torch.exp(freq_gain) # [n_freqs]
# ── apply gain in frequency domain and reconstruct ───────────────────
stft_out = stft * linear_gain.unsqueeze(-1) # [n_freqs, T]
wav_out = torch.istft(stft_out, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True,
length=wav.shape[0]) # [L]
# ── resample back to original sr ──────────────────────────────────────
if sr_in != sr_tgt:
wav_out = AF.resample(wav_out.unsqueeze(0), sr_tgt, sr_in).squeeze(0)
# ── preserve original RMS level ───────────────────────────────────────
rms_in = wav.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = wav_out.pow(2).mean().sqrt().clamp(min=1e-8)
wav_out = wav_out * (rms_in / rms_out)
peak = wav_out.abs().max()
if peak > 1.0:
wav_out = wav_out / peak
# ── reshape to match input layout [1, C, L] ───────────────────────────
out = wav_out.unsqueeze(0).unsqueeze(0)
if waveform.shape[1] > 1:
out = out.expand(-1, waveform.shape[1], -1).clone()
gain_db_range = (
20.0 * torch.log10(linear_gain.clamp(min=1e-8))
)
print(f"[Spectral Matcher] mode={mode} strength={strength:.2f} "
f"gain [{gain_db_range.min():.1f}, {gain_db_range.max():.1f}] dB "
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr_in},)
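The two clamp constants in the Spectral Matcher (20 for the log10 profile, about 8.686 for the natural-log profile) are the same dB limit expressed in different log bases. A short check of that conversion, as a sketch with hypothetical helper names:

```python
import math

LN_TO_DB = 20.0 / math.log(10.0)  # 20 * log10(e), about 8.686


def max_log_gain(max_gain_db: float, use_log10: bool) -> float:
    """Largest log-domain gain that stays within max_gain_db."""
    return max_gain_db / 20.0 if use_log10 else max_gain_db / LN_TO_DB


def log_gain_to_db(log_gain: float, use_log10: bool) -> float:
    """Convert a log-domain amplitude gain back to decibels."""
    linear = 10.0 ** log_gain if use_log10 else math.exp(log_gain)
    return 20.0 * math.log10(linear)


limit_log10 = max_log_gain(12.0, use_log10=True)
limit_ln = max_log_gain(12.0, use_log10=False)
print(limit_log10, log_gain_to_db(limit_log10, True))
print(limit_ln, log_gain_to_db(limit_ln, False))
```

Both limits map back to 12 dB, so the `max_gain_db` input should behave the same way in the 16k and 44k modes despite the different log conventions of their target profiles.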
+77
@@ -0,0 +1,77 @@
"""SelVA BigVGAN Loader.
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
The model is modified in-place so ComfyUI's model cache is updated — no need to
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from .selva_bigvgan_trainer import inject_gafilters
class SelvaBigvganLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
DESCRIPTION = (
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
"Supports both 16k and 44k models. "
"Connect the output to SelVA Sampler instead of the base model loader."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"path": ("STRING", {
"default": "bigvgan_bj.pt",
"tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
"Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, model, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
if "generator" not in ckpt:
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
mode = model["mode"]
if mode == "16k":
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
elif mode == "44k":
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
else:
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
# Remember device before injecting new modules (which default to CPU)
target_device = next(vocoder.parameters()).device
if ckpt.get("has_gafilter", False):
kernel_size = ckpt.get("gafilter_kernel_size", 9)
n_gaf = inject_gafilters(vocoder, kernel_size)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
vocoder.load_state_dict(ckpt["generator"])
vocoder.to(target_device)
vocoder.eval()
print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
return (model,)
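The loader treats the GAFilter fields as optional, so checkpoints saved without them still load. A minimal sketch of that metadata handling, with plain dicts standing in for real checkpoints; `read_vocoder_options` is a hypothetical helper and the defaults mirror the ones used in `load` above.

```python
def read_vocoder_options(ckpt: dict) -> dict:
    """Extract optional GAFilter settings, falling back to the loader's defaults."""
    if "generator" not in ckpt:
        raise ValueError(f"expected a 'generator' entry, got keys: {sorted(ckpt)}")
    return {
        "has_gafilter": bool(ckpt.get("has_gafilter", False)),
        "gafilter_kernel_size": int(ckpt.get("gafilter_kernel_size", 9)),
    }


# An older checkpoint without GAFilter metadata, and a newer one with it.
legacy = read_vocoder_options({"generator": {}})
tuned = read_vocoder_options(
    {"generator": {}, "has_gafilter": True, "gafilter_kernel_size": 7}
)
print(legacy)
print(tuned)
```

Reading the flags before calling `load_state_dict` matters because the extra filter modules have to exist on the vocoder for their weights to find a matching key.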
+625
@@ -0,0 +1,625 @@
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
Each experiment inherits from a shared `base` config and overrides specific keys.
Audio clips are loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image.
JSON format:
{
"name": "bigvgan_sweep",
"description": "optional note",
"data_dir": "/path/to/audio/clips",
"output_root": "/path/to/output",
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
...
]
}
"""
import copy
import csv
import json
import threading
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_bigvgan_trainer import (
_do_train,
_pregenerate_lora_mels,
_load_wav,
)
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
from .selva_lora_scheduler import (
_get_system_info,
_resolve_path,
_draw_comparison_curves,
)
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"train_mode": "snake_alpha_only",
"steps": 2000,
"lr": 1e-4,
"batch_size": 4,
"segment_seconds": 2.0,
"lambda_l2sp": 1e-3,
"use_gafilter": True,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 500,
"seed": 42,
"discriminator_path": "",
"lora_adapter": "",
}
def _merge_config(base: dict, experiment: dict) -> dict:
    """Merge param defaults + file base + experiment overrides."""
    cfg = dict(_PARAM_DEFAULTS)
    cfg.update(base)
    cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
    return cfg
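The merge order decides which layer wins when a key appears more than once. A self-contained sketch of the same three-layer precedence, using a trimmed-down defaults dict with made-up contents for brevity:

```python
DEFAULTS = {"train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4}


def merge_config(defaults: dict, base: dict, experiment: dict) -> dict:
    """Later layers override earlier ones: defaults < base < experiment."""
    cfg = dict(defaults)
    cfg.update(base)
    # 'id' and 'description' label the run; they are not training parameters.
    cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
    return cfg


base = {"steps": 3000}
experiment = {"id": "all_5k", "description": "lower lr", "lr": 1e-5}
cfg = merge_config(DEFAULTS, base, experiment)
print(cfg)
```

Here `steps` comes from the sweep file's `base`, `lr` from the experiment, and `train_mode` from the defaults, while `id` and `description` stay out of the merged config.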
def _parse_training_log(log_path: Path) -> list:
    """Parse BigVGAN training CSV → list of total_loss values."""
    losses = []
    if not log_path.exists():
        return losses
    try:
        with open(log_path) as f:
            reader = csv.DictReader(f)
            for row in reader:
                losses.append(float(row["total_loss"]))
    except Exception:
        pass
    return losses
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   total_steps: int) -> dict:
    """Build {step: loss} at each save_every boundary.
    Uses round-to-nearest to handle log_interval that doesn't divide
    save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
    """
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        # loss_history[i] = loss at step (i+1)*log_interval
        idx = round(target / log_interval) - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result
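The docstring's example (steps=3000, log_interval=150, save_every=1000) can be traced by hand. The sketch below isolates the index arithmetic; `nearest_logged_entry` is a hypothetical helper that assumes, as the comment in `_loss_at_steps` does, that history entry `i` holds the loss at step `(i + 1) * log_interval`.

```python
def nearest_logged_entry(target_step: int, log_interval: int):
    """Return (index, logged_step) of the history entry closest to target_step."""
    idx = round(target_step / log_interval) - 1
    return idx, (idx + 1) * log_interval


picks = [nearest_logged_entry(t, 150) for t in (1000, 2000, 3000)]
for target, (idx, step) in zip((1000, 2000, 3000), picks):
    print(f"checkpoint {target}: history[{idx}] logged at step {step}")
```

The checkpoints at 1000 and 2000 are reported with the losses logged at steps 1050 and 1950. Note that Python's `round` sends exact halves to the nearest even integer, so a target that falls exactly between two log points may resolve to either the earlier or the later entry.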
class SelvaBigvganScheduler:
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
Audio clips are loaded once and reused across all experiments. Each experiment
deep-copies the vocoder and trains independently. Results are written to
`experiment_summary.json` after every completed run so partial results are
preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
"Audio clips are loaded once and reused across all experiments. "
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "bigvgan_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory, falling back to the output directory; "
"absolute paths are used as-is."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
mode = model["mode"]
dtype = model["dtype"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
if mode == "16k":
original_vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
original_vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[BigVGAN Scheduler] {description}", flush=True)
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load audio clips once
# ------------------------------------------------------------------
# Find minimum segment length across all experiments so we load enough
min_segment_seconds = float("inf")
for exp in spec["experiments"]:
cfg = _merge_config(base_cfg, exp)
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
min_segment_samples = int(min_segment_seconds * sample_rate)
audio_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
audio_files.extend(data_dir.rglob(ext))
if not audio_files:
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
clips = []
for af in audio_files:
try:
wav, sr = _load_wav(af)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0) # [L]
if wav.shape[0] >= min_segment_samples:
clips.append(wav.cpu())
else:
print(f" [BigVGAN Scheduler] Skip {af.name}: "
f"shorter than {min_segment_seconds}s", flush=True)
except Exception as e:
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
if not clips:
raise RuntimeError(
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
)
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
# ------------------------------------------------------------------
# 4. Offload unused components to free VRAM
# ------------------------------------------------------------------
comfy.model_management.unload_all_models()
feature_utils.to("cpu")
if "generator" in model:
model["generator"].to("cpu")
if "video_enc" in model:
model["video_enc"].to("cpu")
soft_empty_cache()
# ------------------------------------------------------------------
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
# ------------------------------------------------------------------
text_clip_cache = {}
any_lora = any(
_merge_config(base_cfg, exp).get("lora_adapter", "")
for exp in spec["experiments"]
)
if any_lora:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_model.to(device)
try:
for npz_path in npz_files:
data = dict(np.load(str(npz_path), allow_pickle=False))
prompt = prompt_map.get(
npz_path.name, data.get("prompt", default_prompt)
)
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
tc = feature_utils.encode_text_clip([prompt])
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
finally:
if clip_model is not None:
clip_model.to("cpu")
soft_empty_cache()
if device.type == "cuda":
torch.cuda.empty_cache()
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
f"CLIP embeddings", flush=True)
# ------------------------------------------------------------------
# 6. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 100),
"start_step": 0,
})
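# Keep only completed records: unfinished experiments are re-run below and
# would otherwise be appended a second time under the same id
existing["experiments"] = [
rec for rec in existing.get("experiments", [])
if rec.get("results", {}).get("status") == "completed"
]
summary = existing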
summary["completed_at"] = None
if completed_ids:
print(f"[BigVGAN Scheduler] Resuming — skipping "
f"{len(completed_ids)} completed: "
f"{sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[BigVGAN Scheduler] Could not read existing summary "
f"({e}) — starting fresh", flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": len(clips),
"experiments": [],
}
def _write_summary():
summary_path.write_text(
json.dumps(summary, indent=2), encoding="utf-8"
)
_write_summary()
# ------------------------------------------------------------------
# 7. Compute total steps for progress bar
# ------------------------------------------------------------------
total_steps = 0
for exp in spec["experiments"]:
if exp["id"] not in completed_ids:
cfg = _merge_config(base_cfg, exp)
total_steps += int(cfg.get("steps", 2000))
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
# ------------------------------------------------------------------
# 8. Run experiments in a worker thread
# ------------------------------------------------------------------
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
# inside torch.inference_mode(). inference_mode is thread-local — a
# new thread starts with it OFF, so all tensor operations produce
# normal autograd-compatible tensors.
_exc = [None]
def _worker():
try:
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
f"(already completed)", flush=True)
continue
cfg = _merge_config(base_cfg, exp)
# ── Extract experiment parameters ────────────────────
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
exp_steps = int(cfg.get("steps", 2000))
exp_lr = float(cfg.get("lr", 1e-4))
exp_bs = int(cfg.get("batch_size", 4))
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
exp_gafilter = bool(cfg.get("use_gafilter", True))
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
exp_phase = float(cfg.get("lambda_phase", 1.0))
exp_save = int(cfg.get("save_every", 500))
exp_seed = int(cfg.get("seed", 42))
exp_disc = str(cfg.get("discriminator_path", ""))
exp_lora = str(cfg.get("lora_adapter", ""))
segment_samples = int(exp_seg_s * sample_rate)
# Filter clips long enough for this experiment
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
if not exp_clips:
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
f"no clips >= {exp_seg_s}s", flush=True)
summary["experiments"].append({
"id": exp_id, "description": exp_desc,
"config": dict(cfg),
"results": {
"status": "failed",
"error": f"No clips >= {exp_seg_s}s",
"duration_seconds": 0,
},
"checkpoint_path": None,
"output_dir": str(output_root / exp_id),
})
_write_summary()
continue
# ── Resolve discriminator path ───────────────────────
disc_path = None
if exp_disc:
disc_path = Path(exp_disc.strip())
if not disc_path.is_absolute():
disc_path = (
Path(folder_paths.get_output_directory()) / disc_path
)
if not disc_path.exists():
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"discriminator not found: {disc_path}",
flush=True)
disc_path = None
# ── Pre-generate LoRA mels (disk-cached) ─────────────
lora_mel_pairs = None
if exp_lora:
lora_path = Path(exp_lora.strip())
if not lora_path.is_absolute():
lora_path = Path(folder_paths.base_path) / lora_path
if lora_path.exists():
seq_cfg = model["seq_cfg"]
lora_mel_pairs = _pregenerate_lora_mels(
model, data_dir, str(lora_path),
device, dtype, sample_rate,
seq_cfg.duration, seed=exp_seed,
cache_dir=str(output_root),
text_clip_cache=text_clip_cache,
)
if not lora_mel_pairs:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"no LoRA mel pairs generated",
flush=True)
lora_mel_pairs = None
if device.type == "cuda":
torch.cuda.empty_cache()
else:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"LoRA adapter not found: {lora_path}",
flush=True)
# ── Output dir ───────────────────────────────────────
exp_dir = output_root / exp_id
exp_dir.mkdir(parents=True, exist_ok=True)
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
flush=True)
if exp_desc:
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
print(f"[BigVGAN Scheduler] mode={train_mode} "
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"train_mode": train_mode, "steps": exp_steps,
"lr": exp_lr, "batch_size": exp_bs,
"segment_seconds": exp_seg_s,
"lambda_l2sp": exp_l2sp,
"use_gafilter": exp_gafilter,
"gafilter_kernel_size": exp_gaf_ks,
"lambda_phase": exp_phase,
"save_every": exp_save, "seed": exp_seed,
"discriminator_path": exp_disc,
"lora_adapter": exp_lora,
},
"results": {"status": "running"},
"checkpoint_path": None,
"output_dir": str(exp_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
# Ensure mel_converter is on device for this experiment
mel_converter.to(device)
# Fresh vocoder copy — _do_train modifies it in-place
vocoder_copy = copy.deepcopy(original_vocoder)
checkpoint_path = _do_train(
vocoder_copy, mel_converter, exp_clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, exp_steps, exp_lr, exp_bs,
exp_l2sp, exp_gafilter, exp_gaf_ks,
exp_phase, exp_save, exp_seed,
out_path, disc_path, pbar,
lora_mel_pairs,
)
duration = time.monotonic() - t_start
# Parse training CSV for loss history
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
loss_history = _parse_training_log(log_path)
log_interval = max(1, exp_steps // 20)
smoothed = (
_smooth_losses(loss_history)
if loss_history else []
)
final_loss = (
round(smoothed[-1], 6) if smoothed else None
)
min_loss = (
round(min(smoothed), 6) if smoothed else None
)
min_idx = (
smoothed.index(min(smoothed))
if smoothed else None
)
min_loss_step = (
(min_idx + 1) * log_interval
if min_idx is not None else None
)
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std = round(
float(np.std(loss_history[-quarter:])), 6
)
else:
loss_std = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval,
exp_save, exp_steps,
),
"loss_history": [
round(v, 6) for v in loss_history
],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["checkpoint_path"] = checkpoint_path
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
f"failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
finally:
# Clean up vocoder copy to free VRAM
soft_empty_cache()
_write_summary()
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if _exc[0] is not None:
raise _exc[0]
# ------------------------------------------------------------------
# 9. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[BigVGAN Scheduler] Sweep complete. "
f"Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 10. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
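The worker-thread pattern used in step 8 (start a fresh thread, join it, then re-raise whatever it caught) can be sketched on its own. This is a minimal illustration using only the standard library; `run_in_worker` is a hypothetical helper name and does not exist in the repository.

```python
import threading


def run_in_worker(fn, *args, **kwargs):
    """Run fn in a fresh thread, block until it finishes, re-raise its exception.

    Hypothetical helper mirroring the `_exc = [None]` pattern above.
    """
    result = [None]  # one-element lists let the closure write back to the caller
    error = [None]

    def _worker():
        try:
            result[0] = fn(*args, **kwargs)
        except Exception as exc:  # captured here, re-raised in the calling thread
            error[0] = exc

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    t.join()  # the caller still blocks, so this does not add concurrency
    if error[0] is not None:
        raise error[0]
    return result[0]
```

Because the caller joins immediately, the thread is used only for its fresh thread-local state, not for parallelism.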
+106
@@ -0,0 +1,106 @@
import json
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaDatasetBrowser:
"""Browse a dataset.json file entry by entry using an integer index.
Each entry in the JSON is expected to have:
- "path" : base path (no extension); also the directory that holds frame images
- "label" : text description of the clip (optional, defaults to "")
Derived outputs:
- video_path : path + ".mp4"
- audio_wav : <parent of path>/features/<name> + ".wav"
- audio_flac : <parent of path>/features/<name> + ".flac"
- features_path : <parent of path>/features/<name> + ".npz"
- frames_dir : path (the directory itself, for image-sequence loaders)
- mask_dir : path + "_mask"
- label : entry["label"]
- max_index : number of entries in the file minus 1
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset_json": ("STRING", {
"default": "",
"tooltip": "Absolute or ComfyUI-relative path to a dataset.json file.",
}),
"index": ("INT", {
"default": 0,
"min": 0,
"max": 9999,
"step": 1,
"tooltip": "Zero-based index of the entry to inspect.",
}),
},
}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "INT")
RETURN_NAMES = ("video_path", "audio_wav", "audio_flac", "features_path", "frames_dir", "mask_dir", "label", "max_index")
OUTPUT_TOOLTIPS = (
"path + '.mp4'",
"features/ + name + '.wav'",
"features/ + name + '.flac'",
"features/ + name + '.npz' (pre-extracted SelVA features)",
"path (image-sequence directory)",
"path + '_mask' (mask image-sequence directory)",
"Text label for this clip",
"count - 1 — wire to a primitive INT's max to constrain the index widget",
)
FUNCTION = "browse"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Reads a dataset.json produced by the SelVA dataset preparation pipeline "
"and exposes one entry at a time via an integer index. "
"Outputs the video path, WAV/FLAC audio paths, features path, frames and mask "
"directories, label, and the last valid index (entry count minus 1)."
)
# Re-read the file every call so edits are picked up without restarting ComfyUI.
IS_CHANGED = classmethod(lambda cls, **_: float("nan"))
def browse(self, dataset_json: str, index: int):
p = Path(dataset_json.strip())
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Dataset Browser] File not found: {p}")
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list) or len(data) == 0:
raise ValueError(f"[SelVA Dataset Browser] Expected a non-empty JSON array in {p}")
count = len(data)
if index >= count:
raise IndexError(
f"[SelVA Dataset Browser] index {index} is out of range "
f"(dataset has {count} entries, last index is {count - 1})"
)
entry = data[index]
base = entry["path"]
label = entry.get("label", "")
p_base = Path(base)
feat_base = str(p_base.parent / "features" / p_base.name)
print(
f"[SelVA Dataset Browser] {index + 1}/{count} label='{label}' base={base}",
flush=True,
)
return (
base + ".mp4",
feat_base + ".wav",
feat_base + ".flac",
feat_base + ".npz",
base,
base + "_mask",
label,
count - 1,
)
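The path derivation in `browse` can be checked in isolation. The sketch below reproduces it with `PurePosixPath` so the result does not depend on the host OS (the node itself uses `Path`); `derive_paths` and the example base path are hypothetical.

```python
from pathlib import PurePosixPath


def derive_paths(base: str) -> dict:
    """Derive the per-entry paths from an entry's extension-less base path."""
    p = PurePosixPath(base)
    # Audio and feature files live in a sibling "features" folder
    feat_base = str(p.parent / "features" / p.name)
    return {
        "video_path": base + ".mp4",
        "audio_wav": feat_base + ".wav",
        "audio_flac": feat_base + ".flac",
        "features_path": feat_base + ".npz",
        "frames_dir": base,
        "mask_dir": base + "_mask",
    }


paths = derive_paths("/data/clips/clip01")
print(paths["features_path"])  # /data/clips/features/clip01.npz
```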
+788
@@ -0,0 +1,788 @@
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetCompressor (optional)
↓ AUDIO_DATASET
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
↓ AUDIO_DATASET
SelvaDatasetHfSmoother (optional — batch HF attenuation)
↓ AUDIO_DATASET
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"} # handled natively without FFmpeg
def _load_audio(path: Path):
"""Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues."""
if path.suffix.lower() in _SOUNDFILE_EXTS:
import soundfile as sf
wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True) # [L, C]
wav = torch.from_numpy(wav_np).T.unsqueeze(0) # [1, C, L]
else:
wav, sr = torchaudio.load(str(path)) # [C, L]
wav = wav.unsqueeze(0).float() # [1, C, L]
return wav, sr
class SelvaDatasetLoader:
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("STRING", {
"default": "",
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
def load(self, folder: str):
folder = Path(folder.strip())
if not folder.exists():
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
if not files:
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
dataset = []
for f in sorted(files):
try:
wav, sr = _load_audio(f)
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
except Exception as e:
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
return (dataset,)
class SelvaDatasetResampler:
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_sr": ("INT", {
"default": 44100, "min": 8000, "max": 192000,
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "resample"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
def resample(self, dataset, target_sr: int):
import soxr
out = []
changed = 0
for item in dataset:
sr = item["sample_rate"]
if sr == target_sr:
out.append(item)
continue
wav = item["waveform"][0] # [C, L]
# soxr expects [L, C] (time-first), float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
changed += 1
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
return (out,)
class SelvaDatasetLUFSNormalizer:
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_lufs": ("FLOAT", {
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
)
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
out = []
skipped = 0
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
meter = pyln.Meter(sr)
try:
loudness = meter.integrated_loudness(wav_np)
except Exception:
skipped += 1
out.append(item)
continue
if not np.isfinite(loudness):
skipped += 1
out.append(item)
continue
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_norm = wav * gain_linear
# Peak ceiling (uses the sample peak as an approximation of true peak)
peak = wav_norm.abs().max().item()
if peak > tp_linear:
wav_norm = wav_norm * (tp_linear / peak)
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
print(
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
flush=True,
)
return (out,)
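The gain arithmetic in `normalize` can be followed without pyloudnorm. The sketch below assumes the integrated loudness and the sample peak have already been measured; `normalization_gain` is a hypothetical name used only for illustration.

```python
def db_to_linear(db: float) -> float:
    """Convert a level in dB to a linear amplitude factor."""
    return 10.0 ** (db / 20.0)


def normalization_gain(measured_lufs: float, target_lufs: float,
                       peak: float, peak_ceiling_db: float) -> float:
    """Linear gain that reaches target_lufs, reduced if the peak would exceed the ceiling."""
    gain = db_to_linear(target_lufs - measured_lufs)
    ceiling = db_to_linear(peak_ceiling_db)
    new_peak = peak * gain
    if new_peak > ceiling:
        # Scale the whole clip down so its peak sits exactly on the ceiling
        gain *= ceiling / new_peak
    return gain


# Quiet clip: the full +6 dB of loudness gain fits under the ceiling
print(normalization_gain(-29.0, -23.0, peak=0.1, peak_ceiling_db=-1.0))
# Hot clip: the gain is limited by the -1 dB ceiling instead
print(normalization_gain(-29.0, -23.0, peak=0.8, peak_ceiling_db=-1.0))
```

When the ceiling applies, the clip ends up quieter than `target_lufs`; the peak limit takes priority over the loudness target.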
class SelvaDatasetCompressor:
"""Apply mild parallel compression to reduce within-clip loudness variance.
Uses pedalboard.Compressor (2:1 to 3:1 ratio). Parallel (New York) style:
blends compressed signal with dry so transients are preserved while
the dynamic range is gently tightened. Apply after LUFS normalization.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"threshold_db": ("FLOAT", {
"default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.",
}),
"ratio": ("FLOAT", {
"default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5,
"tooltip": "Compression ratio. 2:1 to 3:1 is mild; stay below 4:1 to avoid pumping.",
}),
"attack_ms": ("FLOAT", {
"default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0,
"tooltip": "Attack time in ms. Slower attack preserves transients.",
}),
"release_ms": ("FLOAT", {
"default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0,
"tooltip": "Release time in ms.",
}),
"mix": ("FLOAT", {
"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.3 to 0.5 is typical.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "compress"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Mild parallel compression to reduce within-clip dynamic range. "
"Blends compressed signal with dry at the given mix ratio. "
"Apply after LUFS normalization."
)
def compress(self, dataset, threshold_db: float, ratio: float,
attack_ms: float, release_ms: float, mix: float):
from pedalboard import Compressor, Pedalboard
board = Pedalboard([Compressor(
threshold_db=threshold_db,
ratio=ratio,
attack_ms=attack_ms,
release_ms=release_ms,
)])
out = []
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pedalboard expects [C, L] float32 numpy
wav_np = wav.float().numpy() # [C, L]
compressed = board(wav_np, sr) # [C, L]
mixed = (1.0 - mix) * wav_np + mix * compressed
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]})
print(
f"[DatasetCompressor] {len(out)} clips compressed "
f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}",
flush=True,
)
return (out,)
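The parallel blend is the line `(1.0 - mix) * wav_np + mix * compressed`. The sketch below shows its effect using a memoryless hard-knee compressor as a stand-in; `static_compress` is a hypothetical simplification (no attack or release) and is not pedalboard's algorithm.

```python
import numpy as np


def static_compress(x: np.ndarray, threshold_db: float, ratio: float) -> np.ndarray:
    """Hypothetical memoryless compressor: per-sample gain reduction above the threshold."""
    level_db = 20.0 * np.log10(np.abs(x) + 1e-12)
    over_db = np.maximum(level_db - threshold_db, 0.0)
    gain_db = -over_db * (1.0 - 1.0 / ratio)  # ratio:1 slope above the threshold
    return x * 10.0 ** (gain_db / 20.0)


def parallel_mix(dry: np.ndarray, wet: np.ndarray, mix: float) -> np.ndarray:
    """Blend dry and compressed signals: 0.0 = dry only, 1.0 = fully compressed."""
    return (1.0 - mix) * dry + mix * wet


dry = np.array([0.01, 1.0])                  # one quiet sample, one full-scale sample
wet = static_compress(dry, threshold_db=-18.0, ratio=2.0)
out = parallel_mix(dry, wet, mix=0.4)
print(out)  # quiet sample unchanged, loud sample reduced but less than in `wet`
```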
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
Method: compare mean energy in the 1-5 kHz band vs the 15-20 kHz band via STFT.
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
"""
if sr < 32000:
return False # can't assess HF at low sample rates
n_fft = 2048
hop = 512
mono = wav[0].mean(0) # [L]
window = torch.hann_window(n_fft, device=mono.device)
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000)
if band_hi.sum() == 0:
return False
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
return ratio_db > 40.0
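The band-energy comparison can be tried on synthetic tones. The sketch below is a NumPy approximation that takes a single windowed FFT over the whole clip rather than the averaged STFT used above, so the numbers will differ slightly; `looks_codec_limited` is a hypothetical name.

```python
import numpy as np


def looks_codec_limited(mono: np.ndarray, sr: int, threshold_db: float = 40.0) -> bool:
    """True if 1-5 kHz energy exceeds 15-20 kHz energy by more than threshold_db."""
    if sr < 32000:
        return False  # not enough bandwidth to judge the 15-20 kHz band
    spectrum = np.fft.rfft(mono * np.hanning(len(mono)))
    power = np.abs(spectrum) ** 2
    freqs = np.fft.rfftfreq(len(mono), d=1.0 / sr)
    lo = (freqs >= 1000) & (freqs < 5000)
    hi = (freqs >= 15000) & (freqs < 20000)
    if not hi.any():
        return False
    energy_lo = max(power[lo].mean(), 1e-12)
    energy_hi = max(power[hi].mean(), 1e-12)
    return bool(10.0 * np.log10(energy_lo / energy_hi) > threshold_db)


sr = 44100
t = np.arange(sr) / sr
low_only = np.sin(2 * np.pi * 3000 * t)                # nothing above 15 kHz
full_band = low_only + np.sin(2 * np.pi * 17000 * t)   # energy in both bands
print(looks_codec_limited(low_only, sr), looks_codec_limited(full_band, sr))
```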
def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L]
if mono.shape[0] < 2048:
return 60.0 # clip too short to frame — assume clean
frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item()
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
return 20.0 * np.log10(p95 / p05 + 1e-8)
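The percentile-based estimate is easiest to understand on a clip with a known dynamic range. The sketch below is a NumPy equivalent of the framing and quantile steps, assuming a mono 1-D array as input; `estimate_snr_db` is a hypothetical name.

```python
import numpy as np


def estimate_snr_db(mono, frame: int = 2048, hop: int = 512) -> float:
    """Ratio of 95th- to 5th-percentile frame RMS, in dB."""
    mono = np.asarray(mono, dtype=np.float64)
    if mono.shape[0] < frame:
        return 60.0  # too short to frame, treated as clean
    # All length-`frame` windows, then keep every `hop`-th one
    windows = np.lib.stride_tricks.sliding_window_view(mono, frame)[::hop]
    rms = np.sqrt((windows ** 2).mean(axis=-1))
    p95 = np.percentile(rms, 95)
    p05 = max(np.percentile(rms, 5), 1e-8)
    return float(20.0 * np.log10(p95 / p05 + 1e-8))


# First half at amplitude 0.01, second half at 1.0: a 40 dB spread
clip = np.concatenate([np.full(40960, 0.01), np.full(40960, 1.0)])
print(round(estimate_snr_db(clip), 2))
```

Note that this measures the spread between loud and quiet frames, so a clip with uniform level scores near 0 dB regardless of how noisy it is.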
class SelvaDatasetInspector:
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"skip_rejected": ("BOOLEAN", {
"default": True,
"tooltip": "If True, flagged clips are removed from the output dataset. "
"If False, all clips pass through but the report still lists issues.",
}),
"min_snr_db": ("FLOAT", {
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
"tooltip": "Clips with estimated SNR below this value are flagged.",
}),
"check_codec_artifacts": ("BOOLEAN", {
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
"max_silence_fraction": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
"(< -60 dBFS). Set to 0 to disable silence detection.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET, "STRING")
RETURN_NAMES = ("dataset", "report")
FUNCTION = "inspect"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Analyze each clip for clipping, low SNR, codec artifacts, and excessive silence. "
"Outputs a filtered AUDIO_DATASET and a text report. "
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
for item in dataset:
wav = item["waveform"]
sr = item["sample_rate"]
name = item["name"]
issues = []
# Clipping
peak = wav.abs().max().item()
if peak > 0.99:
issues.append(f"clipping (peak={peak:.3f})")
# Low SNR
snr = _estimate_snr(wav)
if snr < min_snr_db:
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
# Codec artifacts
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
# Silence detection
if max_silence_fraction > 0:
mono = wav[0].mean(0)
if mono.shape[0] >= 2048:
frames = mono.unfold(0, 2048, 512)
rms = frames.pow(2).mean(-1).sqrt()
silent_frac = (rms < 1e-3).float().mean().item()
if silent_frac > max_silence_fraction:
issues.append(f"mostly silent ({silent_frac:.0%} of frames below -60 dBFS)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
if not skip_rejected:
clean.append(item)
else:
clean.append(item)
lines.append(f" OK {name}")
lines.append("=" * 40)
lines.append(
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
+ (" (removed)" if skip_rejected else " (kept)")
)
report = "\n".join(lines)
print(f"[DatasetInspector]\n{report}", flush=True)
return (clean, report)
class SelvaDatasetItemExtractor:
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
Bridges the dataset pipeline to any node that accepts a standard AUDIO
input — save audio, HF Smoother, Spectral Matcher, etc.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"index": ("INT", {
"default": 0, "min": 0, "max": 9999,
"tooltip": "0-based index. Wraps around if index >= dataset length.",
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "INT")
RETURN_NAMES = ("audio", "name", "total")
FUNCTION = "extract"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Extract one clip from an AUDIO_DATASET by index. "
"Returns standard AUDIO (compatible with all audio nodes), "
"the clip name, and the total dataset length."
)
def extract(self, dataset, index: int):
if not dataset:
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
idx = index % len(dataset)
item = dataset[idx]
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
print(
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
flush=True,
)
return (audio, item["name"], len(dataset))
class SelvaDatasetSaver:
"""Save all clips in an AUDIO_DATASET to disk as FLAC files.
Optionally copies matching .npz feature files from a source directory,
keeping FLAC/NPZ pairs in sync after the inspector has filtered clips.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"output_dir": ("STRING", {
"default": "",
"tooltip": "Absolute path to output folder. Created if it does not exist.",
}),
},
"optional": {
"npz_source_dir": ("STRING", {
"default": "",
"tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. "
"Missing NPZs are warned but do not abort the save.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("report",)
OUTPUT_NODE = True
FUNCTION = "save"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Save every clip in an AUDIO_DATASET to output_dir as FLAC. "
"If npz_source_dir is provided, copies the matching .npz file for each clip — "
"so rejected clips never get their NPZ copied."
)
def save(self, dataset, output_dir: str, npz_source_dir: str = ""):
import shutil
import soundfile as sf
out = Path(output_dir.strip())
out.mkdir(parents=True, exist_ok=True)
npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None
saved = 0
npz_copied = 0
npz_missing = []
for item in dataset:
name = item["name"]
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# soundfile wants [L] mono or [L, C] multichannel, float32
wav_np = wav.permute(1, 0).float().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
flac_path = out / f"{name}.flac"
sf.write(str(flac_path), wav_np, sr, subtype="PCM_24")
saved += 1
if npz_src is not None:
# Augmented clips store their origin name — use it to find the .npz
lookup = item.get("origin_name", name)
npz_path = npz_src / f"{lookup}.npz"
if npz_path.exists():
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
npz_copied += 1
else:
npz_missing.append(name)
lines = [
f"[DatasetSaver] Saved {saved} clips → {out}",
]
if npz_src is not None:
lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}")
for n in npz_missing:
lines.append(f" MISSING NPZ: {n}")
report = "\n".join(lines)
print(report, flush=True)
return (report,)
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
class SelvaDatasetSpectralMatcher:
"""Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.
Wraps SelvaSpectralMatcher so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply adaptive spectral matching to every clip in a dataset. "
"Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
"VAE's expected distribution."
)
def process(self, dataset, mode: str, strength: float, max_gain_db: float):
from .selva_audio_preprocessors import SelvaSpectralMatcher
matcher = SelvaSpectralMatcher()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = matcher.process(audio, mode, strength, max_gain_db)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
f"mode={mode} strength={strength}", flush=True)
return (out,)
class SelvaDatasetHfSmoother:
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
Wraps SelvaHfSmoother so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply soft HF attenuation to every clip in a dataset. "
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
"with the original to tame extreme HF content."
)
def process(self, dataset, cutoff_hz: float, blend: float):
from .selva_audio_preprocessors import SelvaHfSmoother
smoother = SelvaHfSmoother()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = smoother.process(audio, cutoff_hz, blend)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetHfSmoother] {len(out)} clips processed "
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
return (out,)
# ── Dataset augmenter ────────────────────────────────────────────────────────
class SelvaDatasetAugmenter:
"""Create augmented variants of each clip to expand a small dataset.
Supports gain variation (always available) and optionally pitch shift
and time stretch via audiomentations. Install audiomentations for the
full feature set: ``pip install audiomentations``
Each original clip produces ``variants_per_clip`` augmented copies.
Originals are kept by default (toggle ``keep_originals``).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"variants_per_clip": ("INT", {
"default": 2, "min": 1, "max": 20,
"tooltip": "Number of augmented copies per original clip.",
}),
"gain_range_db": ("FLOAT", {
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
}),
"seed": ("INT", {"default": 42}),
},
"optional": {
"pitch_range_semitones": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
}),
"time_stretch_range": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
"tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). "
"Requires audiomentations. 0 = disabled.",
}),
"keep_originals": ("BOOLEAN", {
"default": True,
"tooltip": "Include the original unaugmented clips in the output.",
}),
},
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "augment"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Create augmented variants of each clip (gain, pitch, time stretch) "
"to expand small training datasets. Gain is always available; pitch and "
"time stretch require audiomentations (pip install audiomentations)."
)
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
seed: int, pitch_range_semitones: float = 0.0,
time_stretch_range: float = 0.0, keep_originals: bool = True):
rng = np.random.RandomState(seed)
# Try audiomentations for pitch/stretch
use_am = False
am_compose = None
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
if needs_am:
try:
import audiomentations as am
transforms = []
if pitch_range_semitones > 0:
transforms.append(am.PitchShift(
min_semitones=-pitch_range_semitones,
max_semitones=pitch_range_semitones,
p=0.5,
))
if time_stretch_range > 0:
transforms.append(am.TimeStretch(
min_rate=1.0 - time_stretch_range,
max_rate=1.0 + time_stretch_range,
leave_length_unchanged=True,
p=0.5,
))
am_compose = am.Compose(transforms)
use_am = True
except ImportError:
print("[DatasetAugmenter] audiomentations not installed — "
"pitch_shift and time_stretch disabled. "
"Install: pip install audiomentations", flush=True)
out = []
if keep_originals:
out.extend(dataset)
for item in dataset:
wav = item["waveform"] # [1, C, L]
sr = item["sample_rate"]
name = item["name"]
for v in range(variants_per_clip):
# Gain variation (always applied)
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
gain_lin = 10.0 ** (gain_db / 20.0)
wav_aug = wav * gain_lin
# Pitch/stretch via audiomentations
if use_am and am_compose is not None:
wav_np = wav_aug[0].cpu().numpy() # [C, L] float32 (moved to CPU first; .numpy() fails on GPU tensors)
if wav_np.shape[0] == 1:
wav_np = wav_np[0] # [L] mono for audiomentations
wav_np = am_compose(samples=wav_np, sample_rate=sr)
if wav_np.ndim == 1:
wav_np = wav_np[np.newaxis, :] # back to [1, L]
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
# Prevent clipping
peak = wav_aug.abs().max()
if peak > 1.0:
wav_aug = wav_aug / peak
out.append({
"waveform": wav_aug,
"sample_rate": sr,
"name": f"{name}_aug{v:02d}",
"origin_name": name,
})
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
f"gain=±{gain_range_db:.1f}dB"
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
flush=True)
return (out,)
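The gain stage of the augmenter above draws a gain in dB, converts it to a linear factor, and rescales by the peak only when the result would clip. A minimal standalone sketch of that arithmetic, assuming plain Python lists instead of tensors; `db_to_linear` and `apply_random_gain` are hypothetical helper names, not part of the node:

```python
import random


def db_to_linear(gain_db: float) -> float:
    """Convert a gain in decibels to a linear amplitude factor."""
    return 10.0 ** (gain_db / 20.0)


def apply_random_gain(samples, gain_range_db, rng):
    """Scale samples by a random gain in [-gain_range_db, +gain_range_db].

    Returns (scaled_samples, gain_db). If the scaled peak exceeds 1.0,
    the whole clip is divided by the peak so it never clips.
    """
    gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
    factor = db_to_linear(gain_db)
    scaled = [s * factor for s in samples]
    peak = max((abs(s) for s in scaled), default=0.0)
    if peak > 1.0:
        scaled = [s / peak for s in scaled]
    return scaled, gain_db


if __name__ == "__main__":
    rng = random.Random(42)  # seeded so the variants are reproducible
    clip = [0.9, -0.95, 0.5]
    augmented, gain = apply_random_gain(clip, 3.0, rng)
    print(f"gain={gain:+.2f} dB peak={max(abs(s) for s in augmented):.3f}")
```

Note that peak rescaling can undo part of a positive gain on clips that are already close to full scale, so loud clips see less effective variation than quiet ones.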
+515
@@ -0,0 +1,515 @@
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against target style reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: latent statistics matching (mean latent + Gram matrix) against
target style reference clips encoded with the VAE encoder. Runs entirely before
the VAE decoder and the vocoder; optimization only requires the DiT and the
VAE encoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
def _load_wav(path):
"""Load audio file to [channels, samples] float32 tensor."""
try:
return torchaudio.load(str(path))
except Exception:
pass
import soundfile as sf
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
wav = torch.from_numpy(data.T)
return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
Start at 0; enable only if mean-only optimization converges
without noise, then increase slowly (0.01–0.1).
"""
m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss — captures spectral envelope
gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix loss — captures timbral texture (can add noise if too high)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss computed directly in VAE latent space.
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
ref_mean: [C_lat] mean latent vector of reference clips
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
gram_weight: weight for Gram component — 0 = mean only (recommended start)
Operating in latent space avoids backprop through the VAE decoder, which
is @torch.inference_mode() and produces noisy, unstable gradients.
"""
# Mean latent loss — matches average activation per channel
gen_mean = z.mean(dim=0) # [C_lat]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix — inter-channel covariance, position-invariant
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
class SelvaDittoOptimizer:
"""DITTO inference-time noise optimization.
Freezes all model weights and optimizes only the initial noise latent x_0
to make the generated audio sound like the target style reference clips.
No training data or gradient updates to the model — per-video per-run.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"features": ("SELVA_FEATURES",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Sound description. Leave empty to use features prompt.",
}),
"negative_prompt": ("STRING", {
"default": "", "multiline": False,
}),
"reference_dir": ("STRING", {
"default": "",
"tooltip": "Directory with target style reference audio files (.wav/.flac/.mp3/.ogg). "
"Reference latent statistics are precomputed from these once.",
}),
"n_opt_steps": ("INT", {
"default": 50, "min": 5, "max": 500,
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
"each step requires ~2 DiT forward passes.",
}),
"opt_lr": ("FLOAT", {
"default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
"tooltip": "Adam learning rate for x_0 optimization. "
"0.02–0.05 is recommended; 0.1 (paper default) causes oscillation.",
}),
"n_ode_steps": ("INT", {
"default": 10, "min": 5, "max": 50,
"tooltip": "Euler ODE steps run during each optimization iteration. "
"Lower = faster optimization (10–15 is a good trade-off). "
"Final generation always uses the steps parameter below.",
}),
"n_grad_steps": ("INT", {
"default": 5, "min": 1, "max": 50,
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
"Higher = more accurate gradient, more VRAM. "
"Must be ≤ n_ode_steps. 5 is a good default.",
}),
"style_weight": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
"tooltip": "Weight of the target-style loss. High values push harder toward "
"the target style but add noise. Start at 0.1 and increase slowly.",
}),
"gram_weight": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
"0.1 adds texture matching but can introduce white noise.",
}),
"anchor_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
"Prevents optimization from pushing x0 out of the flow's "
"expected distribution (which causes white noise). "
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
}),
"steps": ("INT", {
"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the final generation pass (after optimization).",
}),
"cfg_strength": ("FLOAT", {
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"normalize": ("BOOLEAN", {"default": True}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("DITTO-optimized audio: x_0 steered toward the target style.",)
FUNCTION = "optimize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"DITTO inference-time noise optimization (arXiv:2401.12179). "
"Optimizes the initial noise latent x_0 to match target style reference clips "
"via a latent-statistics style loss, backpropagating through the ODE. "
"All model weights frozen — zero quality degradation risk."
)
def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0):
import traceback
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
# Validate variant match
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
f"Re-run Feature Extractor."
)
if not prompt or not prompt.strip():
prompt = features.get("prompt", "")
# Resolve duration and seq_cfg
duration = features.get("duration", 0)
if duration <= 0:
raise ValueError("[DITTO] Features contain no duration field.")
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
sample_rate = seq_cfg.sampling_rate
# Load reference clips and encode to latent space.
# Style loss is computed in latent space (after net_generator.unnormalize)
# rather than mel space — this avoids backpropagating through the VAE
# decoder (which is @torch.inference_mode() and produces noisy gradients).
ref_dir = Path(reference_dir.strip())
if not ref_dir.is_absolute():
ref_dir = Path(folder_paths.models_dir) / ref_dir
if not ref_dir.exists():
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
ref_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
ref_files.extend(ref_dir.rglob(ext))
if not ref_files:
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
if not hasattr(feature_utils.tod.vae, "encoder"):
raise RuntimeError(
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
"Reload the model with the encoder enabled."
)
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
mel_converter.to(device, torch.float32) # cuFFT requires float32
ref_latents = []
with torch.no_grad():
for rf in ref_files:
try:
wav, sr = _load_wav(rf)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, torch.float32)
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
# encode → sample → VAE latent space (matches unnormalize(x) in loss)
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
except Exception as e:
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
if not ref_latents:
raise RuntimeError("[DITTO] No usable reference clips.")
# Precompute reference latent statistics (done once — detached, no grad)
with torch.no_grad():
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
ref_mean = all_means.mean(0) # [C_lat]
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
f"grad_steps={n_grad_steps}", flush=True)
if strategy == "offload_to_cpu":
net_generator.to(device)
feature_utils.to(device)
soft_empty_cache()
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
_result = [None]
_exc = [None]
def _worker():
try:
_result[0] = _do_optimize(
net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar,
)
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
if _exc[0] is not None:
raise _exc[0]
return (_result[0],)
def _do_optimize(net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
# Strip inference flags from ref stats (came from main thread) and cast to
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
ref_mean = ref_mean.clone().detach().to(dtype)
ref_gram = ref_gram.clone().detach().to(dtype)
torch.manual_seed(seed)
clip_f = features["clip_features"].to(device, dtype).clone()
sync_f = features["sync_features"].to(device, dtype).clone()
# Strip inference-mode flags from all model weights and buffers BEFORE any
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
# operations on inference tensors produce inference tensors, so conditions
# computed from tainted weights would also be tainted. clone() outside
# inference_mode produces a normal tensor regardless of the source flag.
def _strip_inference(module):
for mod in module.modules():
for name, buf in list(mod._buffers.items()):
if buf is not None:
mod._buffers[name] = buf.clone()
for name, param in list(mod._parameters.items()):
if param is not None:
mod._parameters[name] = torch.nn.Parameter(
param.data.clone(), requires_grad=False
)
_strip_inference(net_generator)
_strip_inference(feature_utils)
_strip_inference(mel_converter)
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
with torch.no_grad():
text_clip = feature_utils.encode_text_clip([prompt])
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Clone all tensors inside conditions/empty_conditions to ensure no inference
# flags survived from intermediate computations inside preprocess_conditions.
def _clone_nested(obj):
if isinstance(obj, torch.Tensor):
return obj.clone()
elif isinstance(obj, dict):
return {k: _clone_nested(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_clone_nested(v) for v in obj)
return obj
conditions = _clone_nested(conditions)
empty_conditions = _clone_nested(empty_conditions)
# Initial noise — x_0 is the parameter we optimize
x0_init = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=device, dtype=dtype,
)
x0 = torch.nn.Parameter(x0_init.clone())
x0_init = x0_init.detach() # anchor — kept fixed, no grad
optimizer = torch.optim.Adam([x0], lr=opt_lr)
# n_grad_steps must not exceed n_ode_steps
n_grad_steps = min(n_grad_steps, n_ode_steps)
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
print(f"[DITTO] Optimizing x_0 "
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
# Freeze all model weights (double-check — should already be frozen at inference)
net_generator.requires_grad_(False)
feature_utils.requires_grad_(False)
mel_converter.requires_grad_(False)
for opt_step in range(n_opt_steps):
comfy.model_management.throw_exception_if_processing_interrupted()
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
# Detach from x0 so Phase 1 does not build a computation graph.
with torch.no_grad():
x = x0.detach()
for i in range(n_free_steps):
t = ts[i]
dt = ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
# Straight-through estimator: reconnect x to x0's gradient path by
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
# creates a grad_fn pointing back to x0, so loss.backward() will
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
# treated as identity for gradient purposes (truncated BPTT).
#
# x may carry an inference tensor flag from Phase 1 (derived from
# conditions which were built outside inference_mode but may have
# propagated the flag). .clone() strips it so the STE addition does
# not try to save an inference tensor for backward.
x = x.clone()
x = x + (x0 - x0.detach())
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
for i in range(n_free_steps, n_ode_steps):
t = ts[i]
dt = ts[i + 1] - t
# Gradient checkpointing: recompute forward during backward,
# avoiding storage of DiT activations for each step.
def _ode_step(x_in, t=t):
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
flow = torch.utils.checkpoint.checkpoint(
_ode_step, x, use_reentrant=False
)
x = x + dt * flow
# ── Style loss in latent space ───────────────────────────────────────
# Unnormalize x back to VAE latent space — fully differentiable, no
# decode needed. ref_mean/ref_gram are computed from encoded reference
# clips in the same space. Avoids backprop through VAE decoder which
# is @torch.inference_mode() and produces noisy gradients.
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
# the ODE into an out-of-distribution region that decodes as white noise.
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
loss = style_loss + anchor_loss
optimizer.zero_grad()
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()
pbar.update(1)
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
f"x0_std={x0.data.std().item():.3f}", flush=True)
# ── Final generation with optimized x_0 ─────────────────────────────────
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
with torch.no_grad():
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
x = x0.detach()
for i in range(steps):
comfy.model_management.throw_exception_if_processing_interrupted()
t = fm_ts[i]
dt = fm_ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
pbar.update(1)
x1_unnorm = net_generator.unnormalize(x)
spec = feature_utils.decode(x1_unnorm)
audio = feature_utils.vocode(spec)
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
flush=True)
audio = audio.float()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
if normalize:
target_rms = 10 ** (target_lufs / 20.0) # approximation: treats target_lufs as an RMS level in dBFS (no K-weighting or gating)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return {"waveform": audio.cpu(), "sample_rate": sample_rate}
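The latent style loss used in the optimizer above compares two statistics of a `[T, C]` latent against precomputed reference values: the per-channel mean (L1) and the channel-by-channel Gram matrix `z.T @ z / T` (MSE). A minimal sketch of the same computation on nested Python lists, assuming mean reduction for both terms as in `F.l1_loss` and `F.mse_loss`; `mean_and_gram` and `style_loss` are hypothetical names used only for illustration:

```python
def mean_and_gram(z):
    """z: list of T rows, each a list of C floats.

    Returns (mean, gram) where mean has length C and gram is C x C,
    with gram[a][b] = sum_t z[t][a] * z[t][b] / T.
    """
    t = len(z)
    c = len(z[0])
    mean = [sum(row[j] for row in z) / t for j in range(c)]
    gram = [[sum(row[a] * row[b] for row in z) / t for b in range(c)]
            for a in range(c)]
    return mean, gram


def style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
    """L1 on channel means, plus optional weighted MSE on Gram matrices."""
    mean, gram = mean_and_gram(z)
    c = len(mean)
    loss_mean = sum(abs(m - r) for m, r in zip(mean, ref_mean)) / c
    if gram_weight <= 0.0:
        return loss_mean  # mean-only mode, the recommended starting point
    loss_gram = sum((gram[a][b] - ref_gram[a][b]) ** 2
                    for a in range(c) for b in range(c)) / (c * c)
    return loss_mean + gram_weight * loss_gram


if __name__ == "__main__":
    reference = [[1.0, 0.0], [0.0, 1.0]]
    ref_mean, ref_gram = mean_and_gram(reference)
    generated = [[0.5, 0.5], [0.5, 0.5]]
    print(style_loss(generated, ref_mean, ref_gram, gram_weight=0.0))
    print(style_loss(generated, ref_mean, ref_gram, gram_weight=1.0))
```

In the usage above, the generated latent has the same channel means as the reference, so the mean-only loss cannot tell them apart; only the Gram term registers the difference in inter-channel structure. That is the trade-off the `gram_weight` tooltip describes.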
+288
@@ -0,0 +1,288 @@
import os
import hashlib
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
import comfy.utils
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
_CLIP_SIZE = 384
_SYNC_SIZE = 224
_CLIP_FPS = 8
_SYNC_FPS = 25
# Sync normalization applied externally: maps [0,1] → [-1,1] with mean=std=0.5
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
def _sample_frames(video, source_fps, target_fps, duration):
"""Sample frames from [T,H,W,C] float32 at target_fps; returns [N,H,W,C]."""
T = video.shape[0]
n_out = max(1, int(duration * target_fps))
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
return video[indices]
def _resize_frames(frames, size):
"""Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
x = frames.permute(0, 3, 1, 2) # [N, C, H, W]
x = F.interpolate(x.float(), size=(size, size), mode="bicubic", align_corners=False)
return x.clamp(0.0, 1.0) # [N, C, H, W]
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
"""
Apply a ComfyUI MASK to resized frames.
frames: [N, C, H, W] float [0,1]
mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame
source_fps: original video fps (for accurate temporal sampling)
target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)
Background pixels are filled with 0.5 rather than 0 — less out-of-distribution
for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path.
"""
N, C, H, W = frames.shape
M = mask.shape[0]
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
if mask_f.shape[2] != H or mask_f.shape[3] != W:
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
# Temporal sampling — use same index formula as _sample_frames for accuracy
if M == 1:
mask_f = mask_f.expand(N, -1, -1, -1)
else:
indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
mask_f = mask_f[indices] # [N, 1, H, W]
mask_f = mask_f.to(frames.device)
# alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
alpha = 1.0 - mask_strength * (1.0 - mask_f)
return frames * alpha + 0.5 * (1.0 - alpha)
def _resolve_named_path(cache_dir: str, name: str) -> str:
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
# Sanitize: replace path separators so the name stays inside cache_dir
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
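# Use the plain name first; add a numeric suffix only if the file already exists
p = os.path.join(cache_dir, f"{name}.npz")
i = 1
while os.path.exists(p):
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
i += 1
return p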
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
h = hashlib.sha256()
raw = video_tensor.cpu().numpy().tobytes()
n = len(raw)
chunk = 512 * 1024 # 512 KB per sample
h.update(raw[:chunk])
h.update(raw[n // 2: n // 2 + chunk])
h.update(raw[max(0, n - chunk):])
if mask is not None:
raw_m = mask.cpu().numpy().tobytes()
nm = len(raw_m)
chunk_m = 256 * 1024
h.update(raw_m[:chunk_m])
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
h.update(raw_m[max(0, nm - chunk_m):])
h.update(str(round(mask_strength, 4)).encode())
h.update(str(mask_clip).encode())
h.update(str(mask_sync).encode())
h.update(prompt.encode())
h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
h.update(variant.encode())
return h.hexdigest()[:32]
class SelvaFeatureExtractor:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"video": ("IMAGE",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Describes the sounds to generate. Used to focus the visual sync features on motion relevant to the prompt — more specific prompts produce cleaner audio sync. Wire the prompt output directly to the Sampler so you only type it once.",
}),
},
"optional": {
"video_info": ("VHS_VIDEOINFO", {
"tooltip": "VHS_VIDEOINFO from VHS LoadVideo. Automatically sets the correct source fps — always connect this when loading video with VHS nodes.",
}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001,
"tooltip": "Source fps of the input video. Ignored when video_info is connected."}),
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
"cache_dir": ("STRING", {"default": "",
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
"name": ("STRING", {"default": "",
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash, which is useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
"mask": ("MASK", {
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are filled with neutral gray (0.5) before encoding, which is useful when multiple objects compete for the same sound. Static (1-frame) and per-frame masks are both supported. Connect SAM2 or Grounding DINO+SAM output.",
}),
"mask_strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
}),
"mask_clip": ("BOOLEAN", {
"default": True,
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
}),
"mask_sync": ("BOOLEAN", {
"default": True,
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
}),
},
}
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
RETURN_NAMES = ("features", "fps", "prompt")
OUTPUT_TOOLTIPS = (
"Extracted feature bundle — connect to Sampler.",
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
)
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
FUNCTION = "extract_features"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir="", name="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
if video_info is not None:
fps = video_info["loaded_fps"]
T = video.shape[0]
if duration <= 0:
duration = T / fps
duration = min(duration, T / fps) # clamp to actual video length
if not prompt.strip():
print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)
# Cache
if not cache_dir:
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
os.makedirs(cache_dir, exist_ok=True)
if name.strip():
# Named mode: always extract and save to an incremented filename
cached_path = _resolve_named_path(cache_dir, name.strip())
else:
# Hash mode: skip extraction if identical inputs were already processed
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path):
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
cached = _load_cached(cached_path)
return (cached, float(fps), cached.get("prompt", prompt))
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
feature_utils = model["feature_utils"]
net_video_enc = model["video_enc"]
if strategy == "offload_to_cpu":
feature_utils.to(device)
net_video_enc.to(device)
soft_empty_cache()
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3)
try:
with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None and mask_clip:
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
pbar.update(1)
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None and mask_sync:
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
# Pad to minimum 16 frames (TextSynchformer segment size)
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
# Normalize [0,1] → [-1,1]
mean = _SYNC_MEAN.to(sync_frames.device)
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
_sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
pbar.update(1)
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
sync_features = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_f, text_mask=text_mask
) # [1, T_sync, 768]
pbar.update(1)
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
finally:
if strategy == "offload_to_cpu":
feature_utils.to(get_offload_device())
net_video_enc.to(get_offload_device())
soft_empty_cache()
np.savez(
cached_path,
clip_features=clip_features.cpu().float().numpy(),
sync_features=sync_features.cpu().float().numpy(),
duration=float(duration),
prompt=np.array(prompt),
variant=np.array(model["variant"]),
)
print(f"[SelVA] Features cached: {cached_path}", flush=True)
return ({
"clip_features": clip_features.cpu(),
"sync_features": sync_features.cpu(),
"duration": float(duration),
"prompt": prompt,
"variant": model["variant"],
}, float(fps), prompt)
def _load_cached(path):
data = np.load(path, allow_pickle=False)
features = {
"clip_features": torch.from_numpy(data["clip_features"]),
"sync_features": torch.from_numpy(data["sync_features"]),
"duration": float(data["duration"]),
}
if "prompt" in data:
features["prompt"] = str(data["prompt"])
if "variant" in data:
features["variant"] = str(data["variant"])
return features
@@ -0,0 +1,421 @@
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
JSON format:
{
"name": "eval_batch_1",
"data_dir": "/path/to/features",
"output_dir": "/path/to/evals/batch1",
"steps": 25,
"seed": 42,
"adapters": [
{"id": "baseline"},
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
]
}
Empty / missing "path" = baseline (no LoRA applied).
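
A minimal sketch of the checks this spec has to pass, assuming only the rules
stated above and the defaults shown in the example (steps=25, seed=42).
`validate_eval_spec` is a hypothetical helper used for illustration; it is not
part of this module.

```python
import json


def validate_eval_spec(spec: dict) -> dict:
    # Hypothetical helper: mirrors the documented rules, not the node's code.
    adapters = spec.get("adapters") or []
    if not adapters:
        raise ValueError("'adapters' list is missing or empty")
    for i, adapter in enumerate(adapters):
        if "id" not in adapter:
            raise ValueError(f"adapter at index {i} is missing 'id'")
    for key in ("data_dir", "output_dir"):
        if key not in spec:
            raise ValueError(f"'{key}' is required")
    return {
        "steps": int(spec.get("steps", 25)),
        "seed": int(spec.get("seed", 42)),
        # An empty or missing "path" marks the no-LoRA baseline.
        "baseline_ids": [a["id"] for a in adapters
                         if not (a.get("path") or "").strip()],
    }


spec = json.loads('{"data_dir": "d", "output_dir": "o", '
                  '"adapters": [{"id": "baseline"}, {"id": "run_a", "path": "a.pt"}]}')
info = validate_eval_spec(spec)
```

Running the checks before any model work means a malformed spec fails fast,
before the dataset is encoded.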
"""
import copy
import json
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_pil_to_tensor,
_find_audio,
_load_audio,
)
from selva_core.model.lora import apply_lora, load_lora
def _avg_metrics(metrics_list: list) -> dict:
"""Average spectral metrics across multiple clips, ignoring None entries."""
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
"spectral_flatness", "temporal_variance"]
valid = [m for m in metrics_list if m]
if not valid:
return {}
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _safe_stem(adapter_id: str) -> str:
"""Replace characters illegal in filenames."""
for ch in r'/\:*?"<>|':
adapter_id = adapter_id.replace(ch, "_")
return adapter_id
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
"""
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
METRICS = [
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
("spectral_flatness", "Spectral Flatness"),
("temporal_variance", "Temporal Variance"),
]
COLORS = [
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
]
fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
for ax, (key, title) in zip(axes, METRICS):
values = []
colors = []
for i, m in enumerate(metrics_list):
v = m.get(key, 0.0) if m else 0.0
values.append(v)
colors.append(COLORS[i % len(COLORS)])
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
ax.set_title(title, fontsize=9)
ax.set_xlabel(key, fontsize=8)
ax.tick_params(axis="y", labelsize=7)
ax.tick_params(axis="x", labelsize=7)
# Value labels on bars
for bar, val in zip(bars, values):
w = bar.get_width()
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
f"{val:.3f}", va="center", ha="left", fontsize=6)
canvas = FigureCanvasAgg(fig)
canvas.draw()
canvas.print_figure(str(output_path), dpi=110)
buf = canvas.buffer_rgba()
w, h = canvas.get_width_height()
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
from PIL import Image
return _pil_to_tensor(Image.fromarray(arr))
class SelvaLoraEvaluator:
"""Evaluates a batch of LoRA adapters on every clip in the dataset.
Generates one audio sample per adapter per clip, computes spectral metrics for
each sample, averages them per adapter, and produces a comparison chart against
the reference audio. Use this after a sweep to compare candidates before running
the next round of training.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_image")
OUTPUT_TOOLTIPS = (
"Path to eval_summary.json — contains spectral metrics per adapter.",
"Bar chart comparing spectral metrics across all evaluated adapters.",
)
DESCRIPTION = (
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
"for every clip in the dataset, then averages spectral metrics for comparison. "
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"eval_file": ("STRING", {
"default": "eval_batch.json",
"tooltip": (
"Path to the JSON evaluation spec. Relative paths are tried in the "
"ComfyUI models directory first, then the output directory. "
"Each adapter entry needs an 'id' and an optional 'path'. "
"Omit 'path' for a no-LoRA baseline."
),
}),
}
}
def run(self, model, eval_file):
# ------------------------------------------------------------------
# 1. Resolve and parse the JSON file
# ------------------------------------------------------------------
eval_path = Path(eval_file.strip())
if not eval_path.is_absolute():
candidate = Path(folder_paths.models_dir) / eval_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / eval_path
eval_path = candidate
if not eval_path.exists():
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
spec = json.loads(eval_path.read_text(encoding="utf-8"))
if "adapters" not in spec or not spec["adapters"]:
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
for i, a in enumerate(spec["adapters"]):
if "id" not in a:
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
if "data_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
if "output_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
name = spec.get("name", eval_path.stem)
data_dir = _resolve_path(spec["data_dir"])
output_dir = _resolve_path(spec["output_dir"])
steps = int(spec.get("steps", 25))
seed = int(spec.get("seed", 42))
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
# ------------------------------------------------------------------
# 2. Prepare dataset (VAE encode once)
# ------------------------------------------------------------------
device = get_device()
dtype = model["dtype"]
dataset = _prepare_dataset(model, data_dir, device)
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
# ------------------------------------------------------------------
# 3. Collect reference metrics for all dataset clips
# ------------------------------------------------------------------
import shutil
npz_files = sorted(data_dir.glob("*.npz"))
ref_dir = output_dir / "reference"
ref_dir.mkdir(exist_ok=True)
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
flush=True)
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
continue
try:
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
ref_wav = ref_wav.unsqueeze(0) # [1, L]
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
shutil.copy2(str(audio_path), str(ref_out))
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
ref_clips.append({
"clip": npz_path.stem,
"wav_path": str(ref_out),
"spectral_metrics": metrics,
})
except Exception as e:
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
# Average reference metrics across all clips
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
print(f"[LoRA Evaluator] Reference avg — "
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
# ------------------------------------------------------------------
# 4. Build summary skeleton
# ------------------------------------------------------------------
summary = {
"name": name,
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"data_dir": str(data_dir),
"output_dir": str(output_dir),
"n_clips": len(ref_clips),
"steps": steps,
"seed": seed,
"reference_avg": ref_avg,
"reference_clips": ref_clips,
"adapters": [],
}
summary_path = output_dir / "eval_summary.json"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Per-adapter evaluation loop (all clips)
# ------------------------------------------------------------------
n_clips = len(dataset)
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
for adapter_spec in spec["adapters"]:
adapter_id = adapter_spec["id"]
adapter_path = (adapter_spec.get("path") or "").strip()
safe_id = _safe_stem(adapter_id)
clip_dir = output_dir / safe_id
clip_dir.mkdir(exist_ok=True)
record = {
"id": adapter_id,
"path": adapter_path or None,
"meta": None,
"clips": [],
"avg_metrics": None,
"status": "running",
}
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
try:
with torch.inference_mode(False):
generator = copy.deepcopy(model["generator"])
if adapter_path:
pt_path = Path(adapter_path)
if not pt_path.is_absolute():
pt_path = Path(folder_paths.base_path) / pt_path
if not pt_path.exists():
raise FileNotFoundError(f"Adapter not found: {pt_path}")
ckpt = torch.load(str(pt_path), map_location="cpu",
weights_only=False)
if isinstance(ckpt, dict) and "state_dict" in ckpt:
state_dict = ckpt["state_dict"]
meta = ckpt.get("meta", {})
else:
state_dict = ckpt
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
dropout = float(meta.get("lora_dropout", 0.0))
use_rslora = meta.get("use_rslora", False)
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
# Always use standard init for loading — PiSSA checkpoints
# include linear.weight (residual) in state_dict
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target), dropout=dropout,
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"apply_lora matched 0 layers (target={target})"
)
load_lora(generator, state_dict)
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
f"(rank={rank}, {n} layers)", flush=True)
else:
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
generator = generator.to(device, dtype)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
clip_metrics_list = []
for clip_idx in range(n_clips):
clip_stem = npz_files[clip_idx].stem
wav, sr = _eval_sample(
generator, feature_utils_orig, dataset,
seq_cfg, device, dtype,
num_steps=steps, seed=seed, clip_idx=clip_idx,
)
if wav is None:
pbar.update(1)
continue
wav_path = clip_dir / f"{clip_stem}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
metrics = _spectral_metrics(wav, sr)
clip_metrics_list.append(metrics)
record["clips"].append({
"clip": clip_stem,
"wav_path": str(wav_path),
"spectral_metrics": metrics,
})
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
pbar.update(1)
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
record["status"] = "completed"
avg = record["avg_metrics"]
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
except Exception as e:
record["status"] = "failed"
record["error"] = str(e)
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
traceback.print_exc()
pbar.update(n_clips - len(record["clips"]))
finally:
try:
del generator
except NameError:
pass
soft_empty_cache()
summary["adapters"].append(record)
_write_summary()
# ------------------------------------------------------------------
# 6. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison chart
# ------------------------------------------------------------------
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
if completed:
ids = ["reference"] + [r["id"] for r in completed]
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
chart_path = output_dir / "metric_comparison.png"
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
else:
from PIL import Image
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
return (str(summary_path), comparison)
@@ -0,0 +1,109 @@
import copy
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from selva_core.model.lora import apply_lora, load_lora
class SelvaLoraLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"adapter_path": ("STRING", {
"default": "",
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
}),
"strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to all LoRA contributions. "
"1.0 = full adapter strength. "
"0.0 = effectively disables the adapter. "
"Values above 1.0 exaggerate the effect.",
}),
},
}
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
"The base model is not modified: the model bundle is shallow-copied and the generator is deep-copied before patching."
)
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
if not adapter_path.strip():
raise ValueError("[SelVA LoRA] adapter_path is empty.")
# Resolve path: allow absolute or relative to ComfyUI base
from pathlib import Path
p = Path(adapter_path)
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
# Support both raw state_dict and {state_dict, meta} formats
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
meta = checkpoint.get("meta", {})
else:
state_dict = checkpoint
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
init_mode = meta.get("init_mode", "standard")
use_rslora = meta.get("use_rslora", False)
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
f"init={init_mode} rslora={use_rslora} strength={strength}",
flush=True)
# Shallow-copy the model bundle and deep-copy the generator so the original is not mutated
patched = {**model}
generator = copy.deepcopy(model["generator"])
# For PiSSA, use standard init (the base weights will be overwritten
# by load_state_dict since the checkpoint includes linear.weight)
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target),
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"[SelVA LoRA] No layers matched target={target}. "
"Check that the adapter was trained with the same target suffixes."
)
load_lora(generator, state_dict)
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
norms = [p.norm().item() for name, p in generator.named_parameters()
if "lora_A" in name]
if norms:
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
else:
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
# Apply strength scaling: multiply all lora_B params by strength.
# (The LoRA update is the product of B and A, so scaling A instead would be
# equivalent; scaling B alone is simpler.)
if strength != 1.0:
with torch.no_grad():
for name, param in generator.named_parameters():
if "lora_B" in name:
param.mul_(strength)
generator.to(next(model["generator"].parameters()).device)
patched["generator"] = generator
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
return (patched,)
@@ -0,0 +1,539 @@
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "tier1_sweep",
"description": "optional human note",
"data_dir": "dataset/dog_bark",
"output_root": "lora_output/tier1_sweep",
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
...
]
}
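
A minimal sketch of how the inheritance described above resolves, assuming the
precedence built-in defaults < `base` < per-experiment overrides. `DEFAULTS`,
`sweep`, and `merge_config` are illustrative names, and the values are made up
for the example.

```python
# Hypothetical stand-ins for the built-in defaults and a parsed sweep file.
DEFAULTS = {"batch_size": 4, "lora_plus_ratio": 1.0}
sweep = {
    "base": {"rank": 16, "lr": 1e-4, "steps": 2000},
    "experiments": [
        {"id": "baseline", "description": "no overrides"},
        {"id": "lora_plus_16", "lora_plus_ratio": 16.0},
    ],
}


def merge_config(defaults: dict, base: dict, experiment: dict) -> dict:
    # Later sources win: defaults < base < experiment overrides.
    cfg = dict(defaults)
    cfg.update(base)
    # "id" and "description" label the run; they are not training params.
    cfg.update({k: v for k, v in experiment.items()
                if k not in ("id", "description")})
    return cfg


configs = {exp["id"]: merge_config(DEFAULTS, sweep["base"], exp)
           for exp in sweep["experiments"]}
```

Here `configs["baseline"]` keeps every base value, while
`configs["lora_plus_16"]` differs only in `lora_plus_ratio`.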
"""
import copy
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
SelvaLoraTrainer,
SkipExperiment,
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
def _get_system_info() -> dict:
"""Collect GPU / torch version info for the summary header."""
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"alpha": 0.0,
"target": "attn.qkv",
"batch_size": 4,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"resume_path": "",
"seed": 42,
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": True,
"latent_mixup_alpha": 0.0,
"latent_noise_sigma": 0.0,
}
# Palette for comparison chart: one color per experiment (cycles if > 8)
_PALETTE = [
(66, 133, 244), # blue
(234, 67, 53), # red
(52, 168, 83), # green
(251, 188, 5), # yellow
(155, 89, 182), # purple
(26, 188, 156), # teal
(230, 126, 34), # orange
(149, 165, 166), # grey
]
def _resolve_path(raw: str) -> Path:
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge base defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
# Don't carry id/description into the training params
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
start_step: int, total_steps: int) -> dict:
"""Build a dict of {step: loss} at each save_every boundary.
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
start + (i+1)*log_interval].
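
A worked example of the index mapping above, as a sketch with illustrative
numbers (2000 steps logged every 50 steps). `loss_index_for_step` is a
hypothetical helper, not part of this module.

```python
def loss_index_for_step(step: int, start_step: int, log_interval: int) -> int:
    # Entry i covers steps start+i*log_interval+1 .. start+(i+1)*log_interval,
    # so the entry that ends at `step` is (step - start) // log_interval - 1.
    return (step - start_step) // log_interval - 1


# Illustrative numbers: 2000 steps logged every 50 steps gives 40 entries.
loss_history = [1.0 / (i + 1) for i in range(40)]
idx_500 = loss_index_for_step(500, start_step=0, log_interval=50)
idx_2000 = loss_index_for_step(2000, start_step=0, log_interval=50)
```

With these numbers, step 500 maps to entry 9 (steps 451 to 500) and step 2000
maps to the last entry, 39.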
"""
result = {}
targets = range(save_every, total_steps + 1, save_every)
for target in targets:
# index of the loss entry nearest to this step
idx = (target - start_step) // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
# Collect all smoothed series
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"log_interval": ed.get("log_interval", 50),
"start_step": ed.get("start_step", 0),
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
# Draw each curve
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
# Legend (right side)
lx = W - pr + 10
ly = pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
class SelvaLoraScheduler:
"""Runs a sweep of LoRA training experiments defined in a JSON file.
The dataset (VAE encoding + .npz loading) is performed once and shared
across all experiments. Each experiment deep-copies the generator and trains
independently. Results are written to `experiment_summary.json` after every
completed run so partial results are preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
"The dataset is encoded once and reused across all experiments. "
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths are tried in the ComfyUI "
"models directory first, then the output directory; absolute paths are used as-is. "
"See LORA_TRAINING.md for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
# Try relative to ComfyUI models dir first, then output dir
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[LoRA Scheduler] {description}", flush=True)
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load + encode dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = [] # collected for comparison image
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 50),
"start_step": 0,
})
# Restore the original summary, clear completed_at so it gets set again
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaLoraTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
variant = model["variant"]
mode = model["mode"]
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
# Training params (each falls back to a default when absent)
steps = int(cfg.get("steps", 2000))
rank = int(cfg.get("rank", 16))
lr = float(cfg.get("lr", 1e-4))
alpha = float(cfg.get("alpha", 0.0))
target = str(cfg.get("target", "attn.qkv"))
batch_size = int(cfg.get("batch_size", 4))
warmup = int(cfg.get("warmup_steps", 100))
grad_accum = int(cfg.get("grad_accum", 1))
save_every = int(cfg.get("save_every", 500))
resume_path = str(cfg.get("resume_path", ""))
seed = int(cfg.get("seed", 42))
ts_mode = str(cfg.get("timestep_mode", "uniform"))
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
curr_switch = float(cfg.get("curriculum_switch", 0.6))
dropout = float(cfg.get("lora_dropout", 0.0))
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
lr_schedule = str(cfg.get("lr_schedule", "constant"))
init_mode = str(cfg.get("init_mode", "pissa"))
use_rslora = bool(cfg.get("use_rslora", True))
alpha_val = alpha if alpha > 0.0 else float(2 * rank)
target_suffixes = tuple(target.strip().split())
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
"batch_size": batch_size, "warmup_steps": warmup,
"grad_accum": grad_accum, "save_every": save_every,
"seed": seed, "target": list(target_suffixes),
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
"curriculum_switch": curr_switch,
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
"lr_schedule": lr_schedule,
"init_mode": init_mode, "use_rslora": use_rslora,
},
"results": {"status": "running"},
"adapter_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup,
grad_accum, save_every, resume_path, seed,
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
lr_schedule, init_mode, use_rslora,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
grad_norm_history = r.get("grad_norm_history", [])
spectral_metrics = r.get("spectral_metrics", {})
run_start_step = r.get("start_step", 0)
smoothed = _smooth_losses(loss_history) if loss_history else []
# Scalar summary metrics
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (
run_start_step + (min_idx + 1) * log_interval
if min_idx is not None else None
)
# Stability: std-dev of raw loss over last 25% of steps
if loss_history:
quarter = max(1, len(loss_history) // 4)
last_q = loss_history[-quarter:]
loss_std_last_quarter = round(float(np.std(last_q)), 6)
else:
loss_std_last_quarter = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, run_start_step, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"grad_norm_history": grad_norm_history,
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["adapter_path"] = r["adapter_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except SkipExperiment as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
partial = getattr(e, "partial", {})
lh = partial.get("loss_history", [])
smoothed = _smooth_losses(lh) if lh else []
exp_record["results"] = {
"status": "skipped",
"stopped_at_step": partial.get("stopped_at_step"),
"final_loss": round(smoothed[-1], 6) if smoothed else None,
"loss_history": [round(v, 6) for v in lh],
"grad_norm_history": partial.get("grad_norm_history", []),
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
except Exception as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
# Continue to next experiment rather than aborting the whole sweep
continue
_write_summary()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
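The scalar summary metrics recorded per experiment above (final and minimum smoothed loss, the step of the minimum, and the standard deviation of the raw loss over the last quarter of the run) can be sketched in isolation as follows. This is a minimal sketch under stated assumptions: `smooth_losses` is a hypothetical trailing moving average standing in for `_smooth_losses`, whose real implementation is not shown in this diff, and `statistics.pstdev` replaces `np.std` (both compute the population standard deviation).

```python
import statistics


def smooth_losses(losses, window=5):
    """Hypothetical trailing moving average (assumption, not the repo's _smooth_losses)."""
    out = []
    for i in range(len(losses)):
        chunk = losses[max(0, i - window + 1): i + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def summarize(loss_history, log_interval, start_step=0):
    """Compute the same scalar metrics the scheduler stores under "results"."""
    if not loss_history:
        return {"final_loss": None, "min_loss": None,
                "min_loss_step": None, "loss_std_last_quarter": None}
    smoothed = smooth_losses(loss_history)
    min_idx = smoothed.index(min(smoothed))
    # Entry i of the history was logged at step start_step + (i + 1) * log_interval.
    min_loss_step = start_step + (min_idx + 1) * log_interval
    # Stability: population std-dev of the raw loss over the last 25% of entries.
    quarter = max(1, len(loss_history) // 4)
    last_q = loss_history[-quarter:]
    return {
        "final_loss": round(smoothed[-1], 6),
        "min_loss": round(min(smoothed), 6),
        "min_loss_step": min_loss_step,
        "loss_std_last_quarter": round(statistics.pstdev(last_q), 6),
    }
```

A low `loss_std_last_quarter` together with a `min_loss_step` near the end of the run suggests the experiment was still improving when it stopped; a minimum far from the end suggests it plateaued or diverged.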
+171
@@ -0,0 +1,171 @@
import os
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
}
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
_PRISMAUDIO_DIR = Path(folder_paths.models_dir) / "prismaudio"
_HF_REPO = "jnwnlee/SelVA"
# filename → (hf_repo_path, expected_md5 or None to skip check)
# Note: 44k generators are named 44khz in the HF repo; md5=None since the
# original download_utils had the wrong filenames so those md5s are unverified.
_WEIGHTS = {
"video_enc_sup_5.pth": ("weights/video_enc_sup_5.pth", "ff09a6dc36148536ee4db97eba081d05"),
"generator_small_16k_sup_5.pth": ("weights/generator_small_16k_sup_5.pth", "1cb0f0deec52de37f67b1fd9965337d0"),
"generator_small_44k_sup_5.pth": ("weights/generator_small_44khz_sup_5.pth", None),
"generator_medium_44k_sup_5.pth":("weights/generator_medium_44khz_sup_5.pth", None),
"generator_large_44k_sup_5.pth": ("weights/generator_large_44khz_sup_5.pth", None),
"v1-16.pth": ("ext_weights/v1-16.pth", "69f56803f59a549a1a507c93859fd4d7"),
"v1-44.pth": ("ext_weights/v1-44.pth", "fab020275fa44c6589820ce025191600"),
"best_netG.pt": ("ext_weights/best_netG.pt", "eeaf372a38a9c31c362120aba2dde292"),
"synchformer_state_dict.pth": ("ext_weights/synchformer_state_dict.pth", "5b2f5594b0730f70e41e549b7c94390c"),
}
def _md5(path):
import hashlib
h = hashlib.md5()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
h.update(chunk)
return h.hexdigest()
def _ensure(filename, subdir=None):
"""Return path to weight file. Re-downloads if missing or MD5 mismatch."""
import shutil
from huggingface_hub import hf_hub_download
dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
dest_path = dest_dir / filename
entry = _WEIGHTS.get(filename)
if entry is None:
raise ValueError(f"[SelVA] Unknown weight file: {filename}")
repo_path, expected_md5 = entry
if dest_path.exists():
if expected_md5 is None:
return str(dest_path)
actual = _md5(dest_path)
if actual == expected_md5:
return str(dest_path)
print(f"[SelVA] {filename}: MD5 mismatch ({actual} != {expected_md5}), re-downloading...", flush=True)
dest_path.unlink()
print(f"[SelVA] Downloading {filename} from {_HF_REPO}...", flush=True)
dest_dir.mkdir(parents=True, exist_ok=True)
cached = hf_hub_download(repo_id=_HF_REPO, filename=repo_path)
shutil.copy2(cached, dest_path)
print(f"[SelVA] Saved to {dest_path}", flush=True)
return str(dest_path)
def _synchformer_path():
"""Return synchformer path, reusing models/prismaudio/ if already present."""
prismaudio_path = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
if prismaudio_path.exists():
return str(prismaudio_path)
return _ensure("synchformer_state_dict.pth")
class SelvaModelLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"variant": (list(_VARIANTS.keys()), {
"tooltip": "Model size and output sample rate. small_16k is fastest (16 kHz). 44k variants output 44.1 kHz. Larger variants give better quality but need more VRAM.",
}),
"precision": (["bf16", "fp16", "fp32"], {
"tooltip": "Compute dtype. bf16 is recommended on Ampere+ GPUs. fp16 for older NVIDIA hardware. fp32 if you see NaN outputs.",
}),
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {
"tooltip": "auto picks keep_in_vram if ≥16 GB VRAM is free, otherwise offload_to_cpu. offload_to_cpu moves weights to RAM between nodes, saving VRAM at the cost of speed.",
}),
}
}
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
FUNCTION = "load_model"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."
def load_model(self, variant, precision, offload_strategy):
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.networks_video_enc import get_my_textsynch
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
gen_filename, mode, has_bigvgan = _VARIANTS[variant]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
precision = "fp16"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
strategy = determine_offload_strategy(offload_strategy)
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
video_enc_path = _ensure("video_enc_sup_5.pth")
gen_path = _ensure(gen_filename)
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _ensure(vae_name, subdir="ext")
synch_path = _synchformer_path()
bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
net_video_enc.load_weights(
torch.load(video_enc_path, map_location="cpu", weights_only=False)
)
print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
net_generator.load_weights(
torch.load(gen_path, map_location="cpu", weights_only=False)
)
print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
feature_utils = FeaturesUtils(
tod_vae_ckpt=vae_path,
synchformer_ckpt=synch_path,
enable_conditions=True,
mode=mode,
bigvgan_vocoder_ckpt=bigvgan_path,
need_vae_encoder=True,
).to(device, dtype).eval()
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
net_video_enc.to(get_offload_device())
feature_utils.to(get_offload_device())
print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)
return ({
"generator": net_generator,
"video_enc": net_video_enc,
"feature_utils": feature_utils,
"variant": variant,
"mode": mode,
"strategy": strategy,
"dtype": dtype,
"seq_cfg": seq_cfg,
},)
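The weight-resolution rule used by `_ensure` (reuse the local file unless it is missing or its MD5 differs from the expected value, and skip the check when no checksum is known) can be exercised without any network access. The sketch below is an illustration only: `file_md5`, `needs_download` and `demo` are hypothetical names, and the actual download step via `hf_hub_download` is deliberately left out.

```python
import hashlib
import tempfile
from pathlib import Path


def file_md5(path, chunk_size=8 * 1024 * 1024):
    """Hash a file in fixed-size chunks so large checkpoints never sit fully in RAM."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def needs_download(path, expected_md5):
    """True if the file is missing or fails its checksum; None skips the check."""
    path = Path(path)
    if not path.exists():
        return True
    if expected_md5 is None:
        return False
    return file_md5(path) != expected_md5


def demo():
    """Walk through the four cases on a throwaway file."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "weights.bin"
        missing = needs_download(path, None)          # file absent
        path.write_bytes(b"hello")
        digest = file_md5(path)
        matches = needs_download(path, digest)        # checksum agrees
        mismatch = needs_download(path, "0" * 32)     # checksum differs
        unchecked = needs_download(path, None)        # no checksum known
        return missing, digest, matches, mismatch, unchecked
```

Entries with `md5=None` (the 44k generators here) are therefore trusted as soon as they exist on disk, so a truncated download of those files would not be detected by this check.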
+279
@@ -0,0 +1,279 @@
import torch
import comfy.utils
import comfy.model_management
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_textual_inversion_trainer import _inject_tokens
class SelvaSampler:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"features": ("SELVA_FEATURES",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Sound description for CLIP text conditioning. Leave empty to reuse the prompt from the Feature Extractor (wire its prompt output here). Changing this without re-extracting features shifts CLIP conditioning but not sync features.",
}),
"negative_prompt": ("STRING", {
"default": "", "multiline": False,
"tooltip": "Sounds to suppress, e.g. 'speech, music, wind noise'. Steered away from via CFG. Leave empty for unconditional guidance baseline.",
}),
"duration": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
"tooltip": "Output audio length in seconds. 0 = match the video duration stored in features.",
}),
"steps": ("INT", {"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the flow matching ODE. 25 is the SelVA default. Diminishing returns above 50; below 10 may sound rough."}),
"cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1,
"tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3-7."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"steering_vectors": ("STEERING_VECTORS", {
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
"Nudges each DiT block's hidden state toward the extracted pattern.",
}),
"steering_strength": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to each steering vector before adding to block output. "
"Start around 0.1-0.3; higher values risk destabilizing the ODE.",
}),
"normalize": ("BOOLEAN", {
"default": True,
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.",
}),
"textual_inversion": ("TEXTUAL_INVERSION", {
"tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. "
"Injects style tokens into CLIP conditioning without modifying model weights.",
}),
"ti_strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Blends between original CLIP conditioning (0.0) and full TI injection (1.0). "
"Reduce toward 0.3-0.5 if TI produces buzz artifacts.",
}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
FUNCTION = "generate"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
import dataclasses
from selva_core.model.flow_matching import FlowMatching
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
# Validate that features were extracted with the same model variant
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
)
# Resolve prompt: use override if given, otherwise fall back to features prompt
if not prompt or not prompt.strip():
prompt = features.get("prompt", "")
if prompt:
print(f"[SelVA] Using prompt from features: '{prompt[:60]}'", flush=True)
else:
print("[SelVA] Warning: no prompt in features or sampler — CLIP text conditioning will be empty.", flush=True)
# Resolve duration
if duration <= 0:
if "duration" not in features:
raise ValueError("[SelVA] duration=0 but features contain no duration field.")
duration = features["duration"]
print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)
# Derive sequence config for this duration from the model's mode template
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
sample_rate = seq_cfg.sampling_rate
if strategy == "offload_to_cpu":
net_generator.to(device)
feature_utils.to(device)
soft_empty_cache()
try:
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
# Update model rotary position embeddings for actual feature shapes and duration.
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
with torch.no_grad():
# Encode text conditioning
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
# Encode negative prompt (or use empty conditions)
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
# Inject textual inversion tokens into CLIP conditioning
if textual_inversion is not None:
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
K = emb.shape[0]
inject_mode = textual_inversion.get("inject_mode", "suffix")
ti_text = _inject_tokens(text_clip, emb, K, inject_mode)
text_clip = torch.lerp(text_clip, ti_text, ti_strength)
if neg_text_clip is not None:
ti_neg = _inject_tokens(neg_text_clip, emb, K, inject_mode)
neg_text_clip = torch.lerp(neg_text_clip, ti_neg, ti_strength)
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode} strength={ti_strength}",
flush=True)
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Initial noise (MPS doesn't support torch.Generator on device)
gen_device = "cpu" if device.type == "mps" else device
rng = torch.Generator(device=gen_device).manual_seed(seed)
x0 = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=gen_device, dtype=dtype, generator=rng,
).to(device)
# Activation steering: apply only during the conditional predict_flow pass
# so steering gets amplified by cfg_strength rather than canceling out.
steering_handles = []
_orig_predict_flow = None
if steering_vectors is not None and steering_strength > 0.0:
vecs = steering_vectors["steering_vectors"]
n_joint = steering_vectors["n_joint"]
# Patch predict_flow to flag which pass is conditional.
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
# identity check tells us which is which.
_is_cond_pass = [False]
_orig_predict_flow = net_generator.predict_flow
def _tracked_predict_flow(latent, t, cond):
_is_cond_pass[0] = (cond is conditions)
return _orig_predict_flow(latent, t, cond)
net_generator.predict_flow = _tracked_predict_flow
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
vec = vec_cpu.to(dev, dt) # [seq, hidden]
def hook(module, input, output):
if not _is_cond_pass[0]:
return # skip unconditional pass; steering amplified by cfg_strength
# Interpolate steering vec to match actual output seq length
# (handles generation at different duration than extraction)
if is_joint:
out_seq = output[0].shape[1]
else:
out_seq = output.shape[1]
v = vec
if v.shape[0] != out_seq:
v = torch.nn.functional.interpolate(
v.T.unsqueeze(0), # [1, hidden, seq_orig]
size=out_seq,
mode="linear",
align_corners=False,
).squeeze(0).T # [seq_new, hidden]
if is_joint:
latent_out = output[0] + strength * v
return (latent_out,) + output[1:]
else:
return output + strength * v
return hook
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
for i, block in enumerate(blocks):
is_joint = i < n_joint
if i < len(vecs):
h = block.register_forward_hook(
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
)
steering_handles.append(h)
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
f"strength={steering_strength} (conditional pass only)", flush=True)
# Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
pbar = comfy.utils.ProgressBar(steps)
def ode_wrapper_tracked(t, x):
comfy.model_management.throw_exception_if_processing_interrupted()
pbar.update(1)
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
try:
x1 = fm.to_data(ode_wrapper_tracked, x0)
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
finally:
if _orig_predict_flow is not None:
net_generator.predict_flow = _orig_predict_flow
for h in steering_handles:
h.remove()
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
# Decode: latent → mel → audio
try:
with torch.no_grad():
x1_unnorm = net_generator.unnormalize(x1)
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
audio = feature_utils.vocode(spec) # mel → waveform
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
finally:
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
# Ensure [1, 1, samples] and normalize to [-1,1]
audio = audio.float()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
# If RMS normalization pushes peaks into clipping, scale back to
# preserve dynamics rather than hard-clipping (no saturation)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
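The output-level handling at the end of `generate` (scale to a target RMS in dBFS, then rescale by the peak only if the result would clip) can be illustrated without PyTorch. The following is a sketch using plain Python lists in place of tensors; `rms` and `normalize_rms` are hypothetical helper names, and the `-27.0` default mirrors the node's `target_lufs` default.

```python
import math


def rms(samples):
    """Root-mean-square level of a sequence of samples."""
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def normalize_rms(samples, target_db=-27.0):
    """Scale to a target RMS (dBFS); fall back to peak scaling if that would clip."""
    target_rms = 10 ** (target_db / 20.0)          # dBFS -> linear amplitude
    level = max(rms(samples), 1e-8)                # guard against silent input
    scaled = [s * (target_rms / level) for s in samples]
    peak = max(max(abs(s) for s in scaled), 1e-8)
    if peak > 1.0:
        # Dividing by the peak keeps the waveform shape instead of hard-clipping it.
        scaled = [s / peak for s in scaled]
    return scaled
```

For steady material the output lands exactly on the target RMS. For very spiky material the peak guard takes over and the final RMS ends up below the target, which is the trade-off the code comment describes as preserving dynamics rather than saturating.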
+50
@@ -0,0 +1,50 @@
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaSkipExperiment:
"""Writes skip_current.flag into a sweep output_root.
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
the current experiment and move to the next one. The trainer picks up
the flag within 50 steps (~a few seconds).
"""
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"output_root": ("STRING", {
"default": "",
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("flag_path",)
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
FUNCTION = "skip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
"and move to the next one. Queue this node while the scheduler is running. "
"Partial scalars collected so far are saved in the summary."
)
def skip(self, output_root: str):
p = Path(output_root.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
flag = p / "skip_current.flag"
flag.touch()
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
return (str(flag),)
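The flag-file handshake between this node and the running sweep can be sketched as below. Only the writer side (`flag.touch()`) appears in this diff; the consumer side is an assumption reconstructed from the docstring ("within 50 steps") and from the `partial` attribute the scheduler reads off `SkipExperiment`. In particular, deleting the flag after reading it and the names `request_skip` and `run_steps` are hypothetical.

```python
import tempfile
from pathlib import Path

FLAG_NAME = "skip_current.flag"
CHECK_EVERY = 50  # assumed polling interval, taken from the node docstring


class SkipExperiment(Exception):
    """Raised by the training loop; carries whatever was collected so far."""

    def __init__(self, message, partial=None):
        super().__init__(message)
        self.partial = partial or {}


def request_skip(output_root):
    """Writer side: what the skip node does."""
    flag = Path(output_root) / FLAG_NAME
    flag.touch()
    return flag


def run_steps(output_root, steps):
    """Consumer side (assumed): poll for the flag every CHECK_EVERY steps."""
    flag = Path(output_root) / FLAG_NAME
    loss_history = []
    for step in range(1, steps + 1):
        loss_history.append(1.0 / step)  # stand-in for a real training step
        if step % CHECK_EVERY == 0 and flag.exists():
            flag.unlink()  # consume the flag so the next experiment is not skipped too
            raise SkipExperiment(
                "skip requested",
                {"stopped_at_step": step, "loss_history": loss_history},
            )
    return loss_history
```

Polling the filesystem only every few dozen steps keeps the check cheap, at the cost of a short delay between queuing the node and the experiment actually stopping.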
+70
@@ -0,0 +1,70 @@
"""SelVA Textual Inversion Loader.
Loads a .pt file produced by SelvaTextualInversionTrainer and returns a
TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaTextualInversionLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. "
"Relative paths resolve to the ComfyUI output directory.",
}),
},
}
RETURN_TYPES = ("TEXTUAL_INVERSION",)
RETURN_NAMES = ("textual_inversion",)
OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. "
"Connect the output to the SelVA Sampler's optional textual_inversion input to guide "
"generation toward the training data style without degrading audio quality."
)
def load(self, path: str) -> tuple:
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[TI Loader] File not found: {p}")
data = torch.load(str(p), map_location="cpu", weights_only=False)
embeddings = data["embeddings"] # [K, 1024]
n_tokens = int(data.get("n_tokens", embeddings.shape[0]))
print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} "
f"shape={tuple(embeddings.shape)}", flush=True)
if data.get("init_text"):
print(f"[TI Loader] init_text='{data['init_text']}'", flush=True)
if data.get("step"):
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
f"lr={data.get('lr', '?')}", flush=True)
inject_mode = data.get("inject_mode", "suffix")
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
bundle = {
"embeddings": embeddings, # [K, 1024] float32 on CPU
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"path": str(p),
"init_text": data.get("init_text", ""),
}
return (bundle,)
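Several nodes in this change set resolve user-supplied paths the same way: strip whitespace, keep absolute paths untouched, and join relative ones onto a base directory. A small sketch of that pattern follows; `resolve_against` is a hypothetical helper, and the base directory is passed in explicitly rather than taken from `folder_paths`, which is only available inside ComfyUI.

```python
from pathlib import Path


def resolve_against(path_str, base_dir):
    """Return an absolute path, joining relative inputs onto base_dir."""
    p = Path(path_str.strip())
    if not p.is_absolute():
        p = Path(base_dir) / p
    return p
```

Which base directory applies depends on the node: the loader and the skip node resolve against the ComfyUI output directory, while the textual inversion trainer resolves `data_dir` against `folder_paths.models_dir`.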
+450
@@ -0,0 +1,450 @@
"""SelVA Textual Inversion Trainer.
Learns K token embedding vectors in CLIP space that guide the base model
to generate audio in the style of the training clips — without modifying
any model weights.
Key difference from LoRA:
- ALL generator parameters are frozen (requires_grad=False)
- Only K×1024 token embeddings receive gradients
- Latents stay on the decoder's natural manifold → no quality degradation
- The learned tokens shift WHICH latents are generated, not HOW
Usage:
1. Train on your .npz audio features
2. Load result with SelVA Textual Inversion Loader
3. Connect to SelVA Sampler optional input
"""
import copy
import random
import traceback
from pathlib import Path
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.flow_matching import FlowMatching
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_smooth_losses,
_draw_loss_curve,
_pil_to_tensor,
)
# ---------------------------------------------------------------------------
# Eval helper with token injection
# ---------------------------------------------------------------------------
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
n_tokens: int, inject_mode: str) -> torch.Tensor:
"""Build a text_clip tensor with learned tokens injected.
inject_mode:
"suffix" — replace last n_tokens positions (EOS/padding zone)
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
Works for both training (tokens is a Parameter) and eval (tokens is detached).
"""
if inject_mode == "prefix":
bos = text_clip[:, :1, :].detach() # [B, 1, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 76-K, D]
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
else: # suffix (default)
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
return torch.cat([front, toks], dim=1) # [B, 77, D]
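The two injection modes are easier to see on a plain list of position labels than on a `[B, 77, D]` tensor. The sketch below is a list-based analogue only, not the tensor code: `inject_tokens` is a hypothetical helper that applies the same slicing along the sequence axis and ignores batching and gradients.

```python
def inject_tokens(seq, tokens, mode="suffix"):
    """Overwrite K positions of seq with tokens, keeping the total length fixed."""
    k = len(tokens)
    if k < 1 or k >= len(seq):
        raise ValueError("need 1 <= K < len(seq)")
    if mode == "prefix":
        # Keep position 0 (BOS), write tokens into 1..K, keep the remaining 76-K.
        return seq[:1] + list(tokens) + seq[1 + k:]
    # suffix: keep the first 77-K positions, overwrite the last K.
    return seq[:-k] + list(tokens)


SEQ = ["BOS"] + [f"t{i}" for i in range(1, 77)]   # 77 positions, as in CLIP
LEARNED = ["L0", "L1", "L2", "L3"]                # K = 4 learned tokens
```

Both modes replace positions rather than inserting new ones, so the sequence length stays at 77. In prefix mode the first K content positions of the prompt are overwritten, which is worth keeping in mind for short prompts whose meaning sits in those positions.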
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, num_steps=25, seed=42, clip_idx=0):
"""Inference pass with learned tokens injected into text conditioning."""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype).clone()
emb = learned_tokens.detach().to(device, dtype)
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_input,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
tod = feature_utils_orig.tod
tod_orig_dev = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
tod.to(tod_orig_dev)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
target_rms = 10 ** (-27.0 / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = (audio * (target_rms / rms))
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
return audio.squeeze(0), seq_cfg.sampling_rate
except Exception as e:
print(f"[TI Trainer] Eval sample failed: {e}", flush=True)
traceback.print_exc()
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTextualInversionTrainer:
"""Learns K CLIP token embeddings that steer SelVA toward a target audio style.
Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor
receives gradients, keeping generated latents on the decoder's natural manifold
and preserving base model audio quality while shifting generation style.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "train"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("embeddings_path", "loss_curve")
OUTPUT_TOOLTIPS = (
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
"Smoothed training loss curve.",
)
DESCRIPTION = (
"Trains K learnable CLIP token embeddings against your audio dataset "
"with all model weights frozen. The tokens are then injected into the "
"sampler to guide generation toward the training data style without "
"degrading audio quality."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).",
}),
"output_path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.",
}),
"n_tokens": ("INT", {
"default": 4, "min": 1, "max": 16,
"tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.",
}),
"steps": ("INT", {
"default": 3000, "min": 100, "max": 50000,
"tooltip": "Training steps. 3000 is a reasonable starting point.",
}),
"lr": ("FLOAT", {
"default": 2e-4, "min": 1e-5, "max": 1e-1, "step": 1e-5,
"tooltip": "Learning rate. 2e-4 matches the LoRA working regime. Higher LR (1e-3) causes token norm to drift without plateauing on small datasets.",
}),
"batch_size": ("INT", {
"default": 4, "min": 1, "max": 64,
"tooltip": "Clips sampled per training step. Smaller batch (4-8) gives more diverse gradients and helps token norm saturate rather than drift.",
}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
"save_every": ("INT", {
"default": 1000, "min": 100, "max": 10000,
"tooltip": "Save a checkpoint and generate an eval sample every N steps.",
}),
},
"optional": {
"inject_mode": (["suffix", "prefix"], {
"default": "suffix",
"tooltip": (
"Where to inject the learned tokens in the 77-token CLIP sequence. "
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
),
}),
"init_text": ("STRING", {
"default": "",
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
}),
"warmup_steps": ("INT", {
"default": 100, "min": 0, "max": 1000,
"tooltip": "Linear LR warmup steps.",
}),
},
}
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
batch_size, seed, save_every,
inject_mode="suffix", init_text="", warmup_steps=100):
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
# --- Resolve paths ---
data_dir = Path(data_dir.strip())
if not data_dir.is_absolute():
data_dir = Path(folder_paths.models_dir) / data_dir
if not data_dir.exists():
raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}")
out_path = Path(output_path.strip())
if not out_path.is_absolute():
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True)
print(f"[TI Trainer] data_dir = {data_dir}", flush=True)
print(f"[TI Trainer] output = {out_path}\n", flush=True)
# --- Load dataset (reuse LoRA trainer helper) ---
dataset = _prepare_dataset(model, data_dir, device)
# Training must run outside inference_mode so autograd works
with torch.inference_mode(False), torch.enable_grad():
r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode,
)
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
return (r["embeddings_path"], _pil_to_tensor(curve_img))
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
):
torch.manual_seed(seed)
random.seed(seed)  # batches are drawn with random.choices(), which torch.manual_seed does not seed
# --- Generator (frozen) ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
generator.requires_grad_(False)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Init learned tokens ---
# Call encode_text_clip outside the grad context (it has @inference_mode),
# grab values only (no grad needed), then wrap as nn.Parameter.
if init_text.strip():
with torch.no_grad():
init_embed = feature_utils_orig.encode_text_clip([init_text.strip()])
# Positions 1:1+n_tokens — after BOS, before EOS — have actual content
init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float()
if init_vals.shape[0] < n_tokens:
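# Prompt was very short; pad remaining rows with small noise,
# created on init_vals' device/dtype so torch.cat does not fail
pad = torch.randn(
n_tokens - init_vals.shape[0], init_vals.shape[1],
device=init_vals.device, dtype=init_vals.dtype,
) * 0.02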
init_vals = torch.cat([init_vals, pad], dim=0)
learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype))
print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1-{n_tokens})", flush=True)
else:
learned_tokens = torch.nn.Parameter(
torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02
)
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
# --- Measure CLIP token norm from the dataset (content positions 1-19, i.e. tc[1:20]) ---
# Learned tokens must stay within this range or the model treats them as
# out-of-distribution and produces buzz artifacts instead of style shift.
with torch.no_grad():
sample_norms = []
for item in dataset[:min(len(dataset), 20)]:
tc = item[3].squeeze(0) # [77, 1024]
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
clip_norm_ref = torch.cat(sample_norms).mean().item()
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
f"limit={clip_norm_limit:.4f}", flush=True)
# --- Optimizer + scheduler ---
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Checkpoint dir ---
ckpt_dir = out_path.parent / out_path.stem
ckpt_dir.mkdir(parents=True, exist_ok=True)
# --- Baseline sample (once, before any training) ---
print(f"[TI Trainer] Generating baseline sample...", flush=True)
baseline_wav, baseline_sr = _eval_sample(
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
)
if baseline_wav is not None:
baseline_path = ckpt_dir / "baseline.wav"
try:
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
except RuntimeError:
import soundfile as sf
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
try:
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
except Exception:
pass
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()
log_interval = 50
pbar = comfy.utils.ProgressBar(steps)
loss_history = []
running_loss = 0.0
print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True)
for step in range(1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
x1 = generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
loss = fm.loss(v_pred, x0, x1).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Clamp token norm to CLIP manifold — prevents out-of-distribution
# embeddings that cause buzz artifacts instead of style shift.
with torch.no_grad():
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
scale = (clip_norm_limit / norms).clamp(max=1.0)
learned_tokens.data.mul_(scale)
running_loss += loss.item()
pbar.update(1)
if step % log_interval == 0:
avg = running_loss / log_interval
loss_history.append(round(avg, 6))
running_loss = 0.0
lr_now = scheduler.get_last_lr()[0]
norm = learned_tokens.norm(dim=-1).mean().item()
print(f"[TI Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e} "
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
if step % save_every == 0 or step == steps:
# Save checkpoint
ckpt = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": step,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
torch.save(ckpt, str(ckpt_path))
# Eval sample
wav, sr = _eval_sample_ti(
generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, seed=seed,
)
if wav is not None:
wav_path = ckpt_dir / f"step_{step:05d}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
try:
metrics = _spectral_metrics(wav, sr)
_save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
print(f"[TI Trainer] step {step} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.4f} "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
except Exception as e:
print(f"[TI Trainer] Spectral/spectrogram failed: {e}", flush=True)
print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True)
# --- Final save ---
final = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": steps,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
torch.save(final, str(out_path))
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
soft_empty_cache()
return {
"embeddings_path": str(out_path),
"loss_history": loss_history,
}
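The training loop above relies on two small pieces of arithmetic: the linear LR warmup multiplier and the per-token norm clamp. Both can be checked in isolation. The following is a minimal plain-Python sketch (no torch), assuming a token is a flat list of floats. `warmup_multiplier` and `clamp_row_norm` are illustrative names, not functions from the repository.

```python
import math


def warmup_multiplier(step, warmup_steps):
    """LR multiplier: linear ramp from 0 during warmup, then constant 1.0."""
    return step / max(1, warmup_steps) if step < warmup_steps else 1.0


def clamp_row_norm(row, limit):
    """Scale `row` down so its L2 norm is at most `limit`.

    Rows that are already within the limit are returned unchanged,
    mirroring the `(limit / norms).clamp(max=1.0)` scale in the trainer.
    """
    norm = max(math.sqrt(sum(v * v for v in row)), 1e-8)
    scale = min(limit / norm, 1.0)
    return [v * scale for v in row]


if __name__ == "__main__":
    print([warmup_multiplier(s, 100) for s in (0, 50, 100)])
    print(clamp_row_norm([3.0, 4.0], 2.5))   # norm 5.0 is halved
    print(clamp_row_norm([0.3, 0.4], 2.5))   # norm 0.5 is left alone
```

Under this schedule the multiplier is 0.0 at step 0, so when `warmup_steps` is greater than zero the first optimizer step appears to run with a zero learning rate.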
+479
@@ -0,0 +1,479 @@
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "ti_sweep_1",
"description": "optional human note",
"data_dir": "dataset/bj_sounds",
"output_root": "ti_output/sweep_1",
"base": {
"n_tokens": 4,
"lr": 1e-3,
"steps": 3000,
"batch_size": 16,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000
},
"experiments": [
{"id": "baseline", "description": "default 4 tokens"},
{"id": "n8_tokens", "n_tokens": 8},
{"id": "lr_5e4", "lr": 5e-4},
{"id": "warm_init", "init_text": "industrial sound design"},
{"id": "n4_more_steps", "steps": 5000}
]
}
"""
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------
def _get_system_info() -> dict:
    info: dict = {
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda or "N/A",
        "gpu_name": None,
        "gpu_vram_gb": None,
    }
    if torch.cuda.is_available():
        try:
            info["gpu_name"] = torch.cuda.get_device_name(0)
            props = torch.cuda.get_device_properties(0)
            info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
        except Exception:
            pass
    return info
_PARAM_DEFAULTS = {
"n_tokens": 4,
"lr": 2e-4,
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000,
"init_text": "",
"inject_mode": "suffix",
}
_PALETTE = [
(66, 133, 244),
(234, 67, 53),
(52, 168, 83),
(251, 188, 5),
(155, 89, 182),
(26, 188, 156),
(230, 126, 34),
(149, 165, 166),
]
def _resolve_path(raw: str) -> Path:
    """Resolve relative (and rootless Unix-style on Windows) paths against the output directory."""
    p = Path(raw.strip())
    unix_style_on_windows = (
        sys.platform == "win32" and p.is_absolute() and not p.drive
    )
    if not p.is_absolute() or unix_style_on_windows:
        p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
    return p


def _merge_config(base: dict, experiment: dict) -> dict:
    """Layer configs: built-in defaults, then `base`, then per-experiment overrides."""
    cfg = dict(_PARAM_DEFAULTS)
    cfg.update(base)
    cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
    return cfg


def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   total_steps: int) -> dict:
    """Map each checkpoint step to the averaged loss logged at that step."""
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
from PIL import Image, ImageDraw
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))
lx, ly = W - pr + 10, pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTiScheduler:
"""Runs a sweep of Textual Inversion experiments defined in a JSON file.
The dataset is loaded once and reused. Each experiment calls
SelvaTextualInversionTrainer._train_inner() with its own config.
Results are written to experiment_summary.json after every completed run.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — compare runs across sweeps.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of Textual Inversion experiments from a JSON sweep file. "
"The dataset is encoded once and reused. Results (loss, config, embeddings "
"paths) are collected in experiment_summary.json after each run."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "ti_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths are looked up in the ComfyUI "
"models directory first, then in the output directory. See node description for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate JSON
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[TI Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[TI Scheduler] {description}", flush=True)
print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
all_curve_data.append({
"id": rec["id"],
"loss_history": rec["results"].get("loss_history", []),
})
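summary = existing
# Drop stale "running"/"failed" records so re-run experiments are not listed twice
summary["experiments"] = [
rec for rec in existing.get("experiments", [])
if rec.get("results", {}).get("status") == "completed"
]
summary["completed_at"] = None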
if completed_ids:
print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
comparison_img_path = output_root / "loss_comparison.png"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
def _save_comparison():
try:
img = _draw_comparison_curves(all_curve_data)
img.save(str(comparison_img_path))
except Exception as e:
print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaTextualInversionTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
n_tokens = int(cfg["n_tokens"])
lr = float(cfg["lr"])
steps = int(cfg["steps"])
batch_size = int(cfg["batch_size"])
warmup = int(cfg["warmup_steps"])
seed = int(cfg["seed"])
save_every = int(cfg["save_every"])
init_text = str(cfg["init_text"])
inject_mode = str(cfg["inject_mode"])
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / "embeddings.pt"
print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[TI Scheduler] {exp_desc}", flush=True)
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
f"batch_size={batch_size} warmup={warmup} seed={seed} "
f"inject_mode={inject_mode}", flush=True)
if init_text:
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"n_tokens": n_tokens,
"lr": lr,
"steps": steps,
"batch_size": batch_size,
"warmup_steps": warmup,
"seed": seed,
"save_every": save_every,
"init_text": init_text,
"inject_mode": inject_mode,
},
"results": {"status": "running"},
"embeddings_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup, seed, save_every, init_text, inject_mode,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
smoothed = _smooth_losses(loss_history) if loss_history else []
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
loss_std_last_quarter = None
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["embeddings_path"] = r["embeddings_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
_write_summary()
_save_comparison()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image (final update, then return to ComfyUI)
# ------------------------------------------------------------------
_save_comparison()
comparison_img = _draw_comparison_curves(all_curve_data)
return (str(summary_path), _pil_to_tensor(comparison_img))
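The scheduler's config layering and per-checkpoint loss lookup can be exercised without ComfyUI or torch. The sketch below reimplements both helpers in plain Python, assuming a trimmed `PARAM_DEFAULTS` table (only four of the keys from `_PARAM_DEFAULTS`) and made-up loss values. It is meant as an illustration of the precedence rules, not as the module's actual code path.

```python
# Trimmed, illustrative copy of the defaults table (assumption: four keys only).
PARAM_DEFAULTS = {"n_tokens": 4, "lr": 2e-4, "steps": 3000, "batch_size": 4}


def merge_config(base, experiment):
    """Precedence, lowest to highest: defaults, sweep-wide `base`, experiment."""
    cfg = dict(PARAM_DEFAULTS)
    cfg.update(base)
    # "id" and "description" label the run; they are not training parameters.
    cfg.update({k: v for k, v in experiment.items()
                if k not in ("id", "description")})
    return cfg


def loss_at_steps(loss_history, log_interval, save_every, total_steps):
    """Pick the logged loss at every checkpoint step.

    loss_history[i] is the mean loss over steps i*log_interval+1 .. (i+1)*log_interval,
    so checkpoint step `target` maps to index target // log_interval - 1.
    """
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


if __name__ == "__main__":
    base = {"lr": 1e-3, "batch_size": 16}
    print(merge_config(base, {"id": "n8_tokens", "n_tokens": 8}))
    # Hypothetical history: four log entries at a log interval of 50 steps.
    print(loss_at_steps([0.9, 0.8, 0.7, 0.6], 50, 100, 200))
```

Because `save_every` is floor-divided by the log interval, checkpoints whose step is not a multiple of 50 are reported with the most recent logged average that precedes them.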
+157
@@ -0,0 +1,157 @@
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
Useful for diagnosing codec reconstruction quality: if the output sounds
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
not the diffusion model or LoRA.
"""
import torch
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
class SelvaVaeRoundtrip:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"audio": ("AUDIO",),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio_reconstructed",)
OUTPUT_TOOLTIPS = (
"Audio after VAE encode → decode roundtrip. "
"Compare to the input to hear codec reconstruction quality.",
)
FUNCTION = "roundtrip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Encodes the input audio through the SelVA VAE then decodes it straight back. "
"Use this to isolate codec reconstruction quality from generation quality. "
"If the output sounds degraded compared to the input, the VAE/DAC is the "
"bottleneck — not the model or LoRA."
)
def roundtrip(self, model, audio):
from selva_core.model.utils.features_utils import FeaturesUtils
mode = model["mode"]
seq_cfg = model["seq_cfg"]
dtype = model["dtype"]
device = get_device()
generator = model["generator"]
feature_utils = model["feature_utils"]
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
# Load encoder only — decoder/vocoder come from model["feature_utils"]
# to mirror exactly what the sampler uses.
# AutoEncoderModule requires vocoder_ckpt_path even when only encoding,
# so pass the BigVGAN path (weights won't actually be used for decode here).
bigvgan_path = _SELVA_DIR / "ext" / "best_netG.pt"
print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
vae_enc = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
bigvgan_vocoder_ckpt=str(bigvgan_path) if bigvgan_path.exists() else None,
).to(device).eval()
try:
# Prepare input audio
waveform = audio["waveform"] # [1, C, L]
sr_in = audio["sample_rate"]
wav = waveform[0].mean(0) # mono [L]
if sr_in != seq_cfg.sampling_rate:
wav = torchaudio.functional.resample(
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
).squeeze(0)
print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz",
flush=True)
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
if wav.shape[0] > target_len:
wav = wav[:target_len]
elif wav.shape[0] < target_len:
wav = F.pad(wav, (0, target_len - wav.shape[0]))
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
with torch.no_grad():
# Encode: audio → raw latent [1, latent_dim, T]
dist = vae_enc.encode_audio(wav_b)
latent = dist.mode().clone()
# Trim/pad to exact model sequence length (same as _prepare_dataset)
tgt = seq_cfg.latent_seq_len
if latent.shape[2] < tgt:
latent = F.pad(latent, (0, tgt - latent.shape[2]))
elif latent.shape[2] > tgt:
latent = latent[:, :, :tgt]
# To [B, T, latent_dim] — layout the generator uses
latent_t = latent.transpose(1, 2).to(dtype)
print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
flush=True)
# Normalize → unnormalize mirrors the training/inference pipeline:
# training normalizes encoded latents; sampler unnormalizes before decode.
# This ensures the latent is in the same space the decoder expects.
latent_norm = generator.normalize(latent_t.clone())
latent_unnorm = generator.unnormalize(latent_norm)
print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
flush=True)
# Decode using model's feature_utils — same path as the sampler
tod = feature_utils.tod
tod_orig_device = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils.decode(latent_unnorm)
out = feature_utils.vocode(spec)
finally:
tod.to(tod_orig_device)
out = out.float().cpu()
if out.dim() == 1:
out = out.unsqueeze(0).unsqueeze(0)
elif out.dim() == 2:
out = out.unsqueeze(1)
elif out.dim() == 3 and out.shape[1] != 1:
out = out.mean(dim=1, keepdim=True)
rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
target_rms = 10 ** (-27.0 / 20.0)
out = out * (target_rms / rms)
out = out.clamp(-1.0, 1.0)
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
flush=True)
finally:
del vae_enc
soft_empty_cache()
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
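The roundtrip node pads or trims the input to a fixed length and rescales the output to a target RMS of -27 dB before clamping. A minimal plain-Python sketch of that arithmetic follows, assuming a mono signal stored as a list of floats. `fit_length`, `db_to_amplitude` and `rms_normalize` are illustrative names, not part of the node.

```python
import math


def db_to_amplitude(db):
    """Convert a level in dB (relative to full scale) to a linear amplitude."""
    return 10 ** (db / 20.0)


def fit_length(samples, target_len):
    """Trim, or zero-pad on the right, so the result has exactly target_len samples."""
    if len(samples) > target_len:
        return samples[:target_len]
    return samples + [0.0] * (target_len - len(samples))


def rms(samples):
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def rms_normalize(samples, target_db=-27.0):
    """Scale to the target RMS level, then hard-clamp to [-1, 1]."""
    gain = db_to_amplitude(target_db) / max(rms(samples), 1e-8)
    return [max(-1.0, min(1.0, s * gain)) for s in samples]


if __name__ == "__main__":
    wav = fit_length([0.5, -0.5, 0.5], 4)          # padded with one zero
    out = rms_normalize([0.5, -0.5, 0.5, -0.5])
    print(wav, round(rms(out), 5), round(db_to_amplitude(-27.0), 5))
```

Since normalization happens after the decode, the output level says nothing about how loud the decoder output was; the mean/std lines printed by the node are the place to look for that.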
-160
@@ -1,160 +0,0 @@
import torch
import comfy.model_management as mm
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features
class PrismAudioTextOnly:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("PRISMAUDIO_MODEL",),
"text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
"duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
"steps": ("INT", {"default": 100, "min": 1, "max": 100}),
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "generate"
CATEGORY = PRISMAUDIO_CATEGORY
def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
diffusion = model["model"]
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
# Encode text with T5-Gemma
text_features = _encode_text_t5(text_prompt, device, dtype)
# Build metadata: tuple of one dict per sample
# Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
# Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
# These will be substituted with learned empty embeddings after conditioning
sample_meta = {
"video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
"text_features": text_features.to(device, dtype=dtype),
"sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
"video_exist": torch.tensor(False),
}
metadata = (sample_meta,)
if strategy == "offload_to_cpu":
diffusion.model.to(device)
diffusion.conditioner.to(device)
soft_empty_cache()
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
conditioning = diffusion.conditioner(metadata, device)
# Substitute empty features for video/sync
_substitute_empty_features(diffusion, conditioning, device, dtype)
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
# Generate noise from seed (MPS doesn't support torch.Generator)
gen_device = "cpu" if device.type == "mps" else device
generator = torch.Generator(device=gen_device).manual_seed(seed)
noise = torch.randn(
[1, IO_CHANNELS, latent_length],
generator=generator,
device=gen_device,
).to(device=device, dtype=dtype)
pbar = comfy.utils.ProgressBar(steps)
from prismaudio_core.inference.sampling import sample_discrete_euler
def on_step(info):
pbar.update(1)
fakes = sample_discrete_euler(
diffusion.model,
noise,
steps,
callback=on_step,
**cond_inputs,
cfg_scale=cfg_scale,
batch_cfg=True,
)
fakes_f = fakes.float()
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
if strategy == "offload_to_cpu":
diffusion.model.to(get_offload_device())
diffusion.conditioner.to(get_offload_device())
soft_empty_cache()
diffusion.pretransform.to(device)
# VAE decode in fp32 (snake activations overflow in fp16)
with torch.amp.autocast(device_type=device.type, enabled=False):
audio = diffusion.pretransform.decode(fakes_f)
if strategy == "offload_to_cpu":
diffusion.pretransform.to(get_offload_device())
soft_empty_cache()
# Peak normalize then clamp
audio = audio.float()
pre_norm_std = audio.std().item()
pre_norm_peak = audio.abs().max().item()
peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1)
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None
def _encode_text_t5(text, device, dtype):
"""Encode text using T5-Gemma.
Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
FeaturesUtils.encode_t5_text() implementation.
No truncation applied (matching reference behavior).
"""
global _t5_model, _t5_tokenizer
if _t5_model is None:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_id = "google/t5gemma-l-l-ul2-it"
token = resolve_hf_token()
print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
_t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
_t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
_t5_model.eval()
_t5_model.to(device, dtype=dtype)
tokens = _t5_tokenizer(
text,
return_tensors="pt",
padding=True,
).to(device)
with torch.no_grad():
outputs = _t5_model(**tokens)
# Move T5 off GPU after encoding to save VRAM
_t5_model.to("cpu")
soft_empty_cache()
return outputs.last_hidden_state.squeeze(0) # [seq_len, dim]
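For reference, the latent length and peak normalization used by the removed text-only node reduce to a few lines of arithmetic. The sketch below uses the `SAMPLE_RATE` and `DOWNSAMPLING_RATIO` values from the removed `utils.py` constants and treats audio as a plain list of floats. The function names are illustrative.

```python
SAMPLE_RATE = 44100          # Hz, from the removed utils constants
DOWNSAMPLING_RATIO = 2048    # audio samples per latent frame


def latent_length(duration_s):
    """Number of latent frames needed for `duration_s` seconds of audio."""
    return round(SAMPLE_RATE * duration_s / DOWNSAMPLING_RATIO)


def peak_normalize(samples):
    """Divide by the absolute peak (floored at 1e-8) and clamp to [-1, 1]."""
    peak = max(max(abs(s) for s in samples), 1e-8)
    return [max(-1.0, min(1.0, s / peak)) for s in samples]


if __name__ == "__main__":
    for seconds in (1.0, 10.0, 30.0):
        print(seconds, latent_length(seconds))
    print(peak_normalize([0.5, -2.0]))
```

Unlike the RMS scaling in the VAE roundtrip node, peak normalization always brings the loudest sample to full scale, so two clips with different dynamics can end up with quite different perceived loudness.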
+3 -46
@@ -1,21 +1,7 @@
-import os
 import torch
-import folder_paths
 import comfy.model_management as mm
-PRISMAUDIO_CATEGORY = "PrismAudio"
+SELVA_CATEGORY = "SelVA"
-SAMPLE_RATE = 44100
-DOWNSAMPLING_RATIO = 2048
-IO_CHANNELS = 64
-def get_prismaudio_model_dir():
-    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
-    os.makedirs(model_dir, exist_ok=True)
-    return model_dir
-def register_model_folder():
-    model_dir = get_prismaudio_model_dir()
-    folder_paths.add_model_folder_path("prismaudio", model_dir)
 def get_device():
     return mm.get_torch_device()
@@ -23,42 +9,13 @@ def get_device():
 def get_offload_device():
     return mm.unet_offload_device()
-def get_free_memory(device=None):
-    if device is None:
-        device = get_device()
-    return mm.get_free_memory(device)
 def soft_empty_cache():
     mm.soft_empty_cache()
-def determine_precision(preference, device):
-    if preference != "auto":
-        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
-    if device.type == "cpu":
-        return torch.float32
-    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-        return torch.bfloat16
-    return torch.float16
 def determine_offload_strategy(preference):
     if preference != "auto":
         return preference
-    free_mem = get_free_memory()
-    gb = free_mem / (1024 ** 3)
-    if gb >= 24:
+    free_mem = mm.get_free_memory(get_device())
+    if free_mem / (1024 ** 3) >= 16:
         return "keep_in_vram"
-    else:
-        return "offload_to_cpu"
+    return "offload_to_cpu"
-def try_import_flash_attn():
-    try:
-        import flash_attn
-        return flash_attn
-    except ImportError:
-        return None
-def resolve_hf_token():
-    env_token = os.environ.get("HF_TOKEN")
-    if env_token:
-        return env_token
-    return None
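Apart from the removed helpers, the one behavioural change in this file is the auto-offload threshold, which drops from 24 GiB to 16 GiB of free device memory. A minimal sketch of that decision, assuming free memory is passed in as a byte count instead of being queried through `comfy.model_management`; the function name `pick_offload_strategy` and the `threshold_gib` parameter are illustrative and not part of the commit:

```python
GIB = 1024 ** 3


def pick_offload_strategy(preference, free_bytes, threshold_gib=16):
    """Hypothetical stand-in for determine_offload_strategy, without ComfyUI."""
    # An explicit user preference always wins over the automatic choice.
    if preference != "auto":
        return preference
    # Keep the model resident only when enough device memory is free.
    if free_bytes / GIB >= threshold_gib:
        return "keep_in_vram"
    return "offload_to_cpu"


# Compare the old (24 GiB) and new (16 GiB) thresholds for a few card sizes.
for gib in (12, 16, 20, 24):
    old = pick_offload_strategy("auto", gib * GIB, threshold_gib=24)
    new = pick_offload_strategy("auto", gib * GIB)
    print(f"{gib} GiB free: old={old}, new={new}")
```

Under this sketch, a device reporting between 16 and 24 GiB free is the case whose outcome changes.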
-5
@@ -1,5 +0,0 @@
"""
PrismAudio core inference modules.
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
Only inference-critical code — no training, no JAX/TF dependencies.
"""
-141
@@ -1,141 +0,0 @@
{
"model_type": "diffusion_cond",
"sample_size": 397312,
"sample_rate": 44100,
"audio_channels": 2,
"model": {
"pretransform": {
"type": "autoencoder",
"iterate_batch": true,
"config": {
"encoder": {
"type": "oobleck",
"config": {
"in_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 128,
"use_snake": true
}
},
"decoder": {
"type": "oobleck",
"config": {
"out_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 64,
"use_snake": true,
"final_tanh": false
}
},
"bottleneck": {
"type": "vae"
},
"latent_dim": 64,
"downsampling_ratio": 2048,
"io_channels": 2
}
},
"conditioning": {
"configs": [
{
"id": "video_features",
"type": "cond_mlp",
"config": {
"dim": 1024,
"output_dim": 1024
}
},
{
"id": "text_features",
"type": "cond_mlp",
"config": {
"dim": 1024,
"output_dim": 1024
}
},
{
"id": "sync_features",
"type": "sync_mlp",
"config": {
"dim": 768,
"output_dim": 1024
}
}
],
"cond_dim": 768
},
"diffusion": {
"cross_attention_cond_ids": ["video_features","text_features"],
"add_cond_ids": ["video_features"],
"sync_cond_ids": ["sync_features"],
"type": "dit",
"diffusion_objective": "rectified_flow",
"config": {
"io_channels": 64,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"cond_token_dim": 1024,
"add_token_dim": 1024,
"sync_token_dim": 1024,
"project_cond_tokens": false,
"transformer_type": "continuous_transformer",
"attn_kwargs":{
"qk_norm": "rns"
},
"use_gated": true,
"use_sync_gated": true
}
},
"io_channels": 64
},
"training": {
"use_ema": true,
"log_loss_info": false,
"cfg_dropout_prob": 0.1,
"pre_encoded": true,
"timestep_sampler": "trunc_logit_normal",
"optimizer_configs": {
"diffusion": {
"optimizer": {
"type": "AdamW",
"config": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"weight_decay": 1e-3
}
},
"scheduler": {
"type": "InverseLR",
"config": {
"inv_gamma": 100000,
"power": 0.5,
"warmup": 0.99
}
}
}
},
"demo": {
"demo_every": 5000,
"demo_steps": 24,
"num_demos": 10,
"demo_cond": [
"dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
"dataset/videoprism/test/bmKtI808DsU_000009.npz",
"dataset/videoprism/test/VC0c22cJTbM_000424.npz",
"dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
"dataset/videoprism/test/WatvT8A8iug_000100.npz",
"dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
"dataset/videoprism/test/3-PFuDkTM48_000080.npz",
"dataset/videoprism/test/luSAuu-BoPs_000232.npz",
"dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
"dataset/videoprism/test/_0m_YMpQayA_000168.npz"
],
"demo_cfg_scales": [5]
}
}
}
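The numeric fields in the removed config are mutually consistent, which can be checked with a few lines of arithmetic. The sketch below copies `sample_size`, `sample_rate`, `strides` and `downsampling_ratio` from the JSON above and assumes that each encoder stage shortens the time axis by exactly its stride, so that the overall ratio is the product of the strides:

```python
import math

# Values copied from the removed config.
sample_size = 397312
sample_rate = 44100
strides = [2, 4, 4, 8, 8]
downsampling_ratio = 2048

# Assumption: each encoder stage divides the length by its stride.
stride_product = math.prod(strides)

# Number of latent frames the diffusion model operates on per clip.
latent_frames, remainder = divmod(sample_size, downsampling_ratio)

# Clip duration in seconds at the configured sample rate.
duration_s = sample_size / sample_rate

print(f"stride product: {stride_product}")
print(f"latent frames:  {latent_frames} (remainder {remainder})")
print(f"duration:       {duration_s:.3f} s")
```

The encoder's `latent_dim` of 128 against the decoder's 64 is also expected with a `vae` bottleneck, on the usual assumption that the encoder emits a mean and a scale for each of the 64 latent channels.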
-413
@@ -1,413 +0,0 @@
"""
Model factory functions for PrismAudio inference.
Extracted from:
- PrismAudio/models/factory.py
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
Only inference-critical factory functions are retained.
"""
import json
import typing as tp
from typing import Dict, Any
import numpy as np
def create_model_from_config(model_config):
model_type = model_config.get('model_type', None)
assert model_type is not None, 'model_type must be specified in model config'
if model_type == 'autoencoder':
return create_autoencoder_from_config(model_config)
elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
return create_diffusion_cond_from_config(model_config)
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
def create_pretransform_from_config(pretransform_config, sample_rate):
pretransform_type = pretransform_config.get('type', None)
assert pretransform_type is not None, 'type must be specified in pretransform config'
if pretransform_type == 'autoencoder':
from prismaudio_core.models.pretransforms import AutoencoderPretransform
# Create fake top-level config to pass sample rate to autoencoder constructor
# This is a bit of a hack but it keeps us from re-defining the sample rate in the config
autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
autoencoder = create_autoencoder_from_config(autoencoder_config)
scale = pretransform_config.get("scale", 1.0)
model_half = pretransform_config.get("model_half", False)
iterate_batch = pretransform_config.get("iterate_batch", False)
chunked = pretransform_config.get("chunked", False)
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
elif pretransform_type == 'wavelet':
raise NotImplementedError("wavelet pretransform type is not supported")
elif pretransform_type == 'pqmf':
from prismaudio_core.models.pretransforms import PQMFPretransform
pqmf_config = pretransform_config["config"]
pretransform = PQMFPretransform(**pqmf_config)
elif pretransform_type == 'dac_pretrained':
from prismaudio_core.models.pretransforms import PretrainedDACPretransform
pretrained_dac_config = pretransform_config["config"]
pretransform = PretrainedDACPretransform(**pretrained_dac_config)
elif pretransform_type == "audiocraft_pretrained":
from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform
audiocraft_config = pretransform_config["config"]
pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
else:
raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
enable_grad = pretransform_config.get('enable_grad', False)
pretransform.enable_grad = enable_grad
pretransform.eval().requires_grad_(pretransform.enable_grad)
return pretransform
def create_bottleneck_from_config(bottleneck_config):
bottleneck_type = bottleneck_config.get('type', None)
assert bottleneck_type is not None, 'type must be specified in bottleneck config'
if bottleneck_type == 'tanh':
from prismaudio_core.models.bottleneck import TanhBottleneck
bottleneck = TanhBottleneck()
elif bottleneck_type == 'vae':
from prismaudio_core.models.bottleneck import VAEBottleneck
bottleneck = VAEBottleneck()
elif bottleneck_type == 'rvq':
from prismaudio_core.models.bottleneck import RVQBottleneck
quantizer_params = {
"dim": 128,
"codebook_size": 1024,
"num_quantizers": 8,
"decay": 0.99,
"kmeans_init": True,
"kmeans_iters": 50,
"threshold_ema_dead_code": 2,
}
quantizer_params.update(bottleneck_config["config"])
bottleneck = RVQBottleneck(**quantizer_params)
elif bottleneck_type == "dac_rvq":
from prismaudio_core.models.bottleneck import DACRVQBottleneck
bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
elif bottleneck_type == 'rvq_vae':
from prismaudio_core.models.bottleneck import RVQVAEBottleneck
quantizer_params = {
"dim": 128,
"codebook_size": 1024,
"num_quantizers": 8,
"decay": 0.99,
"kmeans_init": True,
"kmeans_iters": 50,
"threshold_ema_dead_code": 2,
}
quantizer_params.update(bottleneck_config["config"])
bottleneck = RVQVAEBottleneck(**quantizer_params)
elif bottleneck_type == 'dac_rvq_vae':
from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
elif bottleneck_type == 'l2_norm':
from prismaudio_core.models.bottleneck import L2Bottleneck
bottleneck = L2Bottleneck()
elif bottleneck_type == "wasserstein":
from prismaudio_core.models.bottleneck import WassersteinBottleneck
bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
elif bottleneck_type == "fsq":
from prismaudio_core.models.bottleneck import FSQBottleneck
bottleneck = FSQBottleneck(**bottleneck_config["config"])
else:
raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
requires_grad = bottleneck_config.get('requires_grad', True)
if not requires_grad:
for param in bottleneck.parameters():
param.requires_grad = False
return bottleneck
def create_autoencoder_from_config(config: Dict[str, Any]):
"""Create an AudioAutoencoder from a config dictionary.
Originally in PrismAudio/models/autoencoders.py.
"""
from prismaudio_core.models.autoencoders import (
AudioAutoencoder,
create_encoder_from_config,
create_decoder_from_config,
)
ae_config = config["model"]
encoder = create_encoder_from_config(ae_config["encoder"])
decoder = create_decoder_from_config(ae_config["decoder"])
bottleneck = ae_config.get("bottleneck", None)
latent_dim = ae_config.get("latent_dim", None)
assert latent_dim is not None, "latent_dim must be specified in model config"
downsampling_ratio = ae_config.get("downsampling_ratio", None)
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
io_channels = ae_config.get("io_channels", None)
assert io_channels is not None, "io_channels must be specified in model config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "sample_rate must be specified in model config"
in_channels = ae_config.get("in_channels", None)
out_channels = ae_config.get("out_channels", None)
pretransform = ae_config.get("pretransform", None)
if pretransform is not None:
pretransform = create_pretransform_from_config(pretransform, sample_rate)
if bottleneck is not None:
bottleneck = create_bottleneck_from_config(bottleneck)
soft_clip = ae_config["decoder"].get("soft_clip", False)
return AudioAutoencoder(
encoder,
decoder,
io_channels=io_channels,
latent_dim=latent_dim,
downsampling_ratio=downsampling_ratio,
sample_rate=sample_rate,
bottleneck=bottleneck,
pretransform=pretransform,
in_channels=in_channels,
out_channels=out_channels,
soft_clip=soft_clip
)
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
"""Create a MultiConditioner from a conditioning config dictionary.
Originally in PrismAudio/models/conditioners.py.
"""
from prismaudio_core.models.conditioners import (
MultiConditioner,
T5Conditioner,
CLAPTextConditioner,
CLIPTextConditioner,
MetaCLIPTextConditioner,
CLAPAudioConditioner,
Cond_MLP,
Global_MLP,
Sync_MLP,
Cond_MLP_1,
Cond_ConvMLP,
Cond_MLP_Global,
Cond_MLP_Global_1,
Cond_MLP_Global_2,
Video_Global,
Video_Sync,
Text_Linear,
CLIPConditioner,
IntConditioner,
NumberConditioner,
PhonemeConditioner,
TokenizerLUTConditioner,
PretransformConditioner,
mm_unchang,
)
from prismaudio_core.models.utils import load_ckpt_state_dict
conditioners = {}
cond_dim = config["cond_dim"]
default_keys = config.get("default_keys", {})
for conditioner_info in config["configs"]:
id = conditioner_info["id"]
conditioner_type = conditioner_info["type"]
conditioner_config = {"output_dim": cond_dim}
conditioner_config.update(conditioner_info["config"])
if conditioner_type == "t5":
conditioners[id] = T5Conditioner(**conditioner_config)
elif conditioner_type == "clap_text":
conditioners[id] = CLAPTextConditioner(**conditioner_config)
elif conditioner_type == "clip_text":
conditioners[id] = CLIPTextConditioner(**conditioner_config)
elif conditioner_type == "metaclip_text":
conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
elif conditioner_type == "clap_audio":
conditioners[id] = CLAPAudioConditioner(**conditioner_config)
elif conditioner_type == "cond_mlp":
conditioners[id] = Cond_MLP(**conditioner_config)
elif conditioner_type == "global_mlp":
conditioners[id] = Global_MLP(**conditioner_config)
elif conditioner_type == "sync_mlp":
conditioners[id] = Sync_MLP(**conditioner_config)
elif conditioner_type == "cond_mlp_1":
conditioners[id] = Cond_MLP_1(**conditioner_config)
elif conditioner_type == "cond_convmlp":
conditioners[id] = Cond_ConvMLP(**conditioner_config)
elif conditioner_type == "cond_mlp_global":
conditioners[id] = Cond_MLP_Global(**conditioner_config)
elif conditioner_type == "cond_mlp_global_1":
conditioners[id] = Cond_MLP_Global_1(**conditioner_config)
elif conditioner_type == "cond_mlp_global_2":
conditioners[id] = Cond_MLP_Global_2(**conditioner_config)
elif conditioner_type == "video_global":
conditioners[id] = Video_Global(**conditioner_config)
elif conditioner_type == "video_sync":
conditioners[id] = Video_Sync(**conditioner_config)
elif conditioner_type == "text_linear":
conditioners[id] = Text_Linear(**conditioner_config)
elif conditioner_type == "video_clip":
conditioners[id] = CLIPConditioner(**conditioner_config)
elif conditioner_type == "int":
conditioners[id] = IntConditioner(**conditioner_config)
elif conditioner_type == "number":
conditioners[id] = NumberConditioner(**conditioner_config)
elif conditioner_type == "phoneme":
conditioners[id] = PhonemeConditioner(**conditioner_config)
elif conditioner_type == "lut":
conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
elif conditioner_type == "pretransform":
sample_rate = conditioner_config.pop("sample_rate", None)
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
elif conditioner_type == "mm_unchang":
conditioners[id] = mm_unchang(**conditioner_config)
else:
raise ValueError(f"Unknown conditioner type: {conditioner_type}")
return MultiConditioner(conditioners, default_keys=default_keys)
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
"""Create a ConditionedDiffusionModelWrapper from a config dictionary.
Originally in PrismAudio/models/diffusion.py.
"""
from prismaudio_core.models.diffusion import (
ConditionedDiffusionModelWrapper,
MMConditionedDiffusionModelWrapper,
UNetCFG1DWrapper,
UNet1DCondWrapper,
DiTWrapper,
)
model_config = config["model"]
model_type = config["model_type"]
diffusion_config = model_config.get('diffusion', None)
assert diffusion_config is not None, "Must specify diffusion config"
diffusion_model_type = diffusion_config.get('type', None)
assert diffusion_model_type is not None, "Must specify diffusion model type"
diffusion_model_config = diffusion_config.get('config', None)
assert diffusion_model_config is not None, "Must specify diffusion model config"
if diffusion_model_type == 'adp_cfg_1d':
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
elif diffusion_model_type == 'adp_1d':
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
elif diffusion_model_type == 'dit':
diffusion_model = DiTWrapper(**diffusion_model_config)
elif diffusion_model_type == 'mmdit':
raise NotImplementedError("mmdit diffusion model type is not supported")
else:
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
io_channels = model_config.get('io_channels', None)
assert io_channels is not None, "Must specify io_channels in model config"
sample_rate = config.get('sample_rate', None)
assert sample_rate is not None, "Must specify sample_rate in config"
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
conditioning_config = model_config.get('conditioning', None)
conditioner = None
if conditioning_config is not None:
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
add_cond_ids = diffusion_config.get('add_cond_ids', [])
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
global_cond_ids = diffusion_config.get('global_cond_ids', [])
input_concat_ids = diffusion_config.get('input_concat_ids', [])
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
zero_init = diffusion_config.get('zero_init', False)
pretransform = model_config.get("pretransform", None)
if pretransform is not None:
pretransform = create_pretransform_from_config(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
min_input_length *= np.prod(diffusion_model_config["factors"])
elif diffusion_model_type == "dit":
min_input_length *= diffusion_model.model.patch_size
# Get the proper wrapper class
extra_kwargs = {}
if model_type == "mm_diffusion_cond":
wrapper_fn = MMConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
extra_kwargs["mm_cond_ids"] = mm_cond_ids
elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
wrapper_fn = ConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
elif model_type == "diffusion_prior":
raise NotImplementedError("diffusion_prior model type is not supported")
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return wrapper_fn(
diffusion_model,
conditioner,
min_input_length=min_input_length,
sample_rate=sample_rate,
cross_attn_cond_ids=cross_attention_ids,
global_cond_ids=global_cond_ids,
input_concat_ids=input_concat_ids,
prepend_cond_ids=prepend_cond_ids,
add_cond_ids=add_cond_ids,
sync_cond_ids=sync_cond_ids,
pretransform=pretransform,
io_channels=io_channels,
zero_init=zero_init,
**extra_kwargs
)
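The conditioner factory above is a long `elif` chain in which every branch does the same thing: look up a class by type string and construct it with `cond_dim` as the default `output_dim`, overridden by the per-conditioner config. A minimal sketch of the same behaviour as a table lookup; `CondMLP`, `SyncMLP`, `CONDITIONER_REGISTRY` and `build_conditioners` are hypothetical placeholders and not the classes from `prismaudio_core`:

```python
class CondMLP:
    """Placeholder conditioner that only records its constructor arguments."""

    def __init__(self, dim, output_dim):
        self.dim = dim
        self.output_dim = output_dim


class SyncMLP(CondMLP):
    """Second placeholder type, to show dispatch on the type string."""


# Hypothetical registry mapping config type strings to classes.
CONDITIONER_REGISTRY = {"cond_mlp": CondMLP, "sync_mlp": SyncMLP}


def build_conditioners(config):
    conditioners = {}
    for info in config["configs"]:
        cls = CONDITIONER_REGISTRY.get(info["type"])
        if cls is None:
            raise ValueError(f"Unknown conditioner type: {info['type']}")
        # cond_dim is only a default; an explicit output_dim in the
        # per-conditioner config takes precedence, as in the factory above.
        kwargs = {"output_dim": config["cond_dim"], **info["config"]}
        conditioners[info["id"]] = cls(**kwargs)
    return conditioners


example = {
    "cond_dim": 768,
    "configs": [
        {"id": "sync_features", "type": "sync_mlp", "config": {"dim": 768, "output_dim": 1024}},
        {"id": "text_features", "type": "cond_mlp", "config": {"dim": 1024}},
    ],
}
built = build_conditioners(example)
print({name: cond.output_dim for name, cond in built.items()})
```

This shows why the removed config can set `cond_dim` to 768 while its conditioners still project to 1024: each one sets `output_dim` explicitly.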
-4
@@ -1,4 +0,0 @@
from .sampling import sample_discrete_euler
from .utils import set_audio_channels, prepare_audio
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
-29
@@ -1,29 +0,0 @@
import torch


@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
    Original uses tqdm internally.

    Args:
        model: The diffusion model (DiTWrapper)
        x: Initial noise tensor [B, C, T]
        steps: Number of sampling steps
        sigma_max: Maximum sigma (default 1.0 for rectified flow)
        callback: Optional callable({"i": step, "x": current_x}) for progress
        **extra_args: Passed to model(); includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.
    """
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr
        t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})
    return x
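The sampler's update rule can be checked without a model. The sketch below is a scalar, pure-Python stand-in for `sample_discrete_euler` (no torch, no batching). It assumes the common rectified-flow convention in which `x_t = (1 - t) * data + t * noise`, so the ideal velocity is the constant `noise - data`; the names `euler_rectified_flow`, `data` and `noise` are illustrative only:

```python
def euler_rectified_flow(velocity, x, steps, sigma_max=1.0, callback=None):
    """Scalar analogue of the discrete Euler loop, for illustration only."""
    # Evenly spaced times from sigma_max down to 0, like torch.linspace.
    ts = [sigma_max * (1 - k / steps) for k in range(steps + 1)]
    for i, (t_curr, t_next) in enumerate(zip(ts[:-1], ts[1:])):
        dt = t_next - t_curr  # negative: integration runs from noise to data
        x = x + dt * velocity(x, t_curr)
        if callback is not None:
            callback({"i": i, "x": x})
    return x


data, noise = -1.0, 3.0
seen_steps = []

# With a constant velocity field the Euler steps are exact, so starting
# from the noise sample the loop should land on the data sample.
result = euler_rectified_flow(
    lambda x, t: noise - data,
    noise,
    steps=4,
    callback=lambda state: seen_steps.append(state["i"]),
)
print(result, seen_steps)
```

A trained model's velocity is not constant along the trajectory, which is why the step count trades speed against accuracy.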
-62
@@ -1,62 +0,0 @@
import torch
import torch.nn.functional as F
from torchaudio import transforms as T


def set_audio_channels(audio, target_channels):
    """Convert audio tensor to target number of channels.

    Args:
        audio: Audio tensor of shape [B, C, T]
        target_channels: Desired number of channels (1 for mono, 2 for stereo)

    Returns:
        Audio tensor with the target number of channels.
    """
    if target_channels == 1:
        # Convert to mono
        audio = audio.mean(1, keepdim=True)
    elif target_channels == 2:
        # Convert to stereo
        if audio.shape[1] == 1:
            audio = audio.repeat(1, 2, 1)
        elif audio.shape[1] > 2:
            audio = audio[:, :2, :]
    return audio


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, pad/trim, and convert channels of an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T])
        in_sr: Input sample rate
        target_sr: Target sample rate
        target_length: Target length in samples (padded or cropped)
        target_channels: Target number of channels
        device: Torch device to place the audio on

    Returns:
        Audio tensor of shape [B, target_channels, target_length] on device.
    """
    audio = audio.to(device)
    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)
    # Add batch dimension
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)
    # Pad or crop to target_length
    if audio.shape[-1] < target_length:
        audio = F.pad(audio, (0, target_length - audio.shape[-1]))
    elif audio.shape[-1] > target_length:
        audio = audio[:, :, :target_length]
    audio = set_audio_channels(audio, target_channels)
    return audio
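The channel and length rules in these two helpers are easy to state on plain lists. The sketch below mirrors them for a single unbatched clip stored as a list of channels, each a list of samples. It is a torch-free illustration, and `set_channels` and `pad_or_crop` are hypothetical names, not functions from the removed module:

```python
def set_channels(audio, target_channels):
    """audio is a list of channels; each channel is a list of samples."""
    if target_channels == 1:
        # Mono: average the channels sample by sample.
        count = len(audio)
        return [[sum(frame) / count for frame in zip(*audio)]]
    if target_channels == 2:
        if len(audio) == 1:
            # Mono to stereo: duplicate the single channel.
            return [list(audio[0]), list(audio[0])]
        if len(audio) > 2:
            # More than two channels: keep the first two.
            return audio[:2]
    return audio


def pad_or_crop(channel, target_length):
    """Right-pad with silence or truncate to exactly target_length samples."""
    if len(channel) < target_length:
        return channel + [0.0] * (target_length - len(channel))
    return channel[:target_length]


stereo = [[1.0, 3.0], [3.0, 5.0]]
print(set_channels(stereo, 1))
print(set_channels([[1.0, 2.0]], 2))
print(pad_or_crop([1.0], 3), pad_or_crop([1.0, 2.0, 3.0, 4.0], 2))
```

As in the original, any other `target_channels` value leaves the audio unchanged.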
-9
@@ -1,9 +0,0 @@
"""
PrismAudio model modules for inference.
Re-exports create_model_from_config from the factory module.
"""
from prismaudio_core.factory import create_model_from_config
__all__ = ["create_model_from_config"]
File diff suppressed because it is too large
-821
@@ -1,821 +0,0 @@
import torch
import math
import numpy as np
from torch import nn
from torch.nn import functional as F
from torchaudio import transforms as T
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import Literal, Dict, Any
from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .pretransforms import Pretransform
from .utils import checkpoint
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
"""Minimal stub for inference.utils.prepare_audio used by autoencoders."""
import torchaudio.transforms as T
import torch
if in_sr != target_sr:
resample_tf = T.Resample(in_sr, target_sr).to(device)
audio = resample_tf(audio)
if audio.shape[0] > target_channels:
audio = audio[:target_channels]
elif audio.shape[0] < target_channels:
audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]
if audio.shape[-1] < target_length:
audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
elif audio.shape[-1] > target_length:
audio = audio[..., :target_length]
return audio.unsqueeze(0)
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
from prismaudio_core.factory import create_pretransform_from_config
return create_pretransform_from_config(pretransform, sample_rate)
def _lazy_create_bottleneck_from_config(bottleneck):
from prismaudio_core.factory import create_bottleneck_from_config
return create_bottleneck_from_config(bottleneck)
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
act = nn.ELU()
elif activation == "snake":
act = SnakeBeta(channels)
elif activation == "none":
act = nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
return act
class ResidualUnit(nn.Module):
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
super().__init__()
self.dilation = dilation
padding = (dilation * (7-1)) // 2
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=7, dilation=dilation, padding=padding),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=out_channels, out_channels=out_channels,
kernel_size=1)
)
def forward(self, x):
res = x
#x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=9, use_snake=use_snake),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride,
stride=1,
bias=False,
padding='same')
)
else:
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
upsample_layer,
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=9, use_snake=use_snake),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False
):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
]
for i in range(self.depth-1):
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
def __init__(self,
out_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=True):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
]
for i in range(self.depth-1, 0, -1):
layers += [DecoderBlock(
in_channels=c_mults[i]*channels,
out_channels=c_mults[i-1]*channels,
stride=strides[i-1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class DACEncoderWrapper(nn.Module):
def __init__(self, in_channels=1, **kwargs):
super().__init__()
from dac.model.dac import Encoder as DACEncoder
latent_dim = kwargs.pop("latent_dim", None)
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
self.latent_dim = latent_dim
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
if in_channels != 1:
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
def forward(self, x):
x = self.encoder(x)
x = self.proj_out(x)
return x
class DACDecoderWrapper(nn.Module):
def __init__(self, latent_dim, out_channels=1, **kwargs):
super().__init__()
from dac.model.dac import Decoder as DACDecoder
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
self.latent_dim = latent_dim
def forward(self, x):
return self.decoder(x)
class AudioAutoencoder(nn.Module):
def __init__(
self,
encoder,
decoder,
latent_dim,
downsampling_ratio,
sample_rate,
io_channels=2,
bottleneck: Bottleneck = None,
pretransform: Pretransform = None,
in_channels = None,
out_channels = None,
soft_clip = False
):
super().__init__()
self.downsampling_ratio = downsampling_ratio
self.sample_rate = sample_rate
self.latent_dim = latent_dim
self.io_channels = io_channels
self.in_channels = io_channels
self.out_channels = io_channels
self.min_length = self.downsampling_ratio
if in_channels is not None:
self.in_channels = in_channels
if out_channels is not None:
self.out_channels = out_channels
self.bottleneck = bottleneck
self.encoder = encoder
self.decoder = decoder
self.pretransform = pretransform
self.soft_clip = soft_clip
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
info = {}
if self.pretransform is not None and not skip_pretransform:
if self.pretransform.enable_grad:
if iterate_batch:
audios = []
for i in range(audio.shape[0]):
audios.append(self.pretransform.encode(audio[i:i+1]))
audio = torch.cat(audios, dim=0)
else:
audio = self.pretransform.encode(audio)
else:
with torch.no_grad():
if iterate_batch:
audios = []
for i in range(audio.shape[0]):
audios.append(self.pretransform.encode(audio[i:i+1]))
audio = torch.cat(audios, dim=0)
else:
audio = self.pretransform.encode(audio)
if self.encoder is not None:
if iterate_batch:
latents = []
for i in range(audio.shape[0]):
latents.append(self.encoder(audio[i:i+1]))
latents = torch.cat(latents, dim=0)
else:
latents = self.encoder(audio)
else:
latents = audio
if self.bottleneck is not None:
# TODO: Add iterate batch logic, needs to merge the info dicts
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
info.update(bottleneck_info)
if return_info:
return latents, info
return latents
def decode(self, latents, iterate_batch=False, **kwargs):
if self.bottleneck is not None:
if iterate_batch:
decoded = []
for i in range(latents.shape[0]):
decoded.append(self.bottleneck.decode(latents[i:i+1]))
latents = torch.cat(decoded, dim=0)
else:
latents = self.bottleneck.decode(latents)
if iterate_batch:
decoded = []
for i in range(latents.shape[0]):
decoded.append(self.decoder(latents[i:i+1]))
decoded = torch.cat(decoded, dim=0)
else:
decoded = self.decoder(latents, **kwargs)
if self.pretransform is not None:
if self.pretransform.enable_grad:
if iterate_batch:
decodeds = []
for i in range(decoded.shape[0]):
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
decoded = torch.cat(decodeds, dim=0)
else:
decoded = self.pretransform.decode(decoded)
else:
with torch.no_grad():
if iterate_batch:
decodeds = []
for i in range(decoded.shape[0]):
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
decoded = torch.cat(decodeds, dim=0)
else:
decoded = self.pretransform.decode(decoded)
if self.soft_clip:
decoded = torch.tanh(decoded)
return decoded
def decode_tokens(self, tokens, **kwargs):
'''
Decode discrete tokens to audio
Only works with discrete autoencoders
'''
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
return self.decode(latents, **kwargs)
def preprocess_audio_for_encoder(self, audio, in_sr):
'''
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
If the model is mono, stereo audio will be converted to mono.
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
Audio will be resampled to the model's sample rate.
The output will have batch size 1 and be shape (1 x Channels x Length)
'''
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
'''
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
The audio in that list can be of different lengths and channels.
in_sr_list can be an integer or a list. If it is an integer, it is assumed to be the input sample rate for every audio.
All audio will be resampled to the model's sample rate.
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
If the model is mono, all audio will be converted to mono.
The output will be a tensor of shape (Batch x Channels x Length)
'''
batch_size = len(audio_list)
if isinstance(in_sr_list, int):
in_sr_list = [in_sr_list]*batch_size
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
new_audio = []
max_length = 0
# resample & find the max length
for i in range(batch_size):
audio = audio_list[i]
in_sr = in_sr_list[i]
if len(audio.shape) == 3 and audio.shape[0] == 1:
# batchsize 1 was given by accident. Just squeeze it.
audio = audio.squeeze(0)
elif len(audio.shape) == 1:
# Mono signal, channel dimension is missing, unsqueeze it in
audio = audio.unsqueeze(0)
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
# Resample audio
if in_sr != self.sample_rate:
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
audio = resample_tf(audio)
new_audio.append(audio)
if audio.shape[-1] > max_length:
max_length = audio.shape[-1]
# Pad every audio to the same length, multiple of model's downsampling ratio
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
for i in range(batch_size):
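# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model.
# The audio was already resampled above, so both rates are the model's sample rate (no further resampling).
new_audio[i] = prepare_audio(new_audio[i], in_sr=self.sample_rate, target_sr=self.sample_rate, target_length=padded_audio_length,
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)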
# convert to tensor
return torch.stack(new_audio)
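# Illustrative sketch (not part of the original file): the padding arithmetic
# used for padded_audio_length above, isolated as a plain function. The name
# `padded_length` is hypothetical; it only mirrors that one expression.

```python
def padded_length(max_length: int, min_length: int) -> int:
    """Round max_length up to the nearest multiple of min_length."""
    # The outer modulo keeps lengths that are already a multiple unchanged.
    return max_length + (min_length - (max_length % min_length)) % min_length


print(padded_length(1000, 2048))  # 2048
print(padded_length(4096, 2048))  # 4096 (already aligned, no extra padding)
print(padded_length(4097, 2048))  # 6144
```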
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
'''
Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
Overlap and chunk_size params are both measured in number of latents (not audio samples),
and therefore you can likely use the same values with decode_audio.
An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
Every autoencoder will have a different receptive field size, and thus ideal overlap.
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
Smaller chunk_size uses less memory, but more compute.
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
'''
if not chunked:
# default behavior. Encode the entire audio in parallel
return self.encode(audio, **kwargs)
else:
# CHUNKED ENCODING
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
samples_per_latent = self.downsampling_ratio
total_size = audio.shape[2] # in samples
batch_size = audio.shape[0]
chunk_size *= samples_per_latent # converting metric in latents to samples
overlap *= samples_per_latent # converting metric in latents to samples
hop_size = chunk_size - overlap
chunks = []
for i in range(0, total_size - chunk_size + 1, hop_size):
chunk = audio[:,:,i:i+chunk_size]
chunks.append(chunk)
if i+chunk_size != total_size:
# Final chunk
chunk = audio[:,:,-chunk_size:]
chunks.append(chunk)
chunks = torch.stack(chunks)
num_chunks = chunks.shape[0]
# Note: y_size might be a different value from the latent length used in diffusion training
# because we can encode audio of varying lengths
# However, the audio should've been padded to a multiple of samples_per_latent by now.
y_size = total_size // samples_per_latent
# Create an empty latent, we will populate it with chunks as we encode them
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
for i in range(num_chunks):
x_chunk = chunks[i,:]
# encode the chunk
y_chunk = self.encode(x_chunk)
# figure out where to put the latents along the time domain
if i == num_chunks-1:
# final chunk always goes at the end
t_end = y_size
t_start = t_end - y_chunk.shape[2]
else:
t_start = i * hop_size // samples_per_latent
t_end = t_start + chunk_size // samples_per_latent
# remove the edges of the overlaps
ol = overlap//samples_per_latent//2
chunk_start = 0
chunk_end = y_chunk.shape[2]
if i > 0:
# no overlap for the start of the first chunk
t_start += ol
chunk_start += ol
if i < num_chunks-1:
# no overlap for the end of the last chunk
t_end -= ol
chunk_end -= ol
# paste the encoded chunk into our y_final output latents
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
return y_final
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
'''
Decode latents to audio.
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
Every autoencoder will have a different receptive field size, and thus ideal overlap.
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
Smaller chunk_size uses less memory, but more compute.
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
'''
if not chunked:
# default behavior. Decode the entire latent in parallel
return self.decode(latents, **kwargs)
else:
# chunked decoding
hop_size = chunk_size - overlap
total_size = latents.shape[2]
batch_size = latents.shape[0]
chunks = []
for i in range(0, total_size - chunk_size + 1, hop_size):
chunk = latents[:,:,i:i+chunk_size]
chunks.append(chunk)
if i+chunk_size != total_size:
# Final chunk
chunk = latents[:,:,-chunk_size:]
chunks.append(chunk)
chunks = torch.stack(chunks)
num_chunks = chunks.shape[0]
# samples_per_latent is just the downsampling ratio
samples_per_latent = self.downsampling_ratio
# Create an empty waveform, we will populate it with chunks as we decode them
y_size = total_size * samples_per_latent
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
for i in range(num_chunks):
x_chunk = chunks[i,:]
# decode the chunk
y_chunk = self.decode(x_chunk)
# figure out where to put the audio along the time domain
if i == num_chunks-1:
# final chunk always goes at the end
t_end = y_size
t_start = t_end - y_chunk.shape[2]
else:
t_start = i * hop_size * samples_per_latent
t_end = t_start + chunk_size * samples_per_latent
# remove the edges of the overlaps
ol = (overlap//2) * samples_per_latent
chunk_start = 0
chunk_end = y_chunk.shape[2]
if i > 0:
# no overlap for the start of the first chunk
t_start += ol
chunk_start += ol
if i < num_chunks-1:
# no overlap for the end of the last chunk
t_end -= ol
chunk_end -= ol
# paste the chunked audio into our y_final output audio
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
return y_final
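# Illustrative sketch (not part of the original file): the chunk placement
# logic shared by encode_audio and decode_audio, expressed in latent frames
# with no tensors involved. `chunk_layout` is a hypothetical helper, and it
# assumes total_size >= chunk_size, which the methods above do not check.

```python
def chunk_layout(total_size: int, chunk_size: int = 128, overlap: int = 32):
    """Return (t_start, t_end, chunk_start, chunk_end) per chunk, in latent frames."""
    assert 0 <= overlap < chunk_size <= total_size
    hop_size = chunk_size - overlap
    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    if starts[-1] + chunk_size != total_size:
        # Final chunk is aligned to the end, so it may overlap more than `overlap`.
        starts.append(total_size - chunk_size)
    num_chunks = len(starts)
    ol = overlap // 2  # half of the overlap is trimmed from each inner edge
    layout = []
    for n, start in enumerate(starts):
        t_start, t_end = start, start + chunk_size
        chunk_start, chunk_end = 0, chunk_size
        if n > 0:
            # Inner left edge: skip the first `ol` frames of this chunk.
            t_start += ol
            chunk_start += ol
        if n < num_chunks - 1:
            # Inner right edge: skip the last `ol` frames of this chunk.
            t_end -= ol
            chunk_end -= ol
        layout.append((t_start, t_end, chunk_start, chunk_end))
    return layout


layout = chunk_layout(300)
for t_start, t_end, chunk_start, chunk_end in layout:
    print(f"output[{t_start}:{t_end}] <- chunk[{chunk_start}:{chunk_end}]")
```

# Later chunks overwrite earlier ones where destination ranges intersect,
# which is why the end-aligned final chunk can safely overlap its predecessor.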
class DiffusionAutoencoder(AudioAutoencoder):
def __init__(
self,
diffusion: ConditionedDiffusionModel,
diffusion_downsampling_ratio,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.diffusion = diffusion
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
if self.encoder is not None:
# Shrink the initial encoder parameters to avoid saturated latents
with torch.no_grad():
for param in self.encoder.parameters():
param *= 0.5
def decode(self, latents, steps=100):
upsampled_length = latents.shape[2] * self.downsampling_ratio
if self.bottleneck is not None:
latents = self.bottleneck.decode(latents)
if self.decoder is not None:
latents = self.decoder(latents)
# Upsample latents to match diffusion length
if latents.shape[2] != upsampled_length:
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
from prismaudio_core.inference.sampling import sample
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
if self.pretransform is not None:
if self.pretransform.enable_grad:
decoded = self.pretransform.decode(decoded)
else:
with torch.no_grad():
decoded = self.pretransform.decode(decoded)
return decoded
# AE factories
def create_encoder_from_config(encoder_config: Dict[str, Any]):
encoder_type = encoder_config.get("type", None)
assert encoder_type is not None, "Encoder type must be specified"
if encoder_type == "oobleck":
encoder = OobleckEncoder(
**encoder_config["config"]
)
elif encoder_type == "seanet":
from encodec.modules import SEANetEncoder
seanet_encoder_config = encoder_config["config"]
#SEANet encoder expects strides in reverse order
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
encoder = SEANetEncoder(
**seanet_encoder_config
)
elif encoder_type == "dac":
dac_config = encoder_config["config"]
encoder = DACEncoderWrapper(**dac_config)
elif encoder_type == "local_attn":
from .local_attention import TransformerEncoder1D
local_attn_config = encoder_config["config"]
encoder = TransformerEncoder1D(
**local_attn_config
)
else:
raise ValueError(f"Unknown encoder type {encoder_type}")
requires_grad = encoder_config.get("requires_grad", True)
if not requires_grad:
for param in encoder.parameters():
param.requires_grad = False
return encoder
def create_decoder_from_config(decoder_config: Dict[str, Any]):
decoder_type = decoder_config.get("type", None)
assert decoder_type is not None, "Decoder type must be specified"
if decoder_type == "oobleck":
decoder = OobleckDecoder(
**decoder_config["config"]
)
elif decoder_type == "seanet":
from encodec.modules import SEANetDecoder
decoder = SEANetDecoder(
**decoder_config["config"]
)
elif decoder_type == "dac":
dac_config = decoder_config["config"]
decoder = DACDecoderWrapper(**dac_config)
elif decoder_type == "local_attn":
from .local_attention import TransformerDecoder1D
local_attn_config = decoder_config["config"]
decoder = TransformerDecoder1D(
**local_attn_config
)
else:
raise ValueError(f"Unknown decoder type {decoder_type}")
requires_grad = decoder_config.get("requires_grad", True)
if not requires_grad:
for param in decoder.parameters():
param.requires_grad = False
return decoder
def create_autoencoder_from_config(config: Dict[str, Any]):
ae_config = config["model"]
encoder = create_encoder_from_config(ae_config["encoder"])
decoder = create_decoder_from_config(ae_config["decoder"])
bottleneck = ae_config.get("bottleneck", None)
latent_dim = ae_config.get("latent_dim", None)
assert latent_dim is not None, "latent_dim must be specified in model config"
downsampling_ratio = ae_config.get("downsampling_ratio", None)
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
io_channels = ae_config.get("io_channels", None)
assert io_channels is not None, "io_channels must be specified in model config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "sample_rate must be specified in model config"
in_channels = ae_config.get("in_channels", None)
out_channels = ae_config.get("out_channels", None)
pretransform = ae_config.get("pretransform", None)
if pretransform is not None:
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
if bottleneck is not None:
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
soft_clip = ae_config["decoder"].get("soft_clip", False)
return AudioAutoencoder(
encoder,
decoder,
io_channels=io_channels,
latent_dim=latent_dim,
downsampling_ratio=downsampling_ratio,
sample_rate=sample_rate,
bottleneck=bottleneck,
pretransform=pretransform,
in_channels=in_channels,
out_channels=out_channels,
soft_clip=soft_clip
)
def create_diffAE_from_config(config: Dict[str, Any]):
diffae_config = config["model"]
if "encoder" in diffae_config:
encoder = create_encoder_from_config(diffae_config["encoder"])
else:
encoder = None
if "decoder" in diffae_config:
decoder = create_decoder_from_config(diffae_config["decoder"])
else:
decoder = None
diffusion_model_type = diffae_config["diffusion"]["type"]
if diffusion_model_type == "DAU1d":
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
elif diffusion_model_type == "adp_1d":
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
elif diffusion_model_type == "dit":
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
else:
raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")
latent_dim = diffae_config.get("latent_dim", None)
assert latent_dim is not None, "latent_dim must be specified in model config"
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
io_channels = diffae_config.get("io_channels", None)
assert io_channels is not None, "io_channels must be specified in model config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "sample_rate must be specified in model config"
bottleneck = diffae_config.get("bottleneck", None)
pretransform = diffae_config.get("pretransform", None)
if pretransform is not None:
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
if bottleneck is not None:
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
diffusion_downsampling_ratio = None
if diffusion_model_type == "DAU1d":
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
elif diffusion_model_type == "adp_1d":
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
elif diffusion_model_type == "dit":
diffusion_downsampling_ratio = 1
return DiffusionAutoencoder(
encoder=encoder,
decoder=decoder,
diffusion=diffusion,
io_channels=io_channels,
sample_rate=sample_rate,
latent_dim=latent_dim,
downsampling_ratio=downsampling_ratio,
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
bottleneck=bottleneck,
pretransform=pretransform
)
-331
@@ -1,331 +0,0 @@
from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version
from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
def __init__(self, main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
class ResConvBlock(ResidualBlock):
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
super().__init__([
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_mid),
Snake1d(c_mid) if use_snake else nn.GELU(),
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
], skip)
class SelfAttention1d(nn.Module):
def __init__(self, c_in, n_head=1, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm = nn.GroupNorm(1, c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv1d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
if not self.use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
# Use flash attention for A100 GPUs
self.sdp_kernel_config = (True, False, False)
else:
# Don't use flash attention for other GPUs
self.sdp_kernel_config = (False, True, True)
def forward(self, input):
n, c, s = input.shape
qkv = self.qkv_proj(self.norm(input))
qkv = qkv.view(
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3]**-0.25
if self.use_flash:
with sdp_kernel(*self.sdp_kernel_config):
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
else:
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
return input + self.dropout(self.out_proj(y))
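# Illustrative sketch (not part of the original file): why the fallback path
# multiplies both q and k by d**-0.25. The product of the two scales is
# d**-0.5, so the logits match the usual q.k / sqrt(d). The vectors below are
# made-up values for the demonstration.

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


d = 4  # per-head channel count (hypothetical)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

scale = d ** -0.25
split = dot([x * scale for x in q], [y * scale for y in k])  # scale applied to both sides
standard = dot(q, k) / math.sqrt(d)                          # scale applied once

print(math.isclose(split, standard))
```

# Scaling both operands keeps intermediate magnitudes smaller than scaling the
# product afterwards, which can help in reduced-precision dtypes.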
class SkipBlock(nn.Module):
def __init__(self, *main):
super().__init__()
self.main = nn.Sequential(*main)
def forward(self, input):
return torch.cat([self.main(input), input], dim=1)
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.weight = nn.Parameter(torch.randn(
[out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
def expand_to_planes(input, shape):
return input[..., None].repeat([1, 1, shape[2]])
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
class Downsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv1d(x, weight, stride=2)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
class Upsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
def Downsample1d_2(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
return nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * kernel_multiplier + 1,
stride=factor,
padding=factor * (kernel_multiplier // 2),
)
def Upsample1d_2(
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
if factor == 1:
return nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
)
if use_nearest:
return nn.Sequential(
nn.Upsample(scale_factor=factor, mode="nearest"),
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
),
)
else:
return nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * 2,
stride=factor,
padding=factor // 2 + factor % 2,
output_padding=factor % 2,
)
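# Illustrative sketch (not part of the original file): output lengths implied
# by the kernel/stride/padding choices in Downsample1d_2 and the transposed
# convolution branch of Upsample1d_2. The helper names are hypothetical; the
# two length formulas are the standard ones for dilation 1.

```python
def conv1d_out_len(length, kernel_size, stride, padding):
    """Output length of a 1D convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def conv_transpose1d_out_len(length, kernel_size, stride, padding, output_padding):
    """Output length of a 1D transposed convolution with dilation 1."""
    return (length - 1) * stride - 2 * padding + kernel_size + output_padding


def downsample_len(length, factor, kernel_multiplier=2):
    # Same arguments as the nn.Conv1d built in Downsample1d_2.
    return conv1d_out_len(
        length,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )


def upsample_len(length, factor):
    # Same arguments as the nn.ConvTranspose1d built in Upsample1d_2.
    return conv_transpose1d_out_len(
        length,
        kernel_size=factor * 2,
        stride=factor,
        padding=factor // 2 + factor % 2,
        output_padding=factor % 2,
    )


for factor in (2, 3, 4):
    print(factor, downsample_len(12, factor), upsample_len(12, factor))
```

# With these settings the transposed convolution yields exactly
# length * factor, and the strided convolution yields ceil(length / factor).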
def zero_init(layer):
nn.init.zeros_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
return layer
class AdaRMSNorm(nn.Module):
def __init__(self, features, cond_features, eps=1e-6):
super().__init__()
self.eps = eps
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
def extra_repr(self):
return f"eps={self.eps},"
def forward(self, x, cond):
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
def normalize(x, eps=1e-4):
dim = list(range(1, x.ndim))
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
alpha = np.sqrt(n.numel() / x.numel())
return x / torch.add(eps, n, alpha=alpha)
class ForcedWNConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1):
super().__init__()
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
def forward(self, x):
if self.training:
with torch.no_grad():
self.weight.copy_(normalize(self.weight))
fan_in = self.weight[0].numel()
w = normalize(self.weight) / math.sqrt(fan_in)
return F.conv1d(x, w, padding='same')
# Kernels
use_compile = True
def compile(function, *args, **kwargs):
if not use_compile:
return function
try:
return torch.compile(function, *args, **kwargs)
except RuntimeError:
return function
@compile
def linear_geglu(x, weight, bias=None):
x = x @ weight.mT
if bias is not None:
x = x + bias
x, gate = x.chunk(2, dim=-1)
return x * F.gelu(gate)
@compile
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
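# Illustrative sketch (not part of the original file): the arithmetic of
# rms_norm on a single feature vector, without torch. `rms_norm_1d` is a
# hypothetical helper and ignores the dtype promotion done above.

```python
import math


def rms_norm_1d(x, scale, eps=1e-6):
    """Divide x by its root mean square, then apply a per-feature scale."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * s * inv_rms for v, s in zip(x, scale)]


out = rms_norm_1d([3.0, 4.0], [1.0, 1.0], eps=0.0)
out_rms = math.sqrt(sum(v * v for v in out) / len(out))
print(out, out_rms)
```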
# Layers
class LinearGEGLU(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super().__init__(in_features, out_features * 2, bias=bias)
self.out_features = out_features
def forward(self, x):
return linear_geglu(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, shape, fix_scale = False, eps=1e-6):
super().__init__()
self.eps = eps
if fix_scale:
self.register_buffer("scale", torch.ones(shape))
else:
self.scale = nn.Parameter(torch.ones(shape))
def extra_repr(self):
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
def forward(self, x):
return rms_norm(x, self.scale, self.eps)
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
# try:
# snake_beta = torch.compile(snake_beta)
# except RuntimeError:
# pass
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
-355
@@ -1,355 +0,0 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from vector_quantize_pytorch import ResidualVQ, FSQ
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
class Bottleneck(nn.Module):
def __init__(self, is_discrete: bool = False):
super().__init__()
self.is_discrete = is_discrete
def encode(self, x, return_info=False, **kwargs):
raise NotImplementedError
def decode(self, x):
raise NotImplementedError
class DiscreteBottleneck(Bottleneck):
def __init__(self, num_quantizers, codebook_size, tokens_id):
super().__init__(is_discrete=True)
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.tokens_id = tokens_id
def decode_tokens(self, codes, **kwargs):
raise NotImplementedError
class TanhBottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
self.tanh = nn.Tanh()
def encode(self, x, return_info=False):
info = {}
x = torch.tanh(x)
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4
var = stdev * stdev
logvar = torch.log(var)
latents = torch.randn_like(mean) * stdev + mean
kl = (mean * mean + var - logvar - 1).sum(1).mean()
return latents, kl
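# Illustrative sketch (not part of the original file): the per-element term
# that vae_sample sums into `kl`, written with the standard library only.
# Note that the code above does not include the 0.5 factor found in the
# textbook Gaussian KL divergence. `kl_term` is a hypothetical helper.

```python
import math


def softplus(x):
    """log(1 + exp(x)), fine for the moderate inputs used here."""
    return math.log1p(math.exp(x))


def kl_term(mean, scale):
    """mean^2 + var - log(var) - 1 with stdev = softplus(scale) + 1e-4."""
    stdev = softplus(scale) + 1e-4
    var = stdev * stdev
    return mean * mean + var - math.log(var) - 1.0


# Choose `scale` so that stdev is 1: the term then vanishes for mean 0.
unit_scale = math.log(math.expm1(1.0 - 1e-4))
print(kl_term(0.0, unit_scale))
print(kl_term(1.0, unit_scale))
```

# The term is zero only for a unit Gaussian and positive otherwise, so the
# penalty pulls the latent distribution toward N(0, 1).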
class VAEBottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
def encode(self, x, return_info=False, **kwargs):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["kl"] = kl
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def compute_mean_kernel(x, y):
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
return torch.exp(-kernel_input).mean()
def compute_mmd(latents):
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
noise = torch.randn_like(latents_reshaped)
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
noise_kernel = compute_mean_kernel(noise, noise)
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
return mmd.mean()
class WassersteinBottleneck(Bottleneck):
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
super().__init__(is_discrete=False)
self.noise_augment_dim = noise_augment_dim
self.bypass_mmd = bypass_mmd
def encode(self, x, return_info=False):
info = {}
if self.training and return_info:
if self.bypass_mmd:
mmd = torch.tensor(0.0)
else:
mmd = compute_mmd(x)
info["mmd"] = mmd
if return_info:
return x, info
return x
def decode(self, x):
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
class L2Bottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
def encode(self, x, return_info=False):
info = {}
x = F.normalize(x, dim=1)
if return_info:
return x, info
else:
return x
def decode(self, x):
return F.normalize(x, dim=1)
class RVQBottleneck(DiscreteBottleneck):
def __init__(self, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
self.quantizer = ResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["num_quantizers"]
def encode(self, x, return_info=False, **kwargs):
info = {}
x = rearrange(x, "b c n -> b n c")
x, indices, loss = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
info["quantizer_indices"] = indices
info["quantizer_loss"] = loss.mean()
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def decode_tokens(self, codes, **kwargs):
latents = self.quantizer.get_outputs_from_indices(codes)
return self.decode(latents, **kwargs)
class RVQVAEBottleneck(DiscreteBottleneck):
def __init__(self, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
self.quantizer = ResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["num_quantizers"]
def encode(self, x, return_info=False):
info = {}
x, kl = vae_sample(*x.chunk(2, dim=1))
info["kl"] = kl
x = rearrange(x, "b c n -> b n c")
x, indices, loss = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
info["quantizer_indices"] = indices
info["quantizer_loss"] = loss.mean()
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def decode_tokens(self, codes, **kwargs):
latents = self.quantizer.get_outputs_from_indices(codes)
return self.decode(latents, **kwargs)
class DACRVQBottleneck(DiscreteBottleneck):
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
self.quantizer = DACResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["n_codebooks"]
self.quantize_on_decode = quantize_on_decode
self.noise_augment_dim = noise_augment_dim
def encode(self, x, return_info=False, **kwargs):
info = {}
info["pre_quantizer"] = x
if self.quantize_on_decode:
return (x, info) if return_info else x
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
output = {
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
output["vq/commitment_loss"] /= self.num_quantizers
output["vq/codebook_loss"] /= self.num_quantizers
info.update(output)
if return_info:
return output["z"], info
return output["z"]
def decode(self, x):
if self.quantize_on_decode:
x = self.quantizer(x)[0]
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
def decode_tokens(self, codes, **kwargs):
latents, _, _ = self.quantizer.from_codes(codes)
return self.decode(latents, **kwargs)
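# Illustrative sketch (not part of the original file): why the early return
# in encode needs parentheses around the tuple. A conditional expression
# binds tighter than the comma, so an unparenthesized form always returns a
# 2-tuple. The function names are hypothetical.

```python
def unparenthesized(x, info, return_info):
    # Parsed as: return (x, (info if return_info else x))
    return x, info if return_info else x


def parenthesized(x, info, return_info):
    # Returns the tuple only when return_info is true, otherwise x alone.
    return (x, info) if return_info else x


print(unparenthesized("z", {"kl": 0.0}, False))
print(parenthesized("z", {"kl": 0.0}, False))
```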
class DACRVQVAEBottleneck(DiscreteBottleneck):
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
self.quantizer = DACResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["n_codebooks"]
self.quantize_on_decode = quantize_on_decode
def encode(self, x, return_info=False, n_quantizers: int = None):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["pre_quantizer"] = x
info["kl"] = kl
if self.quantize_on_decode:
return (x, info) if return_info else x
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
output = {
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
output["vq/commitment_loss"] /= self.num_quantizers
output["vq/codebook_loss"] /= self.num_quantizers
info.update(output)
if return_info:
return output["z"], info
return output["z"]
def decode(self, x):
if self.quantize_on_decode:
x = self.quantizer(x)[0]
return x
def decode_tokens(self, codes, **kwargs):
latents, _, _ = self.quantizer.from_codes(codes)
return self.decode(latents, **kwargs)
class FSQBottleneck(DiscreteBottleneck):
def __init__(self, noise_augment_dim=0, **kwargs):
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
self.noise_augment_dim = noise_augment_dim
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
def encode(self, x, return_info=False):
info = {}
orig_dtype = x.dtype
x = x.float()
x = rearrange(x, "b c n -> b n c")
x, indices = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
x = x.to(orig_dtype)
# Reorder indices to match the expected format
indices = rearrange(indices, "b n q -> b q n")
info["quantizer_indices"] = indices
if return_info:
return x, info
else:
return x
def decode(self, x):
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
def decode_tokens(self, tokens, **kwargs):
latents = self.quantizer.indices_to_codes(tokens)
return self.decode(latents, **kwargs)
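`FSQBottleneck` reports its codebook size as `np.prod(kwargs["levels"])`: each latent dimension is quantized to a fixed number of levels, so the number of distinct codes is the product of the per-dimension level counts. The sketch below reproduces that arithmetic with the standard library only. The `levels` value and the mixed-radix index layout (first dimension least significant) are illustrative assumptions; they are not taken from this repository or from the `FSQ` class it wraps.

```python
import math


def fsq_codebook_size(levels):
    """Number of distinct codes when dimension i has levels[i] quantization levels."""
    return math.prod(levels)


def digits_to_index(digits, levels):
    """Pack per-dimension level indices into one integer (assumed mixed-radix layout)."""
    index, base = 0, 1
    for digit, n_levels in zip(digits, levels):
        if not 0 <= digit < n_levels:
            raise ValueError(f"digit {digit} out of range for {n_levels} levels")
        index += digit * base
        base *= n_levels
    return index


def index_to_digits(index, levels):
    """Inverse of digits_to_index under the same assumed layout."""
    digits = []
    for n_levels in levels:
        digits.append(index % n_levels)
        index //= n_levels
    return digits


levels = [8, 5, 5, 5]  # hypothetical example value
print(fsq_codebook_size(levels))              # 1000
print(digits_to_index([7, 4, 4, 4], levels))  # 999
print(index_to_digits(999, levels))           # [7, 4, 4, 4]
```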
@@ -1,884 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner
from .dit import DiffusionTransformer
from .pretransforms import Pretransform
from .adp import UNetCFG1d, UNet1d
# Lazy imports for factory functions to avoid circular imports
def _get_create_pretransform_from_config():
    from prismaudio_core.factory import create_pretransform_from_config
    return create_pretransform_from_config
def _get_create_multi_conditioner_from_conditioning_config():
    from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
    return create_multi_conditioner_from_conditioning_config
class DiffusionModel(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x, t, **kwargs):
raise NotImplementedError()
class DiffusionModelWrapper(nn.Module):
def __init__(
self,
model: DiffusionModel,
io_channels,
sample_size,
sample_rate,
min_input_length,
pretransform: tp.Optional[Pretransform] = None,
):
super().__init__()
self.io_channels = io_channels
self.sample_size = sample_size
self.sample_rate = sample_rate
self.min_input_length = min_input_length
self.model = model
if pretransform is not None:
self.pretransform = pretransform
else:
self.pretransform = None
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
class ConditionedDiffusionModel(nn.Module):
def __init__(self,
*args,
supports_cross_attention: bool = False,
supports_input_concat: bool = False,
supports_global_cond: bool = False,
supports_prepend_cond: bool = False,
**kwargs):
super().__init__(*args, **kwargs)
self.supports_cross_attention = supports_cross_attention
self.supports_input_concat = supports_input_concat
self.supports_global_cond = supports_global_cond
self.supports_prepend_cond = supports_prepend_cond
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
cross_attn_cond: torch.Tensor = None,
cross_attn_mask: torch.Tensor = None,
input_concat_cond: torch.Tensor = None,
global_embed: torch.Tensor = None,
prepend_cond: torch.Tensor = None,
prepend_cond_mask: torch.Tensor = None,
cfg_scale: float = 1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
**kwargs):
raise NotImplementedError()
class ConditionedDiffusionModelWrapper(nn.Module):
"""
A diffusion model that takes in conditioning
"""
def __init__(
self,
model: ConditionedDiffusionModel,
conditioner: MultiConditioner,
io_channels,
sample_rate,
min_input_length: int,
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
zero_init: bool = False,
pretransform: tp.Optional[Pretransform] = None,
cross_attn_cond_ids: tp.List[str] = [],
global_cond_ids: tp.List[str] = [],
input_concat_ids: tp.List[str] = [],
prepend_cond_ids: tp.List[str] = [],
add_cond_ids: tp.List[str] = [],
sync_cond_ids: tp.List[str] = [],
):
super().__init__()
self.model = model
self.conditioner = conditioner
self.io_channels = io_channels
self.sample_rate = sample_rate
self.diffusion_objective = diffusion_objective
self.pretransform = pretransform
self.cross_attn_cond_ids = cross_attn_cond_ids
self.global_cond_ids = global_cond_ids
self.input_concat_ids = input_concat_ids
self.prepend_cond_ids = prepend_cond_ids
self.add_cond_ids = add_cond_ids
self.sync_cond_ids = sync_cond_ids
self.min_input_length = min_input_length
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
if zero_init is True:
self.conditioner.apply(_basic_init)
self.model.model.initialize_weights()
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
cross_attention_input = None
cross_attention_masks = None
global_cond = None
input_concat_cond = None
prepend_cond = None
prepend_cond_mask = None
add_input = None
sync_input = None
if len(self.cross_attn_cond_ids) > 0:
# Concatenate all cross-attention inputs over the sequence dimension
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
cross_attention_input = []
cross_attention_masks = []
for key in self.cross_attn_cond_ids:
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
# Add sequence dimension if it's not there
if len(cross_attn_in.shape) == 2:
cross_attn_in = cross_attn_in.unsqueeze(1)
# cross_attn_mask = cross_attn_mask.unsqueeze(1)
cross_attention_input.append(cross_attn_in)
cross_attention_masks.append(cross_attn_mask)
cross_attention_input = torch.cat(cross_attention_input, dim=1)
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
if len(self.add_cond_ids) > 0:
# Concatenate all additive conditioning inputs over the channel dimension (dim=2)
# Assumes that the additive conditioning inputs are of shape (batch, seq, channels)
add_input = []
for key in self.add_cond_ids:
add_in = conditioning_tensors[key][0]
# Add sequence dimension if it's not there
if len(add_in.shape) == 2:
add_in = add_in.unsqueeze(1)
# add_in = add_in.transpose(1,2)
# add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
# add_in = add_in.transpose(1,2)
add_input.append(add_in)
add_input = torch.cat(add_input, dim=2)
if len(self.sync_cond_ids) > 0:
# Concatenate all sync conditioning inputs over the channel dimension (dim=2)
# Assumes that the sync conditioning inputs are of shape (batch, seq, channels)
sync_input = []
for key in self.sync_cond_ids:
sync_in = conditioning_tensors[key][0]
# Add sequence dimension if it's not there
if len(sync_in.shape) == 2:
sync_in = sync_in.unsqueeze(1)
sync_input.append(sync_in)
sync_input = torch.cat(sync_input, dim=2)
if len(self.global_cond_ids) > 0:
# Sum all global conditioning inputs element-wise (they are not concatenated)
# Assumes that the global conditioning inputs all share the shape (batch, channels)
global_conds = []
for key in self.global_cond_ids:
global_cond_input = conditioning_tensors[key][0]
if len(global_cond_input.shape) == 2:
global_cond_input = global_cond_input.unsqueeze(1)
global_conds.append(global_cond_input)
# # Concatenate over the channel dimension
# if global_conds[0].shape[-1] == 768:
# global_cond = torch.cat(global_conds, dim=-1)
# else:
# global_cond = sum(global_conds)
global_cond = sum(global_conds)
# global_cond = torch.cat(global_conds, dim=-1)
if len(global_cond.shape) == 3:
global_cond = global_cond.squeeze(1)
if len(self.input_concat_ids) > 0:
# Concatenate all input concat conditioning inputs over the channel dimension
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
if len(self.prepend_cond_ids) > 0:
# Concatenate all prepend conditioning inputs over the sequence dimension
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
prepend_conds = []
prepend_cond_masks = []
for key in self.prepend_cond_ids:
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
if len(prepend_cond_input.shape) == 2:
prepend_cond_input = prepend_cond_input.unsqueeze(1)
prepend_conds.append(prepend_cond_input)
prepend_cond_masks.append(prepend_cond_mask)
prepend_cond = torch.cat(prepend_conds, dim=1)
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
if negative:
return {
"negative_cross_attn_cond": cross_attention_input,
"negative_cross_attn_mask": cross_attention_masks,
"negative_global_cond": global_cond,
"negative_input_concat_cond": input_concat_cond
}
else:
return {
"cross_attn_cond": cross_attention_input,
"cross_attn_mask": cross_attention_masks,
"global_cond": global_cond,
"input_concat_cond": input_concat_cond,
"prepend_cond": prepend_cond,
"prepend_cond_mask": prepend_cond_mask,
"add_cond": add_input,
"sync_cond": sync_input
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
from prismaudio_core.inference.generation import generate_diffusion_cond
return generate_diffusion_cond(self, *args, **kwargs)
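`get_conditioning_inputs` combines conditioning signals differently depending on their role: cross-attention and prepend inputs are concatenated over the sequence dimension (`dim=1`), `add_cond` and `sync_cond` inputs over the last dimension (`dim=2`), and global inputs are summed. The sketch below tracks shapes only, as plain tuples, so it runs without PyTorch. The helper names and example shapes are hypothetical, and `sum_shape` is stricter than the real code because it ignores broadcasting.

```python
def add_seq_dim(shape):
    """Mirror `unsqueeze(1)` for inputs given as (batch, channels)."""
    return (shape[0], 1, shape[1]) if len(shape) == 2 else tuple(shape)


def concat_shape(shapes, dim):
    """Shape produced by concatenating same-rank tensors along `dim`."""
    first = tuple(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same rank")
        for axis, (a, b) in enumerate(zip(first, shape)):
            if axis != dim and a != b:
                raise ValueError(f"size mismatch on axis {axis}: {a} vs {b}")
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


def sum_shape(shapes):
    """Shape produced by an element-wise sum of equally shaped tensors."""
    first = tuple(shapes[0])
    if any(tuple(shape) != first for shape in shapes):
        raise ValueError("summed inputs must share one shape")
    return first


# Cross-attention: a token sequence plus a pooled vector that gains a sequence axis.
cross_attn = concat_shape([add_seq_dim(s) for s in [(2, 77, 768), (2, 768)]], dim=1)
# Additive conditioning: same sequence length, channels stacked side by side.
add_cond = concat_shape([(2, 194, 768), (2, 194, 128)], dim=2)
# Global conditioning: summed; the wrapper then squeezes the result back to (2, 768).
global_cond = sum_shape([add_seq_dim(s) for s in [(2, 768), (2, 768)]])

print(cross_attn)   # (2, 78, 768)
print(add_cond)     # (2, 194, 896)
print(global_cond)  # (2, 1, 768)
```

One consequence worth noting: because global inputs are summed, every ID in `global_cond_ids` must produce the same channel width, whereas `add_cond_ids` inputs may differ in width but must agree on sequence length.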
class UNetCFG1DWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
self.model = UNetCFG1d(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
cross_attn_cond=None,
cross_attn_mask=None,
input_concat_cond=None,
global_cond=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
prepend_cond=None,
prepend_cond_mask=None,
**kwargs):
channels_list = None
if input_concat_cond is not None:
channels_list = [input_concat_cond]
outputs = self.model(
x,
t,
embedding=cross_attn_cond,
embedding_mask=cross_attn_mask,
features=global_cond,
channels_list=channels_list,
embedding_scale=cfg_scale,
embedding_mask_proba=cfg_dropout_prob,
batch_cfg=batch_cfg,
rescale_cfg=rescale_cfg,
negative_embedding=negative_cross_attn_cond,
negative_embedding_mask=negative_cross_attn_mask,
**kwargs)
return outputs
class UNet1DCondWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
self.model = UNet1d(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
input_concat_cond=None,
global_cond=None,
cross_attn_cond=None,
cross_attn_mask=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
**kwargs):
channels_list = None
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
if input_concat_cond.shape[2] != x.shape[2]:
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
channels_list = [input_concat_cond]
outputs = self.model(
x,
t,
features=global_cond,
channels_list=channels_list,
**kwargs)
return outputs
class UNet1DUncondWrapper(DiffusionModel):
def __init__(
self,
in_channels,
*args,
**kwargs
):
super().__init__()
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
self.io_channels = in_channels
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
class DAU1DCondWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
self.model = DiffusionAttnUnet1D(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
input_concat_cond=None,
cross_attn_cond=None,
cross_attn_mask=None,
global_cond=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
prepend_cond=None,
**kwargs):
return self.model(x, t, cond = input_concat_cond)
class DiffusionAttnUnet1D(nn.Module):
def __init__(
self,
io_channels = 2,
depth=14,
n_attn_layers = 6,
channels = [128, 128, 256, 256] + [512] * 10,
cond_dim = 0,
cond_noise_aug = False,
kernel_size = 5,
learned_resample = False,
strides = [2] * 13,
conv_bias = True,
use_snake = False
):
super().__init__()
self.cond_noise_aug = cond_noise_aug
self.io_channels = io_channels
if self.cond_noise_aug:
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
self.timestep_embed = FourierFeatures(1, 16)
attn_layer = depth - n_attn_layers
strides = [1] + strides
block = nn.Identity()
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
for i in range(depth, 0, -1):
c = channels[i - 1]
stride = strides[i-1]
if stride > 2 and not learned_resample:
raise ValueError("Must have stride 2 without learned resampling")
if i > 1:
c_prev = channels[i - 2]
add_attn = i >= attn_layer and n_attn_layers > 0
block = SkipBlock(
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
conv_block(c_prev, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
block,
conv_block(c * 2 if i != depth else c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c_prev),
SelfAttention1d(c_prev, c_prev //
32) if add_attn else nn.Identity(),
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
)
else:
cond_embed_dim = 16 if not self.cond_noise_aug else 32
block = nn.Sequential(
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
conv_block(c, c, c),
conv_block(c, c, c),
block,
conv_block(c * 2, c, c),
conv_block(c, c, c),
conv_block(c, c, io_channels, is_last=True),
)
self.net = block
with torch.no_grad():
for param in self.net.parameters():
param *= 0.5
def forward(self, x, t, cond=None, cond_aug_scale=None):
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
inputs = [x, timestep_embed]
if cond is not None:
if cond.shape[2] != x.shape[2]:
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
if self.cond_noise_aug:
# Get a random number between 0 and 1, uniformly sampled
if cond_aug_scale is None:
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
else:
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
# Add noise to the conditioning signal
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
# Get embedding for noise cond level, reusing timestep_embed
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
inputs.append(aug_level_embed)
inputs.append(cond)
outputs = self.net(torch.cat(inputs, dim=1))
return outputs
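Two details of `DiffusionAttnUnet1D` are easy to misread in the constructor: which depth levels receive self-attention, and how many channels the first convolution sees. The sketch below replays only that bookkeeping in plain Python; the function names are hypothetical and no layers are built.

```python
def attention_levels(depth=14, n_attn_layers=6):
    """Depth indices (innermost first) whose SkipBlock gets SelfAttention1d layers."""
    attn_layer = depth - n_attn_layers
    levels = []
    for i in range(depth, 0, -1):
        if i > 1:
            add_attn = i >= attn_layer and n_attn_layers > 0
        else:
            # The outermost level (i == 1) is a plain nn.Sequential without attention.
            add_attn = False
        if add_attn:
            levels.append(i)
    return levels


def first_conv_in_channels(io_channels=2, cond_dim=0, cond_noise_aug=False):
    """Input channels of the outermost conv block: signal + conditioning + embeddings."""
    # 16 channels of timestep embedding, plus 16 more for the noise-level embedding
    # when conditioning noise augmentation is enabled.
    cond_embed_dim = 16 if not cond_noise_aug else 32
    return io_channels + cond_dim + cond_embed_dim


print(attention_levels())                                    # [14, 13, 12, 11, 10, 9, 8]
print(first_conv_in_channels())                              # 18
print(first_conv_in_channels(cond_dim=2, cond_noise_aug=True))  # 36
```

With the default arguments the inclusive comparison `i >= attn_layer` yields seven attention levels, one more than `n_attn_layers` suggests.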
class DiTWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
self.model = DiffusionTransformer(*args, **kwargs)
# with torch.no_grad():
# for param in self.model.parameters():
# param *= 0.5
def forward(self,
x,
t,
cross_attn_cond=None,
cross_attn_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
negative_input_concat_cond=None,
global_cond=None,
negative_global_cond=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = True,
rescale_cfg: bool = False,
scale_phi: float = 0.0,
**kwargs):
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
return self.model(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_mask,
negative_cross_attn_cond=negative_cross_attn_cond,
negative_cross_attn_mask=negative_cross_attn_mask,
input_concat_cond=input_concat_cond,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
cfg_scale=cfg_scale,
cfg_dropout_prob=cfg_dropout_prob,
scale_phi=scale_phi,
global_embed=global_cond,
**kwargs)
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
"""
A diffusion model that takes in conditioning
"""
def __init__(
self,
model,
conditioner: MultiConditioner,
io_channels,
sample_rate,
min_input_length: int,
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
pretransform: tp.Optional[Pretransform] = None,
cross_attn_cond_ids: tp.List[str] = [],
global_cond_ids: tp.List[str] = [],
input_concat_ids: tp.List[str] = [],
prepend_cond_ids: tp.List[str] = [],
add_cond_ids: tp.List[str] = [],
mm_cond_ids: tp.List[str] = [],
):
super().__init__()
self.model = model
self.conditioner = conditioner
self.io_channels = io_channels
self.sample_rate = sample_rate
self.diffusion_objective = diffusion_objective
self.pretransform = pretransform
self.cross_attn_cond_ids = cross_attn_cond_ids
self.global_cond_ids = global_cond_ids
self.input_concat_ids = input_concat_ids
self.prepend_cond_ids = prepend_cond_ids
self.add_cond_ids = add_cond_ids
self.min_input_length = min_input_length
self.mm_cond_ids = mm_cond_ids
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
assert "metaclip_features" in self.mm_cond_ids, "metaclip_features must be specified in mm_cond_ids for MMDiTWrapper"
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
assert not negative, "negative conditioning is not supported for MMDiTWrapper"
cross_attention_input = None
cross_attention_masks = None
global_cond = None
input_concat_cond = None
prepend_cond = None
prepend_cond_mask = None
add_input = None
inpaint_masked_input = None
t5_features = None
metaclip_global_text_features = None
clip_f = conditioning_tensors["metaclip_features"]
sync_f = conditioning_tensors["sync_features"]
text_f = conditioning_tensors["metaclip_text_features"]
if 'inpaint_masked_input' in conditioning_tensors.keys():
inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
if 't5_features' in conditioning_tensors.keys():
t5_features = conditioning_tensors["t5_features"]
if 'metaclip_global_text_features' in conditioning_tensors.keys():
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
return {
"clip_f": clip_f,
"sync_f": sync_f,
"text_f": text_f,
"inpaint_masked_input": inpaint_masked_input,
"t5_features": t5_features,
"metaclip_global_text_features": metaclip_global_text_features
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
from prismaudio_core.inference.generation import generate_diffusion_cond
return generate_diffusion_cond(self, *args, **kwargs)
class DiTUncondWrapper(DiffusionModel):
def __init__(
self,
io_channels,
*args,
**kwargs
):
super().__init__()
self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
self.io_channels = io_channels
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
diffusion_uncond_config = config["model"]
model_type = diffusion_uncond_config.get('type', None)
diffusion_config = diffusion_uncond_config.get('config', {})
assert model_type is not None, "Must specify model type in config"
pretransform = diffusion_uncond_config.get("pretransform", None)
sample_size = config.get("sample_size", None)
assert sample_size is not None, "Must specify sample size in config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "Must specify sample rate in config"
if pretransform is not None:
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if model_type == 'DAU1d':
model = DiffusionAttnUnet1D(
**diffusion_config
)
elif model_type == "adp_uncond_1d":
model = UNet1DUncondWrapper(
**diffusion_config
)
elif model_type == "dit":
model = DiTUncondWrapper(
**diffusion_config
)
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return DiffusionModelWrapper(model,
io_channels=model.io_channels,
sample_size=sample_size,
sample_rate=sample_rate,
pretransform=pretransform,
min_input_length=min_input_length)
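`create_diffusion_uncond_from_config` expects `sample_size` and `sample_rate` at the top level of the config and `type`, `config`, and an optional `pretransform` under `model`. The sketch below mirrors those checks in a standalone function so the expected layout can be validated without building a model. `check_uncond_config` is a hypothetical helper, and every value in `example_config` is a placeholder rather than a setting from this repository.

```python
KNOWN_UNCOND_TYPES = ("DAU1d", "adp_uncond_1d", "dit")


def check_uncond_config(config):
    """Validate the keys read by create_diffusion_uncond_from_config (sketch only)."""
    model_cfg = config["model"]
    model_type = model_cfg.get("type")
    if model_type is None:
        raise ValueError("Must specify model type in config")
    for key in ("sample_size", "sample_rate"):
        if config.get(key) is None:
            raise ValueError(f"Must specify {key} in config")
    if model_type not in KNOWN_UNCOND_TYPES:
        raise NotImplementedError(f"Unknown model type: {model_type}")
    # Keyword arguments that would be forwarded to the model constructor.
    return model_type, model_cfg.get("config", {})


example_config = {
    "sample_size": 65536,   # placeholder
    "sample_rate": 44100,   # placeholder
    "model": {
        "type": "dit",
        "config": {"io_channels": 64, "embed_dim": 768, "depth": 12},  # placeholders
    },
}

model_type, model_kwargs = check_uncond_config(example_config)
print(model_type)            # dit
print(sorted(model_kwargs))  # ['depth', 'embed_dim', 'io_channels']
```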
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
diffusion_uncond_config = config["model"]
diffusion_config = diffusion_uncond_config.get('diffusion', {})
model_type = diffusion_config.get('type', None)
model_config = diffusion_config.get("config",{})
assert model_type is not None, "Must specify model type in config"
pretransform = diffusion_uncond_config.get("pretransform", None)
sample_size = config.get("sample_size", None)
assert sample_size is not None, "Must specify sample size in config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "Must specify sample rate in config"
if pretransform is not None:
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if model_type == 'DAU1d':
model = DiffusionAttnUnet1D(
**model_config
)
elif model_type == "adp_uncond_1d":
io_channels = model_config.get("io_channels", 64)
model = UNet1DUncondWrapper(
io_channels = io_channels,
**model_config
)
elif model_type == "dit":
model = DiTUncondWrapper(
**model_config
)
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return DiffusionModelWrapper(model,
io_channels=model.io_channels,
sample_size=sample_size,
sample_rate=sample_rate,
pretransform=pretransform,
min_input_length=min_input_length)
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    model_config = config["model"]
    model_type = config["model_type"]
    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"
    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"
    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"
    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"
    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"
    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
    conditioning_config = model_config.get('conditioning', None)
    conditioner = None
    if conditioning_config is not None:
        conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)
    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)
    pretransform = model_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size
    # Get the proper wrapper class
    extra_kwargs = {"diffusion_objective": diffusion_objective}
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs["mm_cond_ids"] = mm_cond_ids
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", "diffusion_infill"):
        wrapper_fn = ConditionedDiffusionModelWrapper
        # Only ConditionedDiffusionModelWrapper accepts sync_cond_ids and zero_init;
        # passing them to MMConditionedDiffusionModelWrapper raises a TypeError.
        extra_kwargs["sync_cond_ids"] = sync_cond_ids
        extra_kwargs["zero_init"] = zero_init
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        add_cond_ids=add_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )
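The `min_input_length` passed to the wrapper is the pretransform's downsampling ratio multiplied by the diffusion model's own reduction: the product of `factors` for the ADP U-Nets, or `patch_size` for the DiT. A standalone sketch of that arithmetic follows. The function names and example numbers are hypothetical, and `round_up_to_multiple` shows one plausible use of the value (padding a requested length); that use is an assumption and is not shown in this file.

```python
import math


def min_input_length(diffusion_model_type, downsampling_ratio=1, factors=None, patch_size=1):
    """Smallest input length, in samples, under the rules used by the factory."""
    length = downsampling_ratio
    if diffusion_model_type in ("adp_cfg_1d", "adp_1d"):
        length *= math.prod(factors)
    elif diffusion_model_type == "dit":
        length *= patch_size
    return length


def round_up_to_multiple(n_samples, multiple):
    """Assumed usage: pad a requested length up to the next valid multiple."""
    return math.ceil(n_samples / multiple) * multiple


dit_min = min_input_length("dit", downsampling_ratio=2048, patch_size=1)
adp_min = min_input_length("adp_1d", factors=[1, 2, 2, 2])

print(dit_min)                                 # 2048
print(adp_min)                                 # 8
print(round_up_to_multiple(441000, dit_min))   # 442368
```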
@@ -1,539 +0,0 @@
import typing as tp
import math
import torch
# from beartype.typing import Tuple
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from .utils import mask_from_frac_lengths, resample
class DiffusionTransformer(nn.Module):
def __init__(self,
io_channels=32,
patch_size=1,
embed_dim=768,
cond_token_dim=0,
project_cond_tokens=True,
global_cond_dim=0,
project_global_cond=True,
input_concat_dim=0,
prepend_cond_dim=0,
cond_ctx_dim=0,
depth=12,
num_heads=8,
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
add_token_dim=0,
sync_token_dim=0,
use_mlp=False,
use_zero_init=False,
**kwargs):
super().__init__()
self.cond_token_dim = cond_token_dim
# Timestep embeddings
timestep_features_dim = 256
self.timestep_cond_type = timestep_cond_type
self.timestep_features = FourierFeatures(1, timestep_features_dim)
# to_timestep_embed below always projects to embed_dim
timestep_embed_dim = embed_dim
if timestep_cond_type == "input_concat":
# The timestep embedding is concatenated to the input channels in _forward
input_concat_dim += timestep_embed_dim
self.to_timestep_embed = nn.Sequential(
nn.Linear(timestep_features_dim, embed_dim, bias=True),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=True),
)
self.use_mlp = use_mlp
if cond_token_dim > 0:
# Conditioning tokens
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
self.to_cond_embed = nn.Sequential(
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
)
else:
cond_embed_dim = 0
if global_cond_dim > 0:
# Global conditioning
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
self.to_global_embed = nn.Sequential(
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
)
if add_token_dim > 0:
# Conditioning tokens
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
self.to_add_embed = nn.Sequential(
nn.Linear(add_token_dim, add_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
)
else:
add_embed_dim = 0
if sync_token_dim > 0:
# Conditioning tokens
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
self.to_sync_embed = nn.Sequential(
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
)
else:
sync_embed_dim = 0
if prepend_cond_dim > 0:
# Prepend conditioning
self.to_prepend_embed = nn.Sequential(
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=False)
)
self.input_concat_dim = input_concat_dim
dim_in = io_channels + self.input_concat_dim
self.patch_size = patch_size
# Transformer
self.transformer_type = transformer_type
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
self.global_cond_type = global_cond_type
if self.transformer_type == "continuous_transformer":
global_dim = None
if self.global_cond_type == "adaLN":
# The global conditioning is projected to the embed_dim already at this point
global_dim = embed_dim
self.transformer = ContinuousTransformer(
dim=embed_dim,
depth=depth,
dim_heads=embed_dim // num_heads,
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
cross_attend = cond_token_dim > 0,
cond_token_dim = cond_embed_dim,
global_cond_dim=global_dim,
**kwargs
)
else:
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
nn.init.zeros_(self.preprocess_conv.weight)
nn.init.zeros_(self.postprocess_conv.weight)
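The channel sizes handed to `ContinuousTransformer` and to the two 1x1 convolutions all derive from a few constructor arguments. The sketch below recomputes them in plain Python for a quick sanity check of a config; `dit_dims` is a hypothetical helper, and the second example's argument values are placeholders.

```python
def dit_dims(io_channels=32, input_concat_dim=0, patch_size=1, embed_dim=768, num_heads=8):
    """Derived sizes, following the arithmetic in DiffusionTransformer.__init__."""
    dim_in = io_channels + input_concat_dim
    return {
        "transformer_dim_in": dim_in * patch_size,
        "transformer_dim_out": io_channels * patch_size,
        # Integer division: embed_dim should be divisible by num_heads.
        "dim_heads": embed_dim // num_heads,
        "preprocess_conv_channels": dim_in,
        "postprocess_conv_channels": io_channels,
    }


defaults = dit_dims()
print(defaults["transformer_dim_in"], defaults["dim_heads"])  # 32 96

patched = dit_dims(io_channels=64, input_concat_dim=8, patch_size=2, embed_dim=1024, num_heads=16)
print(patched["transformer_dim_in"], patched["transformer_dim_out"], patched["dim_heads"])  # 144 128 64
```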
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
# if isinstance(module, nn.Conv1d):
# if module.bias is not None:
# nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
# Zero-out output layers:
if self.global_cond_type == "adaLN":
for block in self.transformer.layers:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.empty_clip_feat, 0)
nn.init.constant_(self.empty_sync_feat, 0)
def _forward(
self,
x,
t,
mask=None,
cross_attn_cond=None,
cross_attn_cond_mask=None,
input_concat_cond=None,
global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
add_cond=None,
add_masks=None,
sync_cond=None,
return_info=False,
**kwargs):
if cross_attn_cond is not None:
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
if global_embed is not None:
# Project the global conditioning to the embedding dimension
global_embed = self.to_global_embed(global_embed)
prepend_inputs = None
prepend_mask = None
prepend_length = 0
if prepend_cond is not None:
# Project the prepend conditioning to the embedding dimension
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if input_concat_cond is not None:
# reshape from (b, n, c) to (b, c, n)
if input_concat_cond.shape[1] != x.shape[1]:
input_concat_cond = input_concat_cond.transpose(1,2)
# Interpolate input_concat_cond to the same length as x
# if input_concat_cond.shape[1] != x.shape[2]:
# input_concat_cond = input_concat_cond.transpose(1,2)
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
# input_concat_cond = input_concat_cond.transpose(1,2)
# if len(global_embed.shape) == 2:
# global_embed = global_embed.unsqueeze(1)
# global_embed = global_embed + input_concat_cond
x = torch.cat([x, input_concat_cond], dim=1)
# Get the batch of timestep embeddings
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if self.timestep_cond_type == "global":
if global_embed is not None:
if len(global_embed.shape) == 3:
timestep_embed = timestep_embed.unsqueeze(1)
global_embed = global_embed + timestep_embed
else:
global_embed = timestep_embed
elif self.timestep_cond_type == "input_concat":
x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
if self.global_cond_type == "prepend" and global_embed is not None:
if prepend_inputs is None:
# Prepend inputs are just the global embed, and the mask is all ones
if len(global_embed.shape) == 2:
prepend_inputs = global_embed.unsqueeze(1)
else:
prepend_inputs = global_embed
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
else:
# Prepend inputs are the prepend conditioning + the global embed
if len(global_embed.shape) == 2:
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
else:
prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
prepend_length = prepend_inputs.shape[1]
x = self.preprocess_conv(x) + x
x = rearrange(x, "b c t -> b t c")
extra_args = {}
if self.global_cond_type == "adaLN":
extra_args["global_cond"] = global_embed
if self.patch_size > 1:
b, seq_len, c = x.shape
# Compute how many frames of padding are needed
pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
if pad_amount > 0:
# Pad along the time dimension
x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
if add_cond is not None:
# Interpolate add_cond to the same length as x
# if self.use_mlp:
add_cond = self.to_add_embed(add_cond)
if add_cond.shape[1] != x.shape[1]:
add_cond = add_cond.transpose(1,2)
add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
add_cond = add_cond.transpose(1,2)
# add_cond = resample(add_cond, x)
if sync_cond is not None:
sync_cond = self.to_sync_embed(sync_cond)
if self.transformer_type == "continuous_transformer":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
if return_info:
output, info = output
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
if self.patch_size > 1:
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
# Remove the padding added earlier
if pad_amount > 0:
output = output[:, :, :seq_len]
output = self.postprocess_conv(output) + output
if return_info:
return output, info
return output
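# Illustrative sketch, not part of the original file: the patch padding used in
# `_forward` above, reproduced on plain Python lists instead of tensors. The
# helper names `patch_pad_amount`, `patchify`, and `unpatchify` are hypothetical
# and exist only for this example; the real code uses `F.pad` and `rearrange`.

```python
def patch_pad_amount(seq_len: int, patch_size: int) -> int:
    """Number of zero frames needed so seq_len becomes a multiple of patch_size."""
    return (patch_size - seq_len % patch_size) % patch_size


def patchify(frames, patch_size):
    """Zero-pad the tail of `frames`, then group it into patches of patch_size."""
    pad = patch_pad_amount(len(frames), patch_size)
    padded = list(frames) + [0] * pad
    return [padded[i:i + patch_size] for i in range(0, len(padded), patch_size)]


def unpatchify(patches, seq_len):
    """Flatten patches back into frames and drop the padded tail."""
    flat = [frame for patch in patches for frame in patch]
    return flat[:seq_len]
```

# The outer modulo makes the pad amount zero when the length is already a
# multiple of the patch size, which is why truncating to the original `seq_len`
# afterwards is always safe.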
def forward(
self,
x,
t,
cross_attn_cond=None,
cross_attn_cond_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
add_cond=None,
sync_cond=None,
cfg_scale=1.0,
cfg_dropout_prob=0.0,
causal=False,
scale_phi=0.0,
mask=None,
return_info=False,
**kwargs):
assert not causal, "Causal mode is not supported for DiffusionTransformer"
bsz, a, b = x.shape
model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype)
t = t.to(model_dtype)
if cross_attn_cond is not None:
cross_attn_cond = cross_attn_cond.to(model_dtype)
if negative_cross_attn_cond is not None:
negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
if input_concat_cond is not None:
input_concat_cond = input_concat_cond.to(model_dtype)
if global_embed is not None:
global_embed = global_embed.to(model_dtype)
if negative_global_embed is not None:
negative_global_embed = negative_global_embed.to(model_dtype)
if prepend_cond is not None:
prepend_cond = prepend_cond.to(model_dtype)
if add_cond is not None:
add_cond = add_cond.to(model_dtype)
if sync_cond is not None:
sync_cond = sync_cond.to(model_dtype)
if cross_attn_cond_mask is not None:
cross_attn_cond_mask = cross_attn_cond_mask.bool()
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
if prepend_cond_mask is not None:
prepend_cond_mask = prepend_cond_mask.bool()
# CFG dropout
if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
if add_cond is not None:
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
add_cond = torch.where(dropout_mask, null_embed, add_cond)
if sync_cond is not None:
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
# Classifier-free guidance
# Concatenate conditioned and unconditioned inputs on the batch dimension
batch_inputs = torch.cat([x, x], dim=0)
batch_timestep = torch.cat([t, t], dim=0)
if global_embed is not None and global_embed.shape[0] == bsz:
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
elif global_embed is not None:
batch_global_cond = global_embed
else:
batch_global_cond = None
if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
elif input_concat_cond is not None:
batch_input_concat_cond = input_concat_cond
else:
batch_input_concat_cond = None
batch_cond = None
batch_cond_masks = None
# Handle CFG for cross-attention conditioning
if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
if negative_cross_attn_cond is not None:
# If there's a negative cross-attention mask, set the masked tokens to the null embed
if negative_cross_attn_mask is not None:
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
else:
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
if cross_attn_cond_mask is not None:
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
elif cross_attn_cond is not None:
batch_cond = cross_attn_cond
else:
batch_cond = None
batch_prepend_cond = None
batch_prepend_cond_mask = None
if prepend_cond is not None and prepend_cond.shape[0] == bsz:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
if prepend_cond_mask is not None:
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
elif prepend_cond is not None:
batch_prepend_cond = prepend_cond
else:
batch_prepend_cond = None
batch_add_cond = None
# Handle CFG for additive conditioning
if add_cond is not None and add_cond.shape[0] == bsz:
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
elif add_cond is not None:
batch_add_cond = add_cond
else:
batch_add_cond = None
batch_sync_cond = None
# Handle CFG for sync conditioning
if sync_cond is not None and sync_cond.shape[0] == bsz:
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
elif sync_cond is not None:
batch_sync_cond = sync_cond
else:
batch_sync_cond = None
if mask is not None:
batch_masks = torch.cat([mask, mask], dim=0)
else:
batch_masks = None
batch_output = self._forward(
batch_inputs,
batch_timestep,
cross_attn_cond=batch_cond,
cross_attn_cond_mask=batch_cond_masks,
mask = batch_masks,
input_concat_cond=batch_input_concat_cond,
global_embed = batch_global_cond,
prepend_cond = batch_prepend_cond,
prepend_cond_mask = batch_prepend_cond_mask,
add_cond = batch_add_cond,
sync_cond = batch_sync_cond,
return_info = return_info,
**kwargs)
if return_info:
batch_output, info = batch_output
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
# CFG Rescale
if scale_phi != 0.0:
cond_out_std = cond_output.std(dim=1, keepdim=True)
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
else:
output = cfg_output
if return_info:
return output, info
return output
else:
return self._forward(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_cond_mask,
input_concat_cond=input_concat_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
add_cond=add_cond,
sync_cond=sync_cond,
mask=mask,
return_info=return_info,
**kwargs
)
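# Illustrative sketch, not part of the original file: the classifier-free
# guidance arithmetic from `forward` above, written for a single vector of
# floats that stands in for the channel axis (the original reduces over dim=1
# of a tensor). The function names `cfg_combine` and `cfg_rescale` are
# hypothetical. `statistics.stdev` is used because it is the sample (n - 1)
# standard deviation, which matches the default of `torch.Tensor.std`.

```python
import statistics


def cfg_combine(cond, uncond, cfg_scale):
    """uncond + (cond - uncond) * cfg_scale, element by element."""
    return [u + (c - u) * cfg_scale for c, u in zip(cond, uncond)]


def cfg_rescale(cond, cfg, scale_phi):
    """Blend the guided output with a copy rescaled to the std of `cond`."""
    if scale_phi == 0.0:
        return list(cfg)
    # Assumes len >= 2 and a non-constant `cfg`, otherwise the ratio is undefined.
    ratio = statistics.stdev(cond) / statistics.stdev(cfg)
    return [scale_phi * (v * ratio) + (1 - scale_phi) * v for v in cfg]
```

# With cfg_scale == 1.0 the combination returns the conditional output
# unchanged, and with scale_phi == 1.0 the rescaled output has the same
# standard deviation as the conditional output.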
@@ -1,275 +0,0 @@
import torch
from einops import rearrange
from torch import nn
from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
from .utils import checkpoint
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_in = None,
dim_out = None,
causal = False,
local_attn_window_size = 64,
heads = 8,
ff_mult = 2,
cond_dim = 0,
cross_attn_cond_dim = 0,
**kwargs
):
super().__init__()
dim_head = dim//heads
self.layers = nn.ModuleList([])
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
self.local_attn_window_size = local_attn_window_size
self.cond_dim = cond_dim
self.cross_attn_cond_dim = cross_attn_cond_dim
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
for _ in range(depth):
self.layers.append(nn.ModuleList([
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
Attention(
dim=dim,
dim_heads=dim_head,
causal=causal,
zero_init_output=True,
natten_kernel_size=local_attn_window_size,
),
Attention(
dim=dim,
dim_heads=dim_head,
dim_context = cross_attn_cond_dim,
zero_init_output=True
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
]))
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
x = checkpoint(self.project_in, x)
if prepend_cond is not None:
x = torch.cat([prepend_cond, x], dim=1)
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
residual = x
if cond is not None:
x = checkpoint(attn_norm, x, cond)
else:
x = checkpoint(attn_norm, x)
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
if cross_attn_cond is not None:
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
residual = x
if cond is not None:
x = checkpoint(ff_norm, x, cond)
else:
x = checkpoint(ff_norm, x)
x = checkpoint(ff, x) + residual
return checkpoint(self.project_out, x)
class TransformerDownsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim = 768,
depth = 3,
heads = 12,
downsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.downsample_ratio = downsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size=local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
def forward(self, x):
x = checkpoint(self.project_in, x)
# Run the local-attention transformer
x = self.transformer(x)
# Trade sequence length for channels
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
# Project back to embed dim
x = checkpoint(self.project_down, x)
return x
class TransformerUpsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim,
depth = 3,
heads = 12,
upsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.upsample_ratio = upsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size = local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
def forward(self, x):
# Project to embed dim
x = checkpoint(self.project_in, x)
# Project to increase channel dim
x = checkpoint(self.project_up, x)
# Trade channels for sequence length
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
# Run the local-attention transformer
x = self.transformer(x)
return x
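# Illustrative sketch, not part of the original file: what the two einops
# patterns in the blocks above do to a single batch item, written with nested
# lists (outer list = sequence positions, inner list = channels). The function
# names are hypothetical. "b (n r) c -> b n (c r)" folds r consecutive frames
# into the channel axis; "b n (c r) -> b (n r) c" undoes it.

```python
def fold_time_into_channels(seq, ratio):
    """Downsample: out[n][c * ratio + r] = seq[n * ratio + r][c]."""
    assert len(seq) % ratio == 0, "sequence length must be divisible by ratio"
    channels = len(seq[0])
    out = []
    for n in range(len(seq) // ratio):
        group = seq[n * ratio:(n + 1) * ratio]  # r consecutive frames
        out.append([group[r][c] for c in range(channels) for r in range(ratio)])
    return out


def unfold_channels_into_time(seq, ratio):
    """Upsample: out[n * ratio + r][c] = seq[n][c * ratio + r]."""
    channels = len(seq[0]) // ratio
    out = []
    for frame in seq:
        for r in range(ratio):
            out.append([frame[c * ratio + r] for c in range(channels)])
    return out
```

# The two reshapes are exact inverses, so all learning happens in the linear
# projections (`project_down`, `project_up`) around them.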
class TransformerEncoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [96, 192, 384, 768],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerDownsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
downsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x
class TransformerDecoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [768, 384, 192, 96],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerUpsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
upsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x
@@ -1 +0,0 @@
# mmmodules package
@@ -1 +0,0 @@
# mmmodules.model package
@@ -1,393 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from scipy.optimize import fmin
from scipy.signal import firwin, kaiserord
class PQMF(nn.Module):
"""
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
Uses polyphase representation which is computationally more efficient for real-time.
Parameters:
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
"""
def __init__(self, attenuation, num_bands):
super(PQMF, self).__init__()
# Ensure num_bands is a power of 2
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
assert is_power_of_2, "'num_bands' must be a power of 2."
# Create the prototype filter
prototype_filter = design_prototype_filter(attenuation, num_bands)
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
# Register filters and settings
self.register_buffer("filter_bank", padded_filter_bank)
self.register_buffer("prototype", prototype_filter)
self.num_bands = num_bands
def forward(self, signal):
"""Decompose the signal into multiple frequency bands."""
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
signal = prepare_signal_dimensions(signal)
# The signal length must be a multiple of num_bands. Pad it with zeros.
signal = pad_signal(signal, self.num_bands)
# Run the polyphase analysis filter bank
signal = polyphase_analysis(signal, self.filter_bank)
return apply_alias_cancellation(signal)
def inverse(self, bands):
"""Reconstruct the original signal from the frequency bands."""
bands = apply_alias_cancellation(bands)
return polyphase_synthesis(bands, self.filter_bank)
def prepare_signal_dimensions(signal):
"""
Rearrange signal into Batch x Channels x Length.
Parameters
----------
signal : torch.Tensor or numpy.ndarray
The input signal.
Returns
-------
torch.Tensor
Preprocessed signal tensor.
"""
# Convert numpy to torch tensor
if isinstance(signal, np.ndarray):
signal = torch.from_numpy(signal)
# Ensure tensor
if not isinstance(signal, torch.Tensor):
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
# Modify dimension of signal to Batch x Channels x Length
if signal.dim() == 1:
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
signal = signal.unsqueeze(0).unsqueeze(0)
elif signal.dim() == 2:
# This is a multi-channel signal (e.g. stereo)
# Rearrange so that larger dimension (Length) is last
if signal.shape[0] > signal.shape[1]:
signal = signal.T
# Unsqueeze to 1 x Channels x Length
signal = signal.unsqueeze(0)
return signal
def pad_signal(signal, num_bands):
"""
Pads the signal to make its length divisible by the given number of bands.
Parameters
----------
signal : torch.Tensor
The input signal tensor, where the last dimension represents the signal length.
num_bands : int
The number of bands by which the signal length should be divisible.
Returns
-------
torch.Tensor
The padded signal tensor. If the original signal length was already divisible
by num_bands, returns the original signal unchanged.
"""
remainder = signal.shape[-1] % num_bands
if remainder > 0:
padding_size = num_bands - remainder
signal = nn.functional.pad(signal, (0, padding_size))
return signal
def generate_modulated_filter_bank(prototype_filter, num_bands):
"""
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
Parameters
----------
prototype_filter : torch.Tensor
The prototype filter used as the basis for modulation.
num_bands : int
The number of desired subbands or filters.
Returns
-------
torch.Tensor
A bank of cosine modulated filters.
"""
# Initialize indices for modulation.
subband_indices = torch.arange(num_bands).reshape(-1, 1)
# Calculate the length of the prototype filter.
filter_length = prototype_filter.shape[-1]
# Generate symmetric time indices centered around zero.
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
# Calculate phase offsets to ensure orthogonality between subbands.
phase_offsets = (-1)**subband_indices * np.pi / 4
# Compute the cosine modulation function.
modulation = torch.cos(
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
)
# Apply modulation to the prototype filter.
modulated_filters = 2 * prototype_filter * modulation
return modulated_filters
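# Illustrative sketch, not part of the original file: the cosine modulation in
# `generate_modulated_filter_bank` above, written with the standard library
# only. The function name `cosine_modulate` is hypothetical. It assumes an
# odd-length prototype, which the symmetric time index range in the original
# also requires.

```python
import math


def cosine_modulate(prototype, num_bands):
    """Return num_bands filters: 2 * p[n] * cos((2k+1) * pi / (2M) * n + phase_k)."""
    assert len(prototype) % 2 == 1, "prototype length must be odd"
    half = len(prototype) // 2
    time_indices = range(-half, half + 1)  # symmetric around zero
    bank = []
    for k in range(num_bands):
        phase = (-1) ** k * math.pi / 4  # alternating +pi/4, -pi/4 offsets
        freq = (2 * k + 1) * math.pi / (2 * num_bands)
        bank.append([2 * p * math.cos(freq * n + phase)
                     for p, n in zip(prototype, time_indices)])
    return bank
```

# At the centre tap (n = 0) every band reduces to 2 * p[0] * cos(+-pi/4), so a
# unit impulse prototype gives sqrt(2) in each band.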
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
"""
Design a lowpass filter using the Kaiser window.
Parameters
----------
angular_cutoff : float
The angular frequency cutoff of the filter.
attenuation : float
The desired stopband attenuation in decibels (dB).
filter_length : int, optional
Desired length of the filter. If not provided, it's computed based on the given specs.
Returns
-------
ndarray
The designed lowpass filter coefficients.
"""
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
# Ensure the estimated length is odd.
estimated_length = 2 * (estimated_length // 2) + 1
if filter_length is None:
filter_length = estimated_length
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
"""
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
Parameters
----------
angular_cutoff : float
Angular frequency cutoff of the filter.
attenuation : float
Desired stopband attenuation in dB.
num_bands : int
Number of bands for the multiband filter system.
filter_length : int or None
Desired length of the filter. If None, it is computed from the given specs.
Returns
-------
float
The computed objective (loss) value for the given filter specs.
"""
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
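# Illustrative sketch, not part of the original file: the slicing in
# `evaluate_filter_objective` above, spelled out without NumPy. The function
# name `autocorrelation_objective` is hypothetical. The filter is convolved
# with its own reverse, then the result is sampled from the centre (zero lag)
# in steps of 2 * num_bands, and the zero-lag sample itself is skipped.

```python
def autocorrelation_objective(coeffs, num_bands):
    """Largest |autocorrelation| at non-zero lags that are multiples of 2 * num_bands."""
    n = len(coeffs)
    reversed_coeffs = coeffs[::-1]
    # Full linear convolution of the filter with its time-reversed copy.
    full = [
        sum(coeffs[j] * reversed_coeffs[i - j] for j in range(n) if 0 <= i - j < n)
        for i in range(2 * n - 1)
    ]
    centre = len(full) // 2
    sampled = full[centre::2 * num_bands][1:]  # drop the zero-lag sample
    return max((abs(v) for v in sampled), default=0.0)
```

# An objective of zero would mean the autocorrelation vanishes at every one of
# those lags, which is the condition the cutoff search tries to approach.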
def design_prototype_filter(attenuation, num_bands, filter_length=None):
"""
Design the optimal prototype filter for a multiband system given the desired specs.
Parameters
----------
attenuation : float
The desired stopband attenuation in dB.
num_bands : int
Number of bands for the multiband filter system.
filter_length : int, optional
Desired length of the filter. If not provided, it's computed based on the given specs.
Returns
-------
ndarray
The optimal prototype filter coefficients.
"""
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
1 / num_bands, disp=0)[0]
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
return torch.tensor(prototype_filter, dtype=torch.float32)
def pad_to_nearest_power_of_two(x):
"""
Pads the input tensor 'x' on both sides such that its last dimension
becomes the next power of two (unchanged if it already is one).
Parameters:
-----------
x : torch.Tensor
The input tensor to be padded.
Returns:
--------
torch.Tensor
The padded tensor.
"""
current_length = x.shape[-1]
target_length = 2**math.ceil(math.log2(current_length))
total_padding = target_length - current_length
left_padding = total_padding // 2
right_padding = total_padding - left_padding
return nn.functional.pad(x, (left_padding, right_padding))
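# Illustrative sketch, not part of the original file: the padding arithmetic
# from `pad_to_nearest_power_of_two` above, isolated from the tensor call. The
# function name `power_of_two_padding` is hypothetical.

```python
import math


def power_of_two_padding(current_length):
    """Return (left, right) zero counts that bring the length up to a power of two."""
    target_length = 2 ** math.ceil(math.log2(current_length))
    total_padding = target_length - current_length
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding  # gets the extra sample if odd
    return left_padding, right_padding
```

# When the total padding is odd, the right side receives one more zero than the
# left, so the filter stays as close to centred as possible.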
def apply_alias_cancellation(x):
"""
Applies alias cancellation by inverting the sign of every
second element of every second row, starting from the second
row's first element in a tensor.
This operation helps ensure that the aliasing introduced in
each band during the decomposition will be counteracted during
the reconstruction.
Parameters:
-----------
x : torch.Tensor
The input tensor.
Returns:
--------
torch.Tensor
Tensor with specific elements' sign inverted for alias cancellation.
"""
# Create a mask of the same shape as 'x', initialized with all ones
mask = torch.ones_like(x)
# Update specific elements in the mask to -1 to perform inversion
mask[..., 1::2, ::2] = -1
# Apply the mask to the input tensor 'x'
return x * mask
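# Illustrative sketch, not part of the original file: the sign mask built by
# `apply_alias_cancellation` above, for a single (bands x frames) grid of
# nested lists. The function names are hypothetical. The slice
# `mask[..., 1::2, ::2] = -1` flips odd-numbered bands at even-numbered frames.

```python
def alias_cancellation_mask(num_bands, num_frames):
    """+1 everywhere except odd bands at even frame indices, which get -1."""
    return [
        [-1 if (band % 2 == 1 and frame % 2 == 0) else 1 for frame in range(num_frames)]
        for band in range(num_bands)
    ]


def apply_mask(bands):
    """Multiply a bands x frames grid by the alias cancellation mask."""
    mask = alias_cancellation_mask(len(bands), len(bands[0]))
    return [[v * m for v, m in zip(row, mask_row)] for row, mask_row in zip(bands, mask)]
```

# Every mask entry is +1 or -1, so applying the mask twice restores the input.
# That is why the same function can be used in both `forward` and `inverse`.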
def ensure_odd_length(tensor):
"""
Pads the last dimension of a tensor to ensure its size is odd.
Parameters:
-----------
tensor : torch.Tensor
Input tensor whose last dimension might need padding.
Returns:
--------
torch.Tensor
The original tensor if its last dimension was already odd,
or the padded tensor with an odd-sized last dimension.
"""
last_dim_size = tensor.shape[-1]
if last_dim_size % 2 == 0:
tensor = nn.functional.pad(tensor, (0, 1))
return tensor
def polyphase_analysis(signal, filter_bank):
"""
Applies the polyphase method to efficiently analyze the signal using a filter bank.
Parameters:
-----------
signal : torch.Tensor
Input signal tensor with shape (Batch x Channels x Length).
filter_bank : torch.Tensor
Filter bank tensor with shape (Bands x Length).
Returns:
--------
torch.Tensor
Signal split into sub-bands. (Batch x Channels x Bands x Length)
"""
num_bands = filter_bank.shape[0]
num_channels = signal.shape[1]
# Rearrange signal for polyphase processing.
# Also combine Batch x Channel into one dimension for now.
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
# Rearrange the filter bank for matching signal shape
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
# Apply convolution with appropriate padding to maintain spatial dimensions
padding = filter_bank.shape[-1] // 2
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
# Truncate the last dimension post-convolution to adjust the output shape
filtered_signal = filtered_signal[..., :-1]
# Rearrange the first dimension back into Batch x Channels
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
return filtered_signal
def polyphase_synthesis(signal, filter_bank):
"""
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
Parameters
----------
signal : torch.Tensor
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
filter_bank : torch.Tensor
Analysis filter bank (shape: Bands x Length).
Returns
-------
torch.Tensor
Reconstructed signal (shape: Batch x Channels x Length)
"""
num_bands = filter_bank.shape[0]
num_channels = signal.shape[1]
# Rearrange the filter bank
filter_bank = filter_bank.flip(-1)
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
# Combine Batch x Channels into one dimension for now.
signal = rearrange(signal, "b c n t -> (b c) n t")
# Apply convolution with appropriate padding
padding_amount = filter_bank.shape[-1] // 2 + 1
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
# Scale the result
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
# Reorganize the output and truncate
reconstructed_signal = reconstructed_signal.flip(1)
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
return reconstructed_signal
@@ -1,239 +0,0 @@
import torch
from einops import rearrange
from torch import nn
class Pretransform(nn.Module):
def __init__(self, enable_grad, io_channels, is_discrete):
super().__init__()
self.is_discrete = is_discrete
self.io_channels = io_channels
self.encoded_channels = None
self.downsampling_ratio = None
self.enable_grad = enable_grad
def encode(self, x):
raise NotImplementedError
def decode(self, z):
raise NotImplementedError
def tokenize(self, x):
raise NotImplementedError
def decode_tokens(self, tokens):
raise NotImplementedError
class AutoencoderPretransform(Pretransform):
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
self.model = model
self.model.requires_grad_(False).eval()
self.scale=scale
self.downsampling_ratio = model.downsampling_ratio
self.io_channels = model.io_channels
self.sample_rate = model.sample_rate
self.model_half = model_half
self.iterate_batch = iterate_batch
self.encoded_channels = model.latent_dim
self.chunked = chunked
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
if self.model_half:
self.model.half()
def encode(self, x, **kwargs):
if self.model_half:
x = x.half()
self.model.to(torch.float16)
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
if self.model_half:
encoded = encoded.float()
return encoded / self.scale
def decode(self, z, **kwargs):
z = z * self.scale
if self.model_half:
z = z.half()
self.model.to(torch.float16)
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
if self.model_half:
decoded = decoded.float()
return decoded
def tokenize(self, x, **kwargs):
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
_, info = self.model.encode(x, return_info = True, **kwargs)
return info[self.model.bottleneck.tokens_id]
def decode_tokens(self, tokens, **kwargs):
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
return self.model.decode_tokens(tokens, **kwargs)
def load_state_dict(self, state_dict, strict=True):
self.model.load_state_dict(state_dict, strict=strict)
class PQMFPretransform(Pretransform):
def __init__(self, attenuation=100, num_bands=16):
# TODO: Fix PQMF to take in in-channels
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
from .pqmf import PQMF
self.pqmf = PQMF(attenuation, num_bands)
def encode(self, x):
# x is (Batch x Channels x Time)
x = self.pqmf.forward(x)
# pqmf.forward returns (Batch x Channels x Bands x Time)
# but Pretransform needs Batch x Channels x Time
# so concatenate channels and bands into one axis
return rearrange(x, "b c n t -> b (c n) t")
def decode(self, x):
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
# returns (Batch x Channels x Time)
return self.pqmf.inverse(x)
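# Illustrative sketch (not part of the original file): the "b c n t -> b (c n) t" pattern used by
# PQMFPretransform folds channels and bands into one axis, so band n of channel c lands at flat
# index c * num_bands + n. The helpers fold_bands / unfold_bands are hypothetical names, work on
# plain nested lists, and drop the batch axis for brevity.

```python
def fold_bands(x):
    """x[c][n] is a list of time samples; returns a flat list indexed by c * num_bands + n."""
    return [band for channel in x for band in channel]


def unfold_bands(flat, num_bands):
    """Inverse of fold_bands: regroup every num_bands consecutive entries into one channel."""
    return [flat[i:i + num_bands] for i in range(0, len(flat), num_bands)]


# 2 channels x 3 bands x 2 time steps of made-up data
x = [[[c * 10 + n, c * 10 + n + 0.5] for n in range(3)] for c in range(2)]
flat = fold_bands(x)
restored = unfold_bands(flat, num_bands=3)
```

# Passing n=self.pqmf.num_bands on decode is what makes the unfold unambiguous: the merged axis
# alone does not say how many of its entries belong to each channel.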
class PretrainedDACPretransform(Pretransform):
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
import dac
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
self.model = dac.DAC.load(model_path)
self.quantize_on_decode = quantize_on_decode
if model_type == "44khz":
self.downsampling_ratio = 512
else:
self.downsampling_ratio = 320
self.io_channels = 1
self.scale = scale
self.chunked = chunked
self.encoded_channels = self.model.latent_dim
self.num_quantizers = self.model.n_codebooks
self.codebook_size = self.model.codebook_size
def encode(self, x):
latents = self.model.encoder(x)
if self.quantize_on_decode:
output = latents
else:
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
output = z
if self.scale != 1.0:
output = output / self.scale
return output
def decode(self, z):
if self.scale != 1.0:
z = z * self.scale
if self.quantize_on_decode:
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
return self.model.decode(z)
def tokenize(self, x):
return self.model.encode(x)[1]
def decode_tokens(self, tokens):
latents = self.model.quantizer.from_codes(tokens)
return self.model.decode(latents)
class AudiocraftCompressionPretransform(Pretransform):
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
try:
from audiocraft.models import CompressionModel
except ImportError:
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
self.model = CompressionModel.get_pretrained(model_type)
self.quantize_on_decode = quantize_on_decode
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
self.sample_rate = self.model.sample_rate
self.io_channels = self.model.channels
self.scale = scale
#self.encoded_channels = self.model.latent_dim
self.num_quantizers = self.model.num_codebooks
self.codebook_size = self.model.cardinality
self.model.to(torch.float16).eval().requires_grad_(False)
def encode(self, x):
raise NotImplementedError("Audiocraft compression models do not support continuous encoding")
# latents = self.model.encoder(x)
# if self.quantize_on_decode:
# output = latents
# else:
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
# output = z
# if self.scale != 1.0:
# output = output / self.scale
# return output
def decode(self, z):
raise NotImplementedError("Audiocraft compression models do not support continuous decoding")
# if self.scale != 1.0:
# z = z * self.scale
# if self.quantize_on_decode:
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
# return self.model.decode(z)
def tokenize(self, x):
with torch.cuda.amp.autocast(enabled=False):
return self.model.encode(x.to(torch.float16))[0]
def decode_tokens(self, tokens):
with torch.cuda.amp.autocast(enabled=False):
return self.model.decode(tokens)
-989
@@ -1,989 +0,0 @@
from functools import reduce, partial
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from typing import Callable, Literal
try:
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
flash_attn_kvpacked_func = None
flash_attn_func = None
from .utils import compile, checkpoint
try:
import natten
except ImportError:
natten = None
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
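# Illustrative sketch (not part of the original file): what triu(j - i + 1) produces for i query
# rows and j key columns, written with plain lists. It assumes the x-transformers convention that
# True marks a position the query must NOT attend to; causal_mask_rows is a hypothetical name.

```python
def causal_mask_rows(i, j):
    """Return an i x j boolean grid that is True where col - row >= j - i + 1."""
    k = j - i + 1  # diagonal offset passed to triu in create_causal_mask
    return [[col - row >= k for col in range(j)] for row in range(i)]


square = causal_mask_rows(3, 3)
# With more keys than queries, the queries are aligned to the END of the key sequence,
# so earlier (cached) keys stay visible to every query.
rectangular = causal_mask_rows(2, 4)
```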
def or_reduce(masks):
head, *body = masks
for rest in body:
head = head | rest
return head
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.max_seq_len = max_seq_len
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
assert (dim % 2) == 0, 'dimension must be divisible by 2'
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta ** -freq_seq
self.register_buffer('inv_freq', inv_freq, persistent = False)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]
emb = einsum('i, j -> i j', pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb * self.scale
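# Illustrative sketch (not part of the original file): the values ScaledSinusoidalEmbedding computes
# for a single position, in plain Python. In the module the scale is a learnable parameter that only
# starts at dim ** -0.5; here it is a fixed number. sinusoidal_embedding is a hypothetical name.

```python
import math


def sinusoidal_embedding(pos, dim, theta=10000.0, scale=None):
    """Return [sin(pos * f_k)..., cos(pos * f_k)...] * scale with f_k = theta ** -(k / half_dim)."""
    assert dim % 2 == 0, "dimension must be divisible by 2"
    half_dim = dim // 2
    scale = dim ** -0.5 if scale is None else scale
    inv_freq = [theta ** -(k / half_dim) for k in range(half_dim)]
    angles = [pos * f for f in inv_freq]
    return [math.sin(a) * scale for a in angles] + [math.cos(a) * scale for a in angles]


at_zero = sinusoidal_embedding(0, 4)
at_seven = sinusoidal_embedding(7, 8)
```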
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = scale_base
self.register_buffer('scale', scale)
def forward_from_seq_len(self, seq_len):
device = self.inv_freq.device
t = torch.arange(seq_len, device = device)
return self.forward(t)
@autocast(enabled = False)
def forward(self, t):
device = self.inv_freq.device
t = t.to(torch.float32)
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
power = (torch.arange(t.shape[0], device = device) - (t.shape[0] // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
out_dtype = t.dtype
# cast to float32 if necessary for numerical stability
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
freqs, t = freqs.to(dtype), t.to(dtype)
freqs = freqs[-seq_len:, :]
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
return torch.cat((t, t_unrotated), dim = -1)
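# Illustrative sketch (not part of the original file): rotate_half pairs entry i with entry
# i + half, and apply_rotary_pos_emb rotates each pair by pos * inv_freq[i]. A consequence is that
# the q.k dot product depends only on the position difference. The plain-list helpers below
# (apply_rope, dot) are hypothetical names and ignore the xpos scale and partial-rotary slicing.

```python
import math


def rotate_half(x):
    """Split x into halves (x1, x2) and return (-x2, x1), as the tensor version does."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)


def apply_rope(x, pos, inv_freq):
    """Rotate x by the angles pos * inv_freq, duplicated to cover both halves."""
    angles = [pos * f for f in inv_freq] * 2  # mirrors torch.cat((freqs, freqs), dim=-1)
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


dim = 4
inv_freq = [1.0 / (10000 ** (i / dim)) for i in range(0, dim, 2)]
q = [0.3, -1.2, 0.8, 0.5]
k = [1.1, 0.4, -0.7, 0.9]
# Same offset (2) at two different absolute positions
near = dot(apply_rope(q, 5, inv_freq), apply_rope(k, 3, inv_freq))
far = dot(apply_rope(q, 12, inv_freq), apply_rope(k, 10, inv_freq))
```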
# norms
class DynamicTanh(nn.Module):
def __init__(self, dim, init_alpha=10.0):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
self.gamma = nn.Parameter(torch.ones(dim))
self.beta = nn.Parameter(torch.zeros(dim))
def forward(self, x):
x = F.tanh(self.alpha * x)
return self.gamma * x + self.beta
class RunningInstanceNorm(nn.Module):
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
super().__init__()
self.register_buffer("running_mean", torch.zeros(1,1,dim))
self.register_buffer("running_std", torch.ones(1,1,dim))
self.saturate = saturate
self.eps = eps
self.momentum = momentum
self.dim = dim
self.trainable_gain = trainable_gain
if self.trainable_gain:
self.gain = nn.Parameter(torch.ones(1))
def _update_stats(self, x):
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
def forward(self, x):
if self.training:
self._update_stats(x)
x = (x - self.running_mean) / self.running_std
if self.saturate:
x = torch.asinh(x)
if self.trainable_gain:
x = x * self.gain
return x
class LayerNorm(nn.Module):
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
"""
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
"""
super().__init__()
if fix_scale:
self.register_buffer("gamma", torch.ones(dim))
else:
self.gamma = nn.Parameter(torch.ones(dim))
if bias:
self.beta = nn.Parameter(torch.zeros(dim))
else:
self.register_buffer("beta", torch.zeros(dim))
self.eps = eps
self.force_fp32 = force_fp32
def forward(self, x):
if not self.force_fp32:
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
else:
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
return output.to(x.dtype)
class LayerScale(nn.Module):
def __init__(self, dim, init_val = 1e-5):
super().__init__()
self.scale = nn.Parameter(torch.full([dim], init_val))
def forward(self, x):
return x * self.scale
class GLU(nn.Module):
def __init__(
self,
dim_in,
dim_out,
activation: Callable,
use_conv = False,
conv_kernel_size = 3,
):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
x = rearrange(x, 'b n d -> b d n')
x = self.proj(x)
x = rearrange(x, 'b d n -> b n d')
else:
x = self.proj(x)
x, gate = x.chunk(2, dim = -1)
return x * self.act(gate)
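# Illustrative sketch (not part of the original file): the gating step of GLU on one already
# projected feature vector. The projection to 2 * dim_out is skipped; glu and silu are
# hypothetical plain-Python stand-ins for the module and for nn.SiLU.

```python
import math


def silu(v):
    """SiLU / swish: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def glu(projected, act=silu):
    """Split the vector in half (value, gate) and return value * act(gate), as chunk(2, dim=-1) does."""
    half = len(projected) // 2
    value, gate = projected[:half], projected[half:]
    return [v * act(g) for v, g in zip(value, gate)]


closed = glu([1.0, 2.0, 0.0, 0.0])  # gate of 0 -> silu(0) = 0 -> output fully suppressed
partly_open = glu([3.0, 1.0])
```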
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out = None,
mult = 4,
no_bias = False,
glu = True,
use_conv = False,
conv_kernel_size = 3,
zero_init_output = True,
):
super().__init__()
inner_dim = int(dim * mult)
# Default to SwiGLU
activation = nn.SiLU()
dim_out = dim if dim_out is None else dim_out
if glu:
linear_in = GLU(dim, inner_dim, activation)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
# init last linear layer to 0
if zero_init_output:
nn.init.zeros_(linear_out.weight)
if not no_bias:
nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
return self.ff(x)
class Attention(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
dim_context = None,
causal = False,
zero_init_output=True,
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
differential = False,
feat_scale = False
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.differential = differential
dim_kv = dim_context if dim_context is not None else dim
self.num_heads = dim // dim_heads
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
if differential:
self.to_q = nn.Linear(dim, dim * 2, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
else:
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
else:
if differential:
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
else:
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Linear(dim, dim, bias=False)
if zero_init_output:
nn.init.zeros_(self.to_out.weight)
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
raise ValueError(f'qk_norm must be one of ["l2", "ln", "rns", "dyt", "none"], got {qk_norm}')
self.qk_norm = qk_norm
if self.qk_norm == "ln":
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
elif self.qk_norm == 'rns':
self.q_norm = nn.RMSNorm(dim_heads)
self.k_norm = nn.RMSNorm(dim_heads)
elif self.qk_norm == 'dyt':
self.q_norm = DynamicTanh(dim_heads)
self.k_norm = DynamicTanh(dim_heads)
self.sdp_kwargs = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
self.feat_scale = feat_scale
if self.feat_scale:
self.lambda_dc = nn.Parameter(torch.zeros(dim))
self.lambda_hf = nn.Parameter(torch.zeros(dim))
self.causal = causal
@compile
def apply_qk_layernorm(self, q, k):
q_type = q.dtype
k_type = k.dtype
q = self.q_norm(q).to(q_type)
k = self.k_norm(k).to(k_type)
return q, k
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
if self.num_heads != self.kv_heads:
# Repeat interleave kv_heads to match q_heads for grouped query attention
heads_per_kv_head = self.num_heads // self.kv_heads
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
flash_attn_available = HAS_FLASH_ATTN
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
flex_attention_block_mask = None
flex_attention_score_mod = None
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
raise NotImplementedError(
"FlexAttention is not available in this build. "
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
)
elif flash_attn_available:
fa_dtype_in = q.dtype
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
else:
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
return out
#@compile
def forward(
self,
x,
context = None,
rotary_pos_emb = None,
causal = None,
flex_attention_block_mask = None,
flex_attention_score_mod = None,
flash_attn_sliding_window = None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
if self.differential:
q, q_diff = self.to_q(x).chunk(2, dim=-1)
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
q = torch.stack([q, q_diff], dim = 1)
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
k = torch.stack([k, k_diff], dim = 1)
else:
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
else:
# Use fused linear projection
if self.differential:
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
q = torch.stack([q, q_diff], dim = 1)
k = torch.stack([k, k_diff], dim = 1)
else:
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm == "l2":
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
elif self.qk_norm != "none":
q, k = self.apply_qk_layernorm(q, k)
if rotary_pos_emb is not None:
freqs, _ = rotary_pos_emb
q_dtype = q.dtype
k_dtype = k.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
freqs = freqs.to(torch.float32)
if q.shape[-2] >= k.shape[-2]:
ratio = q.shape[-2] / k.shape[-2]
q_freqs, k_freqs = freqs, ratio * freqs
else:
ratio = k.shape[-2] / q.shape[-2]
q_freqs, k_freqs = ratio * freqs, freqs
q = apply_rotary_pos_emb(q, q_freqs)
k = apply_rotary_pos_emb(k, k_freqs)
q = q.to(v.dtype)
k = k.to(v.dtype)
n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal
if n == 1 and causal:
causal = False
if self.differential:
q, q_diff = q.unbind(dim = 1)
k, k_diff = k.unbind(dim = 1)
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
out = out - out_diff
else:
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
# merge heads
out = rearrange(out, ' b h n d -> b n (h d)')
# Communicate between heads
# with autocast(enabled = False):
# out_dtype = out.dtype
# out = out.to(torch.float32)
# out = self.to_out(out).to(out_dtype)
out = self.to_out(out)
if self.feat_scale:
out_dc = out.mean(dim=-2, keepdim=True)
out_hf = out - out_dc
# Selectively modulate DC and high frequency components
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
return out
class ConformerModule(nn.Module):
def __init__(
self,
dim,
norm_kwargs = {},
):
super().__init__()
self.dim = dim
self.in_norm = LayerNorm(dim, **norm_kwargs)
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
self.glu = GLU(dim, dim, nn.SiLU())
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
self.swish = nn.SiLU()
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
#@compile
def forward(self, x):
x = self.in_norm(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.glu(x)
x = rearrange(x, 'b n d -> b d n')
x = self.depthwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.mid_norm(x)
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return x
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
cross_attend = False,
dim_context = None,
global_cond_dim = None,
causal = False,
zero_init_branch_outputs = True,
conformer = False,
layer_ix = -1,
remove_norms = False,
add_rope = False,
layer_scale = False,
use_sync_block_film = False,
attn_kwargs = {},
ff_kwargs = {},
norm_kwargs = {}
):
super().__init__()
self.dim = dim
self.dim_heads = min(dim_heads,dim)
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
if layer_scale and zero_init_branch_outputs:
zero_init_branch_outputs = False
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.add_rope = add_rope
self.self_attn = Attention(
dim,
dim_heads = self.dim_heads,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.cross_attend = cross_attend
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.cross_attn = Attention(
dim,
dim_heads = self.dim_heads,
dim_context=dim_context,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.layer_ix = layer_ix
self.conformer = None
if conformer:
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.global_cond_dim = global_cond_dim
if global_cond_dim is not None:
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
if use_sync_block_film:
self.sync_film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from sync_cond
)
@compile
def forward(
self,
x,
context = None,
global_cond=None,
rotary_pos_emb = None,
self_attention_block_mask = None,
self_attention_score_mod = None,
cross_attention_block_mask = None,
cross_attention_score_mod = None,
self_attention_flash_sliding_window = None,
cross_attention_flash_sliding_window = None,
sync_cond = None,
prepend_length=0
):
if rotary_pos_emb is None and self.add_rope:
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
if len(global_cond.shape) == 2:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
else:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
# self-attention with adaLN
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
x = x * torch.sigmoid(1 - gate_self)
x = self.self_attn_scale(x)
x = x + residual
if context is not None and self.cross_attend:
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
if self.conformer is not None:
x = x + self.conformer_scale(self.conformer(x))
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
x = x * (1 + scale_ff) + shift_ff
x = self.ff(x)
x = x * torch.sigmoid(1 - gate_ff)
x = self.ff_scale(x)
x = x + residual
else:
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
if context is not None and self.cross_attend:
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
if self.conformer is not None:
x = x + self.conformer_scale(self.conformer(x))
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
prepend_part = x[:, :prepend_length, :]
audio_part = x[:, prepend_length:, :]
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
modulated_audio_part = audio_part * (1 + scale) + shift
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
return x
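# Illustrative sketch (not part of the original file): the shape of one adaLN branch in
# TransformerBlock.forward, reduced to scalars and lists. Normalisation and LayerScale are left
# out, and adaln_branch is a hypothetical name; only the modulate / gate / residual order is shown.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def adaln_branch(x, scale, shift, gate, fn):
    """Return x + sigmoid(1 - gate) * fn(x * (1 + scale) + shift), elementwise."""
    modulated = [v * (1.0 + scale) + shift for v in x]
    out = fn(modulated)
    g = sigmoid(1.0 - gate)
    return [r + g * o for r, o in zip(x, out)]


# Identity sub-layer, no modulation, gate = 1 -> sigmoid(0) = 0.5 of the branch is added
half_open = adaln_branch([1.0, -2.0], 0.0, 0.0, 1.0, lambda v: v)
# A sub-layer whose output is zero (as with zero-initialised output weights) leaves x unchanged
untouched = adaln_branch([1.0, -2.0], 0.3, 0.1, 0.2, lambda v: [0.0] * len(v))
```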
class ContinuousTransformer(nn.Module):
def __init__(
self,
dim,
depth,
*,
dim_in = None,
dim_out = None,
dim_heads = 64,
cross_attend=False,
cond_token_dim=None,
pre_cross_attn_ix=-1,
final_cross_attn_ix=-1,
global_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
conformer=False,
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
num_memory_tokens=0,
sliding_window=None,
use_mlp=False,
use_add_norm=False,
use_gated=False,
use_final_layer=False,
use_zeros=False,
use_conv=False,
use_fusion_mlp=False,
use_film=False,
use_sync_film=False,
use_sync_gated=False,
**kwargs
):
super().__init__()
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
if use_mlp:
self.project_in = nn.Sequential(
nn.Linear(dim_in, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim, bias=False)
)
else:
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
self.video_temporal_conv = None
self.audio_temporal_conv = None
self.fusion_mlp = None
if use_conv:
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
if use_fusion_mlp:
self.fusion_mlp = nn.Sequential(
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
else:
self.rotary_pos_emb = None
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.use_sinusoidal_emb = use_sinusoidal_emb
if use_sinusoidal_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
self.adaLN_modulation = None
if global_cond_dim is not None:
if use_final_layer:
self.norm_final = LayerNorm(dim)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(
dim, 2 * dim, bias=True
),
)
if use_zeros:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.project_out.weight, 0)
self.global_cond_embedder = nn.Sequential(
nn.Linear(global_cond_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6)
)
if use_zeros:
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
self.final_cross_attn_ix = final_cross_attn_ix
self.use_gated = use_gated
self.use_film = use_film
self.use_add_norm = use_add_norm
if self.use_add_norm:
self.add_norm = nn.LayerNorm(dim)
if use_gated:
self.gate = nn.Parameter(torch.ones(1, 1, dim))
if use_film:
self.film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from add_cond
)
else:
self.film_generator = None
if use_sync_film:
self.sync_film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from sync_cond
)
else:
self.sync_film_generator = None
if use_sync_gated:
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
else:
self.sync_gate = None
self.sliding_window = sliding_window
for i in range(depth):
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
# print(f"Layer {i} cross attends: {should_cross_attend}")
self.layers.append(
TransformerBlock(
dim,
dim_heads = dim_heads,
cross_attend = should_cross_attend,
dim_context = cond_token_dim,
global_cond_dim = global_cond_dim,
causal = causal,
zero_init_branch_outputs = zero_init_branch_outputs,
conformer=conformer,
layer_ix=i,
**kwargs
)
)
def forward(
self,
x,
mask = None,
prepend_embeds = None,
prepend_mask = None,
add_cond = None,
sync_cond = None,
global_cond = None,
return_info = False,
use_checkpointing = True,
exit_layer_ix = None,
video_dropout_prob = 0.0,
**kwargs
):
batch, seq, device = *x.shape[:2], x.device
model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype)
prepend_length = 0
info = {
"hidden_states": [],
}
x = self.project_in(x)
if add_cond is not None:
if self.use_gated:
gate = torch.sigmoid(self.gate)
x = x + gate * add_cond
elif self.use_film:
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
else:
x = x + add_cond
if self.use_add_norm:
x = self.add_norm(x)
if self.fusion_mlp is not None:
x = self.fusion_mlp(x)
if sync_cond is not None:
# Resample sync_cond to match audio sequence length if needed
if sync_cond.shape[1] != x.shape[1]:
sync_cond = torch.nn.functional.interpolate(
sync_cond.transpose(1, 2), size=x.shape[1],
mode='linear', align_corners=False,
).transpose(1, 2)
if self.sync_film_generator is not None:
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
elif self.sync_gate is not None:
gate_value = torch.sigmoid(self.sync_gate)
x = x + gate_value * sync_cond
# else:
# x = x + sync_cond
if prepend_embeds is not None:
prepend_length, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
x = torch.cat((prepend_embeds, x), dim = -2)
if self.num_memory_tokens > 0:
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
x = torch.cat((memory_tokens, x), dim=1)
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
else:
rotary_pos_emb = None
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
if global_cond is not None and getattr(self, "global_cond_embedder", None) is not None:
global_cond_embed = self.global_cond_embedder(global_cond)
else:
global_cond_embed = global_cond
# Iterate over the transformer layers
for layer_ix, layer in enumerate(self.layers):
if use_checkpointing:
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
if return_info:
info["hidden_states"].append(x)
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
x = x[:, self.num_memory_tokens:, :]
if return_info:
return x, info
return x
x = x[:, self.num_memory_tokens:, :]
if global_cond is not None and self.adaLN_modulation is not None:
if len(global_cond.shape) == 2:
global_cond = global_cond.unsqueeze(1)
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.project_out(x)
if return_info:
return x, info
return x
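# Illustrative sketch (not part of the original file): which layer indices get a cross-attention
# branch under the should_cross_attend rule in ContinuousTransformer.__init__. A value of -1 means
# "no bound"; cross_attention_layers is a hypothetical helper name.

```python
def cross_attention_layers(depth, cross_attend=True, pre_cross_attn_ix=-1, final_cross_attn_ix=-1):
    """Return the indices i in [0, depth) with pre_cross_attn_ix <= i < final_cross_attn_ix."""
    return [
        i
        for i in range(depth)
        if cross_attend
        and (final_cross_attn_ix == -1 or i < final_cross_attn_ix)
        and (pre_cross_attn_ix == -1 or i >= pre_cross_attn_ix)
    ]


all_layers = cross_attention_layers(6)
middle_only = cross_attention_layers(6, pre_cross_attn_ix=2, final_cross_attn_ix=5)
disabled = cross_attention_layers(6, cross_attend=False, pre_cross_attn_ix=2)
```

# Note the asymmetry: pre_cross_attn_ix is inclusive, final_cross_attn_ix is exclusive.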
-180
@@ -1,180 +0,0 @@
import torch
from safetensors.torch import load_file
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
from torch.nn.utils import remove_weight_norm
def load_ckpt_state_dict(ckpt_path, prefix=None):
if ckpt_path.endswith(".safetensors"):
state_dict = load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
# Keep only the entries whose key starts with the prefix, and strip that leading prefix
filtered_state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict
return filtered_state_dict
def remove_weight_norm_from_model(model):
for module in model.modules():
if hasattr(module, "weight_g"):  # only modules that actually carry weight norm
remove_weight_norm(module)
return model
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
# License can be found in LICENSES/LICENSE_META.txt
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
Args:
input (torch.Tensor): The input tensor containing probabilities.
num_samples (int): Number of samples to draw.
replacement (bool): Whether to draw with replacement or not.
Keywords args:
generator (torch.Generator): A pseudorandom number generator for sampling.
Returns:
torch.Tensor: Last dimension contains num_samples indices
sampled from the multinomial probability distribution
located in the last dimension of tensor input.
"""
if num_samples == 1:
q = torch.empty_like(input).exponential_(1, generator=generator)
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
input_ = input.reshape(-1, input.shape[-1])
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
output = output_.reshape(*list(input.shape[:-1]), -1)
return output
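The `num_samples == 1` fast path above relies on the exponential-race trick: dividing each probability by an independent Exp(1) draw and taking the argmax selects index i with probability proportional to `input[i]`. The sketch below checks this empirically in plain Python. `sample_categorical` is a hypothetical stand-in for the tensor code, not part of the original file.

```python
import random
from collections import Counter


def sample_categorical(probs, rng):
    """Draw one index via argmax_i probs[i] / q_i with q_i ~ Exp(1)."""
    scores = [p / rng.expovariate(1.0) for p in probs]
    return max(range(len(probs)), key=scores.__getitem__)


rng = random.Random(0)
probs = [0.1, 0.6, 0.3]
n = 20000
counts = Counter(sample_categorical(probs, rng) for _ in range(n))
# Empirical frequencies should approach the target probabilities.
freqs = [counts[i] / n for i in range(len(probs))]
```

The trick works because `q_i / probs[i]` is exponentially distributed with rate `probs[i]`, and the smallest of several independent exponentials falls on index i with probability proportional to its rate.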
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
"""Sample next token from top K values along the last dimension of the input probs tensor.
Args:
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
k (int): The k in top-k.
Returns:
torch.Tensor: Sampled tokens.
"""
top_k_value, _ = torch.topk(probs, k, dim=-1)
min_value_top_k = top_k_value[..., [-1]]
probs *= (probs >= min_value_top_k).float()
probs.div_(probs.sum(dim=-1, keepdim=True))
next_token = multinomial(probs, num_samples=1)
return next_token
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
Args:
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
p (float): The p in top-p.
Returns:
torch.Tensor: Sampled tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort *= (~mask).float()
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
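The mask in `sample_top_p` drops a token when the cumulative probability before it already exceeds `p`, so the token that crosses the threshold is still kept. A small pure-Python sketch of that filtering step, without the final sampling, is shown below. `top_p_filter` is an illustrative name and works on a plain list rather than a tensor.

```python
def top_p_filter(probs, p):
    """Return {index: renormalised prob} for the nucleus of `probs` (hypothetical helper)."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = {}
    mass_before = 0.0
    for i in order:
        # Same rule as `probs_sum - probs_sort > p`: test the mass before this token.
        if mass_before > p:
            break
        kept[i] = probs[i]
        mass_before += probs[i]
    total = sum(kept.values())
    return {i: v / total for i, v in kept.items()}


# Masses before each sorted token are 0, 0.5, 0.8, 0.95, so only the first two survive.
filtered = top_p_filter([0.5, 0.3, 0.15, 0.05], p=0.7)
```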
def next_power_of_two(n):
return 2 ** (n - 1).bit_length()
def next_multiple_of_64(n):
return ((n + 63) // 64) * 64
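A few worked values make the rounding behaviour of the two helpers above concrete. The definitions are repeated so the snippet runs on its own; the chosen inputs are arbitrary examples.

```python
def next_power_of_two(n):
    return 2 ** (n - 1).bit_length()


def next_multiple_of_64(n):
    return ((n + 63) // 64) * 64


# Exact powers of two and exact multiples of 64 are returned unchanged.
powers = [next_power_of_two(n) for n in (1, 5, 8, 9)]
multiples = [next_multiple_of_64(n) for n in (1, 64, 65, 194)]
```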
# mask construction helpers
def mask_from_start_end_indices(
seq_len: int,
start: Tensor,
end: Tensor
):
assert start.shape == end.shape
device = start.device
seq = torch.arange(seq_len, device = device, dtype = torch.long)
seq = seq.reshape(*((-1,) * start.ndim), seq_len)
seq = seq.expand(*start.shape, seq_len)
mask = seq >= start[..., None].long()
mask &= seq < end[..., None].long()
return mask
def mask_from_frac_lengths(
seq_len: int,
frac_lengths: Tensor
):
device = frac_lengths.device
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
start = (max_start * rand).clamp(min = 0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
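The two mask helpers above build a boolean mask that is true on the half-open span `[start, end)`, with the span length set by a fraction of the sequence and the start drawn uniformly. A single-sequence sketch in plain Python follows. `span_mask` and `random_span_mask` are hypothetical names, and truncating the start to an int mirrors the `.long()` cast in the tensor version.

```python
import random


def span_mask(seq_len, start, end):
    """True for positions in [start, end), False elsewhere."""
    return [start <= i < end for i in range(seq_len)]


def random_span_mask(seq_len, frac, rng):
    """Mask a contiguous span covering `frac` of the sequence at a random offset."""
    length = int(frac * seq_len)
    max_start = seq_len - length
    start = int(max_start * rng.random())
    return span_mask(seq_len, start, start + length)


fixed = span_mask(6, 2, 5)
mask = random_span_mask(10, 0.5, random.Random(0))
```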
def _build_spline(video_feat, video_t, target_t):
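# Core cubic-spline interpolation. torchcubicspline is imported lazily because the
# module-level import above is commented out.
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline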
coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
spline = NaturalCubicSpline(coeffs)
return spline.evaluate(target_t).permute(0,2,1)
def resample(video_feat, audio_latent):
"""
Shapes below assume a 9 s clip.
video_feat: [B, 72, D]
audio_latent: [B, D', 194] or int
"""
B, Tv, D = video_feat.shape
if isinstance(audio_latent, torch.Tensor):
# audio_latent is a tensor
if audio_latent.shape[1] != 64:
Ta = audio_latent.shape[1]
else:
Ta = audio_latent.shape[2]
elif isinstance(audio_latent, int):
# audio_latent is an int
Ta = audio_latent
else:
raise TypeError("audio_latent must be either a tensor or an int")
# Build timestamps (the key improvement)
video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)
# Rearrange to (Batch, Feature, Time)
video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv]
# Cubic-spline interpolation
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
import os
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
def compile(function, *args, **kwargs):
if enable_torch_compile:
try:
return torch.compile(function, *args, **kwargs)
except RuntimeError:
return function
return function
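The `compile` wrapper above follows an "opt-in with fallback" pattern: wrap only when a flag is set, and return the original callable if wrapping raises. The sketch below shows the same control flow with ordinary Python callables. `maybe_wrap`, `failing_wrapper` and `plus_one_wrapper` are illustrative names, not part of the original module.

```python
def maybe_wrap(function, wrapper, enabled):
    """Apply `wrapper` only when enabled; fall back to `function` on RuntimeError."""
    if not enabled:
        return function
    try:
        return wrapper(function)
    except RuntimeError:
        return function


def double(x):
    return 2 * x


def failing_wrapper(fn):
    raise RuntimeError("backend unavailable")


def plus_one_wrapper(fn):
    def wrapped(x):
        return fn(x) + 1
    return wrapped


disabled = maybe_wrap(double, plus_one_wrapper, enabled=False)
fallback = maybe_wrap(double, failing_wrapper, enabled=True)
wrapped = maybe_wrap(double, plus_one_wrapper, enabled=True)
```

A wrapper that defers its real work until the first call would raise outside the `try` block, so this pattern only guards against failures at wrap time.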
-6
@@ -1,11 +1,5 @@
einops>=0.7.0 einops>=0.7.0
einops-exts
safetensors
huggingface_hub huggingface_hub
transformers>=4.52.3 transformers>=4.52.3
k-diffusion>=0.1.1
alias-free-torch
descript-audio-codec
vector-quantize-pytorch
scipy scipy
tqdm tqdm
-21
@@ -1,21 +0,0 @@
name: prismaudio-extract
channels:
- conda-forge
- defaults
dependencies:
- python=3.10
- pip
- ffmpeg<7
- pip:
- torch>=2.6.0
- torchaudio>=2.6.0
- torchvision>=0.21.0
- tensorflow-cpu==2.15.0
- jax
- jaxlib
- transformers>=4.52.3
- decord
- einops>=0.7.0
- numpy
- mediapy
- git+https://github.com/google-deepmind/videoprism.git
-168
@@ -1,168 +0,0 @@
#!/usr/bin/env python3
"""
Standalone PrismAudio feature extraction script.
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
Usage:
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
"""
import argparse
import os
import sys
import time
import numpy as np
import torch
# Add plugin root to sys.path so data_utils (and prismaudio_core) are importable
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PLUGIN_DIR = os.path.dirname(_SCRIPT_DIR)
if _PLUGIN_DIR not in sys.path:
sys.path.insert(0, _PLUGIN_DIR)
def _step(n, total, label):
"""Print step header and return start time."""
print(f"[extract] Step {n}/{total}: {label}...", flush=True)
return time.perf_counter()
def _done(t0, extra=""):
elapsed = time.perf_counter() - t0
suffix = f" {extra}" if extra else ""
print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
def main():
t_total = time.perf_counter()
parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
parser.add_argument("--video", required=True, help="Path to input video")
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
parser.add_argument("--output", required=True, help="Output .npz path")
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
parser.add_argument("--clip_fps", type=float, default=4.0)
parser.add_argument("--clip_size", type=int, default=288)
parser.add_argument("--sync_fps", type=float, default=25.0)
parser.add_argument("--sync_size", type=int, default=224)
args = parser.parse_args()
print(f"[extract] Python : {sys.executable}", flush=True)
print(f"[extract] Video : {args.video}", flush=True)
print(f"[extract] Output : {args.output}", flush=True)
print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)
if not os.path.exists(args.video):
print(f"[extract] ERROR: video not found: {args.video}", flush=True)
sys.exit(1)
print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------------------------------------------------
t0 = _step(1, 6, "importing dependencies")
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
import torchvision.transforms as T
_done(t0)
# ------------------------------------------------------------------
t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
feat_utils = FeaturesUtils(
vae_config_path=args.vae_config,
synchformer_ckpt=args.synchformer_ckpt,
device=device,
)
_done(t0)
# ------------------------------------------------------------------
t0 = _step(3, 6, "reading and preprocessing video")
if args.video.endswith(".npy"):
all_frames = np.load(args.video) # [T, H, W, C] uint8
fps = args.source_fps
total_frames = all_frames.shape[0]
duration = total_frames / fps
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
clip_frames = all_frames[clip_indices]
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
sync_frames = all_frames[sync_indices]
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
else:
from decord import VideoReader, cpu
vr = VideoReader(args.video, ctx=cpu(0))
fps = vr.get_avg_fps()
total_frames = len(vr)
duration = total_frames / fps
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
clip_frames = vr.get_batch(clip_indices).asnumpy()
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
sync_frames = vr.get_batch(sync_indices).asnumpy()
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
clip_transform = T.Compose([
T.ToPILImage(),
T.Resize(args.clip_size),
T.CenterCrop(args.clip_size),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
sync_transform = T.Compose([
T.ToPILImage(),
T.Resize(args.sync_size),
T.CenterCrop(args.sync_size),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
_done(t0)
# ------------------------------------------------------------------
t0 = _step(4, 6, "encoding text with T5-Gemma")
text_features = feat_utils.encode_t5_text([args.cot_text])
_done(t0, f"shape={tuple(text_features.shape)}")
# ------------------------------------------------------------------
t0 = _step(5, 6, "encoding video with VideoPrism")
global_video_features, video_features, global_text_features = \
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
_done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")
# ------------------------------------------------------------------
t0 = _step(6, 6, "encoding video with Synchformer")
sync_features = feat_utils.encode_video_with_sync(sync_input)
_done(t0, f"shape={tuple(sync_features.shape)}")
# ------------------------------------------------------------------
t0 = time.perf_counter()
print(f"[extract] Saving features to {args.output} ...", flush=True)
np.savez(
args.output,
video_features=video_features.cpu().float().numpy(),
global_video_features=global_video_features.cpu().float().numpy(),
text_features=text_features.cpu().float().numpy(),
global_text_features=global_text_features.cpu().float().numpy(),
sync_features=sync_features.cpu().float().numpy(),
caption_cot=args.cot_text,
duration=duration,
)
print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)
if __name__ == "__main__":
main()
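Both branches of step 3 in the script above pick frames by mapping each output frame index back to a source frame index and clamping it to the last frame. The sketch below isolates that index arithmetic. `sample_indices` is a hypothetical helper name; the 30 fps, 90-frame input is an arbitrary example.

```python
def sample_indices(source_fps, target_fps, total_frames):
    """Indices of source frames to keep when resampling to target_fps."""
    duration = total_frames / source_fps
    n_out = int(duration * target_fps)
    return [
        min(int(i * source_fps / target_fps), total_frames - 1)
        for i in range(n_out)
    ]


# 3 seconds of 30 fps video, sampled at the script's default CLIP and sync rates.
clip_idx = sample_indices(30.0, 4.0, 90)
sync_idx = sample_indices(30.0, 25.0, 90)
```

Since the index is truncated rather than rounded, a target rate that does not divide the source rate yields uneven spacing, such as the alternating steps of 7 and 8 in `clip_idx`.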
-44
@@ -1,44 +0,0 @@
#!/usr/bin/env bash
# Install the PrismAudio feature-extraction environment using pip venv.
# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker).
#
# Usage:
# bash scripts/install_extract_env.sh [/path/to/venv]
#
# Default venv path: /opt/prismaudio-extract
# After installation, point the PrismAudioFeatureExtractor node's python_env to:
# <venv>/bin/python (Linux/Mac)
# <venv>\Scripts\python.exe (Windows)
set -euo pipefail
VENV_DIR="${1:-/opt/prismaudio-extract}"
echo "[PrismAudio] Creating venv at: ${VENV_DIR}"
python3 -m venv "${VENV_DIR}"
PIP="${VENV_DIR}/bin/pip"
echo "[PrismAudio] Upgrading pip..."
"${PIP}" install --upgrade pip
echo "[PrismAudio] Installing PyTorch stack..."
"${PIP}" install torch torchaudio torchvision
echo "[PrismAudio] Installing feature-extraction dependencies..."
"${PIP}" install \
"tensorflow-cpu>=2.16.0" \
"jax[cpu]" \
"jaxlib" \
"transformers" \
"decord" \
"einops" \
"numpy" \
"mediapy"
echo "[PrismAudio] Installing VideoPrism..."
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"
echo ""
echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:"
echo " ${VENV_DIR}/bin/python"
+3
@@ -0,0 +1,3 @@
# Vendored from https://github.com/jnwnlee/selva
# Pinned commit: d7d40a992aab58e7cf246055681a657e5d8b4a4d
# Imports rewritten from selva.* → selva_core.*
+190
@@ -0,0 +1,190 @@
import logging
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional
import av
import numpy as np
import torch
from av import AudioFrame
log = logging.getLogger()
@dataclass
class VideoInfo:
duration_sec: float
fps: Fraction
clip_frames: torch.Tensor
sync_frames: torch.Tensor
all_frames: Optional[list[np.ndarray]]
@property
def height(self):
return self.all_frames[0].shape[0]
@property
def width(self):
return self.all_frames[0].shape[1]
@classmethod
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
fps: Fraction) -> 'VideoInfo':
num_frames = int(duration_sec * fps)
all_frames = [image_info.original_frame] * num_frames
return cls(duration_sec=duration_sec,
fps=fps,
clip_frames=image_info.clip_frames,
sync_frames=image_info.sync_frames,
all_frames=all_frames)
@dataclass
class ImageInfo:
clip_frames: torch.Tensor
sync_frames: torch.Tensor
original_frame: Optional[np.ndarray]
@property
def height(self):
return self.original_frame.shape[0]
@property
def width(self):
return self.original_frame.shape[1]
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
output_frames = [[] for _ in list_of_fps]
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
all_frames = []
# container = av.open(video_path)
with av.open(video_path) as container:
stream = container.streams.video[0]
fps = stream.guessed_rate
stream.thread_type = 'AUTO'
for packet in container.demux(stream):
for frame in packet.decode():
frame_time = frame.time
if frame_time < start_sec:
continue
if frame_time > end_sec:
break
frame_np = None
if need_all_frames:
frame_np = frame.to_ndarray(format='rgb24')
all_frames.append(frame_np)
for i, _ in enumerate(list_of_fps):
this_time = frame_time
while this_time >= next_frame_time_for_each_fps[i]:
if frame_np is None:
frame_np = frame.to_ndarray(format='rgb24')
output_frames[i].append(frame_np)
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
output_frames = [np.stack(frames) for frames in output_frames]
return output_frames, all_frames, fps
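`read_frames` resamples one decoded stream to several frame rates at once: each target rate keeps a "next due" timestamp, and a decoded frame is emitted, possibly more than once, while its timestamp has reached that due time. The sketch below reproduces the scheduling on frame indices instead of decoded images. `resample_by_time` is an illustrative name, and the rates are chosen to be exact in binary floating point.

```python
def resample_by_time(frame_times, target_fps_list):
    """For each target fps, list the source frame indices that would be emitted."""
    outputs = [[] for _ in target_fps_list]
    next_time = [0.0 for _ in target_fps_list]
    delta = [1 / fps for fps in target_fps_list]
    for idx, t in enumerate(frame_times):
        for i in range(len(target_fps_list)):
            # Emit this frame until the schedule for rate i has caught up with t.
            while t >= next_time[i]:
                outputs[i].append(idx)
                next_time[i] += delta[i]
    return outputs


# One second of 8 fps video, resampled to 4 fps (downsampling) and 16 fps (upsampling).
times = [k / 8 for k in range(8)]
down, up = resample_by_time(times, [4.0, 16.0])
```

Downsampling skips frames, while upsampling repeats them. With rates that are not exact in floating point, the accumulated `next_time` can drift slightly past a frame timestamp and shift which frame is picked.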
def normalize_video_chunk(video_chunk: torch.Tensor,
expected_length: int,
*,
n_tolerance_frame: int = 1,
desc: str = "") \
-> torch.Tensor:
# video_chunk: [T, H, W, C]
if video_chunk.shape[0] < expected_length:
if expected_length - video_chunk.shape[0] <= n_tolerance_frame:
# copy the last frame to make it the right length
log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
video_chunk = torch.cat([video_chunk, video_chunk[-1:].repeat(expected_length - video_chunk.shape[0], 1, 1, 1)])
assert video_chunk.shape[0] == expected_length
else:
raise RuntimeError(
f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
)
video_chunk = video_chunk[:expected_length]
if video_chunk.shape[0] != expected_length:
raise RuntimeError(f'Video wrong length {desc}, '
f'expected {expected_length}, '
f'got {video_chunk.shape[0]}')
return video_chunk
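`normalize_video_chunk` enforces an exact frame count: a chunk that is short by at most `n_tolerance_frame` frames is padded with copies of its last frame, a longer chunk is truncated, and anything shorter than the tolerance allows raises. The list-based sketch below mirrors that policy. `normalize_length` is a hypothetical name and works on any sequence of frames rather than a tensor.

```python
def normalize_length(frames, expected_length, n_tolerance_frame=1):
    """Pad with the last frame (within tolerance) or truncate to expected_length."""
    missing = expected_length - len(frames)
    if missing > n_tolerance_frame:
        raise RuntimeError(
            f'Video too short, expected {expected_length}, got {len(frames)}')
    if missing > 0:
        frames = frames + [frames[-1]] * missing
    return frames[:expected_length]


padded = normalize_length([1, 2, 3], 4)
truncated = normalize_length([1, 2, 3, 4, 5], 4)
```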
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
sampling_rate: int):
container = av.open(output_path, 'w')
output_video_stream = container.add_stream('h264', video_info.fps)
output_video_stream.codec_context.bit_rate = 10_000_000 # 10 Mbps, as an int rather than the float 10 * 1e6
output_video_stream.width = video_info.width
output_video_stream.height = video_info.height
output_video_stream.pix_fmt = 'yuv420p'
output_audio_stream = container.add_stream('aac', sampling_rate)
# encode video
for image in video_info.all_frames:
image = av.VideoFrame.from_ndarray(image)
packet = output_video_stream.encode(image)
container.mux(packet)
for packet in output_video_stream.encode():
container.mux(packet)
# convert float tensor audio to numpy array
audio_np = audio.numpy().astype(np.float32)
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
audio_frame.sample_rate = sampling_rate
for packet in output_audio_stream.encode(audio_frame):
container.mux(packet)
for packet in output_audio_stream.encode():
container.mux(packet)
container.close()
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
"""
NOTE: I don't think we can get the exact video duration right without re-encoding
so we are not using this but keeping it here for reference
"""
video = av.open(video_path)
output = av.open(output_path, 'w')
input_video_stream = video.streams.video[0]
output_video_stream = output.add_stream(template=input_video_stream)
output_audio_stream = output.add_stream('aac', sampling_rate)
duration_sec = audio.shape[-1] / sampling_rate
for packet in video.demux(input_video_stream):
# We need to skip the "flushing" packets that `demux` generates.
if packet.dts is None:
continue
# We need to assign the packet to the new stream.
packet.stream = output_video_stream
output.mux(packet)
# convert float tensor audio to numpy array
audio_np = audio.numpy().astype(np.float32)
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
audio_frame.sample_rate = sampling_rate
for packet in output_audio_stream.encode(audio_frame):
output.mux(packet)
for packet in output_audio_stream.encode():
output.mux(packet)
video.close()
output.close()
+227
@@ -0,0 +1,227 @@
import logging
import random
from typing import Optional
import numpy as np
import torch
from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from selva_core.data.vgg_sound import VGGSound
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
from selva_core.data.eval.audiocaps import AudioCapsData
from selva_core.data.mm_dataset import MultiModalDataset
from selva_core.data.mixup import DataMixupCollate
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
# Re-seed randomness every time we start a worker
def worker_init_fn(worker_id: int):
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
np.random.seed(worker_seed)
random.seed(worker_seed)
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
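The seed formula in `worker_init_fn` combines a base seed with the worker id and the rank so that each data-loading worker gets its own random stream. The sketch below evaluates the same expression for a small grid of ranks and workers. `derive_worker_seed` and the base seed value are illustrative assumptions.

```python
def derive_worker_seed(base_seed, worker_id, local_rank):
    """Same arithmetic as worker_init_fn, with the inputs passed explicitly."""
    return base_seed % (2**31) + worker_id + local_rank * 1000


base = 123456789012  # stand-in for torch.initial_seed()
seeds = {
    (rank, worker): derive_worker_seed(base, worker, rank)
    for rank in range(4)
    for worker in range(8)
}
```

The seeds stay distinct across ranks as long as each rank uses fewer than 1000 workers. The modulo keeps the result below 2**32, which is the upper bound `np.random.seed` accepts.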
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
) -> Dataset:
dataset = VGGSound(root=data_cfg.root,
tsv_path=data_cfg.subset_name,
sample_rate=16_000,
duration_sec=8.0,
normalize_audio=normalize_audio,
mmap_dir=data_cfg.memmap_dir,
tsv_tsynch_path=data_cfg.tsv_tsynch,
mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
data_dim=cfg.data_dim
)
return dataset
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
raise NotImplementedError('Audio data loading is not implemented yet')
def setup_training_datasets(cfg: DictConfig,
generator: torch.Generator,
) -> tuple[Dataset, DistributedSampler, DataLoader]:
if cfg.mini_train:
vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
dataset = MultiModalDataset([vgg], [])
elif cfg.example_train:
video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
dataset = MultiModalDataset([video], [])
else:
vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
# load the largest one first
# you can add more video/audio data upon demand, such as
# clotho = load_audio_data(cfg, cfg.data.Clotho)
dataset = MultiModalDataset([vgg], [])
batch_size = cfg.batch_size
num_workers = cfg.num_workers
pin_memory = cfg.pin_memory
if cfg.mixup.domain == 'data':
mixup_params = cfg.mixup.params
collate_fn = DataMixupCollate(generator=generator,
**mixup_params)
else:
collate_fn = None
sampler, loader = construct_loader(dataset,
batch_size,
num_workers,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=collate_fn)
return dataset, sampler, loader
def setup_test_datasets(cfg: DictConfig,
generator: torch.Generator,
) -> tuple[Dataset, DistributedSampler, DataLoader]:
if cfg.example_train:
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
elif cfg.dataset.startswith('vggsound'):
dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False)
else:
raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')
batch_size = cfg.batch_size
num_workers = cfg.get('num_workers_val', cfg.num_workers)
pin_memory = cfg.pin_memory
if cfg.mixup.domain == 'data':
mixup_config = cfg.mixup.params
collate_fn = DataMixupCollate(generator=generator,
**mixup_config)
else:
collate_fn = None
sampler, loader = construct_loader(dataset,
batch_size,
num_workers,
shuffle=False,
drop_last=False,
pin_memory=pin_memory,
collate_fn=collate_fn)
return dataset, sampler, loader
def setup_val_datasets(cfg: DictConfig,
generator: torch.Generator,
) -> tuple[Dataset, DataLoader, DataLoader]:
if cfg.example_train:
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
else:
dataset = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=False)
val_batch_size = cfg.batch_size
val_eval_batch_size = cfg.eval_batch_size
num_workers = cfg.get('num_workers_val', cfg.num_workers)
pin_memory = cfg.pin_memory
if cfg.mixup.domain == 'data':
mixup_config = cfg.mixup.params
collate_fn = DataMixupCollate(generator=generator,
**mixup_config)
else:
collate_fn = None
_, val_loader = construct_loader(dataset,
val_batch_size,
num_workers,
shuffle=False,
drop_last=False,
pin_memory=pin_memory,
collate_fn=collate_fn)
_, eval_loader = construct_loader(dataset,
val_eval_batch_size,
num_workers,
shuffle=False,
drop_last=False,
pin_memory=pin_memory,
collate_fn=collate_fn)
return dataset, val_loader, eval_loader
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
if dataset_name.startswith('audiocaps_full'):
dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
cfg.eval_data.audiocaps_full.csv_path)
elif dataset_name.startswith('audiocaps'):
dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
cfg.eval_data.audiocaps.csv_path)
elif dataset_name.startswith('vggsound'):
dataset = VGGSound(cfg.eval_data.vggsound.video_path,
cfg.eval_data.vggsound.csv_path,
duration_sec=cfg.duration_s)
elif dataset_name.startswith('infer_video'):
dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
cfg.eval_data.infer_video.jsonl_path,
duration_sec=cfg.duration_s)
cfg.batch_size = 1
elif dataset_name.startswith('example_video'):
dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
cfg.eval_data.Example_video.csv_path,
duration_sec=cfg.duration_s)
elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
cfg.eval_data[dataset_name].csv_path,
duration_sec=cfg.duration_s)
else:
raise ValueError(f'Invalid dataset name: {dataset_name}')
batch_size = cfg.batch_size
num_workers = cfg.num_workers
pin_memory = cfg.pin_memory
_, loader = construct_loader(dataset,
batch_size,
num_workers,
shuffle=False,
drop_last=False,
pin_memory=pin_memory,
error_avoidance=True)
return dataset, loader
def error_avoidance_collate(batch):
# Filter out None values
batch = [item for item in batch if item is not None]
if len(batch) == 0:
return None
return default_collate(batch)
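`error_avoidance_collate` lets a dataset signal a failed sample by returning `None`: failed items are dropped before collation, and a batch with no surviving items becomes `None`. The sketch below shows that contract and the check a consuming loop then needs. `simple_collate` is a stand-in for `default_collate` so the example runs without PyTorch, and `skip_failed_collate` is a hypothetical name.

```python
def simple_collate(batch):
    """Stand-in for default_collate: gather each key into a list."""
    return {key: [item[key] for item in batch] for key in batch[0]}


def skip_failed_collate(batch):
    """Drop None items; return None when nothing is left to collate."""
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None
    return simple_collate(batch)


raw_batches = [
    [{'name': 'a'}, None, {'name': 'c'}],  # one sample failed to load
    [None, None],                          # every sample failed
]
collated = [skip_failed_collate(b) for b in raw_batches]
# The consumer has to skip batches that collapsed to None.
usable = [b for b in collated if b is not None]
```

A surviving batch may also be smaller than `batch_size`, so downstream code should not assume a fixed batch dimension.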
def construct_loader(dataset: Dataset,
batch_size: int,
num_workers: int,
*,
shuffle: bool = True,
drop_last: bool = True,
pin_memory: bool = False,
error_avoidance: bool = False,
collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
train_loader = DataLoader(dataset,
batch_size,
sampler=train_sampler,
num_workers=num_workers,
worker_init_fn=worker_init_fn,
drop_last=drop_last,
persistent_workers=num_workers > 0,
pin_memory=pin_memory,
collate_fn=error_avoidance_collate if error_avoidance else collate_fn)
return train_sampler, train_loader
+39
@@ -0,0 +1,39 @@
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
log = logging.getLogger()
class AudioCapsData(Dataset):
def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
df = pd.read_csv(csv_path).to_dict(orient='records')
audio_files = sorted(os.listdir(audio_path))
audio_files = set(
[Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
self.data = []
for row in df:
self.data.append({
'name': row['name'],
'caption': row['caption'],
})
self.audio_path = Path(audio_path)
self.csv_path = Path(csv_path)
log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
def __getitem__(self, idx: int) -> dict[str, str]:
return self.data[idx]
def __len__(self):
return len(self.data)
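The constructor above builds the `audio_files` set but never consults it, so every CSV row is kept even though the log message says "matching". If filtering against the files on disk is what is wanted, which is an assumption and not something the source confirms, a sketch could look like this. `match_rows` is a hypothetical helper.

```python
from pathlib import Path


def match_rows(rows, filenames):
    """Keep only the rows whose 'name' has a .wav or .flac file in `filenames`."""
    stems = {Path(f).stem for f in filenames if f.endswith(('.wav', '.flac'))}
    return [row for row in rows if str(row['name']) in stems]


rows = [
    {'name': 'clip_a', 'caption': 'rain on a roof'},
    {'name': 'clip_b', 'caption': 'a dog barking'},
]
matched = match_rows(rows, ['clip_a.wav', 'clip_b.txt', 'notes.md'])
```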
+237
@@ -0,0 +1,237 @@
import json
import logging
import os
from pathlib import Path
from typing import Union
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VideoDataset(Dataset):
def __init__(
self,
video_root: Union[str, Path],
*,
duration_sec: float = 8.0,
clip_video_required: bool = False,
):
self.video_root = Path(video_root)
self.duration_sec = duration_sec
self.clip_video_required = clip_video_required
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
self.sync_transform = v2.Compose([
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
# v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if self.clip_video_required:
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
self.clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
])
# to be implemented by subclasses
self.captions = {}
self.negative_captions = {}
self.videos = sorted(list(self.captions.keys()))
def sample(self, idx: int) -> dict[str, torch.Tensor]:
video_id = self.videos[idx]
caption = self.captions[video_id]
negative_caption = self.negative_captions.get(video_id, None)
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
reader.add_basic_video_stream(
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
frame_rate=_SYNC_FPS,
format='rgb24',
)
if self.clip_video_required:
reader.add_basic_video_stream(
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
frame_rate=_CLIP_FPS,
format='rgb24',
)
reader.fill_buffer()
data_chunk = reader.pop_chunks()
sync_chunk = data_chunk[0]
if sync_chunk is None:
raise RuntimeError(f'Sync video returned None {video_id}')
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
n_tolerance_frame=3, desc=video_id)
sync_chunk = self.sync_transform(sync_chunk)
if self.clip_video_required:
clip_chunk = data_chunk[1]
if clip_chunk is None:
raise RuntimeError(f'CLIP video returned None {video_id}')
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
n_tolerance_frame=1, desc=video_id)
clip_chunk = self.clip_transform(clip_chunk)
data = {
'name': video_id,
'caption': caption,
'sync_video': sync_chunk,
}
if self.clip_video_required:
data['clip_video'] = clip_chunk
if negative_caption is not None:
data['negative_caption'] = negative_caption
return data
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
try:
return self.sample(idx)
except Exception as e:
log.error(f'Error loading video {self.videos[idx]}: {e}')
return None
def __len__(self):
return len(self.captions)
class VGGSound(VideoDataset):
def __init__(
self,
video_root: Union[str, Path],
csv_path: Union[str, Path],
*,
duration_sec: float = 8.0,
clip_video_required: bool = False,
):
super().__init__(video_root, duration_sec=duration_sec,
clip_video_required=clip_video_required)
self.video_root = Path(video_root)
self.csv_path = Path(csv_path)
videos = sorted(os.listdir(self.video_root))
if local_rank == 0:
log.info(f'{len(videos)} videos found in {video_root}')
self.captions = {}
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
'split']).to_dict(orient='records')
videos_no_found = []
for row in df:
if row['split'] == 'test':
start_sec = int(row['sec'])
video_id = str(row['id'])
# this is how our videos are named
video_name = f'{video_id}_{start_sec:06d}'
if video_name + '.mp4' not in videos:
videos_no_found.append(video_name)
continue
self.captions[video_name] = row['caption']
if local_rank == 0:
log.info(f'{len(videos)} videos found in {video_root}')
log.info(f'{len(self.captions)} useable videos found')
if videos_no_found:
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
log.info(
'A small amount is expected, as not all videos are still available on YouTube')
self.videos = sorted(list(self.captions.keys()))
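The `VGGSound` constructor keeps only `test` rows and looks for a file named `<id>_<start second, zero-padded to six digits>.mp4`. The sketch below reproduces that naming and filtering on in-memory rows. `index_test_videos` and the sample ids are illustrative, and it uses a set for the membership test where the original checks a list.

```python
def index_test_videos(rows, available_files):
    """Return (captions, missing) for the 'test' rows, using VGGSound-style names."""
    available = set(available_files)
    captions, missing = {}, []
    for row in rows:
        if row['split'] != 'test':
            continue
        video_name = f"{row['id']}_{int(row['sec']):06d}"
        if video_name + '.mp4' not in available:
            missing.append(video_name)
            continue
        captions[video_name] = row['caption']
    return captions, missing


rows = [
    {'id': 'abc123', 'sec': 30, 'caption': 'playing piano', 'split': 'test'},
    {'id': 'def456', 'sec': 5, 'caption': 'dog barking', 'split': 'test'},
    {'id': 'ghi789', 'sec': 0, 'caption': 'rain', 'split': 'train'},
]
captions, missing = index_test_videos(rows, ['abc123_000030.mp4'])
```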
class InferenceVideoData(VideoDataset):
def __init__(
self,
video_root: Union[str, Path],
jsonl_root: Union[str, Path],
*,
duration_sec: float = 10.0,
clip_video_required: bool = False,
):
super().__init__(video_root, duration_sec=duration_sec,
clip_video_required=clip_video_required)
self.video_root = Path(video_root)
self.jsonl_root = Path(jsonl_root)
videos = sorted(os.listdir(self.video_root))
videos = [v[:-4] for v in videos] # remove extensions
self.captions = {}
# make sure the dict exists before it is filled in the loop below
self.negative_captions = {}
for v in videos:
with open(self.jsonl_root / (v + '.jsonl')) as f:
data = json.load(f)
self.captions[v] = data['audio_prompt']
self.negative_captions[v] = data.get('negative_audio_prompt', None)
if local_rank == 0:
log.info(f'{len(videos)} videos found in {video_root}')
self.videos = videos
class VGGMonoAudioBench(VideoDataset):
def __init__(
self,
video_root: Union[str, Path],
csv_path: Union[str, Path],
*,
duration_sec: float = 8.0,
clip_video_required: bool = False,
):
super().__init__(video_root, duration_sec=duration_sec,
clip_video_required=clip_video_required)
self.video_root = Path(video_root)
self.csv_path = Path(csv_path)
videos = sorted(os.listdir(self.video_root))
if local_rank == 0:
log.info(f'{len(videos)} videos found in {video_root}')
self.captions = {}
self.negative_captions = {}
df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
).to_dict(orient='records')
videos_no_found = []
for row in df:
video_name = str(Path(row['file_name']).stem)
if video_name + '.mp4' not in videos:
videos_no_found.append(video_name)
continue
self.captions[video_name] = row['label']
self.negative_captions[video_name] = row['paired_label']
if local_rank == 0:
log.info(f'{len(videos)} videos found in {video_root}')
log.info(f'{len(self.captions)} useable videos found')
if videos_no_found:
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')
self.videos = sorted(list(self.captions.keys()))
+194
@@ -0,0 +1,194 @@
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
def __init__(
self,
root: Union[str, Path],
*,
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
audio_required: bool = True,
sample_rate: int = 16_000,
duration_sec: float = 8.0,
audio_samples: Optional[int] = None,
normalize_audio: bool = False,
clip_video_required: bool = True,
):
self.root = Path(root)
self.audio_required = audio_required
if audio_required:
self.normalize_audio = normalize_audio
if audio_samples is None:
self.audio_samples = int(sample_rate * duration_sec)
else:
self.audio_samples = audio_samples
effective_duration = audio_samples / sample_rate
# make sure the duration is close enough, within 15ms
assert abs(effective_duration - duration_sec) < 0.015, \
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
self.clip_video_required = clip_video_required
videos = sorted(os.listdir(self.root))
videos = set([Path(v).stem for v in videos]) # remove extensions
self.labels = {}
self.videos = []
missing_videos = []
# read the tsv for subset information
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
for record in df_list:
id = record['id']
label = record['label']
if id in videos:
self.labels[id] = label
self.videos.append(id)
else:
missing_videos.append(id)
if local_rank == 0:
log.info(f'{len(videos)} videos found in {root}')
log.info(f'{len(self.videos)} videos found in {tsv_path}')
log.info(f'{len(missing_videos)} videos missing in {root}')
self.sample_rate = sample_rate
self.duration_sec = duration_sec
if audio_required:
self.expected_audio_length = self.audio_samples
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
if clip_video_required:
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
self.sync_transform = v2.Compose([
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
# v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if clip_video_required:
self.clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
])
if audio_required:
self.resampler = {}
def sample(self, idx: int) -> dict[str, torch.Tensor]:
video_id = self.videos[idx]
label = self.labels[video_id]
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
reader.add_basic_video_stream(
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
frame_rate=_SYNC_FPS,
format='rgb24',
)
if self.audio_required:
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
if self.clip_video_required:
reader.add_basic_video_stream(
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
frame_rate=_CLIP_FPS,
format='rgb24',
)
reader.fill_buffer()
data_chunk = reader.pop_chunks()
sync_chunk = data_chunk[0]
if sync_chunk is None:
raise RuntimeError(f'Sync video returned None {video_id}')
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
n_tolerance_frame=3, desc=video_id)
sync_chunk = self.sync_transform(sync_chunk)
if self.audio_required:
audio_chunk = data_chunk[1]
if self.clip_video_required:
clip_chunk = data_chunk[2 if self.audio_required else 1]
if clip_chunk is None:
raise RuntimeError(f'CLIP video returned None {video_id}')
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
n_tolerance_frame=1, desc=video_id)
clip_chunk = self.clip_transform(clip_chunk)
# process audio
if self.audio_required:
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
audio_chunk = audio_chunk.transpose(0, 1)
audio_chunk = audio_chunk.mean(dim=0) # mono
if self.normalize_audio:
abs_max = audio_chunk.abs().max()
# reject silent clips before scaling so we never divide by (near) zero
if abs_max <= 1e-6:
raise RuntimeError(f'Audio is silent {video_id}')
audio_chunk = audio_chunk * (0.95 / abs_max)
# resample
if sample_rate == self.sample_rate:
audio_chunk = audio_chunk
else:
if sample_rate not in self.resampler:
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
self.resampler[sample_rate] = torchaudio.transforms.Resample(
sample_rate,
self.sample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method='sinc_interp_kaiser',
beta=14.769656459379492,
)
audio_chunk = self.resampler[sample_rate](audio_chunk)
if audio_chunk.shape[0] < self.expected_audio_length:
raise RuntimeError(f'Audio too short {video_id}')
audio_chunk = audio_chunk[:self.expected_audio_length]
data = {
'id': video_id,
'caption': label,
'sync_video': sync_chunk,
}
if self.audio_required:
data['audio'] = audio_chunk
if self.clip_video_required:
data['clip_video'] = clip_chunk
return data
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
try:
return self.sample(idx)
except Exception as e:
log.error(f'Error loading video {self.videos[idx]}: {e}')
return None
def __len__(self):
return len(self.labels)
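The length bookkeeping in the constructor above can be checked in isolation. Below is a minimal sketch in plain Python, assuming the same constants (25 fps sync stream, 8 fps CLIP stream) and the same 15 ms tolerance; `resolve_lengths` is a hypothetical helper name that is not part of this commit.

```python
from typing import Optional

SYNC_FPS = 25.0  # mirrors _SYNC_FPS in the file above
CLIP_FPS = 8.0   # mirrors _CLIP_FPS in the file above


def resolve_lengths(sample_rate: int,
                    duration_sec: float,
                    audio_samples: Optional[int] = None,
                    tolerance_sec: float = 0.015) -> dict:
    """Hypothetical helper: expected sample and frame counts for one clip."""
    if audio_samples is None:
        # derive the sample count from the requested duration
        audio_samples = int(sample_rate * duration_sec)
    else:
        # an explicit sample count must agree with the duration within the tolerance
        effective_duration = audio_samples / sample_rate
        if abs(effective_duration - duration_sec) >= tolerance_sec:
            raise ValueError(
                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}')
    return {
        'audio_samples': audio_samples,
        'sync_frames': int(SYNC_FPS * duration_sec),
        'clip_frames': int(CLIP_FPS * duration_sec),
    }


lengths = resolve_lengths(16_000, 8.0)
print(lengths)
```

Raising `ValueError` instead of using `assert` is a choice made for this sketch only; the dataset class itself asserts.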
+129
@@ -0,0 +1,129 @@
import logging
import os
from pathlib import Path
from typing import Union
import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
log = logging.getLogger()
class WavTextClipsDataset(Dataset):
def __init__(
self,
root: Union[str, Path],
*,
captions_tsv: Union[str, Path],
clips_tsv: Union[str, Path],
sample_rate: int,
num_samples: int,
normalize_audio: bool = False,
reject_silent: bool = False,
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
):
self.root = Path(root)
self.sample_rate = sample_rate
self.num_samples = num_samples
self.normalize_audio = normalize_audio
self.reject_silent = reject_silent
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
audios = sorted(os.listdir(self.root))
audios = set([
Path(audio).stem for audio in audios
if audio.endswith('.wav') or audio.endswith('.flac')
])
self.captions = {}
# read the caption tsv
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
for record in df_list:
id = record['id']
caption = record['caption']
self.captions[id] = caption
# read the clip tsv
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
'id': str,
'name': str
}).to_dict('records')
self.clips = []
for record in df_list:
name = record['name']
record['caption'] = self.captions[name]
self.clips.append(record)
log.info(f'Found {len(self.clips)} audio files in {self.root}')
self.resampler = {}
def __getitem__(self, idx: int) -> torch.Tensor:
try:
clip = self.clips[idx]
audio_name = clip['name']
audio_id = clip['id']
caption = clip['caption']
start_sample = clip['start_sample']
end_sample = clip['end_sample']
audio_path = self.root / f'{audio_name}.flac'
if not audio_path.exists():
audio_path = self.root / f'{audio_name}.wav'
assert audio_path.exists()
audio_chunk, sample_rate = torchaudio.load(audio_path)
audio_chunk = audio_chunk.mean(dim=0) # mono
abs_max = audio_chunk.abs().max()
# check for silence first so normalization never divides by (near) zero
if self.reject_silent and abs_max < 1e-6:
log.warning(f'Rejecting silent audio {audio_path}')
return None
if self.normalize_audio:
audio_chunk = audio_chunk / abs_max * 0.95
audio_chunk = audio_chunk[start_sample:end_sample]
# resample
if sample_rate == self.sample_rate:
audio_chunk = audio_chunk
else:
if sample_rate not in self.resampler:
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
self.resampler[sample_rate] = torchaudio.transforms.Resample(
sample_rate,
self.sample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method='sinc_interp_kaiser',
beta=14.769656459379492,
)
audio_chunk = self.resampler[sample_rate](audio_chunk)
if audio_chunk.shape[0] < self.num_samples:
raise ValueError('Audio is too short')
audio_chunk = audio_chunk[:self.num_samples]
tokens = self.tokenizer([caption])[0]
output = {
'waveform': audio_chunk,
'id': audio_id,
'caption': caption,
'tokens': tokens,
}
return output
except Exception as e:
# audio_path may be unbound if the failure happened before it was assigned
log.error(f'Error reading clip {idx}: {e}')
return None
def __len__(self):
return len(self.clips)
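Peak normalization divides by the clip's absolute maximum, so the order of the silence check and the scaling matters. The following is a minimal sketch on a plain Python list rather than a tensor; `peak_normalize` and its parameters are hypothetical names, with the 0.95 target and 1e-6 threshold taken from the code above.

```python
def peak_normalize(samples, peak=0.95, silence_threshold=1e-6):
    """Hypothetical helper: scale samples so the largest magnitude equals `peak`."""
    abs_max = max((abs(s) for s in samples), default=0.0)
    # check for silence BEFORE dividing, otherwise the scale factor blows up
    if abs_max <= silence_threshold:
        raise ValueError('audio is silent')
    scale = peak / abs_max
    return [s * scale for s in samples]


normalized = peak_normalize([0.1, -0.5, 0.25])
print(max(abs(s) for s in normalized))
```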
+338
@@ -0,0 +1,338 @@
""" Embedding Mixup
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
"""
from typing import Literal, Tuple, Union, List, Optional
from functools import partial
import gc
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import v2
from einops import rearrange
from omegaconf import DictConfig
from selva_core.data.vgg_sound import _SYNC_SIZE
class MixupBase:
""" Base class for mixup on either data or feature domain.
Applies different params to each element or whole batch.
Args:
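generator (torch.Generator): Random number generator for reproducibility (required; a CUDA generator is mirrored onto a CPU generator with the same seed)
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
mixup_lambda (float): Constant mixup lambda in [0., 1.]; used when < 1. (set to 1. to sample lambda instead).
mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda; used when > 0. and mixup_lambda == 1. (set to 0. when using a constant lambda).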
prob (float): Probability of applying mixup per batch or element
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
eps (float): Small epsilon value to avoid zero lambda
"""
def __init__(self, generator:torch.Generator,
*,
modality:Literal['video', 'audio', 'both'],
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
mode:Literal['elem','pair','batch', 'half']='batch',
eps:float=0.05
):
self.modality = modality
self.mixup_lambda:float = mixup_lambda
self.mixup_alpha:float = mixup_alpha
self.mix_prob:float = prob
self.mode:str = mode
self.eps:float = eps
self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)
if generator.device.type == 'cuda':
self.generator_cuda = generator
generator_seed = generator.initial_seed()
self.generator = torch.Generator(device='cpu')
self.generator.manual_seed(generator_seed)
else:
self.generator = generator
if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
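if not self.mixup_alpha >= 0.:
raise ValueError(f"mixup_alpha {self.mixup_alpha} should be >= 0.")
# Exactly one of the two modes must be selected: constant lambda (mixup_lambda < 1.) or sampled lambda (mixup_alpha > 0.).
# Note: the defaults (mixup_lambda=0.5, mixup_alpha=1.) trip this check, so callers must override one of them.
if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
raise ValueError(f"Exactly one of mixup_alpha > 0. or mixup_lambda < 1. must hold (got mixup_alpha={self.mixup_alpha}, mixup_lambda={self.mixup_lambda}).")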
def _params_per_elem(self, batch_size:int) -> np.ndarray:
lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
if self.mixup_enabled:
if self.mixup_lambda < 1.: # constant lambda
lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
elif self.mixup_alpha > 0.: # sampled lambda
# NOTE: torch.distributions takes no generator argument, so this draw uses the global torch RNG
lam_mix = torch.distributions.Beta(
torch.tensor([self.mixup_alpha]),
torch.tensor([self.mixup_alpha]),
).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
else:
assert False, f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
lam_mix[lam_mix < self.eps] = self.eps
# Use torch's random with generator for the random comparison
rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
return lam
def _params_per_batch(self) -> float:
lam:float = 1.
if self.mixup_enabled:
if self.mixup_lambda < 1.: # constant lambda
lam = self.mixup_lambda
elif self.mixup_alpha > 0.: # sampled lambda
lam = torch.distributions.Beta(
torch.tensor([self.mixup_alpha]),
torch.tensor([self.mixup_alpha]),
).sample().item()
else:
assert False, f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
if lam < self.eps: lam = self.eps
lam = float(lam)
return lam
class DataMixupCollate(MixupBase):
""" Mixup video in data domain.
Applies different params to each element or whole batch.
Args:
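generator (torch.Generator): Random number generator for reproducibility (required; a CUDA generator is mirrored onto a CPU generator with the same seed)
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
mixup_lambda (float): Constant mixup lambda in [0., 1.]; used when < 1. (set to 1. to sample lambda instead).
mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda; used when > 0. and mixup_lambda == 1. (set to 0. when using a constant lambda).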
prob (float): Probability of applying mixup per batch or element
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
eps (float): Small epsilon value to avoid zero lambda
"""
def __init__(self, generator:torch.Generator,
*,
modality:Literal['video', 'audio', 'both']='video',
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
mode:Literal['elem','pair','batch', 'half']='batch',
eps:float=0.05
):
super().__init__(generator, modality=modality,
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
mode=mode, eps=eps)
self.source_video_key= 'sync_video'
self.source_audio_key = 'audio'
self.target_video_key = 'sync_video_mixed'
self.target_audio_key = 'audio_mixed'
if not mode == 'batch':
raise ValueError(f"Mode {mode} is not supported for data domain.")
self.sync_transform = v2.Compose([
v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
# only batch mode supported
batch_size:int = len(batch)
lam:float = self._params_per_batch()
if lam == 1.:
# no mixup, just return
for i in range(batch_size):
batch[i][target_key] = batch[i][source_key]
return lam
# Split direction: currently always horizontal (the random horizontal/vertical choice is disabled below)
orig_size = int(lam * _SYNC_SIZE)
is_horizontal = True # torch.rand(1, generator=self.generator).item() < 0.5
if is_horizontal:
# Horizontal resize
resize_shape_orig = (_SYNC_SIZE, orig_size)
resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
else:
# Vertical resize
resize_shape_orig = (orig_size, _SYNC_SIZE)
resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
sync_resize_orig = v2.Compose([
v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
])
sync_resize_pair = v2.Compose([
v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
])
batch_videos_orig = torch.stack([batch[i][source_key] for i in range(batch_size)], dim=0)
batch_videos_pair = torch.stack([batch[batch_size - i - 1][source_key] for i in range(batch_size)], dim=0)
# (B, T, C, H, W)
# pass through resize, transform and concat
batch_videos_orig = sync_resize_orig(batch_videos_orig)
batch_videos_pair = sync_resize_pair(batch_videos_pair)
batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
batch_videos_concat = self.sync_transform(batch_videos_concat)
num_mixup = int(self.mix_prob * batch_size)
for i in range(num_mixup):
batch[i][target_key] = batch_videos_concat[i]
for i in range(num_mixup, batch_size):
batch[i][target_key] = batch[i][source_key] # no mixup
del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
gc.collect()
return lam
def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
normalize:bool = True) -> float:
# assume source_key audios are normalized
batch_size:int = len(batch)
lam:float = self._params_per_batch()
if lam == 1.:
# no mixup, just return
for i in range(batch_size):
batch[i][target_key] = batch[i][source_key]
return lam
num_mixup = int(self.mix_prob * batch_size)
for i in range(num_mixup):
batch[i][target_key] = batch[i][source_key] * lam + batch[batch_size - i - 1][source_key] * (1 - lam)
if normalize:
source_abs_max = batch[i][source_key].abs().max()
target_abs_max = batch[i][target_key].abs().max()
batch[i][target_key] = batch[i][target_key] * (source_abs_max / target_abs_max)
for i in range(num_mixup, batch_size):
batch[i][target_key] = batch[i][source_key] # no mixup
return lam
def __call__(self, batch:list, _=None) -> torch.tensor:
batch_size:int = len(batch)
assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
half = 'half' in self.mode
if half:
batch_size //= 2
if self.modality == 'video' or self.modality == 'both':
lam = self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
if self.modality == 'audio' or self.modality == 'both':
# raise NotImplementedError('Audio mixup is not implemented yet.')
lam = self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)
return default_collate(batch)
class FeatureMixup(MixupBase):
""" Mixup video in feature domain.
Applies different params to each element or whole batch.
Args:
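generator (torch.Generator): Random number generator for reproducibility (required; a CUDA generator is mirrored onto a CPU generator with the same seed)
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
mixup_lambda (float): Constant mixup lambda in [0., 1.]; used when < 1. (set to 1. to sample lambda instead).
mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda; used when > 0. and mixup_lambda == 1. (set to 0. when using a constant lambda).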
prob (float): Probability of applying mixup per batch or element
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
eps (float): Small epsilon value to avoid zero lambda
"""
def __init__(self, generator:torch.Generator,
*,
modality:Literal['video', 'audio', 'both']='video',
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
mode:Literal['elem','pair','batch', 'half']='batch',
eps:float=0.05
):
super().__init__(generator, modality=modality,
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
mode=mode, eps=eps)
self.source_video_key= 'sync_f_vid_orig'
self.source_audio_key = 'sync_f_aud_orig'
self.target_video_key = 'sync_f_vid_mixed'
self.target_audio_key = 'sync_f_aud_mixed'
def _mix_elem_collate(self, batch:dict,
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
half:bool=False) -> torch.tensor:
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
batch_size:int = len(batch['id'])
num_elem:int = batch_size // 2 if half else batch_size
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))
indices = torch.arange(num_elem)
mix_indices = batch_size - indices - 1
mix_mask = lam_batch < 1
active_indices = indices[mix_mask]
active_mix_indices = mix_indices[mix_mask]
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
for target_key, source_key in zip(target_keys, source_keys):
batch[target_key][active_indices] = (
batch[source_key][active_indices] * active_lambdas +
batch[source_key][active_mix_indices] * (1 - active_lambdas)
)
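# elements not selected for mixing keep their source features;
# invert the boolean mask (not the integer indices, where ~i == -i - 1)
inactive_indices = indices[~mix_mask]
batch[target_key][inactive_indices] = batch[source_key][inactive_indices]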
if half:
lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
return lam_batch.unsqueeze(1)
def _mix_pair_collate(self, batch:dict,
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
batch_size:int = len(batch['id'])
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))
indices = torch.arange(batch_size // 2)
mix_indices = batch_size - indices - 1
mix_mask = lam_batch < 1
active_indices = indices[mix_mask]
active_mix_indices = mix_indices[mix_mask]
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
for target_key, source_key in zip(target_keys, source_keys):
batch[target_key][active_indices] = (
batch[source_key][active_indices] * active_lambdas +
batch[source_key][active_mix_indices] * (1 - active_lambdas)
)
batch[target_key][active_mix_indices] = (
batch[source_key][active_mix_indices] * active_lambdas +
batch[source_key][active_indices] * (1 - active_lambdas)
)
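# pairs not selected for mixing keep their source features on both sides;
# invert the boolean mask (not the integer indices, where ~i == -i - 1)
inactive_indices = torch.cat((indices[~mix_mask], mix_indices[~mix_mask]))
batch[target_key][inactive_indices] = batch[source_key][inactive_indices]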
lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
return lam_batch.unsqueeze(1)
def _mix_batch_collate(self, batch:dict,
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
lam:float = self._params_per_batch()
for target_key, source_key in zip(target_keys, source_keys):
num_mixup = int(self.mix_prob * batch[source_key].shape[0])
flipped_source = torch.flip(batch[source_key], dims=[0])
batch[target_key] = batch[source_key] * lam + flipped_source * (1 - lam)
batch[target_key][num_mixup:] = batch[source_key][num_mixup:] # no mixup
return lam
def __call__(self, batch:dict, _=None) -> None:
batch_size:int = len(batch['id'])
assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
half = 'half' in self.mode
if half:
batch_size //= 2
# Mixup
if self.mode == 'elem' or self.mode == 'half':
collate_fn = partial(self._mix_elem_collate, half=half)
elif self.mode == 'pair':
collate_fn = self._mix_pair_collate
else:
collate_fn = self._mix_batch_collate
if self.modality == 'both':
target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
elif self.modality == 'video':
target_keys, source_keys = [self.target_video_key], [self.source_video_key]
elif self.modality == 'audio':
target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
lam = collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
# return batch
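The per-element mode above pairs element `i` with its mirrored partner `batch_size - i - 1` and only mixes elements whose lambda is below 1. Here is a minimal sketch of that pairing with plain Python floats standing in for feature rows; `mix_mirrored` is a hypothetical name. One detail worth keeping in mind when indexing: the bitwise NOT of an integer index `i` is `-i - 1`, which addresses the mirrored partner, whereas selecting the unmixed elements means inverting the boolean mask.

```python
def mix_mirrored(values, lambdas):
    """Hypothetical helper: mix values[i] with values[n - i - 1] where lambda < 1."""
    n = len(values)
    mixed = list(values)  # unmixed elements keep their source value
    for i, lam in enumerate(lambdas):
        if lam < 1.0:
            partner = n - i - 1  # mirrored index, same as values[~i]
            mixed[i] = lam * values[i] + (1.0 - lam) * values[partner]
    return mixed


values = [0.0, 10.0, 20.0, 30.0]
print(mix_mirrored(values, [0.5, 1.0, 1.0, 0.5]))
# ~0 == -1, so indexing with ~i reaches the partner, not "everything except i"
print(values[~0] == values[-1])
```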
+45
@@ -0,0 +1,45 @@
import bisect
import torch
from torch.utils.data.dataset import Dataset
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
class MultiModalDataset(Dataset):
datasets: list[Dataset]
cumulative_sizes: list[int]
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
super().__init__()
self.video_datasets = list(video_datasets)
self.audio_datasets = list(audio_datasets)
self.datasets = self.video_datasets + self.audio_datasets
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
return self.video_datasets[0].compute_latent_stats()
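The index lookup in `__getitem__` above maps a global index to a (dataset, local index) pair by bisecting the cumulative sizes. A minimal sketch of just that mapping, using only the standard library; `locate` and the example sizes are hypothetical.

```python
import bisect
from itertools import accumulate


def locate(cumulative_sizes, idx):
    """Hypothetical helper: map a global index to (dataset_idx, sample_idx)."""
    total = cumulative_sizes[-1]
    if idx < 0:
        if -idx > total:
            raise ValueError('absolute value of index should not exceed dataset length')
        idx += total
    # first dataset whose cumulative size is strictly greater than idx
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    offset = cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, idx - offset


sizes = [3, 5, 2]                      # lengths of three concatenated datasets
cumulative = list(accumulate(sizes))   # running totals: 3, 8, 10
for global_idx in (0, 3, 7, 8, -1):
    print(global_idx, locate(cumulative, global_idx))
```

As in the class above, an index at or beyond the total length is not rejected here; it would surface as an out-of-range dataset index in the caller.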
+148
@@ -0,0 +1,148 @@
import logging
import os
import random
import tempfile
from pathlib import Path
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
from tensordict import MemoryMappedTensor
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from selva_core.utils.dist_utils import local_rank, world_size
scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
shm_path = Path('/dev/shm')
log = logging.getLogger()
def reseed(seed):
random.seed(seed)
torch.manual_seed(seed)
def local_scatter_torch(obj: Optional[Any]):
if world_size == 1:
# Just one worker. Do nothing.
return obj
array = [obj] * world_size
target_array = [None]
if local_rank == 0:
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
else:
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
return target_array[0]
class ShardDataset(Dataset):
def __init__(self, root):
self.root = root
self.shards = sorted(os.listdir(root))
def __len__(self):
return len(self.shards)
def __getitem__(self, idx):
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
def get_tmp_dir(in_memory: bool) -> Path:
return shm_path if in_memory else scratch_path
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
in_memory: bool) -> MemoryMappedTensor:
if local_rank == 0:
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
log.info(f'Loading shards from {data_path} into {f.name}...')
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
data = share_tensor_to_all(data)
torch.distributed.barrier()
f.close() # why does the context manager not close the file for me?
else:
log.info('Waiting for the data to be shared with me...')
data = share_tensor_to_all(None)
torch.distributed.barrier()
return data
def load_shards(
data_path: Union[str, Path],
ids: list[int],
*,
tmp_file_path: str,
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
id_set = set(ids)
shards = sorted(os.listdir(data_path))
log.info(f'Found {len(shards)} shards in {data_path}.')
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
log.info(f'Rank {local_rank} created file {tmp_file_path}')
first_item = next(iter(first_shard.values()))
log.info(f'First item shape: {first_item.shape}')
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
dtype=torch.float32,
filename=tmp_file_path,
existsok=True)
total_count = 0
used_index = set()
id_indexing = {i: idx for idx, i in enumerate(ids)}
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
for data in tqdm(loader, desc='Loading shards'):
for i, v in data.items():
if i not in id_set:
continue
# tensor_index = ids.index(i)
tensor_index = id_indexing[i]
if tensor_index in used_index:
raise ValueError(f'Duplicate id {i} found in {data_path}.')
used_index.add(tensor_index)
mm_tensor[tensor_index] = v
total_count += 1
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
log.info(f'Loaded {total_count} tensors from {data_path}.')
return mm_tensor
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
"""
x: the tensor to be shared; None if local_rank != 0
return: the shared tensor
"""
# there is no need to share your stuff with anyone if you are alone; must be in memory
if world_size == 1:
return x
if local_rank == 0:
assert x is not None, 'x must not be None if local_rank == 0'
else:
assert x is None, 'x must be None if local_rank != 0'
if local_rank == 0:
filename = x.filename
meta_information = (filename, x.shape, x.dtype)
else:
meta_information = None
filename, data_shape, data_type = local_scatter_torch(meta_information)
if local_rank == 0:
data = x
else:
data = MemoryMappedTensor.from_filename(filename=filename,
dtype=data_type,
shape=data_shape)
return data
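`load_shards` above fills a preallocated tensor by looking up each item's slot in an id-to-index dictionary, skipping unknown ids and rejecting duplicates. The sketch below reproduces that bookkeeping with plain dictionaries and a list in place of shards and a memory-mapped tensor; `place_by_id` is a hypothetical name.

```python
def place_by_id(shards, ids):
    """Hypothetical helper: place shard items into slots ordered like `ids`."""
    slot_of = {key: slot for slot, key in enumerate(ids)}  # O(1) lookup instead of ids.index
    out = [None] * len(ids)
    used = set()
    for shard in shards:
        for key, value in shard.items():
            if key not in slot_of:
                continue  # item belongs to another subset
            slot = slot_of[key]
            if slot in used:
                raise ValueError(f'Duplicate id {key} found.')
            used.add(slot)
            out[slot] = value
    if len(used) != len(ids):
        raise ValueError(f'Expected {len(ids)} items, got {len(used)}.')
    return out


shards = [{'b': 2, 'x': 9}, {'a': 1}]
print(place_by_id(shards, ['a', 'b']))
```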
+299
@@ -0,0 +1,299 @@
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from tensordict import TensorDict
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
def __init__(
self,
root: Union[str, Path],
*,
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
for_generator: bool = True,
audio_required: bool = False,
sample_rate: int = 16_000,
duration_sec: float = 8.0,
audio_samples: Optional[int] = None,
normalize_audio: bool = False,
clip_video_required: bool = False,
mmap_dir: Union[str, Path] = None,
tsv_tsynch_path: Union[str, Path] = None,
mmap_tsync_dir: Union[str, Path] = None,
data_dim: dict[str, int] = None,
):
self.root = Path(root)
self.audio_required = audio_required
if audio_required:
self.normalize_audio = normalize_audio
if audio_samples is None:
self.audio_samples = int(sample_rate * duration_sec)
else:
self.audio_samples = audio_samples
effective_duration = audio_samples / sample_rate
# make sure the duration is close enough, within 15ms
assert abs(effective_duration - duration_sec) < 0.015, \
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
self.clip_video_required = clip_video_required
self.for_generator = for_generator
videos = sorted(os.listdir(self.root))
videos = set([Path(v).stem for v in videos]) # remove extensions
self.labels = {}
self.videos = []
missing_videos = []
# read the tsv for subset information
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
for record in df_list:
id = record['id']
label = record['label']
if id in videos:
self.labels[id] = label
self.videos.append(id)
else:
missing_videos.append(id)
if local_rank == 0:
log.info(f'{len(videos)} videos found in {root}')
log.info(f'{len(self.videos)} videos found in {tsv_path}')
log.info(f'{len(missing_videos)} videos missing in {root}')
self.sample_rate = sample_rate
self.duration_sec = duration_sec
if audio_required:
self.expected_audio_length = self.audio_samples
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
if clip_video_required:
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
self.sync_transform = v2.Compose([
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
# v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if clip_video_required:
self.clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
])
if audio_required:
self.resampler = {}
# mmap
log.info(f'Loading precomputed mmap from {mmap_dir}')
mmap_dir = Path(mmap_dir)
td = TensorDict.load_memmap(mmap_dir)
log.info(f'Loaded precomputed mmap from {mmap_dir}')
self.sync_features = td['sync_features']
if for_generator:
self.mean = td['mean']
self.std = td['std']
self.text_clip_features = td['text_features']
if clip_video_required:
self.clip_features = td['clip_features']
else:
self.clip_features = None
self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}
mmap_tsync_dir = Path(mmap_tsync_dir)
td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
self.text_features = td_tsync['text_features']
self.text_masks = td_tsync['text_masks']
df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t').to_dict('records')
self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}
if local_rank == 0:
log.info(f'Loaded {len(self)} samples.')
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
log.info(f'Loaded text_features: {self.text_features.shape}.')
log.info(f'Loaded text_masks: {self.text_masks.shape}.')
if for_generator:
log.info(f'Loaded mean: {self.mean.shape}.')
log.info(f'Loaded std: {self.std.shape}.')
log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
if clip_video_required:
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
if for_generator:
assert self.mean.shape[1] == data_dim['latent_seq_len'], \
f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
assert self.std.shape[1] == data_dim['latent_seq_len'], \
f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
if clip_video_required:
assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'
self.video_exist = torch.tensor(1, dtype=torch.bool)
self.text_exist = torch.tensor(1, dtype=torch.bool)
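The duration check at the top of the constructor can be read in isolation. The following is a minimal sketch of that arithmetic only; `resolve_audio_samples` and the numbers used are hypothetical and not part of the repository.

```python
from typing import Optional


def resolve_audio_samples(sample_rate: int,
                          duration_sec: float,
                          audio_samples: Optional[int] = None,
                          tol_sec: float = 0.015) -> int:
    """Hypothetical helper mirroring the constructor's sample-count logic."""
    if audio_samples is None:
        # No explicit count: derive it from the nominal duration.
        return int(sample_rate * duration_sec)
    effective_duration = audio_samples / sample_rate
    # An explicit count is accepted only if it is within 15 ms of the nominal duration.
    if abs(effective_duration - duration_sec) >= tol_sec:
        raise ValueError(
            f'audio_samples {audio_samples} does not match duration_sec {duration_sec}')
    return audio_samples


derived = resolve_audio_samples(16000, 8.0)            # 16000 * 8.0 samples
explicit = resolve_audio_samples(44100, 8.0, 353280)   # about 8.0109 s, within tolerance
```

An explicit count that drifts further than the tolerance (for example 130000 samples at 16 kHz, which is 8.125 s) is rejected.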
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: # mmap
latents = self.mean
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
def get_memory_mapped_tensor(self) -> TensorDict:
td = TensorDict({
'sync_features': self.sync_features,
'text_features': self.text_features,
'text_masks': self.text_masks,
})
if self.for_generator:
td['mean'] = self.mean
td['std'] = self.std
td['text_clip_features'] = self.text_clip_features
if self.clip_video_required:
td['clip_features'] = self.clip_features
return td
def sample(self, idx: int) -> dict[str, torch.Tensor]:
video_id = self.videos[idx]
if video_id in self.captions and torch.rand(1).item() < self.autoacd_sample_prob:
label = self.captions[video_id]
else:
label = self.labels[video_id]
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
reader.add_basic_video_stream(
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
frame_rate=_SYNC_FPS,
format='rgb24',
)
if self.audio_required:
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
if self.clip_video_required:
reader.add_basic_video_stream(
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
frame_rate=_CLIP_FPS,
format='rgb24',
)
reader.fill_buffer()
data_chunk = reader.pop_chunks()
sync_chunk = data_chunk[0]
if sync_chunk is None:
raise RuntimeError(f'Sync video returned None {video_id}')
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
n_tolerance_frame=3, desc=video_id)
sync_chunk = self.sync_transform(sync_chunk)
if self.audio_required:
audio_chunk = data_chunk[1]
if self.clip_video_required:
clip_chunk = data_chunk[2 if self.audio_required else 1]
if clip_chunk is None:
raise RuntimeError(f'CLIP video returned None {video_id}')
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
n_tolerance_frame=1, desc=video_id)
clip_chunk = self.clip_transform(clip_chunk)
# process audio
if self.audio_required:
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
audio_chunk = audio_chunk.transpose(0, 1)
audio_chunk = audio_chunk.mean(dim=0) # mono
if self.normalize_audio:
abs_max = audio_chunk.abs().max()
# reject silent clips before scaling, so we never divide by ~0
if abs_max <= 1e-6:
raise RuntimeError(f'Audio is silent {video_id}')
audio_chunk = audio_chunk * (0.95 / abs_max)
# resample
if sample_rate == self.sample_rate:
pass  # already at the target sample rate
else:
if sample_rate not in self.resampler:
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
self.resampler[sample_rate] = torchaudio.transforms.Resample(
sample_rate,
self.sample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method='sinc_interp_kaiser',
beta=14.769656459379492,
)
audio_chunk = self.resampler[sample_rate](audio_chunk)
if audio_chunk.shape[0] < self.expected_audio_length:
raise RuntimeError(f'Audio too short {video_id}')
audio_chunk = audio_chunk[:self.expected_audio_length]
data = {
'id': video_id,
'caption': label,
'sync_video': sync_chunk,
'sync_f_vid_orig': self.sync_features[self.id2idx_mmap[video_id]],
'text_features': self.text_features[self.id2idx_mmap_tsync[video_id]],
'text_masks': self.text_masks[self.id2idx_mmap_tsync[video_id]],
'video_exist': self.video_exist,
'text_exist': self.text_exist,
}
if self.for_generator:
data['a_mean'] = self.mean[self.id2idx_mmap[video_id]]
data['a_std'] = self.std[self.id2idx_mmap[video_id]]
data['text_clip_features'] = self.text_clip_features[self.id2idx_mmap[video_id]]
if self.audio_required:
data['audio'] = audio_chunk
if self.clip_video_required:
data['clip_video'] = clip_chunk
data['clip_features'] = self.clip_features[self.id2idx_mmap[video_id]]
return data
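The peak-normalization step inside `sample()` can be sketched on its own. This is a minimal illustration assuming a mono 1-D waveform; `peak_normalize` is a hypothetical name, and the 0.95 peak and 1e-6 threshold are simply the constants that appear in the code above.

```python
import torch


def peak_normalize(audio: torch.Tensor, peak: float = 0.95, eps: float = 1e-6) -> torch.Tensor:
    """Scale a waveform so that its largest absolute sample equals `peak`."""
    abs_max = audio.abs().max()
    # Guard against (near-)silent input before dividing by the peak.
    if abs_max <= eps:
        raise RuntimeError('Audio is silent')
    return audio * (peak / abs_max)


wave = torch.tensor([0.1, -0.5, 0.25])
scaled = peak_normalize(wave)  # largest magnitude is now 0.95
```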
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
try:
return self.sample(idx)
except Exception as e:
log.error(f'Error loading video {self.videos[idx]}: {e}')
return None
def __len__(self):
return len(self.labels)
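Because `__getitem__` returns `None` when a sample fails to load, whatever batches this dataset has to tolerate `None` entries. How the repository handles this is not shown in this diff; one possible approach, sketched here with the hypothetical name `collate_skip_none`, is a collate function that drops failed samples before delegating to PyTorch's default collation.

```python
from typing import Optional

import torch
from torch.utils.data import default_collate


def collate_skip_none(batch: list) -> Optional[dict]:
    """Hypothetical collate_fn: drop failed (None) samples, then collate the rest."""
    kept = [item for item in batch if item is not None]
    if not kept:
        # Every sample in the batch failed; the caller has to skip this batch.
        return None
    return default_collate(kept)


batch = [{'id': 'a', 'x': torch.ones(2)}, None, {'id': 'b', 'x': torch.zeros(2)}]
out = collate_skip_none(batch)
```

It would be passed as `DataLoader(dataset, collate_fn=collate_skip_none)`; note that the effective batch size then varies from step to step.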
+1
@@ -0,0 +1 @@
+1
@@ -0,0 +1 @@
from .autoencoder import AutoEncoderModule
+52
@@ -0,0 +1,52 @@
from typing import Literal, Optional
import torch
import torch.nn as nn
from selva_core.ext.autoencoder.vae import VAE, get_my_vae
from selva_core.ext.bigvgan import BigVGAN
from selva_core.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
class AutoEncoderModule(nn.Module):
def __init__(self,
*,
vae_ckpt_path,
vocoder_ckpt_path: Optional[str] = None,
mode: Literal['16k', '44k'],
need_vae_encoder: bool = True):
super().__init__()
self.vae: VAE = get_my_vae(mode).eval()
vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
self.vae.load_state_dict(vae_state_dict)
self.vae.remove_weight_norm()
if mode == '16k':
assert vocoder_ckpt_path is not None
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
elif mode == '44k':
self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
use_cuda_kernel=False)
self.vocoder.remove_weight_norm()
else:
raise ValueError(f'Unknown mode: {mode}')
for param in self.parameters():
param.requires_grad = False
if not need_vae_encoder:
del self.vae.encoder
@torch.inference_mode()
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
return self.vae.encode(x)
@torch.inference_mode()
def decode(self, z: torch.Tensor) -> torch.Tensor:
return self.vae.decode(z)
@torch.inference_mode()
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
return self.vocoder(spec)
+168
@@ -0,0 +1,168 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
import numpy as np
import torch
#----------------------------------------------------------------------------
# Cached construction of constant tensors (constant), plus a variant
# (const_like) that inherits dtype and device from the given reference
# tensor by default.
_constant_cache = dict()
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
value = np.asarray(value)
if shape is not None:
shape = tuple(shape)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device('cpu')
if memory_format is None:
memory_format = torch.contiguous_format
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
tensor = _constant_cache.get(key, None)
if tensor is None:
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
if shape is not None:
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
tensor = tensor.contiguous(memory_format=memory_format)
_constant_cache[key] = tensor
return tensor
def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
if dtype is None:
dtype = ref.dtype
if device is None:
device = ref.device
return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class Normalize(torch.nn.Module):
def __init__(self, dim=None, eps=1e-4):
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return normalize(x, dim=self.dim, eps=self.eps)
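To make the effect of `normalize()` concrete: with the default `dim`, each sample is rescaled so that its per-element RMS is approximately 1, not its L2 norm. The sketch below restates the function with plain arithmetic in place of the `torch.add(..., alpha=...)` call (an equivalent rewrite for illustration, not the file's code) and checks the RMS on random data.

```python
import math

import torch


def normalize(x: torch.Tensor, dim=None, eps: float = 1e-4) -> torch.Tensor:
    """Illustrative restatement: rescale x to unit per-element RMS over `dim`."""
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    # Dividing the L2 norm by sqrt(number of reduced elements) turns it into an RMS.
    norm = eps + norm * math.sqrt(norm.numel() / x.numel())
    return x / norm.to(x.dtype)


torch.manual_seed(0)
x = 5.0 * torch.randn(4, 8, 16)          # arbitrary scale
y = normalize(x)
rms = y.pow(2).mean(dim=(1, 2)).sqrt()   # one value per sample, each close to 1
```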
#----------------------------------------------------------------------------
# Upsample or downsample the given tensor with the given filter,
# or keep it as is.
def resample(x, f=[1, 1], mode='keep'):
if mode == 'keep':
return x
f = np.float32(f)
assert f.ndim == 1 and len(f) % 2 == 0
pad = (len(f) - 1) // 2
f = f / f.sum()
f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
f = const_like(x, f)
c = x.shape[1]
if mode == 'down':
return torch.nn.functional.conv2d(x,
f.tile([c, 1, 1, 1]),
groups=c,
stride=2,
padding=(pad, ))
assert mode == 'up'
return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
groups=c,
stride=2,
padding=(pad, ))
#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).
def mp_silu(x):
return torch.nn.functional.silu(x) / 0.596
class MPSiLU(torch.nn.Module):
def forward(self, x):
return mp_silu(x)
#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
#----------------------------------------------------------------------------
# Magnitude-preserving concatenation (Equation 103).
def mp_cat(a, b, dim=1, t=0.5):
Na = a.shape[dim]
Nb = b.shape[dim]
C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
wa = C / np.sqrt(Na) * (1 - t)
wb = C / np.sqrt(Nb) * t
return torch.cat([wa * a, wb * b], dim=dim)
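A quick numerical check shows what "magnitude-preserving" means for `mp_sum` and `mp_cat`. The sketch assumes independent, unit-variance inputs and restates both functions with `math.sqrt` so that the block is self-contained; the tensor shapes are arbitrary.

```python
import math

import torch


def mp_sum(a, b, t=0.5):
    return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)


def mp_cat(a, b, dim=1, t=0.5):
    na, nb = a.shape[dim], b.shape[dim]
    c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
    wa = c / math.sqrt(na) * (1 - t)
    wb = c / math.sqrt(nb) * t
    return torch.cat([wa * a, wb * b], dim=dim)


torch.manual_seed(0)
a = torch.randn(64, 4, 1000)
b = torch.randn(64, 4, 1000)
wide = torch.randn(64, 12, 1000)

plain_std = a.lerp(b, 0.3).std()                  # shrinks to about sqrt(0.58)
mp_std = mp_sum(a, b, t=0.3).std()                # rescaled back to about 1
cat_rms = mp_cat(a, wide).pow(2).mean().sqrt()    # overall RMS stays about 1
```

In `mp_cat` the two inputs are also weighted so that the 4-channel and the 12-channel tensors contribute equally to the output at `t=0.5`, despite their different widths.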
#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with force weight normalization (Equation 66).
class MPConv1D(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super().__init__()
self.out_channels = out_channels
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
self.weight_norm_removed = False
def forward(self, x, gain=1):
assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
w = self.weight * gain
if w.ndim == 2:
return x @ w.t()
assert w.ndim == 3
return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
def remove_weight_norm(self):
w = self.weight.to(torch.float32)
w = normalize(w) # traditional weight normalization
w = w / np.sqrt(w[0].numel())
w = w.to(self.weight.dtype)
self.weight.data.copy_(w)
self.weight_norm_removed = True
return self
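The arithmetic in `remove_weight_norm()` can be traced on a random weight tensor: `normalize(w)` brings each output filter to unit per-element RMS, and the division by `sqrt(fan_in)` then leaves each filter with an L2 norm of about 1. The sketch below reproduces those two steps with plain tensor operations; the sizes are arbitrary and the variable names are illustrative.

```python
import math

import torch

torch.manual_seed(0)
out_ch, in_ch, kernel = 6, 4, 3
w = torch.randn(out_ch, in_ch, kernel)
fan_in = w[0].numel()  # in_ch * kernel

# Step 1: per-filter normalization to unit RMS (what normalize(w) computes).
l2 = torch.linalg.vector_norm(w, dim=(1, 2), keepdim=True)
w_unit_rms = w / (1e-4 + l2 / math.sqrt(fan_in))

# Step 2: magnitude-preserving scaling by 1 / sqrt(fan_in).
w_folded = w_unit_rms / math.sqrt(fan_in)

filter_norms = torch.linalg.vector_norm(w_folded, dim=(1, 2))  # each close to 1
```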
+369
@@ -0,0 +1,369 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from selva_core.ext.autoencoder.edm2_utils import MPConv1D
from selva_core.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
else:
raise ValueError(f'Unsupported data_dim: {data_dim} (expected 80 or 128)')
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - self.data_mean) / self.data_std
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * self.data_std + self.data_mean
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
for name, m in self.named_modules():
if isinstance(m, MPConv1D):
m.remove_weight_norm()
log.debug(f"Removed weight norm from {name}")
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = MPConv1D(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
hs.append(h)
if i_level in self.down_layers:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h, gain=(self.learnable_gain + 1))
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h, gain=(self.learnable_gain + 1))
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
if __name__ == '__main__':
network = get_my_vae('16k')  # get_my_vae() only accepts '16k' or '44k'
# print the number of parameters in terms of millions
num_params = sum(p.numel() for p in network.parameters()) / 1e6
print(f'Number of parameters: {num_params:.2f}M')
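The `normalize()` / `unnormalize()` pair on `VAE` relies on the per-band statistics being reshaped to `(1, C, 1)` so that they broadcast over a `(batch, bands, time)` spectrogram. A minimal sketch of that broadcasting, using four made-up band statistics in place of the real 80- or 128-entry tables:

```python
import torch

torch.manual_seed(0)
# Made-up per-band statistics, reshaped to (1, C, 1) as in VAE.__init__.
data_mean = torch.tensor([-1.6, -1.4, -1.3, -1.2]).view(1, -1, 1)
data_std = torch.tensor([1.0, 1.1, 0.9, 0.95]).view(1, -1, 1)

mel = torch.randn(2, 4, 10)                      # (batch, bands, time)
normalized = (mel - data_mean) / data_std        # same formula as VAE.normalize
restored = normalized * data_std + data_mean     # same formula as VAE.unnormalize
```

Each band is shifted and scaled by its own statistics at every time step, and the round trip recovers the input up to floating-point error.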
+117
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from selva_core.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
def nonlinearity(x):
# magnitude-preserving SiLU (swish divided by 0.596)
return mp_silu(x)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
else:
self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=2).unbind(3)
q = rearrange(q, 'b h c l -> b h l c')
k = rearrange(k, 'b h c l -> b h l c')
v = rearrange(v, 'b h c l -> b h l c')
h = F.scaled_dot_product_attention(q, k, v)
h = rearrange(h, 'b h l c -> b (h c) l')
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
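The reshape in `AttnBlock1D.forward` is easy to misread, so here is a shape-only sketch of how the `3 * C` output channels of `qkv` become per-head query, key and value tensors. It uses `transpose` in place of `einops.rearrange`, feeds a random tensor instead of the real convolution output, and omits the channel normalization; all sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, L, heads = 2, 8, 5, 2
y = torch.randn(B, 3 * C, L)              # stand-in for self.qkv(h)

# (B, 3C, L) -> (B, heads, C // heads, 3, L); q, k, v are interleaved on axis 3.
y = y.reshape(B, heads, -1, 3, L)
q, k, v = y.unbind(3)                     # each (B, heads, C // heads, L)

# Attention expects (B, heads, L, C // heads).
q, k, v = (t.transpose(2, 3) for t in (q, k, v))
h = F.scaled_dot_product_attention(q, k, v)

# Back to (B, C, L), equivalent to 'b h l c -> b (h c) l'.
h = h.transpose(2, 3).reshape(B, C, L)
```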
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# 1x1 convolutions applied before and after average pooling (see forward)
self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x
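Leaving the optional 1x1 convolutions aside, `Downsample1D` halves the time axis with average pooling and `Upsample1D` doubles it with nearest-neighbour interpolation. A small sketch of the resulting shapes and values on a random `(B, C, T)` tensor:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8)                                           # (B, C, T)

down = F.avg_pool1d(x, kernel_size=2, stride=2)                    # (1, 3, 4)
up = F.interpolate(down, scale_factor=2.0, mode='nearest-exact')   # (1, 3, 8)

# Each pooled frame is the mean of two neighbours; upsampling repeats it twice.
first_pair_mean = (x[..., 0] + x[..., 1]) / 2
```

The round trip restores the length but not the content: detail within each pair of frames is averaged away.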
+21
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 NVIDIA CORPORATION.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+1
@@ -0,0 +1 @@
from .bigvgan import BigVGAN

Some files were not shown because too many files have changed in this diff.