feat: auto-download SelVA weights on first use
Uses selva_core/utils/download_utils.py, which already has the URLs and MD5 checksums for all weights. Models are downloaded to models/selva/ on first load. Synchformer reuses models/prismaudio/synchformer_state_dict.pth if it is already present (so PrismAudio users do not download it twice); otherwise it is downloaded to models/selva/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+23
-30
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
@@ -12,17 +13,26 @@ _VARIANTS = {
|
|||||||
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
||||||
}
|
}
|
||||||
|
|
||||||
_SELVA_DIR = os.path.join(folder_paths.models_dir, "selva")
|
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||||||
|
_PRISMAUDIO_DIR = Path(folder_paths.models_dir) / "prismaudio"
|
||||||
|
|
||||||
|
|
||||||
def _selva_path(*parts):
|
def _ensure(filename, subdir=None):
|
||||||
return os.path.join(_SELVA_DIR, *parts)
|
"""Return path to weight file, downloading it if missing."""
|
||||||
|
from selva_core.utils.download_utils import download_model_if_needed
|
||||||
|
dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
|
||||||
|
path = dest_dir / filename
|
||||||
|
download_model_if_needed(path)
|
||||||
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
def _require(path, hint):
|
def _synchformer_path():
|
||||||
if not os.path.exists(path):
|
"""Return synchformer path, reusing models/prismaudio/ if already present."""
|
||||||
raise RuntimeError(f"[SelVA] Missing: {path}\n{hint}")
|
prismaudio_path = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
|
||||||
return path
|
if prismaudio_path.exists():
|
||||||
|
return str(prismaudio_path)
|
||||||
|
# Not downloaded for PrismAudio yet — download to models/selva/
|
||||||
|
return _ensure("synchformer_state_dict.pth")
|
||||||
|
|
||||||
|
|
||||||
class SelvaModelLoader:
|
class SelvaModelLoader:
|
||||||
@@ -53,29 +63,12 @@ class SelvaModelLoader:
|
|||||||
strategy = determine_offload_strategy(offload_strategy)
|
strategy = determine_offload_strategy(offload_strategy)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Resolve weight paths
|
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
|
||||||
video_enc_path = _require(
|
video_enc_path = _ensure("video_enc_sup_5.pth")
|
||||||
_selva_path("video_enc_sup_5.pth"),
|
gen_path = _ensure(gen_filename)
|
||||||
"Download from https://huggingface.co/jnwnlee/selva and place in models/selva/"
|
vae_path = _ensure(f"v1-{mode}.pth", subdir="ext")
|
||||||
)
|
synch_path = _synchformer_path()
|
||||||
gen_path = _require(
|
bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None
|
||||||
_selva_path(gen_filename),
|
|
||||||
f"Download {gen_filename} from https://huggingface.co/jnwnlee/selva and place in models/selva/"
|
|
||||||
)
|
|
||||||
vae_path = _require(
|
|
||||||
_selva_path("ext", f"v1-{mode}.pth"),
|
|
||||||
f"Download v1-{mode}.pth from MMAudio/SelVA release and place in models/selva/ext/"
|
|
||||||
)
|
|
||||||
synch_path = _require(
|
|
||||||
os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth"),
|
|
||||||
"Synchformer checkpoint missing from models/prismaudio/ — download from FunAudioLLM/PrismAudio"
|
|
||||||
)
|
|
||||||
bigvgan_path = None
|
|
||||||
if has_bigvgan:
|
|
||||||
bigvgan_path = _require(
|
|
||||||
_selva_path("ext", "best_netG.pt"),
|
|
||||||
"Download best_netG.pt (BigVGAN 16k vocoder) from MMAudio release and place in models/selva/ext/"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
|
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
|
||||||
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user