fix: use load_lora for resume and remove redundant inference_mode wrapper

- Resume now calls load_lora() instead of load_state_dict() directly,
  giving proper warnings for missing/unexpected LoRA keys.
- Remove redundant `with torch.inference_mode():` around encode_audio
  (already @inference_mode decorated); dist.mode().clone() pattern
  is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 17:09:35 +02:00
parent f206a1b38c
commit 57cd3dd4b4
+4 -4
View File
@@ -259,10 +259,10 @@ class SelvaLoraTrainer:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device, dtype)
with torch.inference_mode():
dist = vae_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu()
dist = vae_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu()
# Text → CLIP features (reuse already-loaded CLIP from inference model)
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
@@ -323,7 +323,7 @@ class SelvaLoraTrainer:
raise ValueError(
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
)
generator.load_state_dict(ckpt["state_dict"], strict=False)
load_lora(generator, ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)