fix: sanitize target_flat — clips are inference tensors from outer inference_mode
The clips list is built inside ComfyUI's inference_mode context, so every element is an inference tensor. torch.stack().clone() propagates the flag. Use zeros+copy_ (same pattern as params/buffers) to get a normal tensor, so mel_converter(target_flat) inside no_grad produces a saveable input. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -268,7 +268,16 @@ class SelvaBigvganTrainer:
|
|||||||
start = random.randint(0, clip.shape[0] - segment_samples)
|
start = random.randint(0, clip.shape[0] - segment_samples)
|
||||||
batch.append(clip[start : start + segment_samples])
|
batch.append(clip[start : start + segment_samples])
|
||||||
|
|
||||||
target_flat = torch.stack(batch).to(device, dtype).clone() # [B, T]
|
# clips were loaded in ComfyUI's outer inference_mode, so every
|
||||||
|
# element is an inference tensor. torch.stack().clone() is still
|
||||||
|
# an inference tensor (the flag propagates through all ops).
|
||||||
|
# Use zeros+copy_ to produce a genuine normal tensor.
|
||||||
|
_stacked = torch.stack(batch).to(device, dtype)
|
||||||
|
target_flat = torch.zeros(
|
||||||
|
_stacked.shape, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
target_flat.copy_(_stacked)
|
||||||
|
del _stacked
|
||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
# Fixed target mel — buffers are now normal tensors (sanitized
|
# Fixed target mel — buffers are now normal tensors (sanitized
|
||||||
|
|||||||
Reference in New Issue
Block a user