fix: add input validation and cpu() guard in OmniVoiceGenerate
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+6
-1
@@ -33,12 +33,17 @@ class OmniVoiceGenerate:
|
|||||||
def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32):
|
def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32):
|
||||||
kwargs = {"text": text, "speed": speed, "num_step": num_step}
|
kwargs = {"text": text, "speed": speed, "num_step": num_step}
|
||||||
|
|
||||||
|
if mode == "voice_cloning" and ref_audio is None:
|
||||||
|
raise ValueError("voice_cloning mode requires ref_audio to be connected")
|
||||||
|
if mode == "voice_design" and not instruct:
|
||||||
|
raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')")
|
||||||
|
|
||||||
if mode == "voice_cloning" and ref_audio is not None:
|
if mode == "voice_cloning" and ref_audio is not None:
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
tmp.close()
|
tmp.close()
|
||||||
try:
|
try:
|
||||||
waveform = ref_audio["waveform"].squeeze(0) # (channels, samples)
|
waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples)
|
||||||
torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"])
|
torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"])
|
||||||
kwargs["ref_audio"] = tmp_path
|
kwargs["ref_audio"] = tmp_path
|
||||||
if ref_text:
|
if ref_text:
|
||||||
|
|||||||
+27
-1
@@ -1,5 +1,5 @@
|
|||||||
# tests/test_generator.py
|
# tests/test_generator.py
|
||||||
from unittest.mock import patch, MagicMock, call
|
from unittest.mock import patch, MagicMock
|
||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
from nodes.generator import OmniVoiceGenerate
|
from nodes.generator import OmniVoiceGenerate
|
||||||
@@ -92,6 +92,32 @@ def test_generate_voice_cloning():
|
|||||||
assert "ref_audio" in call_kwargs
|
assert "ref_audio" in call_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_cloning_without_ref_audio_raises():
    """Check that voice_cloning mode without a reference clip is rejected.

    The call omits ``ref_audio`` entirely, so ``generate`` sees its default
    of ``None`` and is expected to raise a ``ValueError`` whose message
    mentions ``ref_audio``.
    """
    generator = OmniVoiceGenerate()
    call_args = {
        "model": make_mock_model(),
        "text": "Hello",
        "mode": "voice_cloning",
        "speed": 1.0,
        "num_step": 32,
    }
    # ref_audio is deliberately left out of call_args.
    with pytest.raises(ValueError, match="ref_audio"):
        generator.generate(**call_args)
|
||||||
|
|
||||||
|
|
||||||
|
def test_voice_design_without_instruct_raises():
    """Check that voice_design mode without an instruct string is rejected.

    The call omits ``instruct``, so ``generate`` sees its default empty
    string and is expected to raise a ``ValueError`` whose message mentions
    ``instruct``.
    """
    generator = OmniVoiceGenerate()
    call_args = {
        "model": make_mock_model(),
        "text": "Hello",
        "mode": "voice_design",
        "speed": 1.0,
        "num_step": 32,
    }
    # instruct is deliberately left out of call_args.
    with pytest.raises(ValueError, match="instruct"):
        generator.generate(**call_args)
|
||||||
|
|
||||||
|
|
||||||
def test_output_waveform_shape():
|
def test_output_waveform_shape():
|
||||||
node = OmniVoiceGenerate()
|
node = OmniVoiceGenerate()
|
||||||
# Simulate two chunks returned by OmniVoice
|
# Simulate two chunks returned by OmniVoice
|
||||||
|
|||||||
Reference in New Issue
Block a user