From 0fcb6d31069f09a7fdd0dbf89bc948688b03c54c Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 01:58:57 +0200 Subject: [PATCH] fix(bigvgan-trainer): replace parameter objects to fully strip inference tensor flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit param.data = clone() only replaces storage — the nn.Parameter object itself retains the inference tensor flag set when the model was loaded. Replace each parameter with a fresh nn.Parameter(data.clone()) created inside inference_mode(False) so both the object and its data are normal tensors. Move optimizer creation to after re-creation so it references the new objects. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index eed1a0b..0819d2f 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -187,8 +187,6 @@ class SelvaBigvganTrainer: soft_empty_cache() mel_converter.to(device) - vocoder.requires_grad_(True) - optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) torch.manual_seed(seed) random.seed(seed) @@ -222,12 +220,22 @@ class SelvaBigvganTrainer: try: with torch.inference_mode(False): with torch.enable_grad(): - # Vocoder parameters were loaded inside ComfyUI's inference_mode() - # and are inference tensors. Autograd cannot save them for backward. - # Clone inside inference_mode(False) to get normal tensors. - for param in vocoder.parameters(): - param.data = param.data.clone() + # Vocoder parameters are inference tensors (loaded inside ComfyUI's + # inference_mode). param.data = clone() only changes storage — the + # nn.Parameter object itself still carries the inference flag. + # Replace each parameter with a fresh nn.Parameter created here + # (inside inference_mode(False)) so the object itself is normal. + import torch.nn as nn_mod + for module in vocoder.modules(): + for pname, param in list(module._parameters.items()): + if param is not None: + module._parameters[pname] = nn_mod.Parameter( + param.data.clone(), requires_grad=True + ) + optimizer = torch.optim.AdamW( + vocoder.parameters(), lr=lr, betas=(0.8, 0.99) + ) vocoder.train() for step in range(steps):