From 52434a053a212181f1986cc0c833b7f0644c859b Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 21:57:20 +0200 Subject: [PATCH] fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch.stft requires float32 input — casting vae_utils to bf16 caused silent failures during dataset pre-loading. Also adds traceback.print_exc() so future clip-load errors are visible in the ComfyUI log. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_lora_trainer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 2fa6d78..9da3277 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -1,6 +1,7 @@ import copy import json import random +import traceback from pathlib import Path import numpy as np @@ -291,12 +292,13 @@ class SelvaLoraTrainer: "Run SelVA Model Loader first to auto-download weights." ) print("[LoRA Trainer] Loading VAE encoder...", flush=True) + # Keep VAE in float32: mel_converter uses torch.stft which requires float32 input. vae_utils = FeaturesUtils( tod_vae_ckpt=str(vae_path), enable_conditions=False, mode=mode, need_vae_encoder=True, - ).to(device, dtype).eval() + ).to(device).eval() # --- Pre-load dataset --- npz_files = sorted(data_dir.glob("*.npz")) @@ -324,9 +326,9 @@ class SelvaLoraTrainer: try: audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) - # Audio → latent via VAE + # Audio → latent via VAE (float32: mel_converter/stft require float32) # encode_audio is @inference_mode — .clone() exits inference mode - audio_b = audio.unsqueeze(0).to(device, dtype) + audio_b = audio.unsqueeze(0).to(device) dist = vae_utils.encode_audio(audio_b) x1 = dist.mode().clone().cpu() @@ -336,6 +338,7 @@ class SelvaLoraTrainer: dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) except Exception as e: print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True) + traceback.print_exc() pbar_load.update(1)