fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure
torch.stft requires float32 input — casting vae_utils to bf16 caused silent failures during dataset pre-loading. Also adds traceback.print_exc() so future clip-load errors are visible in the ComfyUI log.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -291,12 +292,13 @@ class SelvaLoraTrainer:
|
||||
"Run SelVA Model Loader first to auto-download weights."
|
||||
)
|
||||
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
||||
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
||||
vae_utils = FeaturesUtils(
|
||||
tod_vae_ckpt=str(vae_path),
|
||||
enable_conditions=False,
|
||||
mode=mode,
|
||||
need_vae_encoder=True,
|
||||
).to(device, dtype).eval()
|
||||
).to(device).eval()
|
||||
|
||||
# --- Pre-load dataset ---
|
||||
npz_files = sorted(data_dir.glob("*.npz"))
|
||||
@@ -324,9 +326,9 @@ class SelvaLoraTrainer:
|
||||
try:
|
||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||
|
||||
# Audio → latent via VAE
|
||||
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||
audio_b = audio.unsqueeze(0).to(device, dtype)
|
||||
audio_b = audio.unsqueeze(0).to(device)
|
||||
dist = vae_utils.encode_audio(audio_b)
|
||||
x1 = dist.mode().clone().cpu()
|
||||
|
||||
@@ -336,6 +338,7 @@ class SelvaLoraTrainer:
|
||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
||||
except Exception as e:
|
||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
|
||||
pbar_load.update(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user