fix: use load_lora for resume and remove redundant inference_mode wrapper
- Resume now calls load_lora() instead of calling load_state_dict() directly, giving proper warnings for missing/unexpected LoRA keys.
- Remove the redundant `with torch.inference_mode():` wrapper around encode_audio (it is already decorated with @inference_mode); the dist.mode().clone() pattern is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -259,10 +259,10 @@ class SelvaLoraTrainer:
|
||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
|
||||
# Audio → latent via VAE
|
||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||
audio_b = audio.unsqueeze(0).to(device, dtype)
|
||||
with torch.inference_mode():
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
x1 = dist.mode().clone().cpu()
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
x1 = dist.mode().clone().cpu()
|
||||
|
||||
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||
@@ -323,7 +323,7 @@ class SelvaLoraTrainer:
|
||||
raise ValueError(
|
||||
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
|
||||
)
|
||||
generator.load_state_dict(ckpt["state_dict"], strict=False)
|
||||
load_lora(generator, ckpt["state_dict"])
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
scheduler.load_state_dict(ckpt["scheduler"])
|
||||
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user