"""Model-loader node for OmniVoice.

Exposes ``OmniVoiceModelLoader``, a node class following the ComfyUI
custom-node conventions (``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
``CATEGORY``), which loads the ``k2-fsa/OmniVoice`` checkpoint.
"""

import os

import torch

try:
    from omnivoice import OmniVoice
except ImportError:
    OmniVoice = None  # deferred; will raise at runtime if package is missing

try:
    import folder_paths

    # When ``folder_paths`` is importable, keep the weights under its models
    # directory.
    CACHE_DIR = os.path.join(folder_paths.models_dir, "omnivoice")
except ImportError:
    # Otherwise fall back to a per-user cache directory.
    CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "omnivoice")

# Maps the dtype names offered in the UI to torch dtypes. This dict is the
# single source of truth: the "dtype" choices in INPUT_TYPES are derived from
# its keys, so insertion order here is the order shown to the user.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


class OmniVoiceModelLoader:
    """Node that loads the OmniVoice model onto a chosen device and dtype."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for this node.

        Returns:
            A dict with a single ``"required"`` section describing three
            inputs: ``device`` (choice list), ``dtype`` (choice list built
            from ``DTYPE_MAP``) and ``compile`` (boolean, default ``False``).
        """
        return {
            "required": {
                "device": (
                    ["cuda:0", "cuda:1", "cpu"],
                    {"default": "cuda:0"},
                ),
                "dtype": (
                    # Derived from DTYPE_MAP so the UI choices can never drift
                    # from what load_model() is able to resolve.
                    list(DTYPE_MAP),
                    {"default": "float16"},
                ),
                "compile": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": (
                            "Run torch.compile() on the model after loading. "
                            "First generation will be slow (~30-60s warmup) while the graph is compiled, "
                            "then every subsequent generation in the session will be faster. "
                            "Recommended for audiobook pipelines. Requires PyTorch 2.0+."
                        ),
                    },
                ),
            },
        }

    RETURN_TYPES = ("OMNIVOICE_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "OmniVoice"

    def load_model(self, device, dtype, compile=False):
        """Load the OmniVoice checkpoint and optionally wrap it with torch.compile.

        All cheap precondition checks run before the checkpoint is loaded, so
        a misconfiguration fails immediately instead of after loading weights.

        Args:
            device: Device string forwarded to ``from_pretrained`` as
                ``device_map`` (e.g. ``"cuda:0"`` or ``"cpu"``).
            dtype: One of the keys of ``DTYPE_MAP``.
            compile: When true, pass the loaded model through
                ``torch.compile``. The name shadows the builtin on purpose:
                it must match the input key declared in ``INPUT_TYPES``.

        Returns:
            A 1-tuple ``(model,)`` matching ``RETURN_TYPES``.

        Raises:
            KeyError: ``dtype`` is not a key of ``DTYPE_MAP``.
            ImportError: the ``omnivoice`` package is not installed.
            RuntimeError: ``compile`` was requested but the installed torch
                does not provide ``torch.compile``.
        """
        try:
            torch_dtype = DTYPE_MAP[dtype]
        except KeyError:
            # Same exception type as the plain lookup, but with a message
            # that tells the caller which values are accepted.
            raise KeyError(
                f"Unsupported dtype {dtype!r}; expected one of {list(DTYPE_MAP)}"
            ) from None

        if OmniVoice is None:
            raise ImportError(
                "omnivoice is not installed. Run: pip install omnivoice --no-deps"
            )

        # Check before loading: otherwise an old PyTorch would only fail with
        # an AttributeError after the whole checkpoint had been loaded.
        if compile and not hasattr(torch, "compile"):
            raise RuntimeError(
                "torch.compile is not available in this PyTorch build "
                f"({torch.__version__}); it requires PyTorch 2.0+. "
                "Disable the 'compile' option or upgrade PyTorch."
            )

        model = OmniVoice.from_pretrained(
            "k2-fsa/OmniVoice",
            device_map=device,
            dtype=torch_dtype,
            cache_dir=CACHE_DIR,
        )

        if compile:
            model = torch.compile(model)

        return (model,)