fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure

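torch.stft requires float32 input — casting vae_utils to bf16 caused clip
encoding to fail during dataset pre-loading, and each failure was reported only
as a one-line warning. This commit also adds traceback.print_exc() to the
clip-load exception handler so future errors are fully visible in the ComfyUI log.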

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 21:57:20 +02:00
parent 56c8d5d6b4
commit 52434a053a
+6 -3
View File
@@ -1,6 +1,7 @@
import copy
import json
import random
import traceback
from pathlib import Path
import numpy as np
@@ -291,12 +292,13 @@ class SelvaLoraTrainer:
"Run SelVA Model Loader first to auto-download weights."
)
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device, dtype).eval()
).to(device).eval()
# --- Pre-load dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
@@ -324,9 +326,9 @@ class SelvaLoraTrainer:
try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE
# Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device, dtype)
audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu()
@@ -336,6 +338,7 @@ class SelvaLoraTrainer:
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1)