diff --git a/README.md b/README.md index 968716e..1ae7280 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ Loads the OmniVoice model. Downloads automatically from HuggingFace on first run | Input | Type | Description | |-------|------|-------------| -| `model_source` | dropdown | `Auto-download (HuggingFace)` or `Local path` | -| `local_path` | string | Path to local checkpoint (optional) | | `device` | dropdown | `cuda:0`, `cuda:1`, or `cpu` | | `dtype` | dropdown | `float16`, `bfloat16`, or `float32` | +Downloads automatically from HuggingFace on first run and caches to `ComfyUI/models/omnivoice/`. + **Output:** `OMNIVOICE_MODEL` --- diff --git a/nodes/loader.py b/nodes/loader.py index f1d1ee7..80fefe5 100644 --- a/nodes/loader.py +++ b/nodes/loader.py @@ -25,10 +25,6 @@ class OmniVoiceModelLoader: def INPUT_TYPES(cls): return { "required": { - "model_source": ( - ["Auto-download (HuggingFace)", "Local path"], - {"default": "Auto-download (HuggingFace)"}, - ), "device": ( ["cuda:0", "cuda:1", "cpu"], {"default": "cuda:0"}, @@ -38,9 +34,6 @@ class OmniVoiceModelLoader: {"default": "float16"}, ), }, - "optional": { - "local_path": ("STRING", {"default": ""}), - }, } RETURN_TYPES = ("OMNIVOICE_MODEL",) @@ -48,19 +41,14 @@ class OmniVoiceModelLoader: FUNCTION = "load_model" CATEGORY = "OmniVoice" - def load_model(self, model_source, device, dtype, local_path=""): + def load_model(self, device, dtype): if OmniVoice is None: raise ImportError( - "omnivoice is not installed. Run: pip install omnivoice" + "omnivoice is not installed. Run: pip install omnivoice --no-deps" ) - if model_source == "Local path" and local_path: - source = local_path - else: - source = "k2-fsa/OmniVoice" - model = OmniVoice.from_pretrained( - source, + "k2-fsa/OmniVoice", device_map=device, dtype=DTYPE_MAP[dtype], cache_dir=CACHE_DIR, diff --git a/tests/test_loader.py b/tests/test_loader.py index aa8d8a1..5bf31d9 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,6 +1,5 @@ # tests/test_loader.py from unittest.mock import patch, MagicMock -import torch import pytest from nodes.loader import OmniVoiceModelLoader @@ -8,52 +7,38 @@ from nodes.loader import OmniVoiceModelLoader def test_input_types_structure(): inputs = OmniVoiceModelLoader.INPUT_TYPES() required = inputs["required"] - assert "model_source" in required assert "device" in required assert "dtype" in required - optional = inputs.get("optional", {}) - assert "local_path" in optional + assert "optional" not in inputs or "local_path" not in inputs.get("optional", {}) -def test_input_types_model_source_choices(): +def test_input_types_device_choices(): inputs = OmniVoiceModelLoader.INPUT_TYPES() - choices = inputs["required"]["model_source"][0] - assert "Auto-download (HuggingFace)" in choices - assert "Local path" in choices + choices = inputs["required"]["device"][0] + assert "cuda:0" in choices + assert "cpu" in choices def test_return_type(): assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",) -def test_load_model_auto_download(): +def test_load_model(): loader = OmniVoiceModelLoader() mock_model = MagicMock() with patch("nodes.loader.OmniVoice") as MockOmniVoice: MockOmniVoice.from_pretrained.return_value = mock_model - result = loader.load_model( - model_source="Auto-download (HuggingFace)", - device="cpu", - dtype="float32", - local_path="", - ) + result = loader.load_model(device="cpu", dtype="float32") assert result == (mock_model,) - MockOmniVoice.from_pretrained.assert_called_once() - call_kwargs = MockOmniVoice.from_pretrained.call_args - assert call_kwargs[0][0] == "k2-fsa/OmniVoice" + call_args = MockOmniVoice.from_pretrained.call_args + assert call_args[0][0] == "k2-fsa/OmniVoice" -def test_load_model_local_path(): +def test_load_model_dtype_mapped(): + import torch loader = OmniVoiceModelLoader() - mock_model = MagicMock() with patch("nodes.loader.OmniVoice") as MockOmniVoice: - MockOmniVoice.from_pretrained.return_value = mock_model - result = loader.load_model( - model_source="Local path", - device="cpu", - dtype="float32", - local_path="/some/local/path", - ) - assert result == (mock_model,) - call_args = MockOmniVoice.from_pretrained.call_args[0][0] - assert call_args == "/some/local/path" + MockOmniVoice.from_pretrained.return_value = MagicMock() + loader.load_model(device="cpu", dtype="float16") + call_kwargs = MockOmniVoice.from_pretrained.call_args[1] + assert call_kwargs["dtype"] == torch.float16 diff --git a/workflows/voice_cloning.json b/workflows/voice_cloning.json index 87b5348..2b7b6cb 100644 --- a/workflows/voice_cloning.json +++ b/workflows/voice_cloning.json @@ -20,7 +20,7 @@ } ], "properties": {"Node name for S&R": "OmniVoiceModelLoader"}, - "widgets_values": ["Auto-download (HuggingFace)", "cuda:0", "float16", ""] + "widgets_values": ["cuda:0", "float16"] }, { "id": 2,