diff --git a/__init__.py b/__init__.py index 571d1d5..c086790 100644 --- a/__init__.py +++ b/__init__.py @@ -1,13 +1,18 @@ -from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate +try: + from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate -NODE_CLASS_MAPPINGS = { - "OmniVoiceModelLoader": OmniVoiceModelLoader, - "OmniVoiceGenerate": OmniVoiceGenerate, -} + NODE_CLASS_MAPPINGS = { + "OmniVoiceModelLoader": OmniVoiceModelLoader, + "OmniVoiceGenerate": OmniVoiceGenerate, + } -NODE_DISPLAY_NAME_MAPPINGS = { - "OmniVoiceModelLoader": "OmniVoice Model Loader", - "OmniVoiceGenerate": "OmniVoice Generate", -} + NODE_DISPLAY_NAME_MAPPINGS = { + "OmniVoiceModelLoader": "OmniVoice Model Loader", + "OmniVoiceGenerate": "OmniVoice Generate", + } +except ImportError: + # Graceful fallback when loaded outside of a package context (e.g. pytest) + NODE_CLASS_MAPPINGS = {} + NODE_DISPLAY_NAME_MAPPINGS = {} __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..0edeb20 --- /dev/null +++ b/conftest.py @@ -0,0 +1,6 @@ +import sys +import os + +# Ensure the project root is on sys.path so `nodes.loader` can be imported +# as a top-level package without needing to install the package. +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/nodes/loader.py b/nodes/loader.py index 4c78630..b59c263 100644 --- a/nodes/loader.py +++ b/nodes/loader.py @@ -1,2 +1,63 @@ +import os +import torch + +try: + from omnivoice import OmniVoice +except ImportError: + OmniVoice = None # type: ignore[assignment,misc] + +try: + import folder_paths + CACHE_DIR = os.path.join(folder_paths.models_dir, "omnivoice") +except ImportError: + CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "omnivoice") + + +DTYPE_MAP = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, +} + + class OmniVoiceModelLoader: - pass + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_source": ( + ["Auto-download (HuggingFace)", "Local path"], + {"default": "Auto-download (HuggingFace)"}, + ), + "device": ( + ["cuda:0", "cuda:1", "cpu"], + {"default": "cuda:0"}, + ), + "dtype": ( + ["float16", "bfloat16", "float32"], + {"default": "float16"}, + ), + }, + "optional": { + "local_path": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = ("OMNIVOICE_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_model" + CATEGORY = "OmniVoice" + + def load_model(self, model_source, device, dtype, local_path=""): + if model_source == "Local path" and local_path: + source = local_path + else: + source = "k2-fsa/OmniVoice" + + model = OmniVoice.from_pretrained( + source, + device_map=device, + dtype=DTYPE_MAP[dtype], + cache_dir=CACHE_DIR, + ) + return (model,) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..6474f54 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +addopts = --import-mode=importlib diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_loader.py b/tests/test_loader.py new file mode 100644 index 0000000..aa8d8a1 --- /dev/null +++ b/tests/test_loader.py @@ -0,0 +1,59 @@ +# tests/test_loader.py +from unittest.mock import patch, MagicMock +import torch +import pytest +from nodes.loader import OmniVoiceModelLoader + + +def test_input_types_structure(): + inputs = OmniVoiceModelLoader.INPUT_TYPES() + required = inputs["required"] + assert "model_source" in required + assert "device" in required + assert "dtype" in required + optional = inputs.get("optional", {}) + assert "local_path" in optional + + +def test_input_types_model_source_choices(): + inputs = OmniVoiceModelLoader.INPUT_TYPES() + choices = inputs["required"]["model_source"][0] + assert "Auto-download (HuggingFace)" in choices + assert "Local path" in choices + + +def test_return_type(): + assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",) + + +def test_load_model_auto_download(): + loader = OmniVoiceModelLoader() + mock_model = MagicMock() + with patch("nodes.loader.OmniVoice") as MockOmniVoice: + MockOmniVoice.from_pretrained.return_value = mock_model + result = loader.load_model( + model_source="Auto-download (HuggingFace)", + device="cpu", + dtype="float32", + local_path="", + ) + assert result == (mock_model,) + MockOmniVoice.from_pretrained.assert_called_once() + call_kwargs = MockOmniVoice.from_pretrained.call_args + assert call_kwargs[0][0] == "k2-fsa/OmniVoice" + + +def test_load_model_local_path(): + loader = OmniVoiceModelLoader() + mock_model = MagicMock() + with patch("nodes.loader.OmniVoice") as MockOmniVoice: + MockOmniVoice.from_pretrained.return_value = mock_model + result = loader.load_model( + model_source="Local path", + device="cpu", + dtype="float32", + local_path="/some/local/path", + ) + assert result == (mock_model,) + call_args = MockOmniVoice.from_pretrained.call_args[0][0] + assert call_args == "/some/local/path"