diff --git a/nodes/generator.py b/nodes/generator.py index 3704373..32a02a9 100644 --- a/nodes/generator.py +++ b/nodes/generator.py @@ -1,2 +1,62 @@ +import tempfile +import os +import torch +import torchaudio + + class OmniVoiceGenerate: - pass + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("OMNIVOICE_MODEL",), + "text": ("STRING", {"multiline": True, "default": ""}), + "mode": ( + ["voice_cloning", "voice_design", "auto_voice"], + {"default": "voice_cloning"}, + ), + }, + "optional": { + "ref_audio": ("AUDIO",), + "ref_text": ("STRING", {"default": ""}), + "instruct": ("STRING", {"default": ""}), + "speed": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1}), + "num_step": ("INT", {"default": 32, "min": 1, "max": 100}), + }, + } + + RETURN_TYPES = ("AUDIO",) + RETURN_NAMES = ("audio",) + FUNCTION = "generate" + CATEGORY = "OmniVoice" + + def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32): + kwargs = {"text": text, "speed": speed, "num_step": num_step} + + if mode == "voice_cloning" and ref_audio is not None: + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + tmp.close() + try: + waveform = ref_audio["waveform"].squeeze(0) # (channels, samples) + torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"]) + kwargs["ref_audio"] = tmp_path + if ref_text: + kwargs["ref_text"] = ref_text + audio_tensors = model.generate(**kwargs) + finally: + os.unlink(tmp_path) + + elif mode == "voice_design" and instruct: + kwargs["instruct"] = instruct + audio_tensors = model.generate(**kwargs) + + else: # auto_voice or fallback + audio_tensors = model.generate(**kwargs) + + # Concatenate chunks: each tensor is (1, T) → concat along T → (1, T_total) + combined = torch.cat(audio_tensors, dim=1) # (1, T_total) + # ComfyUI AUDIO format: (batch, channels, samples) + waveform = combined.unsqueeze(0) # (1, 1, T_total) + + return ({"waveform": waveform, "sample_rate": 24000},) diff --git a/tests/test_generator.py b/tests/test_generator.py new file mode 100644 index 0000000..eda3a12 --- /dev/null +++ b/tests/test_generator.py @@ -0,0 +1,106 @@ +# tests/test_generator.py +from unittest.mock import patch, MagicMock, call +import torch +import pytest +from nodes.generator import OmniVoiceGenerate + + +def make_mock_model(return_tensors=None): + mock = MagicMock() + if return_tensors is None: + return_tensors = [torch.zeros(1, 24000)] # 1 second of silence + mock.generate.return_value = return_tensors + return mock + + +def test_input_types_structure(): + inputs = OmniVoiceGenerate.INPUT_TYPES() + required = inputs["required"] + assert "model" in required + assert "text" in required + assert "mode" in required + optional = inputs.get("optional", {}) + assert "ref_audio" in optional + assert "ref_text" in optional + assert "instruct" in optional + assert "speed" in optional + assert "num_step" in optional + + +def test_return_type(): + assert OmniVoiceGenerate.RETURN_TYPES == ("AUDIO",) + + +def test_generate_auto_voice(): + node = OmniVoiceGenerate() + mock_model = make_mock_model() + result = node.generate( + model=mock_model, + text="Hello world", + mode="auto_voice", + speed=1.0, + num_step=32, + ) + audio = result[0] + assert "waveform" in audio + assert "sample_rate" in audio + assert audio["sample_rate"] == 24000 + mock_model.generate.assert_called_once_with( + text="Hello world", speed=1.0, num_step=32 + ) + + +def test_generate_voice_design(): + node = OmniVoiceGenerate() + mock_model = make_mock_model() + result = node.generate( + model=mock_model, + text="Hello world", + mode="voice_design", + instruct="female, low pitch", + speed=1.0, + num_step=32, + ) + audio = result[0] + assert audio["sample_rate"] == 24000 + mock_model.generate.assert_called_once_with( + text="Hello world", instruct="female, low pitch", speed=1.0, num_step=32 + ) + + +def test_generate_voice_cloning(): + node = OmniVoiceGenerate() + mock_model = make_mock_model() + # Simulate ComfyUI AUDIO input: waveform shape (batch, channels, samples) + ref_waveform = torch.zeros(1, 1, 24000) + ref_audio_input = {"waveform": ref_waveform, "sample_rate": 24000} + + with patch("nodes.generator.torchaudio.save") as mock_save: + result = node.generate( + model=mock_model, + text="Hello world", + mode="voice_cloning", + ref_audio=ref_audio_input, + ref_text="reference text", + speed=1.0, + num_step=32, + ) + + assert mock_save.called + call_kwargs = mock_model.generate.call_args[1] + assert call_kwargs["ref_text"] == "reference text" + assert "ref_audio" in call_kwargs + + +def test_output_waveform_shape(): + node = OmniVoiceGenerate() + # Simulate two chunks returned by OmniVoice + chunk1 = torch.zeros(1, 24000) + chunk2 = torch.zeros(1, 12000) + mock_model = make_mock_model(return_tensors=[chunk1, chunk2]) + result = node.generate( + model=mock_model, text="Long text", mode="auto_voice", speed=1.0, num_step=32 + ) + waveform = result[0]["waveform"] + # Shape must be (batch=1, channels=1, samples=36000) + assert waveform.shape == (1, 1, 36000)