fix: sanitize target_flat — clips are inference tensors from outer inference_mode

The clips list is built inside ComfyUI's inference_mode context, so every
element is an inference tensor. torch.stack().clone() propagates the flag.
Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor,
so mel_converter(target_flat) inside no_grad produces a saveable input.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 02:09:26 +02:00
parent b7565ec458
commit 51ac099073
+11 -2
View File
@@ -268,8 +268,17 @@ class SelvaBigvganTrainer:
start = random.randint(0, clip.shape[0] - segment_samples)
batch.append(clip[start : start + segment_samples])
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
# clips were loaded in ComfyUI's outer inference_mode, so every
# element is an inference tensor. torch.stack().clone() is still
# an inference tensor (the flag propagates through all ops).
# Use zeros+copy_ to produce a genuine normal tensor.
_stacked = torch.stack(batch).to(device, dtype)
target_flat = torch.zeros(
_stacked.shape, device=device, dtype=dtype
)
target_flat.copy_(_stacked)
del _stacked
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
# Fixed target mel — buffers are now normal tensors (sanitized
# above), so torch.no_grad() correctly produces a non-inference,