"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.

Typical chain:

    SelvaDatasetLoader
        ↓ AUDIO_DATASET
    SelvaDatasetResampler          (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetLUFSNormalizer     (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetInspector          (optional)
        ↓ AUDIO_DATASET + STRING report
    SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)

An AUDIO_DATASET is a plain Python list of dicts. As produced by
SelvaDatasetLoader each dict has the keys ``"waveform"`` (float tensor shaped
[1, C, L]), ``"sample_rate"`` (int) and ``"name"`` (source file stem).
"""

from pathlib import Path

import numpy as np
import torch
import torchaudio

from .utils import SELVA_CATEGORY

# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"

# Lower-case file extensions that SelvaDatasetLoader attempts to decode.
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}

# Oversampling factor for true-peak estimation (ITU-R BS.1770-4 Annex 2
# recommends at least 4x for 48 kHz material).
_TRUE_PEAK_OVERSAMPLE = 4


def _channels_last_f64(wav: torch.Tensor) -> np.ndarray:
    """Convert a [C, L] tensor into a C-contiguous float64 ndarray shaped [L, C].

    soxr and pyloudnorm both take time-major (frames, channels) float input.
    ``.cpu()`` keeps this working for tensors an upstream node left on a GPU.
    """
    return np.ascontiguousarray(
        wav.detach().cpu().to(torch.float64).permute(1, 0).numpy()
    )


def _true_peak(wav: torch.Tensor) -> float:
    """Estimate the true (inter-sample) peak magnitude of a [C, L] waveform.

    The signal is upsampled _TRUE_PEAK_OVERSAMPLE-fold with torchaudio's
    band-limited sinc resampler. The larger of the oversampled peak and the
    plain sample peak is returned, so the estimate is never below the sample
    peak. An empty waveform yields 0.0.
    """
    if wav.numel() == 0:
        return 0.0
    sample_peak = wav.abs().max().item()
    upsampled = torchaudio.functional.resample(
        wav, orig_freq=1, new_freq=_TRUE_PEAK_OVERSAMPLE
    )
    return max(sample_peak, upsampled.abs().max().item())


class SelvaDatasetLoader:
    """Load all audio files in a folder into an in-memory AUDIO_DATASET."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to folder containing audio files. Searched recursively.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "load"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."

    def load(self, folder: str):
        """Decode every supported audio file under ``folder`` (recursively).

        Args:
            folder: Dataset directory. Surrounding whitespace is stripped.

        Returns:
            A 1-tuple holding the dataset list, ordered by file path. Each item
            is ``{"waveform": [1, C, L] float32, "sample_rate": int, "name": stem}``.

        Raises:
            ValueError: ``folder`` is empty. ``Path("")`` would otherwise resolve
                to the current working directory and load it silently.
            FileNotFoundError: ``folder`` does not exist.
            NotADirectoryError: ``folder`` exists but is not a directory.
            RuntimeError: no supported audio files were found, or none could
                be decoded.
        """
        folder_str = folder.strip()
        if not folder_str:
            raise ValueError(
                "[DatasetLoader] No folder given — set 'folder' to a directory of audio files"
            )
        root = Path(folder_str)
        if not root.exists():
            raise FileNotFoundError(f"[DatasetLoader] Folder not found: {root}")
        if not root.is_dir():
            raise NotADirectoryError(f"[DatasetLoader] Not a folder: {root}")

        # rglob("*") also yields directories (e.g. one named "x.wav"), so keep
        # only regular files with a supported extension.
        files = sorted(
            f for f in root.rglob("*")
            if f.is_file() and f.suffix.lower() in _AUDIO_EXTS
        )
        if not files:
            raise RuntimeError(f"[DatasetLoader] No audio files found in {root}")

        dataset = []
        for f in files:
            try:
                wav, sr = torchaudio.load(str(f))  # [C, L]
            except Exception as e:  # best effort: one bad file must not abort the load
                print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
                continue
            dataset.append({
                "waveform": wav.unsqueeze(0).float(),  # [1, C, L]
                "sample_rate": sr,
                "name": f.stem,
            })

        if not dataset:
            raise RuntimeError(
                f"[DatasetLoader] None of the {len(files)} audio files in {root} could be decoded"
            )
        print(f"[DatasetLoader] Loaded {len(dataset)} clips from {root}", flush=True)
        return (dataset,)


class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."

    def resample(self, dataset, target_sr: int):
        """Resample every clip whose rate differs from ``target_sr``.

        Clips already at ``target_sr`` are passed through unchanged (same dict
        object). Resampled clips are new dicts that keep all of the input
        item's other keys. Their waveform is a CPU float32 tensor shaped
        [1, C, L'].

        Returns:
            A 1-tuple holding the new dataset list, in input order.
        """
        import soxr  # imported lazily: only needed when this node actually runs

        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                out.append(item)
                continue

            wav_np = _channels_last_f64(item["waveform"][0])  # [L, C] float64
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")  # [L', C]
            wav_t = (
                torch.from_numpy(wav_rs).float().permute(1, 0).contiguous().unsqueeze(0)
            )  # [1, C, L']
            # Copy the item so metadata keys added by upstream nodes survive.
            out.append({**item, "waveform": wav_t, "sample_rate": target_sr})
            changed += 1

        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)


class SelvaDatasetLUFSNormalizer:
    """Normalize each clip to a target integrated LUFS level + true peak limit."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_lufs": ("FLOAT", {
                    "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
                    "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
                }),
                "true_peak_dbtp": ("FLOAT", {
                    "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
                    "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "normalize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
        "Skips clips that are too short for LUFS measurement (< 0.4 s)."
    )

    def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
        """Gain each clip to ``target_lufs``, then scale down if it exceeds the ceiling.

        A clip is passed through unchanged (and counted as skipped) when
        pyloudnorm cannot measure it or reports non-finite loudness. pyloudnorm
        raises for clips shorter than its gating block. Each skip is logged with
        its reason.

        The ceiling is compared against the oversampled true peak, not the raw
        sample peak. Inter-sample peaks can exceed the largest sample value.

        Returns:
            A 1-tuple holding the new dataset list, in input order. Normalized
            items keep all of the input item's other keys.
        """
        import pyloudnorm as pyln  # imported lazily: only needed when this node runs

        tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
        out = []
        skipped = 0
        for item in dataset:
            wav = item["waveform"][0]  # [C, L]
            sr = item["sample_rate"]
            name = item.get("name", "?")

            # pyloudnorm wants [L] for mono or [L, C] for multichannel, float64.
            wav_np = _channels_last_f64(wav)  # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0]  # [L] mono

            meter = pyln.Meter(sr)
            try:
                loudness = meter.integrated_loudness(wav_np)
            except Exception as e:  # best effort: an unmeasurable clip is passed through
                print(f"[LUFSNormalizer] Skipping {name}: {e}", flush=True)
                skipped += 1
                out.append(item)
                continue
            if not np.isfinite(loudness):
                print(f"[LUFSNormalizer] Skipping {name}: loudness is {loudness}", flush=True)
                skipped += 1
                out.append(item)
                continue

            gain_db = target_lufs - loudness
            wav_norm = wav * (10.0 ** (gain_db / 20.0))

            # True peak limit, measured on the oversampled signal (see _true_peak).
            peak = _true_peak(wav_norm)
            if peak > tp_linear:
                wav_norm = wav_norm * (tp_linear / peak)

            out.append({**item, "waveform": wav_norm.unsqueeze(0), "sample_rate": sr})

        print(
            f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
            f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
            flush=True,
        )
        return (out,)