fix: handle all three inference-tensor sources in vocoder sanitization
remove_parametrizations() stores weight as a plain __dict__ tensor (not nn.Parameter), making it invisible to _parameters iteration. Also, buffers (Activation1d anti-aliasing filters) are inference tensors that break the backward graph mid-network. Fix all three categories: 1. _parameters: clone().detach(), wrap as Parameter 2. plain __dict__ tensors: clone(), register_parameter (also makes trainable) 3. _buffers: clone() to strip inference flag without parametrizing Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -277,15 +277,42 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
|
||||
_save_sample("baseline")
|
||||
|
||||
# Replace inference-tensor parameters with fresh trainable parameters.
|
||||
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
|
||||
# may carry the inference flag from that context. Here we're in a clean thread
|
||||
# so all new tensors are normal.)
|
||||
# Sanitize all inference tensors in the vocoder.
|
||||
# Three categories to handle (all loaded in ComfyUI's inference_mode):
|
||||
#
|
||||
# 1. Registered parameters (_parameters): covers bias, alpha, etc.
|
||||
#
|
||||
# 2. Plain tensor attributes in __dict__: torch.nn.utils.parametrize.
|
||||
# remove_parametrizations() calls setattr(module, name, tensor) with
|
||||
# a raw tensor, NOT nn.Parameter. Module.__setattr__ stores raw tensors
|
||||
# in __dict__ (not _parameters), so our parameter loop misses them.
|
||||
# This is how BigVGAN's conv.weight ends up invisible to _parameters.
|
||||
# Fix: re-register as Parameter, which also makes them trainable.
|
||||
#
|
||||
# 3. Registered buffers (_buffers): Activation1d's anti-aliasing filter
|
||||
# tensors. Not trainable, but operations on inference buffers produce
|
||||
# inference tensor outputs — which breaks the backward graph mid-network.
|
||||
# Fix: clone to strip the inference flag (not registered as parameters).
|
||||
for module in vocoder.modules():
|
||||
# Category 1: registered parameters
|
||||
for pname, param in list(module._parameters.items()):
|
||||
if param is not None:
|
||||
fresh = param.data.clone().detach().requires_grad_(True)
|
||||
module._parameters[pname] = nn_mod.Parameter(fresh)
|
||||
module._parameters[pname] = nn_mod.Parameter(
|
||||
param.data.clone().detach(), requires_grad=True
|
||||
)
|
||||
|
||||
# Category 2: plain tensor attributes (e.g. weight left by remove_parametrizations)
|
||||
for name, val in list(module.__dict__.items()):
|
||||
if (isinstance(val, torch.Tensor)
|
||||
and not isinstance(val, nn_mod.Parameter)
|
||||
and name not in module._buffers
|
||||
and name not in module._modules):
|
||||
module.register_parameter(name, nn_mod.Parameter(val.clone()))
|
||||
|
||||
# Category 3: buffers (Activation1d filter, etc.) — clone, don't parametrize
|
||||
for bname, buf in list(module._buffers.items()):
|
||||
if buf is not None:
|
||||
module._buffers[bname] = buf.clone()
|
||||
|
||||
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||
vocoder.train()
|
||||
@@ -305,22 +332,6 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
|
||||
if step == 0:
|
||||
print(f"[BigVGAN DEBUG] inference_mode={torch.is_inference_mode_enabled()}", flush=True)
|
||||
cp = vocoder.conv_pre
|
||||
print(f"[BigVGAN DEBUG] conv_pre._parameters keys={list(cp._parameters.keys())}", flush=True)
|
||||
for k, v in cp._parameters.items():
|
||||
if v is not None:
|
||||
print(f"[BigVGAN DEBUG] _parameters[{k!r}].is_inference()={v.is_inference()}", flush=True)
|
||||
has_p = hasattr(cp, 'parametrizations')
|
||||
print(f"[BigVGAN DEBUG] has parametrizations={has_p}", flush=True)
|
||||
if has_p:
|
||||
for attr, plist in cp.parametrizations.items():
|
||||
print(f"[BigVGAN DEBUG] parametrizations[{attr!r}]={plist}", flush=True)
|
||||
for sub_name, sub_p in plist.named_parameters():
|
||||
print(f"[BigVGAN DEBUG] {sub_name}.is_inference()={sub_p.is_inference()}", flush=True)
|
||||
print(f"[BigVGAN DEBUG] conv_pre.weight.is_inference()={cp.weight.is_inference()}", flush=True)
|
||||
|
||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||
|
||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||
|
||||
Reference in New Issue
Block a user