diff --git a/nodes/voice_presets.py b/nodes/voice_presets.py index 5878655..649301f 100644 --- a/nodes/voice_presets.py +++ b/nodes/voice_presets.py @@ -53,14 +53,45 @@ PRESETS = { } -def _load_audio(url): - """Download (once) and return (waveform_tensor, sample_rate).""" +_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} +_BUILTIN_FILES = frozenset(os.path.basename(url.split("?")[0]) for url, _ in PRESETS.values()) + + +def _scan_user_presets(): + """Return a dict of user presets found in _CACHE_DIR. + + For each audio file that is not a cached built-in, look for a same-stem + .txt file for the transcript. Key format: " (local)". + """ + user = {} + if not os.path.isdir(_CACHE_DIR): + return user + for fname in sorted(os.listdir(_CACHE_DIR)): + stem, ext = os.path.splitext(fname) + if ext.lower() not in _AUDIO_EXTS or fname in _BUILTIN_FILES: + continue + audio_path = os.path.join(_CACHE_DIR, fname) + txt_path = os.path.join(_CACHE_DIR, stem + ".txt") + transcript = "" + if os.path.exists(txt_path): + with open(txt_path, "r", encoding="utf-8") as f: + transcript = f.read().strip() + user[f"{stem} (local)"] = (audio_path, transcript) + return user + + +def _load_audio(source): + """Load audio from a URL (downloading once) or a local file path.""" os.makedirs(_CACHE_DIR, exist_ok=True) - filename = os.path.basename(url.split("?")[0]) - cache_path = os.path.join(_CACHE_DIR, filename) - if not os.path.exists(cache_path): - urllib.request.urlretrieve(url, cache_path) - audio_np, sr = sf.read(cache_path, dtype="float32") + if source.startswith("http://") or source.startswith("https://"): + filename = os.path.basename(source.split("?")[0]) + cache_path = os.path.join(_CACHE_DIR, filename) + if not os.path.exists(cache_path): + urllib.request.urlretrieve(source, cache_path) + path = cache_path + else: + path = source + audio_np, sr = sf.read(path, dtype="float32") if audio_np.ndim == 1: audio_np = audio_np[np.newaxis, :] # (1, samples) else: @@ -72,15 +103,20 @@ def _load_audio(url): class OmniVoiceVoicePreset: @classmethod def INPUT_TYPES(cls): + all_presets = {**PRESETS, **_scan_user_presets()} return { "required": { "preset": ( - list(PRESETS.keys()), + list(all_presets.keys()), { "tooltip": ( "Pre-fetched reference voice for OmniVoice Generate.\n" "Connect ref_audio → ref_audio and ref_text → ref_text.\n" - "If ref_text is blank, connect a Whisper node to supply the transcript." + "\n" + "To add your own presets, drop audio files into:\n" + f" {_CACHE_DIR}\n" + "Add a same-name .txt file alongside for the transcript.\n" + "Restart ComfyUI to pick up new files." ), }, ), @@ -93,6 +129,7 @@ class OmniVoiceVoicePreset: CATEGORY = "OmniVoice" def load_preset(self, preset): - url, transcript = PRESETS[preset] - waveform, sr = _load_audio(url) + all_presets = {**PRESETS, **_scan_user_presets()} + source, transcript = all_presets[preset] + waveform, sr = _load_audio(source) return ({"waveform": waveform, "sample_rate": sr}, transcript)