feat: add OmniVoiceGenerate node with voice cloning, design, and auto modes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+61
-1
@@ -1,2 +1,62 @@
|
|||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
class OmniVoiceGenerate:
|
class OmniVoiceGenerate:
|
||||||
pass
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("OMNIVOICE_MODEL",),
|
||||||
|
"text": ("STRING", {"multiline": True, "default": ""}),
|
||||||
|
"mode": (
|
||||||
|
["voice_cloning", "voice_design", "auto_voice"],
|
||||||
|
{"default": "voice_cloning"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"ref_audio": ("AUDIO",),
|
||||||
|
"ref_text": ("STRING", {"default": ""}),
|
||||||
|
"instruct": ("STRING", {"default": ""}),
|
||||||
|
"speed": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1}),
|
||||||
|
"num_step": ("INT", {"default": 32, "min": 1, "max": 100}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
|
def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32):
|
||||||
|
kwargs = {"text": text, "speed": speed, "num_step": num_step}
|
||||||
|
|
||||||
|
if mode == "voice_cloning" and ref_audio is not None:
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
tmp.close()
|
||||||
|
try:
|
||||||
|
waveform = ref_audio["waveform"].squeeze(0) # (channels, samples)
|
||||||
|
torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"])
|
||||||
|
kwargs["ref_audio"] = tmp_path
|
||||||
|
if ref_text:
|
||||||
|
kwargs["ref_text"] = ref_text
|
||||||
|
audio_tensors = model.generate(**kwargs)
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
|
elif mode == "voice_design" and instruct:
|
||||||
|
kwargs["instruct"] = instruct
|
||||||
|
audio_tensors = model.generate(**kwargs)
|
||||||
|
|
||||||
|
else: # auto_voice or fallback
|
||||||
|
audio_tensors = model.generate(**kwargs)
|
||||||
|
|
||||||
|
# Concatenate chunks: each tensor is (1, T) → concat along T → (1, T_total)
|
||||||
|
combined = torch.cat(audio_tensors, dim=1) # (1, T_total)
|
||||||
|
# ComfyUI AUDIO format: (batch, channels, samples)
|
||||||
|
waveform = combined.unsqueeze(0) # (1, 1, T_total)
|
||||||
|
|
||||||
|
return ({"waveform": waveform, "sample_rate": 24000},)
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
# tests/test_generator.py
|
||||||
|
from unittest.mock import patch, MagicMock, call
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from nodes.generator import OmniVoiceGenerate
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_model(return_tensors=None):
|
||||||
|
mock = MagicMock()
|
||||||
|
if return_tensors is None:
|
||||||
|
return_tensors = [torch.zeros(1, 24000)] # 1 second of silence
|
||||||
|
mock.generate.return_value = return_tensors
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_types_structure():
|
||||||
|
inputs = OmniVoiceGenerate.INPUT_TYPES()
|
||||||
|
required = inputs["required"]
|
||||||
|
assert "model" in required
|
||||||
|
assert "text" in required
|
||||||
|
assert "mode" in required
|
||||||
|
optional = inputs.get("optional", {})
|
||||||
|
assert "ref_audio" in optional
|
||||||
|
assert "ref_text" in optional
|
||||||
|
assert "instruct" in optional
|
||||||
|
assert "speed" in optional
|
||||||
|
assert "num_step" in optional
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_type():
|
||||||
|
assert OmniVoiceGenerate.RETURN_TYPES == ("AUDIO",)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_auto_voice():
|
||||||
|
node = OmniVoiceGenerate()
|
||||||
|
mock_model = make_mock_model()
|
||||||
|
result = node.generate(
|
||||||
|
model=mock_model,
|
||||||
|
text="Hello world",
|
||||||
|
mode="auto_voice",
|
||||||
|
speed=1.0,
|
||||||
|
num_step=32,
|
||||||
|
)
|
||||||
|
audio = result[0]
|
||||||
|
assert "waveform" in audio
|
||||||
|
assert "sample_rate" in audio
|
||||||
|
assert audio["sample_rate"] == 24000
|
||||||
|
mock_model.generate.assert_called_once_with(
|
||||||
|
text="Hello world", speed=1.0, num_step=32
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_voice_design():
|
||||||
|
node = OmniVoiceGenerate()
|
||||||
|
mock_model = make_mock_model()
|
||||||
|
result = node.generate(
|
||||||
|
model=mock_model,
|
||||||
|
text="Hello world",
|
||||||
|
mode="voice_design",
|
||||||
|
instruct="female, low pitch",
|
||||||
|
speed=1.0,
|
||||||
|
num_step=32,
|
||||||
|
)
|
||||||
|
audio = result[0]
|
||||||
|
assert audio["sample_rate"] == 24000
|
||||||
|
mock_model.generate.assert_called_once_with(
|
||||||
|
text="Hello world", instruct="female, low pitch", speed=1.0, num_step=32
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_voice_cloning():
|
||||||
|
node = OmniVoiceGenerate()
|
||||||
|
mock_model = make_mock_model()
|
||||||
|
# Simulate ComfyUI AUDIO input: waveform shape (batch, channels, samples)
|
||||||
|
ref_waveform = torch.zeros(1, 1, 24000)
|
||||||
|
ref_audio_input = {"waveform": ref_waveform, "sample_rate": 24000}
|
||||||
|
|
||||||
|
with patch("nodes.generator.torchaudio.save") as mock_save:
|
||||||
|
result = node.generate(
|
||||||
|
model=mock_model,
|
||||||
|
text="Hello world",
|
||||||
|
mode="voice_cloning",
|
||||||
|
ref_audio=ref_audio_input,
|
||||||
|
ref_text="reference text",
|
||||||
|
speed=1.0,
|
||||||
|
num_step=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_save.called
|
||||||
|
call_kwargs = mock_model.generate.call_args[1]
|
||||||
|
assert call_kwargs["ref_text"] == "reference text"
|
||||||
|
assert "ref_audio" in call_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_waveform_shape():
|
||||||
|
node = OmniVoiceGenerate()
|
||||||
|
# Simulate two chunks returned by OmniVoice
|
||||||
|
chunk1 = torch.zeros(1, 24000)
|
||||||
|
chunk2 = torch.zeros(1, 12000)
|
||||||
|
mock_model = make_mock_model(return_tensors=[chunk1, chunk2])
|
||||||
|
result = node.generate(
|
||||||
|
model=mock_model, text="Long text", mode="auto_voice", speed=1.0, num_step=32
|
||||||
|
)
|
||||||
|
waveform = result[0]["waveform"]
|
||||||
|
# Shape must be (batch=1, channels=1, samples=36000)
|
||||||
|
assert waveform.shape == (1, 1, 36000)
|
||||||
Reference in New Issue
Block a user