Initial release: ComfyUI-MisoTTS (modernized CSM 8B)
Modernized MisoTTS integration for ComfyUI with no torchtune/moshi: - vendored plain-torch Llama backbone (csm_llama), parity-verified Δ=0 vs torchtune - transformers.MimiModel codec (bit-identical codes to moshi), drops moshi/bnb/sphn - low-memory loader: streams 32GB fp32 checkpoint to GPU in bf16 (~18GB VRAM) - nodes: Model Loader, Generate (audiobook chunking + voice anchoring), EPUB Loader - pin-free requirements; runs on modern torch / Blackwell GPUs Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,175 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from ..misotts import Segment
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- audio helpers
|
||||
def _audio_to_mono24k(audio_dict, sr_target=24000):
|
||||
"""ComfyUI AUDIO dict -> 1-D mono tensor at 24 kHz (Mimi's rate)."""
|
||||
wav = audio_dict["waveform"]
|
||||
sr = int(audio_dict["sample_rate"])
|
||||
if wav.dim() == 3:
|
||||
wav = wav[0] # (C, T)
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(0, keepdim=True) # mix to mono
|
||||
if sr != sr_target:
|
||||
wav = torchaudio.functional.resample(wav, sr, sr_target)
|
||||
return wav.squeeze(0).contiguous().float()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- text chunking
|
||||
def _split_sentences(text):
|
||||
parts = re.split(r"(?<=[.!?…])\s+", text.strip())
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
def _hard_split(s, max_chars):
|
||||
"""Break an over-long sentence on commas, then on words, so no chunk exceeds max_chars."""
|
||||
out, cur = [], ""
|
||||
for tok in re.split(r"(?<=,)\s+", s):
|
||||
if cur and len(cur) + 1 + len(tok) > max_chars:
|
||||
out.append(cur)
|
||||
cur = tok
|
||||
else:
|
||||
cur = f"{cur} {tok}".strip()
|
||||
if cur:
|
||||
out.append(cur)
|
||||
final = []
|
||||
for c in out:
|
||||
if len(c) <= max_chars:
|
||||
final.append(c)
|
||||
continue
|
||||
cc = ""
|
||||
for w in c.split():
|
||||
if cc and len(cc) + 1 + len(w) > max_chars:
|
||||
final.append(cc)
|
||||
cc = w
|
||||
else:
|
||||
cc = f"{cc} {w}".strip()
|
||||
if cc:
|
||||
final.append(cc)
|
||||
return final
|
||||
|
||||
|
||||
def _chunk_text(text, max_chars):
|
||||
"""Sentence-aware chunking. Respects paragraph breaks and EPUB '---' chapter markers,
|
||||
packs whole sentences up to max_chars, and hard-splits any sentence longer than that."""
|
||||
chunks = []
|
||||
paragraphs = re.split(r"\n\s*\n|\n?-{3,}\n?", text)
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
cur = ""
|
||||
for s in _split_sentences(para):
|
||||
if len(s) > max_chars:
|
||||
if cur:
|
||||
chunks.append(cur)
|
||||
cur = ""
|
||||
chunks.extend(_hard_split(s, max_chars))
|
||||
continue
|
||||
if cur and len(cur) + 1 + len(s) > max_chars:
|
||||
chunks.append(cur)
|
||||
cur = s
|
||||
else:
|
||||
cur = f"{cur} {s}".strip()
|
||||
if cur:
|
||||
chunks.append(cur)
|
||||
return chunks
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- node
|
||||
class MisoTTSGenerate:
|
||||
"""Generate speech from text. Handles arbitrarily long text (audiobooks/EPUB chapters)
|
||||
by sentence-aware chunking, and keeps a consistent voice across chunks by feeding prior
|
||||
audio (and an optional reference clip) back as context — CSM models otherwise drift."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MISOTTS_MODEL", {"tooltip": "Loaded by the MisoTTS Model Loader node."}),
|
||||
"text": ("STRING", {"multiline": True, "default": "",
|
||||
"tooltip": "Text to synthesize. Long text is chunked automatically."}),
|
||||
},
|
||||
"optional": {
|
||||
"ref_audio": ("AUDIO", {
|
||||
"tooltip": "Optional reference clip to clone the voice from. Anchored across every chunk.",
|
||||
}),
|
||||
"ref_text": ("STRING", {"default": "",
|
||||
"tooltip": "Transcript of ref_audio. Improves cloning quality."}),
|
||||
"speaker": ("INT", {"default": 0, "min": 0, "max": 31,
|
||||
"tooltip": "Speaker id. Keep fixed for a single narrator."}),
|
||||
"temperature": ("FLOAT", {"default": 0.9, "min": 0.1, "max": 2.0, "step": 0.05,
|
||||
"tooltip": "Sampling temperature. Lower = steadier, higher = more varied."}),
|
||||
"topk": ("INT", {"default": 50, "min": 1, "max": 500,
|
||||
"tooltip": "Top-k sampling cutoff."}),
|
||||
"max_chunk_seconds": ("FLOAT", {"default": 30.0, "min": 5.0, "max": 90.0, "step": 1.0,
|
||||
"tooltip": "Max audio length generated per text chunk."}),
|
||||
"chunk_chars": ("INT", {"default": 300, "min": 50, "max": 2000, "step": 10,
|
||||
"tooltip": "Target characters per chunk. Larger = fewer joins, more VRAM/time."}),
|
||||
"context_window": ("INT", {"default": 1, "min": 0, "max": 4,
|
||||
"tooltip": (
|
||||
"How many previous chunks to feed back as context to keep the voice "
|
||||
"consistent. 1 is a good default; 0 makes each chunk independent "
|
||||
"(voice may drift). Higher = steadier but slower / more VRAM.")}),
|
||||
"silence_ms": ("INT", {"default": 250, "min": 0, "max": 2000, "step": 10,
|
||||
"tooltip": "Silence inserted between chunks."}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 2**32 - 1,
|
||||
"tooltip": "0 = random each run. Set a fixed value for reproducible narration."}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
RETURN_NAMES = ("audio",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = "MisoTTS"
|
||||
|
||||
def generate(self, model, text, ref_audio=None, ref_text="", speaker=0, temperature=0.9,
|
||||
topk=50, max_chunk_seconds=30.0, chunk_chars=300, context_window=1,
|
||||
silence_ms=250, seed=0):
|
||||
if seed != 0:
|
||||
torch.manual_seed(seed)
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
raise ValueError("MisoTTS Generate: text is empty.")
|
||||
|
||||
chunks = _chunk_text(text, int(chunk_chars))
|
||||
if not chunks:
|
||||
raise ValueError("MisoTTS Generate: no text chunks produced.")
|
||||
|
||||
sr = int(model.sample_rate)
|
||||
ms = float(max_chunk_seconds) * 1000.0
|
||||
|
||||
ref_seg = None
|
||||
if ref_audio is not None:
|
||||
ref_seg = Segment(speaker=int(speaker), text=(ref_text or "").strip(),
|
||||
audio=_audio_to_mono24k(ref_audio, sr))
|
||||
|
||||
gap = torch.zeros(int(sr * silence_ms / 1000.0)) if silence_ms > 0 else None
|
||||
keep = max(int(context_window), 1)
|
||||
|
||||
history, pieces = [], []
|
||||
for i, chunk in enumerate(chunks):
|
||||
ctx = []
|
||||
if ref_seg is not None:
|
||||
ctx.append(ref_seg)
|
||||
if context_window > 0:
|
||||
ctx.extend(history[-context_window:])
|
||||
|
||||
audio = model.generate(text=chunk, speaker=int(speaker), context=ctx,
|
||||
max_audio_length_ms=ms, temperature=float(temperature), topk=int(topk))
|
||||
audio = audio.detach().to("cpu", torch.float32)
|
||||
|
||||
if i > 0 and gap is not None:
|
||||
pieces.append(gap)
|
||||
pieces.append(audio)
|
||||
|
||||
history.append(Segment(speaker=int(speaker), text=chunk, audio=audio))
|
||||
history = history[-keep:]
|
||||
|
||||
waveform = torch.cat(pieces).unsqueeze(0).unsqueeze(0) # (1, 1, T)
|
||||
return ({"waveform": waveform, "sample_rate": sr},)
|
||||
Reference in New Issue
Block a user