diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 0819d2f..6f0061a 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -225,13 +225,35 @@ class SelvaBigvganTrainer: # nn.Parameter object itself still carries the inference flag. # Replace each parameter with a fresh nn.Parameter created here # (inside inference_mode(False)) so the object itself is normal. + # param.data.clone() of an inference tensor still produces an + # inference tensor. Use torch.zeros + copy_ to create a genuinely + # fresh normal tensor, then wrap in nn.Parameter (created here, + # inside inference_mode(False), so it is a normal parameter). import torch.nn as nn_mod for module in vocoder.modules(): for pname, param in list(module._parameters.items()): if param is not None: - module._parameters[pname] = nn_mod.Parameter( - param.data.clone(), requires_grad=True + fresh = torch.zeros( + param.shape, device=param.device, dtype=param.dtype ) + fresh.copy_(param.data) + module._parameters[pname] = nn_mod.Parameter( + fresh, requires_grad=True + ) + + # mel_converter buffers (mel_basis, hann_window, etc.) were loaded + # inside ComfyUI's outer inference_mode context, so they are inference + # tensors. Operations on inference tensors ALWAYS produce inference + # tensors, even inside inference_mode(False). torch.zeros() et al. + # create normal tensors in the current (non-inference) context, so + # we replace every buffer once via copy_() to break the chain. + for bname, buf in list(mel_converter._buffers.items()): + if buf is not None: + fresh = torch.zeros( + buf.shape, device=buf.device, dtype=buf.dtype + ) + fresh.copy_(buf) + mel_converter._buffers[bname] = fresh optimizer = torch.optim.AdamW( vocoder.parameters(), lr=lr, betas=(0.8, 0.99) @@ -249,11 +271,11 @@ class SelvaBigvganTrainer: target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T] target_wav = target_flat.unsqueeze(1) # [B, 1, T] - # Fixed target mel (no grad needed here). - # .clone() strips the inference-tensor flag inherited from - # mel_converter's buffers (loaded inside ComfyUI's inference_mode). + # Fixed target mel — buffers are now normal tensors (sanitized + # above), so torch.no_grad() correctly produces a non-inference, + # no-grad leaf tensor that conv layers can save for backward. with torch.no_grad(): - target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel] + target_mel = mel_converter(target_flat) # [B, 80, T_mel] # Vocoder forward: mel → waveform pred_wav = vocoder(target_mel) # [B, 1, T_wav] @@ -263,8 +285,9 @@ class SelvaBigvganTrainer: pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] - # Mel reconstruction loss: mel(pred) vs target_mel - pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel'] + # Mel reconstruction loss — no no_grad: grad must flow + # through pred_t → mel_converter → loss. + pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel'] T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])