feat: add OmniVoiceModelLoader node

Implements OmniVoiceModelLoader with INPUT_TYPES, RETURN_TYPES, and
load_model supporting both HuggingFace auto-download and local path
sources. Adds TDD test suite and pytest infrastructure (conftest.py,
pytest.ini) to enable testing outside ComfyUI without omnivoice installed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 08:52:26 +02:00
parent 0ed43a83ca
commit 069169485d
6 changed files with 144 additions and 10 deletions
+14 -9
View File
@@ -1,13 +1,18 @@
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate
try:
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate
NODE_CLASS_MAPPINGS = {
"OmniVoiceModelLoader": OmniVoiceModelLoader,
"OmniVoiceGenerate": OmniVoiceGenerate,
}
NODE_CLASS_MAPPINGS = {
"OmniVoiceModelLoader": OmniVoiceModelLoader,
"OmniVoiceGenerate": OmniVoiceGenerate,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"OmniVoiceModelLoader": "OmniVoice Model Loader",
"OmniVoiceGenerate": "OmniVoice Generate",
}
NODE_DISPLAY_NAME_MAPPINGS = {
"OmniVoiceModelLoader": "OmniVoice Model Loader",
"OmniVoiceGenerate": "OmniVoice Generate",
}
except ImportError:
# Graceful fallback when loaded outside of a package context (e.g. pytest)
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
+6
View File
@@ -0,0 +1,6 @@
import sys
import os
# Ensure the project root is on sys.path so `nodes.loader` can be imported
# as a top-level package without needing to install the package.
sys.path.insert(0, os.path.dirname(__file__))
+62 -1
View File
@@ -1,2 +1,63 @@
import os
import torch
try:
from omnivoice import OmniVoice
except ImportError:
OmniVoice = None # type: ignore[assignment,misc]
try:
import folder_paths
CACHE_DIR = os.path.join(folder_paths.models_dir, "omnivoice")
except ImportError:
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "omnivoice")
DTYPE_MAP = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
class OmniVoiceModelLoader:
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_source": (
["Auto-download (HuggingFace)", "Local path"],
{"default": "Auto-download (HuggingFace)"},
),
"device": (
["cuda:0", "cuda:1", "cpu"],
{"default": "cuda:0"},
),
"dtype": (
["float16", "bfloat16", "float32"],
{"default": "float16"},
),
},
"optional": {
"local_path": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("OMNIVOICE_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "OmniVoice"
def load_model(self, model_source, device, dtype, local_path=""):
if model_source == "Local path" and local_path:
source = local_path
else:
source = "k2-fsa/OmniVoice"
model = OmniVoice.from_pretrained(
source,
device_map=device,
dtype=DTYPE_MAP[dtype],
cache_dir=CACHE_DIR,
)
return (model,)
+3
View File
@@ -0,0 +1,3 @@
[pytest]
testpaths = tests
addopts = --import-mode=importlib
View File
+59
View File
@@ -0,0 +1,59 @@
# tests/test_loader.py
from unittest.mock import patch, MagicMock
import torch
import pytest
from nodes.loader import OmniVoiceModelLoader
def test_input_types_structure():
inputs = OmniVoiceModelLoader.INPUT_TYPES()
required = inputs["required"]
assert "model_source" in required
assert "device" in required
assert "dtype" in required
optional = inputs.get("optional", {})
assert "local_path" in optional
def test_input_types_model_source_choices():
inputs = OmniVoiceModelLoader.INPUT_TYPES()
choices = inputs["required"]["model_source"][0]
assert "Auto-download (HuggingFace)" in choices
assert "Local path" in choices
def test_return_type():
assert OmniVoiceModelLoader.RETURN_TYPES == ("OMNIVOICE_MODEL",)
def test_load_model_auto_download():
loader = OmniVoiceModelLoader()
mock_model = MagicMock()
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
MockOmniVoice.from_pretrained.return_value = mock_model
result = loader.load_model(
model_source="Auto-download (HuggingFace)",
device="cpu",
dtype="float32",
local_path="",
)
assert result == (mock_model,)
MockOmniVoice.from_pretrained.assert_called_once()
call_kwargs = MockOmniVoice.from_pretrained.call_args
assert call_kwargs[0][0] == "k2-fsa/OmniVoice"
def test_load_model_local_path():
loader = OmniVoiceModelLoader()
mock_model = MagicMock()
with patch("nodes.loader.OmniVoice") as MockOmniVoice:
MockOmniVoice.from_pretrained.return_value = mock_model
result = loader.load_model(
model_source="Local path",
device="cpu",
dtype="float32",
local_path="/some/local/path",
)
assert result == (mock_model,)
call_args = MockOmniVoice.from_pretrained.call_args[0][0]
assert call_args == "/some/local/path"