fix: handle all three inference-tensor sources in vocoder sanitization
remove_parametrizations() stores weight as a plain __dict__ tensor (not nn.Parameter), making it invisible to _parameters iteration. Also, buffers (Activation1d anti-aliasing filters) are inference tensors that break the backward graph mid-network. Fix all three categories: 1. _parameters: clone().detach(), wrap as Parameter 2. plain __dict__ tensors: clone(), register_parameter (also makes trainable) 3. _buffers: clone() to strip inference flag without parametrizing Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -277,15 +277,42 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
|
|
||||||
_save_sample("baseline")
|
_save_sample("baseline")
|
||||||
|
|
||||||
# Replace inference-tensor parameters with fresh trainable parameters.
|
# Sanitize all inference tensors in the vocoder.
|
||||||
# (Vocoder was loaded by ComfyUI before this thread existed; its parameters
|
# Three categories to handle (all loaded in ComfyUI's inference_mode):
|
||||||
# may carry the inference flag from that context. Here we're in a clean thread
|
#
|
||||||
# so all new tensors are normal.)
|
# 1. Registered parameters (_parameters): covers bias, alpha, etc.
|
||||||
|
#
|
||||||
|
# 2. Plain tensor attributes in __dict__: torch.nn.utils.parametrize.
|
||||||
|
# remove_parametrizations() calls setattr(module, name, tensor) with
|
||||||
|
# a raw tensor, NOT nn.Parameter. Module.__setattr__ stores raw tensors
|
||||||
|
# in __dict__ (not _parameters), so our parameter loop misses them.
|
||||||
|
# This is how BigVGAN's conv.weight ends up invisible to _parameters.
|
||||||
|
# Fix: re-register as Parameter, which also makes them trainable.
|
||||||
|
#
|
||||||
|
# 3. Registered buffers (_buffers): Activation1d's anti-aliasing filter
|
||||||
|
# tensors. Not trainable, but operations on inference buffers produce
|
||||||
|
# inference tensor outputs — which breaks the backward graph mid-network.
|
||||||
|
# Fix: clone to strip the inference flag (not registered as parameters).
|
||||||
for module in vocoder.modules():
|
for module in vocoder.modules():
|
||||||
|
# Category 1: registered parameters
|
||||||
for pname, param in list(module._parameters.items()):
|
for pname, param in list(module._parameters.items()):
|
||||||
if param is not None:
|
if param is not None:
|
||||||
fresh = param.data.clone().detach().requires_grad_(True)
|
module._parameters[pname] = nn_mod.Parameter(
|
||||||
module._parameters[pname] = nn_mod.Parameter(fresh)
|
param.data.clone().detach(), requires_grad=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Category 2: plain tensor attributes (e.g. weight left by remove_parametrizations)
|
||||||
|
for name, val in list(module.__dict__.items()):
|
||||||
|
if (isinstance(val, torch.Tensor)
|
||||||
|
and not isinstance(val, nn_mod.Parameter)
|
||||||
|
and name not in module._buffers
|
||||||
|
and name not in module._modules):
|
||||||
|
module.register_parameter(name, nn_mod.Parameter(val.clone()))
|
||||||
|
|
||||||
|
# Category 3: buffers (Activation1d filter, etc.) — clone, don't parametrize
|
||||||
|
for bname, buf in list(module._buffers.items()):
|
||||||
|
if buf is not None:
|
||||||
|
module._buffers[bname] = buf.clone()
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||||
vocoder.train()
|
vocoder.train()
|
||||||
@@ -305,22 +332,6 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||||
|
|
||||||
if step == 0:
|
|
||||||
print(f"[BigVGAN DEBUG] inference_mode={torch.is_inference_mode_enabled()}", flush=True)
|
|
||||||
cp = vocoder.conv_pre
|
|
||||||
print(f"[BigVGAN DEBUG] conv_pre._parameters keys={list(cp._parameters.keys())}", flush=True)
|
|
||||||
for k, v in cp._parameters.items():
|
|
||||||
if v is not None:
|
|
||||||
print(f"[BigVGAN DEBUG] _parameters[{k!r}].is_inference()={v.is_inference()}", flush=True)
|
|
||||||
has_p = hasattr(cp, 'parametrizations')
|
|
||||||
print(f"[BigVGAN DEBUG] has parametrizations={has_p}", flush=True)
|
|
||||||
if has_p:
|
|
||||||
for attr, plist in cp.parametrizations.items():
|
|
||||||
print(f"[BigVGAN DEBUG] parametrizations[{attr!r}]={plist}", flush=True)
|
|
||||||
for sub_name, sub_p in plist.named_parameters():
|
|
||||||
print(f"[BigVGAN DEBUG] {sub_name}.is_inference()={sub_p.is_inference()}", flush=True)
|
|
||||||
print(f"[BigVGAN DEBUG] conv_pre.weight.is_inference()={cp.weight.is_inference()}", flush=True)
|
|
||||||
|
|
||||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||||
|
|
||||||
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
T = min(pred_wav.shape[-1], target_wav.shape[-1])
|
||||||
|
|||||||
Reference in New Issue
Block a user