fix(bigvgan-trainer): replace parameter objects to fully strip inference tensor flag
param.data = clone() only replaces storage — the nn.Parameter object itself retains the inference tensor flag set when the model was loaded. Replace each parameter with a fresh nn.Parameter(data.clone()) created inside inference_mode(False) so both the object and its data are normal tensors. Move optimizer creation to after re-creation so it references the new objects. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -187,8 +187,6 @@ class SelvaBigvganTrainer:
|
||||
soft_empty_cache()
|
||||
|
||||
mel_converter.to(device)
|
||||
vocoder.requires_grad_(True)
|
||||
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
@@ -222,12 +220,22 @@ class SelvaBigvganTrainer:
|
||||
try:
|
||||
with torch.inference_mode(False):
|
||||
with torch.enable_grad():
|
||||
# Vocoder parameters were loaded inside ComfyUI's inference_mode()
|
||||
# and are inference tensors. Autograd cannot save them for backward.
|
||||
# Clone inside inference_mode(False) to get normal tensors.
|
||||
for param in vocoder.parameters():
|
||||
param.data = param.data.clone()
|
||||
# Vocoder parameters are inference tensors (loaded inside ComfyUI's
|
||||
# inference_mode). param.data = clone() only changes storage — the
|
||||
# nn.Parameter object itself still carries the inference flag.
|
||||
# Replace each parameter with a fresh nn.Parameter created here
|
||||
# (inside inference_mode(False)) so the object itself is normal.
|
||||
import torch.nn as nn_mod
|
||||
for module in vocoder.modules():
|
||||
for pname, param in list(module._parameters.items()):
|
||||
if param is not None:
|
||||
module._parameters[pname] = nn_mod.Parameter(
|
||||
param.data.clone(), requires_grad=True
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
||||
)
|
||||
vocoder.train()
|
||||
|
||||
for step in range(steps):
|
||||
|
||||
Reference in New Issue
Block a user