fix(bigvgan-trainer): clone mel outputs to strip inference tensor flag from buffers
The mel_converter buffers (mel_basis, hann_window) are inference tensors because the model was loaded inside ComfyUI's torch.inference_mode(). Operations on these buffers propagate the inference-tensor flag to their outputs. Clone both target_mel and pred_mel to obtain normal, autograd-compatible tensors. .clone() is differentiable, so the gradient graph back to the vocoder parameters is preserved.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -235,9 +235,11 @@ class SelvaBigvganTrainer:
|
||||
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
# Fixed target mel (no grad needed here)
|
||||
# Fixed target mel (no grad needed here).
|
||||
# .clone() strips the inference-tensor flag inherited from
|
||||
# mel_converter's buffers (loaded inside ComfyUI's inference_mode).
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, 80, T_mel]
|
||||
target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel]
|
||||
|
||||
# Vocoder forward: mel → waveform
|
||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||
@@ -248,7 +250,7 @@ class SelvaBigvganTrainer:
|
||||
target_t = target_wav[..., :T]
|
||||
|
||||
# Mel reconstruction loss: mel(pred) vs target_mel
|
||||
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel']
|
||||
pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel']
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user