diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index b86d8fc..b8b87b8 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -235,9 +235,11 @@ class SelvaBigvganTrainer: target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T] - # Fixed target mel (no grad needed here) + # Fixed target mel (no grad needed here). + # .clone() strips the inference-tensor flag inherited from + # mel_converter's buffers (loaded inside ComfyUI's inference_mode). with torch.no_grad(): - target_mel = mel_converter(target_flat) # [B, 80, T_mel] + target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel] # Vocoder forward: mel → waveform pred_wav = vocoder(target_mel) # [B, 1, T_wav] @@ -248,7 +250,7 @@ class SelvaBigvganTrainer: target_t = target_wav[..., :T] # Mel reconstruction loss: mel(pred) vs target_mel - pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel'] + pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel'] T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])