Remove local path option from model loader
Models always download to ComfyUI/models/omnivoice/ via HuggingFace. Local path added unnecessary complexity; users who want a custom path can symlink into the models directory. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -37,11 +37,11 @@ Loads the OmniVoice model. Downloads automatically from HuggingFace on first run
|
|||||||
|
|
||||||
| Input | Type | Description |
|
| Input | Type | Description |
|
||||||
|-------|------|-------------|
|
|-------|------|-------------|
|
||||||
| `model_source` | dropdown | `Auto-download (HuggingFace)` or `Local path` |
|
|
||||||
| `local_path` | string | Path to local checkpoint (optional) |
|
|
||||||
| `device` | dropdown | `cuda:0`, `cuda:1`, or `cpu` |
|
| `device` | dropdown | `cuda:0`, `cuda:1`, or `cpu` |
|
||||||
| `dtype` | dropdown | `float16`, `bfloat16`, or `float32` |
|
| `dtype` | dropdown | `float16`, `bfloat16`, or `float32` |
|
||||||
|
|
||||||
|
Downloads automatically from HuggingFace on first run and caches to `ComfyUI/models/omnivoice/`.
|
||||||
|
|
||||||
**Output:** `OMNIVOICE_MODEL`
|
**Output:** `OMNIVOICE_MODEL`
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
+3
-15
@@ -25,10 +25,6 @@ class OmniVoiceModelLoader:
|
|||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"model_source": (
|
|
||||||
["Auto-download (HuggingFace)", "Local path"],
|
|
||||||
{"default": "Auto-download (HuggingFace)"},
|
|
||||||
),
|
|
||||||
"device": (
|
"device": (
|
||||||
["cuda:0", "cuda:1", "cpu"],
|
["cuda:0", "cuda:1", "cpu"],
|
||||||
{"default": "cuda:0"},
|
{"default": "cuda:0"},
|
||||||
@@ -38,9 +34,6 @@ class OmniVoiceModelLoader:
|
|||||||
{"default": "float16"},
|
{"default": "float16"},
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"optional": {
|
|
||||||
"local_path": ("STRING", {"default": ""}),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("OMNIVOICE_MODEL",)
|
RETURN_TYPES = ("OMNIVOICE_MODEL",)
|
||||||
@@ -48,19 +41,14 @@ class OmniVoiceModelLoader:
|
|||||||
FUNCTION = "load_model"
|
FUNCTION = "load_model"
|
||||||
CATEGORY = "OmniVoice"
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
def load_model(self, model_source, device, dtype, local_path=""):
|
def load_model(self, device, dtype):
|
||||||
if OmniVoice is None:
|
if OmniVoice is None:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"omnivoice is not installed. Run: pip install omnivoice"
|
"omnivoice is not installed. Run: pip install omnivoice --no-deps"
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_source == "Local path" and local_path:
|
|
||||||
source = local_path
|
|
||||||
else:
|
|
||||||
source = "k2-fsa/OmniVoice"
|
|
||||||
|
|
||||||
model = OmniVoice.from_pretrained(
|
model = OmniVoice.from_pretrained(
|
||||||
source,
|
"k2-fsa/OmniVoice",
|
||||||
device_map=device,
|
device_map=device,
|
||||||
dtype=DTYPE_MAP[dtype],
|
dtype=DTYPE_MAP[dtype],
|
||||||
cache_dir=CACHE_DIR,
|
cache_dir=CACHE_DIR,
|
||||||
|
|||||||
+15
-30
@@ -1,6 +1,5 @@
|
|||||||
# tests/test_loader.py
|
# tests/test_loader.py
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
import torch
|
|
||||||
import pytest
|
import pytest
|
||||||
from nodes.loader import OmniVoiceModelLoader
|
from nodes.loader import OmniVoiceModelLoader
|
||||||
|
|
||||||
@@ -8,52 +7,38 @@ from nodes.loader import OmniVoiceModelLoader
|
|||||||
def test_input_types_structure():
|
def test_input_types_structure():
|
||||||
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
||||||
required = inputs["required"]
|
required = inputs["required"]
|
||||||
assert "model_source" in required
|
|
||||||
assert "device" in required
|
assert "device" in required
|
||||||
assert "dtype" in required
|
assert "dtype" in required
|
||||||
optional = inputs.get("optional", {})
|
assert "optional" not in inputs or "local_path" not in inputs.get("optional", {})
|
||||||
assert "local_path" in optional
|
|
||||||
|
|
||||||
|
|
||||||
def test_input_types_model_source_choices():
|
def test_input_types_device_choices():
|
||||||
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
||||||
choices = inputs["required"]["model_source"][0]
|
choices = inputs["required"]["device"][0]
|
||||||
assert "Auto-download (HuggingFace)" in choices
|
assert "cuda:0" in choices
|
||||||
assert "Local path" in choices
|
assert "cpu" in choices
|
||||||
|
|
||||||
|
|
||||||
def test_return_type():
|
def test_return_type():
|
||||||
assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",)
|
assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",)
|
||||||
|
|
||||||
|
|
||||||
def test_load_model_auto_download():
|
def test_load_model():
|
||||||
loader = OmniVoiceModelLoader()
|
loader = OmniVoiceModelLoader()
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
||||||
MockOmniVoice.from_pretrained.return_value = mock_model
|
MockOmniVoice.from_pretrained.return_value = mock_model
|
||||||
result = loader.load_model(
|
result = loader.load_model(device="cpu", dtype="float32")
|
||||||
model_source="Auto-download (HuggingFace)",
|
|
||||||
device="cpu",
|
|
||||||
dtype="float32",
|
|
||||||
local_path="",
|
|
||||||
)
|
|
||||||
assert result == (mock_model,)
|
assert result == (mock_model,)
|
||||||
MockOmniVoice.from_pretrained.assert_called_once()
|
call_args = MockOmniVoice.from_pretrained.call_args
|
||||||
call_kwargs = MockOmniVoice.from_pretrained.call_args
|
assert call_args[0][0] == "k2-fsa/OmniVoice"
|
||||||
assert call_kwargs[0][0] == "k2-fsa/OmniVoice"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_model_local_path():
|
def test_load_model_dtype_mapped():
|
||||||
|
import torch
|
||||||
loader = OmniVoiceModelLoader()
|
loader = OmniVoiceModelLoader()
|
||||||
mock_model = MagicMock()
|
|
||||||
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
||||||
MockOmniVoice.from_pretrained.return_value = mock_model
|
MockOmniVoice.from_pretrained.return_value = MagicMock()
|
||||||
result = loader.load_model(
|
loader.load_model(device="cpu", dtype="float16")
|
||||||
model_source="Local path",
|
call_kwargs = MockOmniVoice.from_pretrained.call_args[1]
|
||||||
device="cpu",
|
assert call_kwargs["dtype"] == torch.float16
|
||||||
dtype="float32",
|
|
||||||
local_path="/some/local/path",
|
|
||||||
)
|
|
||||||
assert result == (mock_model,)
|
|
||||||
call_args = MockOmniVoice.from_pretrained.call_args[0][0]
|
|
||||||
assert call_args == "/some/local/path"
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"properties": {"Node name for S&R": "OmniVoiceModelLoader"},
|
"properties": {"Node name for S&R": "OmniVoiceModelLoader"},
|
||||||
"widgets_values": ["Auto-download (HuggingFace)", "cuda:0", "float16", ""]
|
"widgets_values": ["cuda:0", "float16"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2,
|
"id": 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user