diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 7511bf1..4b1b10a 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -277,15 +277,42 @@ def _do_train(vocoder, mel_converter, clips, _save_sample("baseline") - # Replace inference-tensor parameters with fresh trainable parameters. - # (Vocoder was loaded by ComfyUI before this thread existed; its parameters - # may carry the inference flag from that context. Here we're in a clean thread - # so all new tensors are normal.) + # Sanitize all inference tensors in the vocoder. + # Three categories to handle (all loaded in ComfyUI's inference_mode): + # + # 1. Registered parameters (_parameters): covers bias, alpha, etc. + # + # 2. Plain tensor attributes in __dict__: torch.nn.utils.parametrize. + # remove_parametrizations() calls setattr(module, name, tensor) with + # a raw tensor, NOT nn.Parameter. Module.__setattr__ stores raw tensors + # in __dict__ (not _parameters), so our parameter loop misses them. + # This is how BigVGAN's conv.weight ends up invisible to _parameters. + # Fix: re-register as Parameter, which also makes them trainable. + # + # 3. Registered buffers (_buffers): Activation1d's anti-aliasing filter + # tensors. Not trainable, but operations on inference buffers produce + # inference tensor outputs — which breaks the backward graph mid-network. + # Fix: clone to strip the inference flag (not registered as parameters). for module in vocoder.modules(): + # Category 1: registered parameters for pname, param in list(module._parameters.items()): if param is not None: - fresh = param.data.clone().detach().requires_grad_(True) - module._parameters[pname] = nn_mod.Parameter(fresh) + module._parameters[pname] = nn_mod.Parameter( + param.data.clone().detach(), requires_grad=True + ) + + # Category 2: plain tensor attributes (e.g. weight left by remove_parametrizations) + for name, val in list(module.__dict__.items()): + if (isinstance(val, torch.Tensor) + and not isinstance(val, nn_mod.Parameter) + and name not in module._buffers + and name not in module._modules): + module.register_parameter(name, nn_mod.Parameter(val.clone())) + + # Category 3: buffers (Activation1d filter, etc.) — clone, don't parametrize + for bname, buf in list(module._buffers.items()): + if buf is not None: + module._buffers[bname] = buf.clone() optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) vocoder.train() @@ -305,22 +332,6 @@ def _do_train(vocoder, mel_converter, clips, with torch.no_grad(): target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] - if step == 0: - print(f"[BigVGAN DEBUG] inference_mode={torch.is_inference_mode_enabled()}", flush=True) - cp = vocoder.conv_pre - print(f"[BigVGAN DEBUG] conv_pre._parameters keys={list(cp._parameters.keys())}", flush=True) - for k, v in cp._parameters.items(): - if v is not None: - print(f"[BigVGAN DEBUG] _parameters[{k!r}].is_inference()={v.is_inference()}", flush=True) - has_p = hasattr(cp, 'parametrizations') - print(f"[BigVGAN DEBUG] has parametrizations={has_p}", flush=True) - if has_p: - for attr, plist in cp.parametrizations.items(): - print(f"[BigVGAN DEBUG] parametrizations[{attr!r}]={plist}", flush=True) - for sub_name, sub_p in plist.named_parameters(): - print(f"[BigVGAN DEBUG] {sub_name}.is_inference()={sub_p.is_inference()}", flush=True) - print(f"[BigVGAN DEBUG] conv_pre.weight.is_inference()={cp.weight.is_inference()}", flush=True) - pred_wav = vocoder(target_mel) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1])