diff --git a/nodes/generator.py b/nodes/generator.py index 32a02a9..5239f90 100644 --- a/nodes/generator.py +++ b/nodes/generator.py @@ -33,12 +33,17 @@ class OmniVoiceGenerate: def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32): kwargs = {"text": text, "speed": speed, "num_step": num_step} + if mode == "voice_cloning" and ref_audio is None: + raise ValueError("voice_cloning mode requires ref_audio to be connected") + if mode == "voice_design" and not instruct: + raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')") + if mode == "voice_cloning" and ref_audio is not None: tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) tmp_path = tmp.name tmp.close() try: - waveform = ref_audio["waveform"].squeeze(0) # (channels, samples) + waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples) torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"]) kwargs["ref_audio"] = tmp_path if ref_text: diff --git a/tests/test_generator.py b/tests/test_generator.py index eda3a12..6c1a306 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,5 +1,5 @@ # tests/test_generator.py -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock import torch import pytest from nodes.generator import OmniVoiceGenerate @@ -92,6 +92,32 @@ def test_generate_voice_cloning(): assert "ref_audio" in call_kwargs +def test_voice_cloning_without_ref_audio_raises(): + node = OmniVoiceGenerate() + mock_model = make_mock_model() + with pytest.raises(ValueError, match="ref_audio"): + node.generate( + model=mock_model, + text="Hello", + mode="voice_cloning", + speed=1.0, + num_step=32, + ) + + +def test_voice_design_without_instruct_raises(): + node = OmniVoiceGenerate() + mock_model = make_mock_model() + with pytest.raises(ValueError, match="instruct"): + node.generate( + model=mock_model, + text="Hello", + mode="voice_design", + speed=1.0, + num_step=32, + ) + + def test_output_waveform_shape(): node = OmniVoiceGenerate() # Simulate two chunks returned by OmniVoice