fix(bigvgan-trainer): replace parameter objects to fully strip inference tensor flag

param.data = clone() only replaces storage — the nn.Parameter object itself
retains the inference tensor flag set when the model was loaded. Replace each
parameter with a fresh nn.Parameter(data.clone()) created inside
inference_mode(False) so both the object and its data are normal tensors.
Move optimizer creation to after re-creation so it references the new objects.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:58:57 +02:00
parent c86306bde8
commit 0fcb6d3106
+15 -7
View File
@@ -187,8 +187,6 @@ class SelvaBigvganTrainer:
soft_empty_cache() soft_empty_cache()
mel_converter.to(device) mel_converter.to(device)
vocoder.requires_grad_(True)
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
@@ -222,12 +220,22 @@ class SelvaBigvganTrainer:
try: try:
with torch.inference_mode(False): with torch.inference_mode(False):
with torch.enable_grad(): with torch.enable_grad():
# Vocoder parameters were loaded inside ComfyUI's inference_mode() # Vocoder parameters are inference tensors (loaded inside ComfyUI's
# and are inference tensors. Autograd cannot save them for backward. # inference_mode). param.data = clone() only changes storage — the
# Clone inside inference_mode(False) to get normal tensors. # nn.Parameter object itself still carries the inference flag.
for param in vocoder.parameters(): # Replace each parameter with a fresh nn.Parameter created here
param.data = param.data.clone() # (inside inference_mode(False)) so the object itself is normal.
import torch.nn as nn_mod
for module in vocoder.modules():
for pname, param in list(module._parameters.items()):
if param is not None:
module._parameters[pname] = nn_mod.Parameter(
param.data.clone(), requires_grad=True
)
optimizer = torch.optim.AdamW(
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
)
vocoder.train() vocoder.train()
for step in range(steps): for step in range(steps):