fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure

torch.stft requires float32 input — casting vae_utils and the input audio to
bf16 made clips fail during dataset pre-loading, and each failure surfaced only
as a one-line warning. Also adds traceback.print_exc() so future clip-load
errors are visible with a full traceback in the ComfyUI log.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 21:57:20 +02:00
parent 56c8d5d6b4
commit 52434a053a
+6 -3
View File
@@ -1,6 +1,7 @@
import copy import copy
import json import json
import random import random
import traceback
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -291,12 +292,13 @@ class SelvaLoraTrainer:
"Run SelVA Model Loader first to auto-download weights." "Run SelVA Model Loader first to auto-download weights."
) )
print("[LoRA Trainer] Loading VAE encoder...", flush=True) print("[LoRA Trainer] Loading VAE encoder...", flush=True)
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
vae_utils = FeaturesUtils( vae_utils = FeaturesUtils(
tod_vae_ckpt=str(vae_path), tod_vae_ckpt=str(vae_path),
enable_conditions=False, enable_conditions=False,
mode=mode, mode=mode,
need_vae_encoder=True, need_vae_encoder=True,
).to(device, dtype).eval() ).to(device).eval()
# --- Pre-load dataset --- # --- Pre-load dataset ---
npz_files = sorted(data_dir.glob("*.npz")) npz_files = sorted(data_dir.glob("*.npz"))
@@ -324,9 +326,9 @@ class SelvaLoraTrainer:
try: try:
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration) audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
# Audio → latent via VAE # Audio → latent via VAE (float32: mel_converter/stft require float32)
# encode_audio is @inference_mode — .clone() exits inference mode # encode_audio is @inference_mode — .clone() exits inference mode
audio_b = audio.unsqueeze(0).to(device, dtype) audio_b = audio.unsqueeze(0).to(device)
dist = vae_utils.encode_audio(audio_b) dist = vae_utils.encode_audio(audio_b)
x1 = dist.mode().clone().cpu() x1 = dist.mode().clone().cpu()
@@ -336,6 +338,7 @@ class SelvaLoraTrainer:
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip)) dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
except Exception as e: except Exception as e:
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True) print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
traceback.print_exc()
pbar_load.update(1) pbar_load.update(1)