fix: keep VAE in float32 for mel/stft; print full traceback on clip load failure
torch.stft requires float32 input — casting vae_utils to bf16 caused silent failures during dataset pre-loading, so the VAE and the audio passed to it are now kept in float32. Also adds traceback.print_exc() so future clip-load errors are visible in the ComfyUI log.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -291,12 +292,13 @@ class SelvaLoraTrainer:
|
|||||||
"Run SelVA Model Loader first to auto-download weights."
|
"Run SelVA Model Loader first to auto-download weights."
|
||||||
)
|
)
|
||||||
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
||||||
|
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
||||||
vae_utils = FeaturesUtils(
|
vae_utils = FeaturesUtils(
|
||||||
tod_vae_ckpt=str(vae_path),
|
tod_vae_ckpt=str(vae_path),
|
||||||
enable_conditions=False,
|
enable_conditions=False,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
need_vae_encoder=True,
|
need_vae_encoder=True,
|
||||||
).to(device, dtype).eval()
|
).to(device).eval()
|
||||||
|
|
||||||
# --- Pre-load dataset ---
|
# --- Pre-load dataset ---
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
npz_files = sorted(data_dir.glob("*.npz"))
|
||||||
@@ -324,9 +326,9 @@ class SelvaLoraTrainer:
|
|||||||
try:
|
try:
|
||||||
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||||
|
|
||||||
# Audio → latent via VAE
|
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
||||||
# encode_audio is @inference_mode — .clone() exits inference mode
|
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||||
audio_b = audio.unsqueeze(0).to(device, dtype)
|
audio_b = audio.unsqueeze(0).to(device)
|
||||||
dist = vae_utils.encode_audio(audio_b)
|
dist = vae_utils.encode_audio(audio_b)
|
||||||
x1 = dist.mode().clone().cpu()
|
x1 = dist.mode().clone().cpu()
|
||||||
|
|
||||||
@@ -336,6 +338,7 @@ class SelvaLoraTrainer:
|
|||||||
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
dataset.append((x1, bundle["clip_features"], bundle["sync_features"], text_clip))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
pbar_load.update(1)
|
pbar_load.update(1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user