feat: add OmniVoiceModelLoader node
Implements OmniVoiceModelLoader with INPUT_TYPES, RETURN_TYPES, and load_model supporting both HuggingFace auto-download and local path sources. Adds TDD test suite and pytest infrastructure (conftest.py, pytest.ini) to enable testing outside ComfyUI without omnivoice installed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+10
-5
@@ -1,13 +1,18 @@
|
|||||||
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate
|
try:
|
||||||
|
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"OmniVoiceModelLoader": OmniVoiceModelLoader,
|
"OmniVoiceModelLoader": OmniVoiceModelLoader,
|
||||||
"OmniVoiceGenerate": OmniVoiceGenerate,
|
"OmniVoiceGenerate": OmniVoiceGenerate,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"OmniVoiceModelLoader": "OmniVoice Model Loader",
|
"OmniVoiceModelLoader": "OmniVoice Model Loader",
|
||||||
"OmniVoiceGenerate": "OmniVoice Generate",
|
"OmniVoiceGenerate": "OmniVoice Generate",
|
||||||
}
|
}
|
||||||
|
except ImportError:
|
||||||
|
# Graceful fallback when loaded outside of a package context (e.g. pytest)
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Ensure the project root is on sys.path so `nodes.loader` can be imported
|
||||||
|
# as a top-level package without needing to install the package.
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
+62
-1
@@ -1,2 +1,63 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from omnivoice import OmniVoice
|
||||||
|
except ImportError:
|
||||||
|
OmniVoice = None # type: ignore[assignment,misc]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import folder_paths
|
||||||
|
CACHE_DIR = os.path.join(folder_paths.models_dir, "omnivoice")
|
||||||
|
except ImportError:
|
||||||
|
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "omnivoice")
|
||||||
|
|
||||||
|
|
||||||
|
DTYPE_MAP = {
|
||||||
|
"float16": torch.float16,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OmniVoiceModelLoader:
|
class OmniVoiceModelLoader:
|
||||||
pass
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model_source": (
|
||||||
|
["Auto-download (HuggingFace)", "Local path"],
|
||||||
|
{"default": "Auto-download (HuggingFace)"},
|
||||||
|
),
|
||||||
|
"device": (
|
||||||
|
["cuda:0", "cuda:1", "cpu"],
|
||||||
|
{"default": "cuda:0"},
|
||||||
|
),
|
||||||
|
"dtype": (
|
||||||
|
["float16", "bfloat16", "float32"],
|
||||||
|
{"default": "float16"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"local_path": ("STRING", {"default": ""}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("OMNIVOICE_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
FUNCTION = "load_model"
|
||||||
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
|
def load_model(self, model_source, device, dtype, local_path=""):
|
||||||
|
if model_source == "Local path" and local_path:
|
||||||
|
source = local_path
|
||||||
|
else:
|
||||||
|
source = "k2-fsa/OmniVoice"
|
||||||
|
|
||||||
|
model = OmniVoice.from_pretrained(
|
||||||
|
source,
|
||||||
|
device_map=device,
|
||||||
|
dtype=DTYPE_MAP[dtype],
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
)
|
||||||
|
return (model,)
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
addopts = --import-mode=importlib
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
# tests/test_loader.py
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from nodes.loader import OmniVoiceModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_types_structure():
|
||||||
|
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
||||||
|
required = inputs["required"]
|
||||||
|
assert "model_source" in required
|
||||||
|
assert "device" in required
|
||||||
|
assert "dtype" in required
|
||||||
|
optional = inputs.get("optional", {})
|
||||||
|
assert "local_path" in optional
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_types_model_source_choices():
|
||||||
|
inputs = OmniVoiceModelLoader.INPUT_TYPES()
|
||||||
|
choices = inputs["required"]["model_source"][0]
|
||||||
|
assert "Auto-download (HuggingFace)" in choices
|
||||||
|
assert "Local path" in choices
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_type():
|
||||||
|
assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_auto_download():
|
||||||
|
loader = OmniVoiceModelLoader()
|
||||||
|
mock_model = MagicMock()
|
||||||
|
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
||||||
|
MockOmniVoice.from_pretrained.return_value = mock_model
|
||||||
|
result = loader.load_model(
|
||||||
|
model_source="Auto-download (HuggingFace)",
|
||||||
|
device="cpu",
|
||||||
|
dtype="float32",
|
||||||
|
local_path="",
|
||||||
|
)
|
||||||
|
assert result == (mock_model,)
|
||||||
|
MockOmniVoice.from_pretrained.assert_called_once()
|
||||||
|
call_kwargs = MockOmniVoice.from_pretrained.call_args
|
||||||
|
assert call_kwargs[0][0] == "k2-fsa/OmniVoice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_local_path():
|
||||||
|
loader = OmniVoiceModelLoader()
|
||||||
|
mock_model = MagicMock()
|
||||||
|
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
|
||||||
|
MockOmniVoice.from_pretrained.return_value = mock_model
|
||||||
|
result = loader.load_model(
|
||||||
|
model_source="Local path",
|
||||||
|
device="cpu",
|
||||||
|
dtype="float32",
|
||||||
|
local_path="/some/local/path",
|
||||||
|
)
|
||||||
|
assert result == (mock_model,)
|
||||||
|
call_args = MockOmniVoice.from_pretrained.call_args[0][0]
|
||||||
|
assert call_args == "/some/local/path"
|
||||||
Reference in New Issue
Block a user