The frozen discriminators are loaded in model dtype (bfloat16) but vocoder
waveform outputs are float32, causing a Conv2d dtype mismatch.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Training confirmed working — VRAM usage is normal backward-pass
activation memory, not a leak. Removed all debug _vram_log and _vram
calls. Kept the video_enc offload and torch.cuda.empty_cache fixes.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
torch.cuda.memory_allocated only tracks PyTorch allocator. Added
torch.cuda.mem_get_info to see actual CUDA driver memory usage.
Also offload video_enc (TextSynch) which was missed in the original
offload — stays on GPU when strategy != offload_to_cpu.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
PyTorch's caching allocator reserves GPU memory from pre-generation
(~90 GiB for generator + tod) and doesn't return it to CUDA/OS.
soft_empty_cache may not call torch.cuda.empty_cache(). Force a full
cache release after CLIP encoding and after LoRA mel pre-generation.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Logs VRAM at: after target_mel, after vocoder forward, before loss,
after loss computation, and after backward. Only logs for step 0 to
avoid spam. Will identify which operation causes the 94 GiB spike.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Logs torch.cuda.memory_allocated/reserved at each step: before unload,
after unload_all_models, after feature_utils.to(cpu), after generator
to(cpu), after cache clear, after mel_converter to(device), and before
training loop. This will identify what's holding VRAM.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
_save_sample("baseline") was called before the vocoder's inference
tensors were sanitized, causing "Inference tensors do not track version
counter". Moved it after the clone/detach loop and vocoder.to(device).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
CLIP weights are inference tensors from ComfyUI loading. inference_mode
is thread-local, so the worker thread can't use CLIP even with a context
manager. Pre-compute all text embeddings in the main thread (where
inference_mode IS active), clone+detach to normal tensors, and pass them
to the worker via text_clip_cache dict. CLIP no longer needs to be on
GPU during pre-generation.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
CLIP weights are inference tensors from ComfyUI loading. The worker
thread runs without inference_mode, so PyTorch rejects inference tensors
in multi_head_attention_forward (version counter tracking). Wrap the
encode_text_clip call in torch.inference_mode() since text encoding
doesn't need gradients.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous offload ran inside the worker thread, but by then ComfyUI
had already loaded the full model to GPU. Now feature_utils.to('cpu')
and generator.to('cpu') run in the main thread right after
unload_all_models(), before the worker starts. vocoder.to(device, dtype)
is called explicitly after inference flag stripping in _do_train to
bring only the vocoder back to GPU.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16
before inference flag stripping. Cast mel to vocoder's dtype to prevent
input/bias type mismatch during baseline sample save.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
feature_utils.to(device) was loading CLIP ViT-H, synchformer, T5, VAE,
and vocoder (~90 GiB) to GPU for the entire training run. Now only
mel_converter (tiny) is moved to GPU. Pre-generation manages its own
device placement: temporarily moves CLIP and tod to GPU, then moves them
back when done. This frees ~90 GiB for the backward pass.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Only the vocoder and mel_converter are needed during BigVGAN training.
The rest of the SelVA pipeline (CLIP ViT-H, synchformer, T5, generator,
VAE) was staying on GPU and consuming ~90 GiB, leaving no room for
backward pass activations. Now offloaded individually to CPU before
the training loop starts.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
LoRA mel pre-generation runs a full ODE+CFG for every clip, which is slow.
Cache results to a .pt file next to the output, keyed by a SHA-256 hash
of the LoRA adapter content + generation parameters (seed, steps, CFG,
duration, sample rate, npz file list). Automatically reused on subsequent
runs when parameters haven't changed.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Discriminators are constructed as float32 but receive bfloat16 tensors
from the vocoder. Cast to model dtype on load to prevent conv dtype
mismatch in feature matching loss.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
GAFilter conv weights are created as float32 but the rest of the vocoder
is bfloat16. vocoder.to(device) missed the dtype cast, causing conv1d
dtype mismatch when Snake bfloat16 output flows into GAFilter.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
mel_converter outputs float32 (cuFFT requirement) but vocoder weights are
bfloat16 from model loading. Cast input_mel back to model dtype before
feeding the vocoder to avoid conv1d dtype mismatch.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-generated mels were using a bare forward pass with no classifier-free
guidance, producing mels that don't match what the vocoder sees at inference
(where cfg_strength=4.5 is the default). Now uses ode_wrapper with
preprocess_conditions/get_empty_conditions, same as the sampler node.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
mel_basis and hann_window buffers inherit bfloat16 from model loading.
Since all mel_converter inputs are cast to float32 for cuFFT, the
internal buffers must also be float32 to avoid matmul dtype mismatch.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cuFFT does not support bfloat16 tensors. When the model is loaded in
bfloat16, all torch.stft calls (mel_converter, discriminator spectrogram,
multi-resolution STFT loss) crash. Add .float() at every STFT boundary.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a lora_adapter path is provided, the trainer pre-generates
LoRA-distorted mels for each training clip (full ODE generation +
VAE decode) and trains the vocoder to produce clean audio from them.
This teaches the vocoder to compensate for LoRA latent distribution
shift without requiring perfectly aligned training pairs.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Drop SelvaFlashSR node — audiosr pins numpy<=1.23.5 which cannot build
on Python 3.12 (pkgutil.ImpImporter removed); use Saganaki22/ComfyUI-AudioSR instead
- BigVGAN trainer now writes <output_stem>_training_log.csv alongside the
checkpoint: step, total, fm, mel, stft, phase, l2sp columns, line-buffered
so loss can be tailed live during training
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Activation1d's anti-alias Kaiser sinc resampling (asymmetric pad_left /
pad_right) can produce ±1-2 sample rounding in edge cases, causing the
BigVGAN AMPBlock residual addition (xt + x) to fail with a size mismatch.
Trim or pad the output to exactly match the input length so the resblock
skip connection always has matching dimensions.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
L2-SP anchors trainable params to their pretrained values. GAFilter is a
newly initialized module (identity FIR filter) with no pretrained values —
anchoring it to identity initialization would resist learning. Exclude
gafilter params from the L2-SP loss so they train freely.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Implements AF-Vocoder GAFilter (Interspeech 2025): learnable per-channel
depthwise FIR filter inserted after each Snake/Activation1d in BigVGAN
residual blocks. Initialized as identity so training starts from pretrained
behaviour.
- inject_gafilters() walks resblocks.*.activations and wraps each Activation1d
with _ActivationWithGAFilter — weights appear in vocoder.state_dict() automatically
- Trained alongside Snake alphas in snake_alpha_only mode
- Checkpoint saves has_gafilter + gafilter_kernel_size metadata
- Loader detects metadata and injects before load_state_dict so weights populate correctly
- Controlled by use_gafilter (default True) and gafilter_kernel_size (default 9)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds L1 loss on real, imaginary, and magnitude STFT components across
three resolutions (FA-GAN, arXiv:2407.04575). Penalizes phase smearing
directly — magnitude-only losses cannot distinguish correct spectrum
with wrong phase from a smeared spectrum.
Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top
of both the discriminator FM path and the fallback mel+STFT path.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
If no matching key was found for MPD or MRD in the checkpoint, the for-loops
completed silently and randomly-initialized discriminators were used as frozen
feature extractors — producing meaningless feature matching loss while
appearing to work. Now raises RuntimeError (caught by outer except) which
triggers the existing fallback to mel+STFT losses with a clear warning.
Also prints available checkpoint keys to help diagnose format mismatches.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Hardcoded max of 4.0 prevented using full 8s clips. Raised to 30s.
Also bumped default from 1.0 to 2.0 as a more sensible starting point.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
BigVGAN's 512x upsampling stack stores huge intermediate activations for
backward even in snake_alpha_only mode (only 5K trainable params, but
activation graph runs through the full network after each snake op).
Wrapping vocoder() in checkpoint(use_reentrant=False) recomputes activations
during backward instead of storing them — ~2x compute cost, large reduction
in peak VRAM. Should allow batch_size > 1 on 96 GB without OOM.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two bugs:
1. _DiscriminatorR used channels=32 but the BigVGAN pretrained discriminator
checkpoint has channels=128. All convs in _DiscriminatorR now use 128,
matching the checkpoint architecture so state_dict loads without error.
2. BigVGAN trainer OOM: SelVA generator and other ComfyUI models remain in
VRAM during training (~90 GiB used). Add unload_all_models() + cache
flush before the training loop to reclaim VRAM headroom.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.
Fix: use a straight-through estimator after the no-grad prefix:
x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.
Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Writes _gt_spec.png from ref_mel before training starts so each step's
_spec.png can be compared against the unmodified vocoder roundtrip target.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds _save_spectrogram() using PIL only (no matplotlib). Each _save_sample
call now writes both a .wav and a _spec.png so training progress is visible
without listening. Colour map is blue→green→yellow (viridis-ish), low
frequencies at the bottom.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Same environment has no compatible ffmpeg/torchcodec for saving.
Mirror the _load_wav pattern: try torchaudio, fall back to soundfile.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
remove_parametrizations() stores weight as a plain __dict__ tensor (not
nn.Parameter), making it invisible to _parameters iteration. Also, buffers
(Activation1d anti-aliasing filters) are inference tensors that break the
backward graph mid-network. Fix all three categories:
1. _parameters: clone().detach(), wrap as Parameter
2. plain __dict__ tensors: clone(), register_parameter (also makes trainable)
3. _buffers: clone() to strip inference flag without parametrizing
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
torch.inference_mode is thread-local, but the inference flag lives on the
tensor object. Operations on inference tensors always propagate it, even in
a clean thread. The only escape is .clone() called outside inference_mode.
At thread entry (inference_mode disabled): clone clips and mel_converter
buffers to get clean normal tensors before any training computation.
Vocoder parameter clone() also now works correctly in this thread context.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
torch.inference_mode is thread-local. ComfyUI sets it on the node-execution
thread; inference_mode(False) alone is insufficient to escape it in some
environments (e.g. async wrappers, lora-manager hook). A new thread always
starts clean. Moved all training logic into _do_train() called via
threading.Thread so every tensor is a normal autograd tensor by default.
Simplified parameter cloning: clone().detach().requires_grad_(True).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Previous fix only iterated mel_converter._buffers (direct buffers). Submodules
(e.g. Spectrogram.window) still held inference tensors. Switch to .modules()
to cover all nested buffers, matching the vocoder parameter sanitization.
Also add a zeros+copy_ safety net on target_mel output so conv can save it.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The clips list is built inside ComfyUI's inference_mode context, so every
element is an inference tensor. torch.stack().clone() propagates the flag.
Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor,
so mel_converter(target_flat) inside no_grad produces a saveable input.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
param.data.clone() and tensor.detach() on inference tensors both produce
inference tensors — the flag propagates through all operations on them.
Inside inference_mode(False), torch.zeros() creates genuine normal tensors.
Use zeros+copy_ to sanitize both vocoder parameters and mel_converter
buffers once before training, so autograd can save inputs for backward.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
param.data = clone() only replaces storage — the nn.Parameter object itself
retains the inference tensor flag set when the model was loaded. Replace each
parameter with a fresh nn.Parameter(data.clone()) created inside
inference_mode(False) so both the object and its data are normal tensors.
Move optimizer creation to after re-creation so it references the new objects.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The vocoder is loaded inside ComfyUI's torch.inference_mode(), making all
its parameters inference tensors. Autograd cannot save inference tensors
for backward even with requires_grad=True. Clone all parameters inside
torch.inference_mode(False) before training to get normal tensors.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
mel_converter buffers (mel_basis, hann_window) are inference tensors
because the model was loaded inside ComfyUI's torch.inference_mode().
Operations on them propagate the flag to outputs. Clone both target_mel
and pred_mel to get normal autograd-compatible tensors. .clone() is
differentiable so the grad graph to vocoder parameters is preserved.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Clips loaded outside torch.inference_mode(False) are inference tensors.
Autograd cannot save them for backward. .clone() creates a normal tensor,
same fix pattern as selva_lora_trainer's dist.mode().clone().
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Model loaded in bf16 causes mel_basis buffer to be bf16. Audio loaded
from disk is float32, causing matmul dtype mismatch. Cast all audio
tensors to model["dtype"] before passing to mel_converter/vocoder.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
torchcodec/libavutil soname mismatch causes torchaudio to fail on every
file load, silently emptying clips. Add _load_wav() that tries torchaudio
first then falls back to soundfile (handles wav/flac without ffmpeg).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>