fix(bigvgan-trainer): clone mel outputs to strip inference tensor flag from buffers

mel_converter buffers (mel_basis, hann_window) are inference tensors
because the model was loaded inside ComfyUI's torch.inference_mode().
Operations on them propagate the flag to outputs. Clone both target_mel
and pred_mel to get normal autograd-compatible tensors. .clone() is
differentiable so the grad graph to vocoder parameters is preserved.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:51:28 +02:00
parent daa36a5f7b
commit f04d59fe63
+5 -3
View File
@@ -235,9 +235,11 @@ class SelvaBigvganTrainer:
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T] target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T]
# Fixed target mel (no grad needed here) # Fixed target mel (no grad needed here).
# .clone() strips the inference-tensor flag inherited from
# mel_converter's buffers (loaded inside ComfyUI's inference_mode).
with torch.no_grad(): with torch.no_grad():
target_mel = mel_converter(target_flat) # [B, 80, T_mel] target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel]
# Vocoder forward: mel → waveform # Vocoder forward: mel → waveform
pred_wav = vocoder(target_mel) # [B, 1, T_wav] pred_wav = vocoder(target_mel) # [B, 1, T_wav]
@@ -248,7 +250,7 @@ class SelvaBigvganTrainer:
target_t = target_wav[..., :T] target_t = target_wav[..., :T]
# Mel reconstruction loss: mel(pred) vs target_mel # Mel reconstruction loss: mel(pred) vs target_mel
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel'] pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel']
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])