From 95cf706b19257d5ba7f94b3bda740e8c8d8294aa Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Mon, 6 Apr 2026 09:08:23 +0200 Subject: [PATCH] feat: add multi-speaker generation with JS-powered dynamic slots MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add OmniVoiceSpeaker node (label + ref_audio + ref_text → OMNIVOICE_SPEAKER) - Add OmniVoiceSpeakers node (roster with dynamic speaker_N inputs driven by num_speakers INT widget; slots expand/collapse via ComfyUI JS extension) - Add web/multi_speaker.js: ComfyUI extension that hooks onNodeCreated and onConfigure to sync speaker_N inputs in real time (max 8 speakers) - Extend OmniVoiceGenerate with optional speakers (OMNIVOICE_SPEAKERS) input; when connected it routes each paragraph to the assigned speaker and concatenates the results — supports alternate_paragraphs and tagged_speakers modes - Remove OmniVoiceMultiSpeakerGenerate (generation now lives in the existing Generate node) - Refactor generator.py: extract _write_tmp_wav helper, add _tensors_to_audio Co-Authored-By: Claude Sonnet 4.6 --- __init__.py | 10 +++- nodes/__init__.py | 3 +- nodes/generator.py | 108 ++++++++++++++++++++++++++++++++++++----- nodes/multi_speaker.py | 97 ++++++++++++++++++++++++++++++++++++ web/multi_speaker.js | 70 ++++++++++++++++++++++++++ 5 files changed, 272 insertions(+), 16 deletions(-) create mode 100644 nodes/multi_speaker.py create mode 100644 web/multi_speaker.js diff --git a/__init__.py b/__init__.py index b5e605b..264ed0b 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate, OmniVoiceEpubLoader, OmniVoiceVoicePreset, OmniVoiceMixVoices, OmniVoiceVoiceDesign +from .nodes import OmniVoiceModelLoader, OmniVoiceGenerate, OmniVoiceEpubLoader, OmniVoiceVoicePreset, OmniVoiceMixVoices, OmniVoiceVoiceDesign, OmniVoiceSpeaker, OmniVoiceSpeakers NODE_CLASS_MAPPINGS = { "OmniVoiceModelLoader": OmniVoiceModelLoader, @@ -7,6 +7,8 @@ NODE_CLASS_MAPPINGS = { "OmniVoiceVoicePreset": OmniVoiceVoicePreset, "OmniVoiceMixVoices": OmniVoiceMixVoices, "OmniVoiceVoiceDesign": OmniVoiceVoiceDesign, + "OmniVoiceSpeaker": OmniVoiceSpeaker, + "OmniVoiceSpeakers": OmniVoiceSpeakers, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -16,6 +18,10 @@ NODE_DISPLAY_NAME_MAPPINGS = { "OmniVoiceVoicePreset": "OmniVoice Voice Preset", "OmniVoiceMixVoices": "OmniVoice Mix Voices", "OmniVoiceVoiceDesign": "OmniVoice Voice Design", + "OmniVoiceSpeaker": "OmniVoice Speaker", + "OmniVoiceSpeakers": "OmniVoice Speakers", } -__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] +WEB_DIRECTORY = "./web" + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] diff --git a/nodes/__init__.py b/nodes/__init__.py index dd8be47..7db320c 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -4,5 +4,6 @@ from .epub_loader import OmniVoiceEpubLoader from .voice_presets import OmniVoiceVoicePreset from .mix_voices import OmniVoiceMixVoices from .voice_design import OmniVoiceVoiceDesign +from .multi_speaker import OmniVoiceSpeaker, OmniVoiceSpeakers -__all__ = ["OmniVoiceModelLoader", "OmniVoiceGenerate", "OmniVoiceEpubLoader", "OmniVoiceVoicePreset", "OmniVoiceMixVoices", "OmniVoiceVoiceDesign"] +__all__ = ["OmniVoiceModelLoader", "OmniVoiceGenerate", "OmniVoiceEpubLoader", "OmniVoiceVoicePreset", "OmniVoiceMixVoices", "OmniVoiceVoiceDesign", "OmniVoiceSpeaker", "OmniVoiceSpeakers"] diff --git a/nodes/generator.py b/nodes/generator.py index 9aecb2e..e32f978 100644 --- a/nodes/generator.py +++ b/nodes/generator.py @@ -1,8 +1,26 @@ +import re import tempfile import os import torch import soundfile as sf +_TAG_RE = re.compile(r'^\[([^\]]+)\]\s*(.*)', re.DOTALL) + + +def _write_tmp_wav(ref_audio): + """Write a ComfyUI AUDIO dict to a temp WAV file. Returns the path (caller must delete).""" + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + tmp.close() + waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples) + audio_np = waveform.numpy() + sf.write( + tmp_path, + audio_np[0] if audio_np.shape[0] == 1 else audio_np.T, + int(ref_audio["sample_rate"]), + ) + return tmp_path + class OmniVoiceGenerate: @classmethod @@ -49,12 +67,21 @@ class OmniVoiceGenerate: "tooltip": ( "voice_cloning – clone the voice from ref_audio (requires ref_audio)\n" "voice_design – describe a voice with the instruct field (requires instruct)\n" - "auto_voice – model picks a voice automatically" + "auto_voice – model picks a voice automatically\n" + "\n" + "Ignored when a Speakers roster is connected." ), }, ), }, "optional": { + "speakers": ("OMNIVOICE_SPEAKERS", { + "tooltip": ( + "Connect an OmniVoice Speakers node to enable multi-speaker generation.\n" + "When connected, ref_audio / instruct / mode are ignored and each paragraph\n" + "is routed to its assigned speaker automatically." + ), + }), "ref_audio": ("AUDIO", { "tooltip": "Reference audio clip to clone the voice from. Used in voice_cloning mode.", }), @@ -113,10 +140,16 @@ class OmniVoiceGenerate: FUNCTION = "generate" CATEGORY = "OmniVoice" - def generate(self, model, text, mode, ref_audio=None, ref_text="", + def generate(self, model, text, mode, speakers=None, ref_audio=None, ref_text="", instruct="", guidance_scale=2.0, speed=1.0, num_step=32, seed=0): if seed != 0: torch.manual_seed(seed) + + if speakers is not None: + return self._generate_multi_speaker( + model, text, speakers, guidance_scale, speed, num_step + ) + kwargs = {"text": text, "speed": speed, "num_step": num_step, "guidance_scale": guidance_scale} if mode == "voice_cloning" and ref_audio is None: @@ -125,14 +158,8 @@ class OmniVoiceGenerate: raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')") if mode == "voice_cloning": - tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) - tmp_path = tmp.name - tmp.close() + tmp_path = _write_tmp_wav(ref_audio) try: - ref_waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples) - audio_np = ref_waveform.numpy() - # soundfile expects (samples,) for mono or (samples, channels) for multi-channel - sf.write(tmp_path, audio_np[0] if audio_np.shape[0] == 1 else audio_np.T, int(ref_audio["sample_rate"])) kwargs["ref_audio"] = tmp_path if ref_text: kwargs["ref_text"] = ref_text @@ -152,9 +179,64 @@ class OmniVoiceGenerate: else: # auto_voice or fallback audio_tensors = model.generate(**kwargs) - # Concatenate chunks: each tensor is (1, T) → concat along T → (1, T_total) - combined = torch.cat(audio_tensors, dim=1).cpu() # (1, T_total) on CPU - # ComfyUI AUDIO format: (batch, channels, samples) - waveform = combined.unsqueeze(0) # (1, 1, T_total) + return self._tensors_to_audio(audio_tensors) + def _generate_multi_speaker(self, model, text, speakers_data, guidance_scale, speed, num_step): + speaker_list = speakers_data["speakers"] + spk_mode = speakers_data["mode"] + label_map = {s["label"].lower(): i for i, s in enumerate(speaker_list)} + + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] + if not paragraphs: + raise ValueError("OmniVoice Multi-Speaker: no paragraphs found in text.") + + if spk_mode == "alternate_paragraphs": + segments = [ + (para, speaker_list[i % len(speaker_list)]) + for i, para in enumerate(paragraphs) + ] + else: # tagged_speakers + segments = [] + for para in paragraphs: + m = _TAG_RE.match(para) + if m: + tag = m.group(1).strip().lower() + body = m.group(2).strip() + spk = speaker_list[label_map.get(tag, 0)] + else: + body = para + spk = speaker_list[0] + if body: + segments.append((body, spk)) + + if not segments: + raise ValueError("OmniVoice Multi-Speaker: no text segments to generate.") + + all_chunks = [] + for para_text, spk in segments: + tmp_path = _write_tmp_wav(spk["ref_audio"]) + try: + kwargs = { + "text": para_text, + "ref_audio": tmp_path, + "speed": speed, + "num_step": num_step, + "guidance_scale": guidance_scale, + } + if spk["ref_text"]: + kwargs["ref_text"] = spk["ref_text"] + chunks = model.generate(**kwargs) + all_chunks.extend(chunks) + finally: + try: + os.unlink(tmp_path) + except OSError: + pass + + return self._tensors_to_audio(all_chunks) + + @staticmethod + def _tensors_to_audio(tensors): + combined = torch.cat(tensors, dim=1).cpu() # (1, T_total) + waveform = combined.unsqueeze(0) # (1, 1, T_total) return ({"waveform": waveform, "sample_rate": 24000},) diff --git a/nodes/multi_speaker.py b/nodes/multi_speaker.py new file mode 100644 index 0000000..96918cd --- /dev/null +++ b/nodes/multi_speaker.py @@ -0,0 +1,97 @@ +class OmniVoiceSpeaker: + """Bundle a label, reference audio, and optional transcript into a speaker slot.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "label": ("STRING", { + "default": "Narrator", + "tooltip": ( + "Name used to identify this speaker.\n" + "In tagged_speakers mode, prefix paragraphs with [Label]:\n" + " [Narrator] Once upon a time...\n" + "In alternate_paragraphs mode the label is informational only." + ), + }), + "ref_audio": ("AUDIO", { + "tooltip": "Reference audio clip for this speaker's voice.", + }), + }, + "optional": { + "ref_text": ("STRING", { + "default": "", + "tooltip": "Transcript of ref_audio. Improves cloning quality.", + }), + }, + } + + RETURN_TYPES = ("OMNIVOICE_SPEAKER",) + RETURN_NAMES = ("speaker",) + FUNCTION = "build" + CATEGORY = "OmniVoice" + + def build(self, label, ref_audio, ref_text=""): + return ({"label": label, "ref_audio": ref_audio, "ref_text": ref_text},) + + +class OmniVoiceSpeakers: + """Collect multiple speakers into a roster for multi-speaker generation. + + The number of speaker input slots expands dynamically when num_speakers changes + (requires the OmniVoice web extension to be loaded by ComfyUI). + Connect one OmniVoice Speaker node per slot. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "num_speakers": ("INT", { + "default": 2, "min": 2, "max": 8, "step": 1, + "tooltip": ( + "Number of active speaker slots.\n" + "Changing this value adds or removes speaker_N inputs on the node." + ), + }), + "mode": ( + ["alternate_paragraphs", "tagged_speakers"], + { + "default": "alternate_paragraphs", + "tooltip": ( + "alternate_paragraphs – paragraphs (separated by blank lines) rotate\n" + " through speakers in order: 1 → 2 → 3 → 1 → …\n" + "\n" + "tagged_speakers – prefix each paragraph with [Label] to assign\n" + " a specific speaker. Labels must match those on the Speaker nodes.\n" + " Unrecognised tags fall back to speaker 1.\n" + "\n" + " Example:\n" + " [Narrator] The door creaked open.\n" + "\n" + " [Alice] Who is there?" + ), + }, + ), + }, + # speaker_1 … speaker_8 are added/removed dynamically by the JS extension. + # They are not listed here so ComfyUI does not render them as static widgets. + } + + RETURN_TYPES = ("OMNIVOICE_SPEAKERS",) + RETURN_NAMES = ("speakers",) + FUNCTION = "build" + CATEGORY = "OmniVoice" + + def build(self, num_speakers, mode, **kwargs): + speakers = [] + for i in range(1, num_speakers + 1): + spk = kwargs.get(f"speaker_{i}") + if spk is not None: + speakers.append(spk) + if len(speakers) < 2: + raise ValueError( + f"OmniVoice Speakers: at least 2 speakers must be connected " + f"(got {len(speakers)})." + ) + return ({"speakers": speakers, "mode": mode},) diff --git a/web/multi_speaker.js b/web/multi_speaker.js new file mode 100644 index 0000000..dc5864a --- /dev/null +++ b/web/multi_speaker.js @@ -0,0 +1,70 @@ +import { app } from "../../scripts/app.js"; + +const MAX_SPEAKERS = 8; + +app.registerExtension({ + name: "OmniVoice.MultiSpeaker", + + beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name !== "OmniVoiceSpeakers") return; + + /** + * Ensure the node has exactly `count` speaker_N inputs. + * Safe to call multiple times with the same count (idempotent). + */ + function syncSpeakerInputs(node, count) { + count = Math.max(2, Math.min(MAX_SPEAKERS, Math.floor(count))); + + // Add any missing slots + for (let i = 1; i <= count; i++) { + const name = `speaker_${i}`; + if (!node.inputs?.find(inp => inp.name === name)) { + node.addInput(name, "OMNIVOICE_SPEAKER"); + } + } + + // Remove excess slots (high → low so indices stay valid) + for (let i = MAX_SPEAKERS; i > count; i--) { + const name = `speaker_${i}`; + const idx = node.inputs?.findIndex(inp => inp.name === name) ?? -1; + if (idx === -1) continue; + // Sever any connected link before removing the slot + const linkId = node.inputs[idx].link; + if (linkId != null) node.graph?.removeLink(linkId); + node.removeInput(idx); + } + + node.setDirtyCanvas(true, true); + } + + /** + * Attach the num_speakers widget callback once per node instance. + * Guarded by a flag so configure() can call it safely on reload. + */ + function attachCallback(node) { + if (node._omnivoiceCbAttached) return; + const w = node.widgets?.find(w => w.name === "num_speakers"); + if (!w) return; + node._omnivoiceCbAttached = true; + w.callback = (value) => syncSpeakerInputs(node, value); + } + + // --- Fresh node creation --- + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + onNodeCreated?.apply(this, arguments); + attachCallback(this); + const w = this.widgets?.find(w => w.name === "num_speakers"); + if (w) syncSpeakerInputs(this, w.value); + }; + + // --- Workflow load: called by LiteGraph after widget values are restored --- + const onConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (data) { + onConfigure?.apply(this, arguments); + attachCallback(this); + const w = this.widgets?.find(w => w.name === "num_speakers"); + if (w) syncSpeakerInputs(this, w.value); + }; + }, +});