fix: sanitize target_flat — clips are inference tensors from outer inference_mode
The clips list is built inside ComfyUI's inference_mode context, so every element is an inference tensor. torch.stack().clone() propagates the flag. Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor, so mel_converter(target_flat) inside no_grad produces a saveable input. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -268,8 +268,17 @@ class SelvaBigvganTrainer:
|
||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||
batch.append(clip[start : start + segment_samples])
|
||||
|
||||
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
# clips were loaded in ComfyUI's outer inference_mode, so every
|
||||
# element is an inference tensor. torch.stack().clone() is still
|
||||
# an inference tensor (the flag propagates through all ops).
|
||||
# Use zeros+copy_ to produce a genuine normal tensor.
|
||||
_stacked = torch.stack(batch).to(device, dtype)
|
||||
target_flat = torch.zeros(
|
||||
_stacked.shape, device=device, dtype=dtype
|
||||
)
|
||||
target_flat.copy_(_stacked)
|
||||
del _stacked
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
# Fixed target mel — buffers are now normal tensors (sanitized
|
||||
# above), so torch.no_grad() correctly produces a non-inference,
|
||||
|
||||
Reference in New Issue
Block a user