diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 76fbe1a..cfd6873 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -121,16 +121,18 @@ def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype, x1_pred = eval_fm.to_data(velocity_fn, x0) x1_unnorm = generator.unnormalize(x1_pred) - # feature_utils_orig may be on CPU (offload strategy) — move temporarily - orig_device = next(feature_utils_orig.parameters()).device - if orig_device != device: - feature_utils_orig.to(device) + # Only move the VAE+vocoder (tod) to GPU — avoids moving the + # entire FeaturesUtils (CLIP, T5, Synchformer) which wastes VRAM + # and fixes mixed-device state issues where the first parameter + # check could miss CPU-stranded vocoder weights. + tod = feature_utils_orig.tod + tod_orig_device = next(tod.parameters()).device + tod.to(device) try: spec = feature_utils_orig.decode(x1_unnorm) audio = feature_utils_orig.vocode(spec) finally: - if orig_device != device: - feature_utils_orig.to(orig_device) + tod.to(tod_orig_device) audio = audio.float().cpu() if audio.dim() == 2: diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py index 2f64b5b..66f8a32 100644 --- a/nodes/selva_textual_inversion_trainer.py +++ b/nodes/selva_textual_inversion_trainer.py @@ -93,15 +93,14 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode, x1_pred = eval_fm.to_data(velocity_fn, x0) x1_unnorm = generator.unnormalize(x1_pred) - orig_dev = next(feature_utils_orig.parameters()).device - if orig_dev != device: - feature_utils_orig.to(device) + tod = feature_utils_orig.tod + tod_orig_dev = next(tod.parameters()).device + tod.to(device) try: spec = feature_utils_orig.decode(x1_unnorm) audio = feature_utils_orig.vocode(spec) finally: - if orig_dev != device: - feature_utils_orig.to(orig_dev) + tod.to(tod_orig_dev) audio = audio.float().cpu() if audio.dim() == 2: diff --git a/nodes/selva_vae_roundtrip.py b/nodes/selva_vae_roundtrip.py index d563459..5a6db3d 100644 --- a/nodes/selva_vae_roundtrip.py +++ b/nodes/selva_vae_roundtrip.py @@ -124,15 +124,14 @@ class SelvaVaeRoundtrip: flush=True) # Decode using model's feature_utils — same path as the sampler - orig_device = next(feature_utils.parameters()).device - if orig_device != device: - feature_utils.to(device) + tod = feature_utils.tod + tod_orig_device = next(tod.parameters()).device + tod.to(device) try: spec = feature_utils.decode(latent_unnorm) out = feature_utils.vocode(spec) finally: - if orig_device != device: - feature_utils.to(orig_device) + tod.to(tod_orig_device) out = out.float().cpu() if out.dim() == 1: