diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index f50dfba..2f19249 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -259,10 +259,10 @@ class SelvaLoraTrainer: audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) # Audio → latent via VAE + # encode_audio is @inference_mode — .clone() exits inference mode audio_b = audio.unsqueeze(0).to(device, dtype) - with torch.inference_mode(): - dist = vae_utils.encode_audio(audio_b) - x1 = dist.mode().clone().cpu() + dist = vae_utils.encode_audio(audio_b) + x1 = dist.mode().clone().cpu() # Text → CLIP features (reuse already-loaded CLIP from inference model) text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu() @@ -323,7 +323,7 @@ class SelvaLoraTrainer: raise ValueError( f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}." ) - generator.load_state_dict(ckpt["state_dict"], strict=False) + load_lora(generator, ckpt["state_dict"]) optimizer.load_state_dict(ckpt["optimizer"]) scheduler.load_state_dict(ckpt["scheduler"]) print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)