fix: move only VAE+vocoder to GPU during eval to prevent device mismatch
The previous check (next(feature_utils_orig.parameters()).device) only inspected the first parameter (from CLIP), missing CPU-stranded vocoder weights when the module was in a mixed-device state. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -121,16 +121,18 @@ def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
|||||||
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
||||||
x1_unnorm = generator.unnormalize(x1_pred)
|
x1_unnorm = generator.unnormalize(x1_pred)
|
||||||
|
|
||||||
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
|
# Only move the VAE+vocoder (tod) to GPU — avoids moving the
|
||||||
orig_device = next(feature_utils_orig.parameters()).device
|
# entire FeaturesUtils (CLIP, T5, Synchformer) which wastes VRAM
|
||||||
if orig_device != device:
|
# and fixes mixed-device state issues where the first parameter
|
||||||
feature_utils_orig.to(device)
|
# check could miss CPU-stranded vocoder weights.
|
||||||
|
tod = feature_utils_orig.tod
|
||||||
|
tod_orig_device = next(tod.parameters()).device
|
||||||
|
tod.to(device)
|
||||||
try:
|
try:
|
||||||
spec = feature_utils_orig.decode(x1_unnorm)
|
spec = feature_utils_orig.decode(x1_unnorm)
|
||||||
audio = feature_utils_orig.vocode(spec)
|
audio = feature_utils_orig.vocode(spec)
|
||||||
finally:
|
finally:
|
||||||
if orig_device != device:
|
tod.to(tod_orig_device)
|
||||||
feature_utils_orig.to(orig_device)
|
|
||||||
|
|
||||||
audio = audio.float().cpu()
|
audio = audio.float().cpu()
|
||||||
if audio.dim() == 2:
|
if audio.dim() == 2:
|
||||||
|
|||||||
@@ -93,15 +93,14 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
|
|||||||
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
||||||
x1_unnorm = generator.unnormalize(x1_pred)
|
x1_unnorm = generator.unnormalize(x1_pred)
|
||||||
|
|
||||||
orig_dev = next(feature_utils_orig.parameters()).device
|
tod = feature_utils_orig.tod
|
||||||
if orig_dev != device:
|
tod_orig_dev = next(tod.parameters()).device
|
||||||
feature_utils_orig.to(device)
|
tod.to(device)
|
||||||
try:
|
try:
|
||||||
spec = feature_utils_orig.decode(x1_unnorm)
|
spec = feature_utils_orig.decode(x1_unnorm)
|
||||||
audio = feature_utils_orig.vocode(spec)
|
audio = feature_utils_orig.vocode(spec)
|
||||||
finally:
|
finally:
|
||||||
if orig_dev != device:
|
tod.to(tod_orig_dev)
|
||||||
feature_utils_orig.to(orig_dev)
|
|
||||||
|
|
||||||
audio = audio.float().cpu()
|
audio = audio.float().cpu()
|
||||||
if audio.dim() == 2:
|
if audio.dim() == 2:
|
||||||
|
|||||||
@@ -124,15 +124,14 @@ class SelvaVaeRoundtrip:
|
|||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
# Decode using model's feature_utils — same path as the sampler
|
# Decode using model's feature_utils — same path as the sampler
|
||||||
orig_device = next(feature_utils.parameters()).device
|
tod = feature_utils.tod
|
||||||
if orig_device != device:
|
tod_orig_device = next(tod.parameters()).device
|
||||||
feature_utils.to(device)
|
tod.to(device)
|
||||||
try:
|
try:
|
||||||
spec = feature_utils.decode(latent_unnorm)
|
spec = feature_utils.decode(latent_unnorm)
|
||||||
out = feature_utils.vocode(spec)
|
out = feature_utils.vocode(spec)
|
||||||
finally:
|
finally:
|
||||||
if orig_device != device:
|
tod.to(tod_orig_device)
|
||||||
feature_utils.to(orig_device)
|
|
||||||
|
|
||||||
out = out.float().cpu()
|
out = out.float().cpu()
|
||||||
if out.dim() == 1:
|
if out.dim() == 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user