feat: add multi-speaker generation with JS-powered dynamic slots
- Add OmniVoiceSpeaker node (label + ref_audio + ref_text → OMNIVOICE_SPEAKER) - Add OmniVoiceSpeakers node (roster with dynamic speaker_N inputs driven by num_speakers INT widget; slots expand/collapse via ComfyUI JS extension) - Add web/multi_speaker.js: ComfyUI extension that hooks onNodeCreated and onConfigure to sync speaker_N inputs in real time (max 8 speakers) - Extend OmniVoiceGenerate with optional speakers (OMNIVOICE_SPEAKERS) input; when connected it routes each paragraph to the assigned speaker and concatenates the results — supports alternate_paragraphs and tagged_speakers modes - Remove OmniVoiceMultiSpeakerGenerate (generation now lives in the existing Generate node) - Refactor generator.py: extract _write_tmp_wav helper, add _tensors_to_audio Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+8
-2
@@ -1,4 +1,4 @@
|
|||||||
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate, OmniVoiceEpubLoader, OmniVoiceVoicePreset, OmniVoiceMixVoices, OmniVoiceVoiceDesign
|
from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate, OmniVoiceEpubLoader, OmniVoiceVoicePreset, OmniVoiceMixVoices, OmniVoiceVoiceDesign, OmniVoiceSpeaker, OmniVoiceSpeakers
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"OmniVoiceModelLoader": OmniVoiceModelLoader,
|
"OmniVoiceModelLoader": OmniVoiceModelLoader,
|
||||||
@@ -7,6 +7,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"OmniVoiceVoicePreset": OmniVoiceVoicePreset,
|
"OmniVoiceVoicePreset": OmniVoiceVoicePreset,
|
||||||
"OmniVoiceMixVoices": OmniVoiceMixVoices,
|
"OmniVoiceMixVoices": OmniVoiceMixVoices,
|
||||||
"OmniVoiceVoiceDesign": OmniVoiceVoiceDesign,
|
"OmniVoiceVoiceDesign": OmniVoiceVoiceDesign,
|
||||||
|
"OmniVoiceSpeaker": OmniVoiceSpeaker,
|
||||||
|
"OmniVoiceSpeakers": OmniVoiceSpeakers,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@@ -16,6 +18,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"OmniVoiceVoicePreset": "OmniVoice Voice Preset",
|
"OmniVoiceVoicePreset": "OmniVoice Voice Preset",
|
||||||
"OmniVoiceMixVoices": "OmniVoice Mix Voices",
|
"OmniVoiceMixVoices": "OmniVoice Mix Voices",
|
||||||
"OmniVoiceVoiceDesign": "OmniVoice Voice Design",
|
"OmniVoiceVoiceDesign": "OmniVoice Voice Design",
|
||||||
|
"OmniVoiceSpeaker": "OmniVoice Speaker",
|
||||||
|
"OmniVoiceSpeakers": "OmniVoice Speakers",
|
||||||
}
|
}
|
||||||
|
|
||||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
WEB_DIRECTORY = "./web"
|
||||||
|
|
||||||
|
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
||||||
|
|||||||
+2
-1
@@ -4,5 +4,6 @@ from .epub_loader import OmniVoiceEpubLoader
|
|||||||
from .voice_presets import OmniVoiceVoicePreset
|
from .voice_presets import OmniVoiceVoicePreset
|
||||||
from .mix_voices import OmniVoiceMixVoices
|
from .mix_voices import OmniVoiceMixVoices
|
||||||
from .voice_design import OmniVoiceVoiceDesign
|
from .voice_design import OmniVoiceVoiceDesign
|
||||||
|
from .multi_speaker import OmniVoiceSpeaker, OmniVoiceSpeakers
|
||||||
|
|
||||||
__all__ = ["OmniVoiceModelLoader", "OmniVoiceGenerate", "OmniVoiceEpubLoader", "OmniVoiceVoicePreset", "OmniVoiceMixVoices", "OmniVoiceVoiceDesign"]
|
__all__ = ["OmniVoiceModelLoader", "OmniVoiceGenerate", "OmniVoiceEpubLoader", "OmniVoiceVoicePreset", "OmniVoiceMixVoices", "OmniVoiceVoiceDesign", "OmniVoiceSpeaker", "OmniVoiceSpeakers"]
|
||||||
|
|||||||
+95
-13
@@ -1,8 +1,26 @@
|
|||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
_TAG_RE = re.compile(r'^\[([^\]]+)\]\s*(.*)', re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_tmp_wav(ref_audio):
|
||||||
|
"""Write a ComfyUI AUDIO dict to a temp WAV file. Returns the path (caller must delete)."""
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
tmp.close()
|
||||||
|
waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples)
|
||||||
|
audio_np = waveform.numpy()
|
||||||
|
sf.write(
|
||||||
|
tmp_path,
|
||||||
|
audio_np[0] if audio_np.shape[0] == 1 else audio_np.T,
|
||||||
|
int(ref_audio["sample_rate"]),
|
||||||
|
)
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
class OmniVoiceGenerate:
|
class OmniVoiceGenerate:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -49,12 +67,21 @@ class OmniVoiceGenerate:
|
|||||||
"tooltip": (
|
"tooltip": (
|
||||||
"voice_cloning – clone the voice from ref_audio (requires ref_audio)\n"
|
"voice_cloning – clone the voice from ref_audio (requires ref_audio)\n"
|
||||||
"voice_design – describe a voice with the instruct field (requires instruct)\n"
|
"voice_design – describe a voice with the instruct field (requires instruct)\n"
|
||||||
"auto_voice – model picks a voice automatically"
|
"auto_voice – model picks a voice automatically\n"
|
||||||
|
"\n"
|
||||||
|
"Ignored when a Speakers roster is connected."
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
|
"speakers": ("OMNIVOICE_SPEAKERS", {
|
||||||
|
"tooltip": (
|
||||||
|
"Connect an OmniVoice Speakers node to enable multi-speaker generation.\n"
|
||||||
|
"When connected, ref_audio / instruct / mode are ignored and each paragraph\n"
|
||||||
|
"is routed to its assigned speaker automatically."
|
||||||
|
),
|
||||||
|
}),
|
||||||
"ref_audio": ("AUDIO", {
|
"ref_audio": ("AUDIO", {
|
||||||
"tooltip": "Reference audio clip to clone the voice from. Used in voice_cloning mode.",
|
"tooltip": "Reference audio clip to clone the voice from. Used in voice_cloning mode.",
|
||||||
}),
|
}),
|
||||||
@@ -113,10 +140,16 @@ class OmniVoiceGenerate:
|
|||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
CATEGORY = "OmniVoice"
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
def generate(self, model, text, mode, ref_audio=None, ref_text="",
|
def generate(self, model, text, mode, speakers=None, ref_audio=None, ref_text="",
|
||||||
instruct="", guidance_scale=2.0, speed=1.0, num_step=32, seed=0):
|
instruct="", guidance_scale=2.0, speed=1.0, num_step=32, seed=0):
|
||||||
if seed != 0:
|
if seed != 0:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if speakers is not None:
|
||||||
|
return self._generate_multi_speaker(
|
||||||
|
model, text, speakers, guidance_scale, speed, num_step
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = {"text": text, "speed": speed, "num_step": num_step, "guidance_scale": guidance_scale}
|
kwargs = {"text": text, "speed": speed, "num_step": num_step, "guidance_scale": guidance_scale}
|
||||||
|
|
||||||
if mode == "voice_cloning" and ref_audio is None:
|
if mode == "voice_cloning" and ref_audio is None:
|
||||||
@@ -125,14 +158,8 @@ class OmniVoiceGenerate:
|
|||||||
raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')")
|
raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')")
|
||||||
|
|
||||||
if mode == "voice_cloning":
|
if mode == "voice_cloning":
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
tmp_path = _write_tmp_wav(ref_audio)
|
||||||
tmp_path = tmp.name
|
|
||||||
tmp.close()
|
|
||||||
try:
|
try:
|
||||||
ref_waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples)
|
|
||||||
audio_np = ref_waveform.numpy()
|
|
||||||
# soundfile expects (samples,) for mono or (samples, channels) for multi-channel
|
|
||||||
sf.write(tmp_path, audio_np[0] if audio_np.shape[0] == 1 else audio_np.T, int(ref_audio["sample_rate"]))
|
|
||||||
kwargs["ref_audio"] = tmp_path
|
kwargs["ref_audio"] = tmp_path
|
||||||
if ref_text:
|
if ref_text:
|
||||||
kwargs["ref_text"] = ref_text
|
kwargs["ref_text"] = ref_text
|
||||||
@@ -152,9 +179,64 @@ class OmniVoiceGenerate:
|
|||||||
else: # auto_voice or fallback
|
else: # auto_voice or fallback
|
||||||
audio_tensors = model.generate(**kwargs)
|
audio_tensors = model.generate(**kwargs)
|
||||||
|
|
||||||
# Concatenate chunks: each tensor is (1, T) → concat along T → (1, T_total)
|
return self._tensors_to_audio(audio_tensors)
|
||||||
combined = torch.cat(audio_tensors, dim=1).cpu() # (1, T_total) on CPU
|
|
||||||
# ComfyUI AUDIO format: (batch, channels, samples)
|
|
||||||
waveform = combined.unsqueeze(0) # (1, 1, T_total)
|
|
||||||
|
|
||||||
|
def _generate_multi_speaker(self, model, text, speakers_data, guidance_scale, speed, num_step):
|
||||||
|
speaker_list = speakers_data["speakers"]
|
||||||
|
spk_mode = speakers_data["mode"]
|
||||||
|
label_map = {s["label"].lower(): i for i, s in enumerate(speaker_list)}
|
||||||
|
|
||||||
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
|
if not paragraphs:
|
||||||
|
raise ValueError("OmniVoice Multi-Speaker: no paragraphs found in text.")
|
||||||
|
|
||||||
|
if spk_mode == "alternate_paragraphs":
|
||||||
|
segments = [
|
||||||
|
(para, speaker_list[i % len(speaker_list)])
|
||||||
|
for i, para in enumerate(paragraphs)
|
||||||
|
]
|
||||||
|
else: # tagged_speakers
|
||||||
|
segments = []
|
||||||
|
for para in paragraphs:
|
||||||
|
m = _TAG_RE.match(para)
|
||||||
|
if m:
|
||||||
|
tag = m.group(1).strip().lower()
|
||||||
|
body = m.group(2).strip()
|
||||||
|
spk = speaker_list[label_map.get(tag, 0)]
|
||||||
|
else:
|
||||||
|
body = para
|
||||||
|
spk = speaker_list[0]
|
||||||
|
if body:
|
||||||
|
segments.append((body, spk))
|
||||||
|
|
||||||
|
if not segments:
|
||||||
|
raise ValueError("OmniVoice Multi-Speaker: no text segments to generate.")
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
for para_text, spk in segments:
|
||||||
|
tmp_path = _write_tmp_wav(spk["ref_audio"])
|
||||||
|
try:
|
||||||
|
kwargs = {
|
||||||
|
"text": para_text,
|
||||||
|
"ref_audio": tmp_path,
|
||||||
|
"speed": speed,
|
||||||
|
"num_step": num_step,
|
||||||
|
"guidance_scale": guidance_scale,
|
||||||
|
}
|
||||||
|
if spk["ref_text"]:
|
||||||
|
kwargs["ref_text"] = spk["ref_text"]
|
||||||
|
chunks = model.generate(**kwargs)
|
||||||
|
all_chunks.extend(chunks)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self._tensors_to_audio(all_chunks)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tensors_to_audio(tensors):
|
||||||
|
combined = torch.cat(tensors, dim=1).cpu() # (1, T_total)
|
||||||
|
waveform = combined.unsqueeze(0) # (1, 1, T_total)
|
||||||
return ({"waveform": waveform, "sample_rate": 24000},)
|
return ({"waveform": waveform, "sample_rate": 24000},)
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
class OmniVoiceSpeaker:
|
||||||
|
"""Bundle a label, reference audio, and optional transcript into a speaker slot."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"label": ("STRING", {
|
||||||
|
"default": "Narrator",
|
||||||
|
"tooltip": (
|
||||||
|
"Name used to identify this speaker.\n"
|
||||||
|
"In tagged_speakers mode, prefix paragraphs with [Label]:\n"
|
||||||
|
" [Narrator] Once upon a time...\n"
|
||||||
|
"In alternate_paragraphs mode the label is informational only."
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
"ref_audio": ("AUDIO", {
|
||||||
|
"tooltip": "Reference audio clip for this speaker's voice.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"ref_text": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Transcript of ref_audio. Improves cloning quality.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("OMNIVOICE_SPEAKER",)
|
||||||
|
RETURN_NAMES = ("speaker",)
|
||||||
|
FUNCTION = "build"
|
||||||
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
|
def build(self, label, ref_audio, ref_text=""):
|
||||||
|
return ({"label": label, "ref_audio": ref_audio, "ref_text": ref_text},)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniVoiceSpeakers:
|
||||||
|
"""Collect multiple speakers into a roster for multi-speaker generation.
|
||||||
|
|
||||||
|
The number of speaker input slots expands dynamically when num_speakers changes
|
||||||
|
(requires the OmniVoice web extension to be loaded by ComfyUI).
|
||||||
|
Connect one OmniVoice Speaker node per slot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"num_speakers": ("INT", {
|
||||||
|
"default": 2, "min": 2, "max": 8, "step": 1,
|
||||||
|
"tooltip": (
|
||||||
|
"Number of active speaker slots.\n"
|
||||||
|
"Changing this value adds or removes speaker_N inputs on the node."
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
"mode": (
|
||||||
|
["alternate_paragraphs", "tagged_speakers"],
|
||||||
|
{
|
||||||
|
"default": "alternate_paragraphs",
|
||||||
|
"tooltip": (
|
||||||
|
"alternate_paragraphs – paragraphs (separated by blank lines) rotate\n"
|
||||||
|
" through speakers in order: 1 → 2 → 3 → 1 → …\n"
|
||||||
|
"\n"
|
||||||
|
"tagged_speakers – prefix each paragraph with [Label] to assign\n"
|
||||||
|
" a specific speaker. Labels must match those on the Speaker nodes.\n"
|
||||||
|
" Unrecognised tags fall back to speaker 1.\n"
|
||||||
|
"\n"
|
||||||
|
" Example:\n"
|
||||||
|
" [Narrator] The door creaked open.\n"
|
||||||
|
"\n"
|
||||||
|
" [Alice] Who is there?"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
# speaker_1 … speaker_8 are added/removed dynamically by the JS extension.
|
||||||
|
# They are not listed here so ComfyUI does not render them as static widgets.
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("OMNIVOICE_SPEAKERS",)
|
||||||
|
RETURN_NAMES = ("speakers",)
|
||||||
|
FUNCTION = "build"
|
||||||
|
CATEGORY = "OmniVoice"
|
||||||
|
|
||||||
|
def build(self, num_speakers, mode, **kwargs):
|
||||||
|
speakers = []
|
||||||
|
for i in range(1, num_speakers + 1):
|
||||||
|
spk = kwargs.get(f"speaker_{i}")
|
||||||
|
if spk is not None:
|
||||||
|
speakers.append(spk)
|
||||||
|
if len(speakers) < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"OmniVoice Speakers: at least 2 speakers must be connected "
|
||||||
|
f"(got {len(speakers)})."
|
||||||
|
)
|
||||||
|
return ({"speakers": speakers, "mode": mode},)
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
|
||||||
|
const MAX_SPEAKERS = 8;
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "OmniVoice.MultiSpeaker",
|
||||||
|
|
||||||
|
beforeRegisterNodeDef(nodeType, nodeData) {
|
||||||
|
if (nodeData.name !== "OmniVoiceSpeakers") return;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ensure the node has exactly `count` speaker_N inputs.
|
||||||
|
* Safe to call multiple times with the same count (idempotent).
|
||||||
|
*/
|
||||||
|
function syncSpeakerInputs(node, count) {
|
||||||
|
count = Math.max(2, Math.min(MAX_SPEAKERS, Math.floor(count)));
|
||||||
|
|
||||||
|
// Add any missing slots
|
||||||
|
for (let i = 1; i <= count; i++) {
|
||||||
|
const name = `speaker_${i}`;
|
||||||
|
if (!node.inputs?.find(inp => inp.name === name)) {
|
||||||
|
node.addInput(name, "OMNIVOICE_SPEAKER");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove excess slots (high → low so indices stay valid)
|
||||||
|
for (let i = MAX_SPEAKERS; i > count; i--) {
|
||||||
|
const name = `speaker_${i}`;
|
||||||
|
const idx = node.inputs?.findIndex(inp => inp.name === name) ?? -1;
|
||||||
|
if (idx === -1) continue;
|
||||||
|
// Sever any connected link before removing the slot
|
||||||
|
const linkId = node.inputs[idx].link;
|
||||||
|
if (linkId != null) node.graph?.removeLink(linkId);
|
||||||
|
node.removeInput(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
node.setDirtyCanvas(true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attach the num_speakers widget callback once per node instance.
|
||||||
|
* Guarded by a flag so configure() can call it safely on reload.
|
||||||
|
*/
|
||||||
|
function attachCallback(node) {
|
||||||
|
if (node._omnivoiceCbAttached) return;
|
||||||
|
const w = node.widgets?.find(w => w.name === "num_speakers");
|
||||||
|
if (!w) return;
|
||||||
|
node._omnivoiceCbAttached = true;
|
||||||
|
w.callback = (value) => syncSpeakerInputs(node, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Fresh node creation ---
|
||||||
|
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
|
onNodeCreated?.apply(this, arguments);
|
||||||
|
attachCallback(this);
|
||||||
|
const w = this.widgets?.find(w => w.name === "num_speakers");
|
||||||
|
if (w) syncSpeakerInputs(this, w.value);
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Workflow load: called by LiteGraph after widget values are restored ---
|
||||||
|
const onConfigure = nodeType.prototype.onConfigure;
|
||||||
|
nodeType.prototype.onConfigure = function (data) {
|
||||||
|
onConfigure?.apply(this, arguments);
|
||||||
|
attachCallback(this);
|
||||||
|
const w = this.widgets?.find(w => w.name === "num_speakers");
|
||||||
|
if (w) syncSpeakerInputs(this, w.value);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user