From f04d59fe63d013cccc7ef69dad24235adff16530 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 01:51:28 +0200 Subject: [PATCH] fix(bigvgan-trainer): clone mel outputs to strip inference tensor flag from buffers mel_converter buffers (mel_basis, hann_window) are inference tensors because the model was loaded inside ComfyUI's torch.inference_mode(). Operations on them propagate the flag to outputs. Clone both target_mel and pred_mel to get normal autograd-compatible tensors. .clone() is differentiable so the grad graph to vocoder parameters is preserved. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index b86d8fc..b8b87b8 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -235,9 +235,11 @@ class SelvaBigvganTrainer: target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T] - # Fixed target mel (no grad needed here) + # Fixed target mel (no grad needed here). + # .clone() strips the inference-tensor flag inherited from + # mel_converter's buffers (loaded inside ComfyUI's inference_mode). with torch.no_grad(): - target_mel = mel_converter(target_flat) # [B, 80, T_mel] + target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel] # Vocoder forward: mel → waveform pred_wav = vocoder(target_mel) # [B, 1, T_wav] @@ -248,7 +250,7 @@ class SelvaBigvganTrainer: target_t = target_wav[..., :T] # Mel reconstruction loss: mel(pred) vs target_mel - pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel'] + pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel'] T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])