fix: use load_lora for resume and remove redundant inference_mode wrapper
- Resume now calls load_lora() instead of calling load_state_dict() directly, giving proper warnings for missing/unexpected LoRA keys.
- Remove the redundant `with torch.inference_mode():` around encode_audio (already decorated with @inference_mode); the dist.mode().clone() pattern is now clearer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -259,8 +259,8 @@ class SelvaLoraTrainer:
|
|||||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||||
|
|
||||||
# Audio → latent via VAE
|
# Audio → latent via VAE
|
||||||
|
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||||
audio_b = audio.unsqueeze(0).to(device, dtype)
|
audio_b = audio.unsqueeze(0).to(device, dtype)
|
||||||
with torch.inference_mode():
|
|
||||||
dist = vae_utils.encode_audio(audio_b)
|
dist = vae_utils.encode_audio(audio_b)
|
||||||
x1 = dist.mode().clone().cpu()
|
x1 = dist.mode().clone().cpu()
|
||||||
|
|
||||||
@@ -323,7 +323,7 @@ class SelvaLoraTrainer:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
|
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
|
||||||
)
|
)
|
||||||
generator.load_state_dict(ckpt["state_dict"], strict=False)
|
load_lora(generator, ckpt["state_dict"])
|
||||||
optimizer.load_state_dict(ckpt["optimizer"])
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
scheduler.load_state_dict(ckpt["scheduler"])
|
scheduler.load_state_dict(ckpt["scheduler"])
|
||||||
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
|
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user