import re

import torch
import torchaudio

from ..misotts import Segment

# Patterns are compiled once at import time; they run on every paragraph/sentence.
# NOTE: the paragraph pattern matches any run of 3+ hyphens, including one that
# appears inline (e.g. "a---b"), not only a marker on a line of its own.
_PARAGRAPH_BREAK_RE = re.compile(r"\n\s*\n|\n?-{3,}\n?")
_SENTENCE_BREAK_RE = re.compile(r"(?<=[.!?…])\s+")
_CLAUSE_BREAK_RE = re.compile(r"(?<=,)\s+")


# --------------------------------------------------------------------------- audio helpers
def _audio_to_mono24k(audio_dict, sr_target=24000):
    """ComfyUI AUDIO dict -> 1-D mono float32 tensor at ``sr_target`` Hz.

    Args:
        audio_dict: Mapping with a ``"waveform"`` tensor of shape (B, C, T),
            (C, T) or (T,) and a ``"sample_rate"`` value convertible to int.
        sr_target: Output sample rate. Defaults to 24 kHz (Mimi's rate).

    Returns:
        A contiguous 1-D float32 tensor. For 3-D input only the first item
        along dim 0 is used.
    """
    wav = audio_dict["waveform"]
    sr = int(audio_dict["sample_rate"])
    if wav.dim() == 3:
        wav = wav[0]  # (C, T)
    elif wav.dim() == 1:
        # Without this, shape[0] would be the sample count and the "mix to
        # mono" mean below would average the whole clip into one sample.
        wav = wav.unsqueeze(0)  # (1, T)
    # Cast before mean/resample: torch.mean rejects integer tensors.
    wav = wav.float()
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)  # mix to mono
    if sr != sr_target:
        wav = torchaudio.functional.resample(wav, sr, sr_target)
    return wav.squeeze(0).contiguous()


# --------------------------------------------------------------------------- text chunking
def _split_sentences(text):
    """Split ``text`` after '.', '!', '?' or '…' when followed by whitespace.

    Returns the non-empty, stripped sentences in order.
    """
    parts = _SENTENCE_BREAK_RE.split(text.strip())
    return [p.strip() for p in parts if p.strip()]


def _pack_tokens(tokens, max_chars):
    """Greedily join ``tokens`` with single spaces into pieces of at most
    ``max_chars`` characters. A token that is itself longer than ``max_chars``
    is emitted on its own, unmodified."""
    packed, cur = [], ""
    for tok in tokens:
        if cur and len(cur) + 1 + len(tok) > max_chars:
            packed.append(cur)
            cur = tok
        else:
            cur = f"{cur} {tok}".strip()
    if cur:
        packed.append(cur)
    return packed


def _hard_split(s, max_chars):
    """Break an over-long sentence so that no chunk exceeds ``max_chars``.

    Splitting is tried at progressively finer boundaries: after commas, then
    between words, and finally inside any single word that is still too long
    (e.g. a URL or text in a script written without spaces).

    Args:
        s: The sentence to split.
        max_chars: Maximum characters per returned chunk; values below 1 are
            treated as 1.

    Returns:
        List of chunks, each at most ``max_chars`` characters long.
    """
    max_chars = max(int(max_chars), 1)
    final = []
    for clause in _pack_tokens(_CLAUSE_BREAK_RE.split(s), max_chars):
        if len(clause) <= max_chars:
            final.append(clause)
            continue
        # Clause still too long: fall back to words, slicing any word that
        # alone exceeds the limit so the guarantee in the docstring holds.
        words = []
        for word in clause.split():
            words.extend(
                word[i:i + max_chars] for i in range(0, len(word), max_chars)
            )
        final.extend(_pack_tokens(words, max_chars))
    return final


def _chunk_text(text, max_chars):
    """Sentence-aware chunking.

    Respects paragraph breaks and EPUB '---' chapter markers, packs whole
    sentences up to max_chars, and hard-splits any sentence longer than that.
    Chunks never span a paragraph break.

    Args:
        text: Full input text.
        max_chars: Maximum characters per chunk; values below 1 are treated
            as 1.

    Returns:
        List of non-empty chunk strings in reading order (empty if ``text``
        contains no non-whitespace content).
    """
    max_chars = max(int(max_chars), 1)
    chunks = []
    for para in _PARAGRAPH_BREAK_RE.split(text):
        para = para.strip()
        if not para:
            continue
        cur = ""
        for sentence in _split_sentences(para):
            if len(sentence) > max_chars:
                # Flush what is pending so output order follows the text,
                # then emit the oversized sentence as its own sub-chunks.
                if cur:
                    chunks.append(cur)
                    cur = ""
                chunks.extend(_hard_split(sentence, max_chars))
                continue
            if cur and len(cur) + 1 + len(sentence) > max_chars:
                chunks.append(cur)
                cur = sentence
            else:
                cur = f"{cur} {sentence}".strip()
        if cur:
            chunks.append(cur)
    return chunks


# --------------------------------------------------------------------------- node
class MisoTTSGenerate:
    """Generate speech from text. Handles arbitrarily long text (audiobooks/EPUB
    chapters) by sentence-aware chunking, and keeps a consistent voice across
    chunks by feeding prior audio (and an optional reference clip) back as
    context — CSM models otherwise drift."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "model": ("MISOTTS_MODEL", {"tooltip": "Loaded by the MisoTTS Model Loader node."}),
                "text": ("STRING", {"multiline": True, "default": "",
                                    "tooltip": "Text to synthesize. Long text is chunked automatically."}),
            },
            "optional": {
                "ref_audio": ("AUDIO", {
                    "tooltip": "Optional reference clip to clone the voice from. Anchored across every chunk.",
                }),
                "ref_text": ("STRING", {"default": "",
                                        "tooltip": "Transcript of ref_audio. Improves cloning quality."}),
                "speaker": ("INT", {"default": 0, "min": 0, "max": 31,
                                    "tooltip": "Speaker id. Keep fixed for a single narrator."}),
                "temperature": ("FLOAT", {"default": 0.9, "min": 0.1, "max": 2.0, "step": 0.05,
                                          "tooltip": "Sampling temperature. Lower = steadier, higher = more varied."}),
                "topk": ("INT", {"default": 50, "min": 1, "max": 500,
                                 "tooltip": "Top-k sampling cutoff."}),
                "max_chunk_seconds": ("FLOAT", {"default": 30.0, "min": 5.0, "max": 90.0, "step": 1.0,
                                                "tooltip": "Max audio length generated per text chunk."}),
                "chunk_chars": ("INT", {"default": 300, "min": 50, "max": 2000, "step": 10,
                                        "tooltip": "Target characters per chunk. Larger = fewer joins, more VRAM/time."}),
                "context_window": ("INT", {"default": 1, "min": 0, "max": 4,
                                           "tooltip": (
                                               "How many previous chunks to feed back as context to keep the voice "
                                               "consistent. 1 is a good default; 0 makes each chunk independent "
                                               "(voice may drift). Higher = steadier but slower / more VRAM.")}),
                "silence_ms": ("INT", {"default": 250, "min": 0, "max": 2000, "step": 10,
                                       "tooltip": "Silence inserted between chunks."}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 2**32 - 1,
                                 "tooltip": "0 = random each run. Set a fixed value for reproducible narration."}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = "MisoTTS"

    def generate(self, model, text, ref_audio=None, ref_text="", speaker=0,
                 temperature=0.9, topk=50, max_chunk_seconds=30.0, chunk_chars=300,
                 context_window=1, silence_ms=250, seed=0):
        """Synthesize ``text`` chunk by chunk and return one concatenated clip.

        Args:
            model: Object exposing ``sample_rate`` and
                ``generate(text, speaker, context, max_audio_length_ms,
                temperature, topk)`` that returns a tensor. The result is
                expected to be 1-D so it can be concatenated with the
                1-D silence gap.
            text: Text to synthesize; split with ``_chunk_text``.
            ref_audio: Optional ComfyUI AUDIO dict. When given it is converted
                to mono at the model's sample rate and placed first in the
                context of every chunk.
            ref_text: Transcript stored alongside ``ref_audio``.
            speaker: Speaker id used for the reference, history and new chunks.
            temperature: Sampling temperature passed to the model.
            topk: Top-k cutoff passed to the model.
            max_chunk_seconds: Per-chunk audio cap, passed on in milliseconds.
            chunk_chars: Maximum characters per text chunk.
            context_window: Number of previously generated chunks included as
                context; 0 or less disables history.
            silence_ms: Silence inserted between consecutive chunks.
            seed: If non-zero, seeds torch's global RNG before generation.

        Returns:
            One-element tuple holding an AUDIO dict with ``"waveform"`` of
            shape (1, 1, T) and ``"sample_rate"``.

        Raises:
            ValueError: If ``text`` is empty/whitespace or yields no chunks.
        """
        if seed != 0:
            torch.manual_seed(int(seed))

        text = (text or "").strip()
        if not text:
            raise ValueError("MisoTTS Generate: text is empty.")
        chunks = _chunk_text(text, int(chunk_chars))
        if not chunks:
            raise ValueError("MisoTTS Generate: no text chunks produced.")

        sr = int(model.sample_rate)
        max_ms = float(max_chunk_seconds) * 1000.0
        speaker_id = int(speaker)
        # Normalise once so the slice below never sees a float or a negative.
        window = max(int(context_window), 0)
        keep = max(window, 1)

        ref_seg = None
        if ref_audio is not None:
            ref_seg = Segment(speaker=speaker_id, text=(ref_text or "").strip(),
                              audio=_audio_to_mono24k(ref_audio, sr))

        gap = None
        if silence_ms > 0:
            # Explicit dtype so the gap matches the float32 pieces below.
            gap = torch.zeros(int(sr * silence_ms / 1000.0), dtype=torch.float32)

        history, pieces = [], []
        for chunk in chunks:
            ctx = []
            if ref_seg is not None:
                ctx.append(ref_seg)  # reference clip is anchored on every chunk
            if window > 0:
                ctx.extend(history[-window:])

            audio = model.generate(text=chunk, speaker=speaker_id, context=ctx,
                                   max_audio_length_ms=max_ms,
                                   temperature=float(temperature), topk=int(topk))
            audio = audio.detach().to("cpu", torch.float32)

            if pieces and gap is not None:
                pieces.append(gap)  # silence only between chunks, never leading
            pieces.append(audio)

            history.append(Segment(speaker=speaker_id, text=chunk, audio=audio))
            history = history[-keep:]  # bound memory on long inputs

        waveform = torch.cat(pieces).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        return ({"waveform": waveform, "sample_rate": sr},)