f7a6f7790d
Modernized MisoTTS integration for ComfyUI with no torchtune/moshi: - vendored plain-torch Llama backbone (csm_llama), parity-verified Δ=0 vs torchtune - transformers.MimiModel codec (bit-identical codes to moshi), drops moshi/bnb/sphn - low-memory loader: streams 32GB fp32 checkpoint to GPU in bf16 (~18GB VRAM) - nodes: Model Loader, Generate (audiobook chunking + voice anchoring), EPUB Loader - pin-free requirements; runs on modern torch / Blackwell GPUs Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import os
|
|
|
|
import torch
|
|
|
|
_import_error = None
|
|
try:
|
|
from ..misotts import load_miso_8b
|
|
except Exception as e: # pragma: no cover - surfaced to the user at node runtime
|
|
load_miso_8b = None
|
|
_import_error = e
|
|
|
|
try:
|
|
import folder_paths
|
|
CACHE_DIR = os.path.join(folder_paths.models_dir, "misotts")
|
|
except ImportError:
|
|
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "misotts")
|
|
|
|
DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
|
|
|
|
|
class MisoTTSModelLoader:
|
|
"""Load the MisoTTS 8B model (modernized: no torchtune/moshi).
|
|
|
|
The 32 GB fp32 checkpoint is streamed straight to the GPU in the chosen dtype, so
|
|
loading needs ~18 GB VRAM (bf16) and almost no system RAM.
|
|
"""
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"device": (["cuda:0", "cuda:1", "cpu"], {"default": "cuda:0"}),
|
|
"dtype": (["bfloat16", "float16", "float32"], {"default": "bfloat16"}),
|
|
},
|
|
"optional": {
|
|
"model_repo_or_path": ("STRING", {
|
|
"default": "MisoLabs/MisoTTS",
|
|
"tooltip": "HF repo id or a local path to a model.safetensors / model dir.",
|
|
}),
|
|
"tokenizer": ("STRING", {
|
|
"default": "unsloth/Llama-3.2-1B",
|
|
"tooltip": (
|
|
"Llama-3.2 tokenizer source. Default is an ungated mirror byte-identical "
|
|
"to meta-llama/Llama-3.2-1B. Change only if you know what you're doing."
|
|
),
|
|
}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("MISOTTS_MODEL",)
|
|
RETURN_NAMES = ("model",)
|
|
FUNCTION = "load_model"
|
|
CATEGORY = "MisoTTS"
|
|
|
|
def load_model(self, device, dtype, model_repo_or_path="MisoLabs/MisoTTS", tokenizer="unsloth/Llama-3.2-1B"):
|
|
if load_miso_8b is None:
|
|
raise ImportError(
|
|
"MisoTTS engine failed to import. Ensure transformers, safetensors, tokenizers "
|
|
f"and torchaudio are installed.\nOriginal error: {_import_error}"
|
|
)
|
|
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
os.environ.setdefault("HF_HOME", CACHE_DIR)
|
|
source = model_repo_or_path.strip() or "MisoLabs/MisoTTS"
|
|
gen = load_miso_8b(
|
|
device=device,
|
|
model_path_or_repo_id=source,
|
|
dtype=DTYPE_MAP[dtype],
|
|
tokenizer_name=tokenizer.strip() or "unsloth/Llama-3.2-1B",
|
|
)
|
|
return (gen,)
|