Initial release: ComfyUI-MisoTTS (modernized CSM 8B)
Modernized MisoTTS integration for ComfyUI with no torchtune/moshi: - vendored plain-torch Llama backbone (csm_llama), parity-verified Δ=0 vs torchtune - transformers.MimiModel codec (bit-identical codes to moshi), drops moshi/bnb/sphn - low-memory loader: streams 32GB fp32 checkpoint to GPU in bf16 (~18GB VRAM) - nodes: Model Loader, Generate (audiobook chunking + voice anchoring), EPUB Loader - pin-free requirements; runs on modern torch / Blackwell GPUs Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
_import_error = None
|
||||
try:
|
||||
from ..misotts import load_miso_8b
|
||||
except Exception as e: # pragma: no cover - surfaced to the user at node runtime
|
||||
load_miso_8b = None
|
||||
_import_error = e
|
||||
|
||||
try:
|
||||
import folder_paths
|
||||
CACHE_DIR = os.path.join(folder_paths.models_dir, "misotts")
|
||||
except ImportError:
|
||||
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "misotts")
|
||||
|
||||
DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
||||
|
||||
|
||||
class MisoTTSModelLoader:
|
||||
"""Load the MisoTTS 8B model (modernized: no torchtune/moshi).
|
||||
|
||||
The 32 GB fp32 checkpoint is streamed straight to the GPU in the chosen dtype, so
|
||||
loading needs ~18 GB VRAM (bf16) and almost no system RAM.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"device": (["cuda:0", "cuda:1", "cpu"], {"default": "cuda:0"}),
|
||||
"dtype": (["bfloat16", "float16", "float32"], {"default": "bfloat16"}),
|
||||
},
|
||||
"optional": {
|
||||
"model_repo_or_path": ("STRING", {
|
||||
"default": "MisoLabs/MisoTTS",
|
||||
"tooltip": "HF repo id or a local path to a model.safetensors / model dir.",
|
||||
}),
|
||||
"tokenizer": ("STRING", {
|
||||
"default": "unsloth/Llama-3.2-1B",
|
||||
"tooltip": (
|
||||
"Llama-3.2 tokenizer source. Default is an ungated mirror byte-identical "
|
||||
"to meta-llama/Llama-3.2-1B. Change only if you know what you're doing."
|
||||
),
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MISOTTS_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = "MisoTTS"
|
||||
|
||||
def load_model(self, device, dtype, model_repo_or_path="MisoLabs/MisoTTS", tokenizer="unsloth/Llama-3.2-1B"):
|
||||
if load_miso_8b is None:
|
||||
raise ImportError(
|
||||
"MisoTTS engine failed to import. Ensure transformers, safetensors, tokenizers "
|
||||
f"and torchaudio are installed.\nOriginal error: {_import_error}"
|
||||
)
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
os.environ.setdefault("HF_HOME", CACHE_DIR)
|
||||
source = model_repo_or_path.strip() or "MisoLabs/MisoTTS"
|
||||
gen = load_miso_8b(
|
||||
device=device,
|
||||
model_path_or_repo_id=source,
|
||||
dtype=DTYPE_MAP[dtype],
|
||||
tokenizer_name=tokenizer.strip() or "unsloth/Llama-3.2-1B",
|
||||
)
|
||||
return (gen,)
|
||||
Reference in New Issue
Block a user