fix: add input validation and cpu() guard in OmniVoiceGenerate
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+27
-1
@@ -1,5 +1,5 @@
|
||||
# tests/test_generator.py
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
from unittest.mock import patch, MagicMock
|
||||
import torch
|
||||
import pytest
|
||||
from nodes.generator import OmniVoiceGenerate
|
||||
@@ -92,6 +92,32 @@ def test_generate_voice_cloning():
|
||||
assert "ref_audio" in call_kwargs
|
||||
|
||||
|
||||
def test_voice_cloning_without_ref_audio_raises():
|
||||
node = OmniVoiceGenerate()
|
||||
mock_model = make_mock_model()
|
||||
with pytest.raises(ValueError, match="ref_audio"):
|
||||
node.generate(
|
||||
model=mock_model,
|
||||
text="Hello",
|
||||
mode="voice_cloning",
|
||||
speed=1.0,
|
||||
num_step=32,
|
||||
)
|
||||
|
||||
|
||||
def test_voice_design_without_instruct_raises():
|
||||
node = OmniVoiceGenerate()
|
||||
mock_model = make_mock_model()
|
||||
with pytest.raises(ValueError, match="instruct"):
|
||||
node.generate(
|
||||
model=mock_model,
|
||||
text="Hello",
|
||||
mode="voice_design",
|
||||
speed=1.0,
|
||||
num_step=32,
|
||||
)
|
||||
|
||||
|
||||
def test_output_waveform_shape():
|
||||
node = OmniVoiceGenerate()
|
||||
# Simulate two chunks returned by OmniVoice
|
||||
|
||||
Reference in New Issue
Block a user