fix: move only VAE+vocoder to GPU during eval to prevent device mismatch

The previous check, next(feature_utils_orig.parameters()).device, inspected
only the first parameter (which belongs to CLIP), so it missed vocoder
weights stranded on the CPU when the module was in a mixed-device state.
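
A minimal sketch of the failure mode described above, not code from this repository. FakeParam and FakeModule are hypothetical stand-ins so the example runs without a GPU or PyTorch; the parameter order (CLIP first, vocoder last) is assumed from the message. Because the first parameter already reports the target device, a guard such as "if orig_dev != device" would skip the move and leave the vocoder weights on the CPU.

```python
from collections import Counter


class FakeParam:
    """Hypothetical stand-in for a parameter; only `.device` matters here."""

    def __init__(self, device):
        self.device = device


class FakeModule:
    """Hypothetical stand-in for a module that exposes `.parameters()`."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


def first_param_device(module):
    # The style of check the commit replaces: it looks at one parameter only.
    return next(module.parameters()).device


def param_devices(module):
    # Count parameters per device so a mixed-device state becomes visible.
    return Counter(p.device for p in module.parameters())


# Assumed layout: two CLIP parameters on the GPU, one vocoder weight on the CPU.
mixed = FakeModule([FakeParam("cuda:0"), FakeParam("cuda:0"), FakeParam("cpu")])

first = first_param_device(mixed)   # reports only the CLIP device
devices = param_devices(mixed)      # reports every device in use
is_mixed = len(devices) > 1
```

The helpers rely only on `.parameters()` and `.device`, so they should also accept a real torch.nn.Module, although that is untested here.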

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 19:36:02 +02:00
parent 8fa2699551
commit 1d1ae61409
3 changed files with 16 additions and 16 deletions
+4 -5
@@ -93,15 +93,14 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
     x1_pred = eval_fm.to_data(velocity_fn, x0)
     x1_unnorm = generator.unnormalize(x1_pred)
-    orig_dev = next(feature_utils_orig.parameters()).device
-    if orig_dev != device:
-        feature_utils_orig.to(device)
+    tod = feature_utils_orig.tod
+    tod_orig_dev = next(tod.parameters()).device
+    tod.to(device)
     try:
         spec = feature_utils_orig.decode(x1_unnorm)
         audio = feature_utils_orig.vocode(spec)
     finally:
-        if orig_dev != device:
-            feature_utils_orig.to(orig_dev)
+        tod.to(tod_orig_dev)
     audio = audio.float().cpu()
     if audio.dim() == 2: