Files
ComfyUI-MisoTTS/nodes/loader.py
T
Ethanfel f7a6f7790d Initial release: ComfyUI-MisoTTS (modernized CSM 8B)
Modernized MisoTTS integration for ComfyUI with no torchtune/moshi:
- vendored plain-torch Llama backbone (csm_llama), parity-verified Δ=0 vs torchtune
- transformers.MimiModel codec (bit-identical codes to moshi), drops moshi/bnb/sphn
- low-memory loader: streams 32GB fp32 checkpoint to GPU in bf16 (~18GB VRAM)
- nodes: Model Loader, Generate (audiobook chunking + voice anchoring), EPUB Loader
- pin-free requirements; runs on modern torch / Blackwell GPUs

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-06 23:37:54 +02:00

71 lines
2.5 KiB
Python

import os
import torch
# The engine import may fail here without breaking node registration; the
# captured exception is reported to the user when the loader node runs.
try:
    from ..misotts import load_miso_8b
except Exception as exc:  # pragma: no cover - surfaced to the user at node runtime
    load_miso_8b = None
    _import_error = exc
else:
    _import_error = None

# Cache root for MisoTTS: inside ComfyUI's models directory when running under
# ComfyUI, otherwise ~/.cache/misotts.
try:
    import folder_paths
except ImportError:
    CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "misotts")
else:
    CACHE_DIR = os.path.join(folder_paths.models_dir, "misotts")

# Maps the dtype names offered in the UI to the corresponding torch dtypes.
DTYPE_MAP = {name: getattr(torch, name) for name in ("bfloat16", "float16", "float32")}
class MisoTTSModelLoader:
    """Load the MisoTTS 8B model (modernized: no torchtune/moshi).

    The 32 GB fp32 checkpoint is streamed straight to the GPU in the chosen dtype, so
    loading needs ~18 GB VRAM (bf16) and almost no system RAM.
    """

    # Single source of truth for the optional-input defaults. These are used in the
    # UI schema, in the load_model signature, and as the fallback when a field is blank.
    DEFAULT_MODEL_REPO = "MisoLabs/MisoTTS"
    DEFAULT_TOKENIZER = "unsloth/Llama-3.2-1B"

    RETURN_TYPES = ("MISOTTS_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "MisoTTS"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for ComfyUI (required combos + optional strings)."""
        return {
            "required": {
                "device": (["cuda:0", "cuda:1", "cpu"], {"default": "cuda:0"}),
                "dtype": (["bfloat16", "float16", "float32"], {"default": "bfloat16"}),
            },
            "optional": {
                "model_repo_or_path": ("STRING", {
                    "default": cls.DEFAULT_MODEL_REPO,
                    "tooltip": "HF repo id or a local path to a model.safetensors / model dir.",
                }),
                "tokenizer": ("STRING", {
                    "default": cls.DEFAULT_TOKENIZER,
                    "tooltip": (
                        "Llama-3.2 tokenizer source. Default is an ungated mirror byte-identical "
                        "to meta-llama/Llama-3.2-1B. Change only if you know what you're doing."
                    ),
                }),
            },
        }

    @staticmethod
    def _check_device(device):
        """Fail fast if a CUDA device is requested that this machine cannot provide.

        This runs before the (potentially very large) checkpoint load starts.
        A wrong device choice therefore surfaces immediately with a readable
        message. Non-CUDA devices are accepted unchanged.

        Raises:
            RuntimeError: CUDA is unavailable, or the requested index is out of range.
        """
        dev = torch.device(device)
        if dev.type != "cuda":
            return None
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"MisoTTS: device '{device}' was requested but CUDA is not available. "
                "Select 'cpu' or install a CUDA-enabled PyTorch build."
            )
        # A bare "cuda" (no index) means the current device, index 0 by default.
        index = 0 if dev.index is None else dev.index
        count = torch.cuda.device_count()
        if index >= count:
            raise RuntimeError(
                f"MisoTTS: device '{device}' was requested but only {count} "
                "CUDA device(s) are visible."
            )
        return None

    def load_model(
        self,
        device,
        dtype,
        model_repo_or_path=DEFAULT_MODEL_REPO,
        tokenizer=DEFAULT_TOKENIZER,
    ):
        """Build the MisoTTS generator and return it as a 1-tuple for ComfyUI.

        Args:
            device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
            dtype: Key into ``DTYPE_MAP`` selecting the dtype passed to the loader.
            model_repo_or_path: HF repo id or local checkpoint path. Blank or None
                falls back to ``DEFAULT_MODEL_REPO``.
            tokenizer: Tokenizer source. Blank or None falls back to ``DEFAULT_TOKENIZER``.

        Returns:
            ``(gen,)`` where ``gen`` is whatever ``load_miso_8b`` returns.

        Raises:
            ImportError: The MisoTTS engine failed to import. The original error is
                chained as ``__cause__``.
            RuntimeError: A CUDA device was requested that is not available.
        """
        if load_miso_8b is None:
            # Chain the stored exception so its full traceback stays visible.
            raise ImportError(
                "MisoTTS engine failed to import. Ensure transformers, safetensors, tokenizers "
                f"and torchaudio are installed.\nOriginal error: {_import_error}"
            ) from _import_error

        self._check_device(device)

        os.makedirs(CACHE_DIR, exist_ok=True)
        # setdefault: never override an HF_HOME the user configured themselves.
        # NOTE(review): huggingface_hub resolves its cache paths from HF_HOME when it
        # is first imported. If the engine import already pulled it in, this may not
        # redirect downloads. Confirm, or pass an explicit cache dir to the loader.
        os.environ.setdefault("HF_HOME", CACHE_DIR)

        # Optional inputs may arrive blank (or None); fall back to the defaults.
        source = (model_repo_or_path or "").strip() or self.DEFAULT_MODEL_REPO
        tokenizer_name = (tokenizer or "").strip() or self.DEFAULT_TOKENIZER

        gen = load_miso_8b(
            device=device,
            model_path_or_repo_id=source,
            dtype=DTYPE_MAP[dtype],
            tokenizer_name=tokenizer_name,
        )
        return (gen,)