import os

import torch

# Hugging Face ids used when the corresponding node input is left blank.
DEFAULT_MODEL_REPO = "MisoLabs/MisoTTS"
DEFAULT_TOKENIZER = "unsloth/Llama-3.2-1B"

# The engine import is deferred-fail: any error is remembered here and only
# raised when the node actually runs, so importing this module never crashes.
_import_error = None
try:
    from ..misotts import load_miso_8b
except Exception as e:  # pragma: no cover - surfaced to the user at node runtime
    load_miso_8b = None
    _import_error = e

# Prefer the host application's models directory when `folder_paths` is
# importable; otherwise fall back to a per-user cache directory.
try:
    import folder_paths

    CACHE_DIR = os.path.join(folder_paths.models_dir, "misotts")
except ImportError:
    CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "misotts")

# Maps the dtype names offered in the node UI to torch dtypes.
DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


class MisoTTSModelLoader:
    """Load the MisoTTS 8B model (modernized: no torchtune/moshi).

    The 32 GB fp32 checkpoint is streamed straight to the GPU in the chosen
    dtype, so loading needs ~18 GB VRAM (bf16) and almost no system RAM.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification.

        Required inputs are the target ``device`` and the ``dtype`` name;
        optional inputs are the model source and the tokenizer source.
        """
        return {
            "required": {
                "device": (["cuda:0", "cuda:1", "cpu"], {"default": "cuda:0"}),
                # Derived from DTYPE_MAP so the choices cannot drift from the
                # names load_model is able to resolve.
                "dtype": (list(DTYPE_MAP), {"default": "bfloat16"}),
            },
            "optional": {
                "model_repo_or_path": ("STRING", {
                    "default": DEFAULT_MODEL_REPO,
                    "tooltip": "HF repo id or a local path to a model.safetensors / model dir.",
                }),
                "tokenizer": ("STRING", {
                    "default": DEFAULT_TOKENIZER,
                    "tooltip": (
                        "Llama-3.2 tokenizer source. Default is an ungated mirror byte-identical "
                        "to meta-llama/Llama-3.2-1B. Change only if you know what you're doing."
                    ),
                }),
            },
        }

    RETURN_TYPES = ("MISOTTS_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "MisoTTS"

    def load_model(self, device, dtype, model_repo_or_path=DEFAULT_MODEL_REPO,
                   tokenizer=DEFAULT_TOKENIZER):
        """Load the model and return it as a one-element tuple.

        Args:
            device: Device string forwarded to ``load_miso_8b``.
            dtype: A key of ``DTYPE_MAP`` selecting the torch dtype.
            model_repo_or_path: Model source; blank or ``None`` falls back to
                ``DEFAULT_MODEL_REPO``.
            tokenizer: Tokenizer source; blank or ``None`` falls back to
                ``DEFAULT_TOKENIZER``.

        Returns:
            ``(gen,)`` where ``gen`` is whatever ``load_miso_8b`` returns.

        Raises:
            ImportError: If the MisoTTS engine could not be imported; the
                original import failure is attached as ``__cause__``.
            KeyError: If ``dtype`` is not a key of ``DTYPE_MAP``.
        """
        if load_miso_8b is None:
            # Chain the stored failure so its traceback is kept, not just its text.
            raise ImportError(
                "MisoTTS engine failed to import. Ensure transformers, safetensors, tokenizers "
                f"and torchaudio are installed.\nOriginal error: {_import_error}"
            ) from _import_error

        os.makedirs(CACHE_DIR, exist_ok=True)
        # setdefault keeps any HF_HOME the user has already exported.
        # NOTE(review): libraries that read HF_HOME at import time would not see
        # a value first set here -- confirm load_miso_8b honours it.
        os.environ.setdefault("HF_HOME", CACHE_DIR)

        # Whitespace-only (or missing) values mean "use the default".
        source = (model_repo_or_path or "").strip() or DEFAULT_MODEL_REPO
        tokenizer_name = (tokenizer or "").strip() or DEFAULT_TOKENIZER

        gen = load_miso_8b(
            device=device,
            model_path_or_repo_id=source,
            dtype=DTYPE_MAP[dtype],
            tokenizer_name=tokenizer_name,
        )
        return (gen,)