fix: handle all three inference-tensor sources in vocoder sanitization

remove_parametrizations() stores weight as a plain __dict__ tensor (not
nn.Parameter), making it invisible to _parameters iteration. Also, buffers
(Activation1d anti-aliasing filters) are inference tensors that break the
backward graph mid-network. Fix all three categories:
1. _parameters: clone().detach(), wrap as Parameter
2. plain __dict__ tensors: clone(), register_parameter (also makes trainable)
3. _buffers: clone() to strip inference flag without parametrizing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 02:54:41 +02:00
parent b243908873
commit 5df2abd6dd
+33 -22
View File
@@ -277,15 +277,42 @@ def _do_train(vocoder, mel_converter, clips,
_save_sample("baseline")
# Replace inference-tensor parameters with fresh trainable parameters.
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
# may carry the inference flag from that context. Here we're in a clean thread
# so all new tensors are normal.)
# Sanitize all inference tensors in the vocoder.
# Three categories to handle (all loaded in ComfyUI's inference_mode):
#
# 1. Registered parameters (_parameters): covers bias, alpha, etc.
#
# 2. Plain tensor attributes in __dict__: torch.nn.utils.parametrize.
# remove_parametrizations() calls setattr(module, name, tensor) with
# a raw tensor, NOT nn.Parameter. Module.__setattr__ stores raw tensors
# in __dict__ (not _parameters), so our parameter loop misses them.
# This is how BigVGAN's conv.weight ends up invisible to _parameters.
# Fix: re-register as Parameter, which also makes them trainable.
#
# 3. Registered buffers (_buffers): Activation1d's anti-aliasing filter
# tensors. Not trainable, but operations on inference buffers produce
# inference tensor outputs — which breaks the backward graph mid-network.
# Fix: clone to strip the inference flag (not registered as parameters).
for module in vocoder.modules():
# Category 1: registered parameters
for pname, param in list(module._parameters.items()):
if param is not None:
fresh = param.data.clone().detach().requires_grad_(True)
module._parameters[pname] = nn_mod.Parameter(fresh)
module._parameters[pname] = nn_mod.Parameter(
param.data.clone().detach(), requires_grad=True
)
# Category 2: plain tensor attributes (e.g. weight left by remove_parametrizations)
for name, val in list(module.__dict__.items()):
if (isinstance(val, torch.Tensor)
and not isinstance(val, nn_mod.Parameter)
and name not in module._buffers
and name not in module._modules):
module.register_parameter(name, nn_mod.Parameter(val.clone()))
# Category 3: buffers (Activation1d filter, etc.) — clone, don't parametrize
for bname, buf in list(module._buffers.items()):
if buf is not None:
module._buffers[bname] = buf.clone()
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
vocoder.train()
@@ -305,22 +332,6 @@ def _do_train(vocoder, mel_converter, clips,
with torch.no_grad():
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
if step == 0:
print(f"[BigVGAN DEBUG] inference_mode={torch.is_inference_mode_enabled()}", flush=True)
cp = vocoder.conv_pre
print(f"[BigVGAN DEBUG] conv_pre._parameters keys={list(cp._parameters.keys())}", flush=True)
for k, v in cp._parameters.items():
if v is not None:
print(f"[BigVGAN DEBUG] _parameters[{k!r}].is_inference()={v.is_inference()}", flush=True)
has_p = hasattr(cp, 'parametrizations')
print(f"[BigVGAN DEBUG] has parametrizations={has_p}", flush=True)
if has_p:
for attr, plist in cp.parametrizations.items():
print(f"[BigVGAN DEBUG] parametrizations[{attr!r}]={plist}", flush=True)
for sub_name, sub_p in plist.named_parameters():
print(f"[BigVGAN DEBUG] {sub_name}.is_inference()={sub_p.is_inference()}", flush=True)
print(f"[BigVGAN DEBUG] conv_pre.weight.is_inference()={cp.weight.is_inference()}", flush=True)
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
T = min(pred_wav.shape[-1], target_wav.shape[-1])