fix: use load_lora for resume and remove redundant inference_mode wrapper

- Resume now calls load_lora() instead of load_state_dict() directly,
  giving proper warnings for missing/unexpected LoRA keys.
- Remove the redundant `with torch.inference_mode():` around encode_audio
  (it is already decorated with @inference_mode), so the dist.mode().clone()
  pattern, which yields a normal tensor outside inference mode, is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 17:09:35 +02:00
parent f206a1b38c
commit 57cd3dd4b4
+2 -2
View File
@@ -259,8 +259,8 @@ class SelvaLoraTrainer:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE # Audio → latent via VAE
# encode_audio is decorated with @inference_mode — .clone() outside that context yields a normal (non-inference) tensor
audio_b = audio.unsqueeze(0).to(device, dtype) audio_b = audio.unsqueeze(0).to(device, dtype)
with torch.inference_mode():
dist = vae_utils.encode_audio(audio_b) dist = vae_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu() x1 = dist.mode().clone().cpu()
@@ -323,7 +323,7 @@ class SelvaLoraTrainer:
raise ValueError( raise ValueError(
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}." f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
) )
generator.load_state_dict(ckpt["state_dict"], strict=False) load_lora(generator, ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"]) optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"]) scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True) print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)