diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 2fa6d78..9da3277 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -1,6 +1,7 @@ import copy import json import random +import traceback from pathlib import Path import numpy as np @@ -291,12 +292,13 @@ class SelvaLoraTrainer: "Run SelVA Model Loader first to auto-download weights." ) print("[LoRA Trainer] Loading VAE encoder...", flush=True) + # Keep VAE in float32: mel_converter uses torch.stft which requires float32 input. vae_utils = FeaturesUtils( tod_vae_ckpt=str(vae_path), enable_conditions=False, mode=mode, need_vae_encoder=True, - ).to(device, dtype).eval() + ).to(device).eval() # --- Pre-load dataset --- npz_files = sorted(data_dir.glob("*.npz")) @@ -324,9 +326,9 @@ class SelvaLoraTrainer: try: audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) - # Audio → latent via VAE + # Audio → latent via VAE (float32: mel_converter/stft require float32) # encode_audio is @inference_mode — .clone() exits inference mode - audio_b = audio.unsqueeze(0).to(device, dtype) + audio_b = audio.unsqueeze(0).to(device) dist = vae_utils.encode_audio(audio_b) x1 = dist.mode().clone().cpu() @@ -336,6 +338,7 @@ class SelvaLoraTrainer: dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) except Exception as e: print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True) + traceback.print_exc() pbar_load.update(1)