diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index e27bfd6..f235d17 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -241,19 +241,17 @@ class SelvaBigvganTrainer: fresh, requires_grad=True ) - # mel_converter buffers (mel_basis, hann_window, etc.) were loaded - # inside ComfyUI's outer inference_mode context, so they are inference - # tensors. Operations on inference tensors ALWAYS produce inference - # tensors, even inside inference_mode(False). torch.zeros() et al. - # create normal tensors in the current (non-inference) context, so - # we replace every buffer once via copy_() to break the chain. - for bname, buf in list(mel_converter._buffers.items()): - if buf is not None: - fresh = torch.zeros( - buf.shape, device=buf.device, dtype=buf.dtype - ) - fresh.copy_(buf) - mel_converter._buffers[bname] = fresh + # mel_converter and its submodules (e.g. Spectrogram.window) have + # inference-tensor buffers loaded in ComfyUI's outer inference_mode. + # Must iterate .modules() — ._buffers only covers direct buffers. + for sub in mel_converter.modules(): + for bname, buf in list(sub._buffers.items()): + if buf is not None: + fresh = torch.zeros( + buf.shape, device=buf.device, dtype=buf.dtype + ) + fresh.copy_(buf) + sub._buffers[bname] = fresh optimizer = torch.optim.AdamW( vocoder.parameters(), lr=lr, betas=(0.8, 0.99) @@ -280,11 +278,16 @@ class SelvaBigvganTrainer: del _stacked target_wav = target_flat.unsqueeze(1) # [B, 1, T] - # Fixed target mel — buffers are now normal tensors (sanitized - # above), so torch.no_grad() correctly produces a non-inference, - # no-grad leaf tensor that conv layers can save for backward. + # Compute target mel and guarantee it is not an inference tensor. + # Even with sanitized buffers a submodule we missed could still + # taint the output, so we always copy into a fresh tensor. with torch.no_grad(): - target_mel = mel_converter(target_flat) # [B, 80, T_mel] + _mel = mel_converter(target_flat) + target_mel = torch.empty( + _mel.shape, device=device, dtype=dtype + ) + target_mel.copy_(_mel) + del _mel # Vocoder forward: mel → waveform pred_wav = vocoder(target_mel) # [B, 1, T_wav]