fix: sanitize all submodule buffers of mel_converter + guarantee target_mel output
Previous fix only iterated mel_converter._buffers (direct buffers). Submodules (e.g. Spectrogram.window) still held inference tensors. Switch to .modules() to cover all nested buffers, matching the vocoder parameter sanitization. Also add a zeros+copy_ safety net on target_mel output so conv can save it. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -241,19 +241,17 @@ class SelvaBigvganTrainer:
|
||||
fresh, requires_grad=True
|
||||
)
|
||||
|
||||
# mel_converter buffers (mel_basis, hann_window, etc.) were loaded
|
||||
# inside ComfyUI's outer inference_mode context, so they are inference
|
||||
# tensors. Operations on inference tensors ALWAYS produce inference
|
||||
# tensors, even inside inference_mode(False). torch.zeros() et al.
|
||||
# create normal tensors in the current (non-inference) context, so
|
||||
# we replace every buffer once via copy_() to break the chain.
|
||||
for bname, buf in list(mel_converter._buffers.items()):
|
||||
# mel_converter and its submodules (e.g. Spectrogram.window) have
|
||||
# inference-tensor buffers loaded in ComfyUI's outer inference_mode.
|
||||
# Must iterate .modules() — ._buffers only covers direct buffers.
|
||||
for sub in mel_converter.modules():
|
||||
for bname, buf in list(sub._buffers.items()):
|
||||
if buf is not None:
|
||||
fresh = torch.zeros(
|
||||
buf.shape, device=buf.device, dtype=buf.dtype
|
||||
)
|
||||
fresh.copy_(buf)
|
||||
mel_converter._buffers[bname] = fresh
|
||||
sub._buffers[bname] = fresh
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
||||
@@ -280,11 +278,16 @@ class SelvaBigvganTrainer:
|
||||
del _stacked
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
# Fixed target mel — buffers are now normal tensors (sanitized
|
||||
# above), so torch.no_grad() correctly produces a non-inference,
|
||||
# no-grad leaf tensor that conv layers can save for backward.
|
||||
# Compute target mel and guarantee it is not an inference tensor.
|
||||
# Even with sanitized buffers a submodule we missed could still
|
||||
# taint the output, so we always copy into a fresh tensor.
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, 80, T_mel]
|
||||
_mel = mel_converter(target_flat)
|
||||
target_mel = torch.empty(
|
||||
_mel.shape, device=device, dtype=dtype
|
||||
)
|
||||
target_mel.copy_(_mel)
|
||||
del _mel
|
||||
|
||||
# Vocoder forward: mel → waveform
|
||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||
|
||||
Reference in New Issue
Block a user