fix: sanitize inference tensors in BigVGAN trainer via zeros+copy_ pattern
param.data.clone() and tensor.detach() on inference tensors both produce inference tensors — the flag propagates through all operations on them. Inside inference_mode(False), however, torch.zeros() creates genuine normal tensors. Use the zeros + copy_ pattern to sanitize both the vocoder parameters and the mel_converter buffers once before training, so that autograd can save inputs for backward.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -225,13 +225,35 @@ class SelvaBigvganTrainer:
|
|||||||
# nn.Parameter object itself still carries the inference flag.
|
# nn.Parameter object itself still carries the inference flag.
|
||||||
# Replace each parameter with a fresh nn.Parameter created here
|
# Replace each parameter with a fresh nn.Parameter created here
|
||||||
# (inside inference_mode(False)) so the object itself is normal.
|
# (inside inference_mode(False)) so the object itself is normal.
|
||||||
|
# param.data.clone() of an inference tensor still produces an
|
||||||
|
# inference tensor. Use torch.zeros + copy_ to create a genuinely
|
||||||
|
# fresh normal tensor, then wrap in nn.Parameter (created here,
|
||||||
|
# inside inference_mode(False), so it is a normal parameter).
|
||||||
import torch.nn as nn_mod
|
import torch.nn as nn_mod
|
||||||
for module in vocoder.modules():
|
for module in vocoder.modules():
|
||||||
for pname, param in list(module._parameters.items()):
|
for pname, param in list(module._parameters.items()):
|
||||||
if param is not None:
|
if param is not None:
|
||||||
module._parameters[pname] = nn_mod.Parameter(
|
fresh = torch.zeros(
|
||||||
param.data.clone(), requires_grad=True
|
param.shape, device=param.device, dtype=param.dtype
|
||||||
)
|
)
|
||||||
|
fresh.copy_(param.data)
|
||||||
|
module._parameters[pname] = nn_mod.Parameter(
|
||||||
|
fresh, requires_grad=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# mel_converter buffers (mel_basis, hann_window, etc.) were loaded
|
||||||
|
# inside ComfyUI's outer inference_mode context, so they are inference
|
||||||
|
# tensors. Operations on inference tensors ALWAYS produce inference
|
||||||
|
# tensors, even inside inference_mode(False). torch.zeros() et al.
|
||||||
|
# create normal tensors in the current (non-inference) context, so
|
||||||
|
# we replace every buffer once via copy_() to break the chain.
|
||||||
|
for bname, buf in list(mel_converter._buffers.items()):
|
||||||
|
if buf is not None:
|
||||||
|
fresh = torch.zeros(
|
||||||
|
buf.shape, device=buf.device, dtype=buf.dtype
|
||||||
|
)
|
||||||
|
fresh.copy_(buf)
|
||||||
|
mel_converter._buffers[bname] = fresh
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
vocoder.parameters(), lr=lr, betas=(0.8, 0.99)
|
||||||
@@ -249,11 +271,11 @@ class SelvaBigvganTrainer:
|
|||||||
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
|
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
|
||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
# Fixed target mel (no grad needed here).
|
# Fixed target mel — buffers are now normal tensors (sanitized
|
||||||
# .clone() strips the inference-tensor flag inherited from
|
# above), so torch.no_grad() correctly produces a non-inference,
|
||||||
# mel_converter's buffers (loaded inside ComfyUI's inference_mode).
|
# no-grad leaf tensor that conv layers can save for backward.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_mel = mel_converter(target_flat).clone() # [B, 80, T_mel]
|
target_mel = mel_converter(target_flat) # [B, 80, T_mel]
|
||||||
|
|
||||||
# Vocoder forward: mel → waveform
|
# Vocoder forward: mel → waveform
|
||||||
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
pred_wav = vocoder(target_mel) # [B, 1, T_wav]
|
||||||
@@ -263,8 +285,9 @@ class SelvaBigvganTrainer:
|
|||||||
pred_t = pred_wav[..., :T]
|
pred_t = pred_wav[..., :T]
|
||||||
target_t = target_wav[..., :T]
|
target_t = target_wav[..., :T]
|
||||||
|
|
||||||
# Mel reconstruction loss: mel(pred) vs target_mel
|
# Mel reconstruction loss — deliberately not wrapped in no_grad: grad must flow
|
||||||
pred_mel = mel_converter(pred_t.squeeze(1)).clone() # [B, 80, T_mel']
|
# through pred_t → mel_converter → loss.
|
||||||
|
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, 80, T_mel']
|
||||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user