"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes. Typical chain: SelvaDatasetLoader ↓ AUDIO_DATASET SelvaDatasetResampler (optional) ↓ AUDIO_DATASET SelvaDatasetLUFSNormalizer (optional) ↓ AUDIO_DATASET SelvaDatasetCompressor (optional) ↓ AUDIO_DATASET SelvaDatasetSpectralMatcher (optional — batch spectral EQ) ↓ AUDIO_DATASET SelvaDatasetHfSmoother (optional — batch HF attenuation) ↓ AUDIO_DATASET SelvaDatasetAugmenter (optional — gain/pitch/stretch variants) ↓ AUDIO_DATASET SelvaDatasetInspector (optional) ↓ AUDIO_DATASET + STRING report SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes) """ from pathlib import Path import numpy as np import torch import torchaudio from .utils import SELVA_CATEGORY # ComfyUI custom type name — passed between all dataset pipeline nodes AUDIO_DATASET = "AUDIO_DATASET" _AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"} _SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"} # handled natively without FFmpeg def _load_audio(path: Path): """Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues.""" if path.suffix.lower() in _SOUNDFILE_EXTS: import soundfile as sf wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True) # [L, C] wav = torch.from_numpy(wav_np).T.unsqueeze(0) # [1, C, L] else: wav, sr = torchaudio.load(str(path)) # [C, L] wav = wav.unsqueeze(0).float() # [1, C, L] return wav, sr class SelvaDatasetLoader: """Load all audio files in a folder into an in-memory AUDIO_DATASET.""" @classmethod def INPUT_TYPES(cls): return { "required": { "folder": ("STRING", { "default": "", "tooltip": "Absolute path to folder containing audio files. Searched recursively.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "load" CATEGORY = SELVA_CATEGORY DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET." def load(self, folder: str): folder = Path(folder.strip()) if not folder.exists(): raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}") files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS] if not files: raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}") dataset = [] for f in sorted(files): try: wav, sr = _load_audio(f) dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem}) except Exception as e: print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True) print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True) return (dataset,) class SelvaDatasetResampler: """Resample all clips in a dataset to a target sample rate using soxr VHQ.""" @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "target_sr": ("INT", { "default": 44100, "min": 8000, "max": 192000, "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "resample" CATEGORY = SELVA_CATEGORY DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate." def resample(self, dataset, target_sr: int): import soxr out = [] changed = 0 for item in dataset: sr = item["sample_rate"] if sr == target_sr: out.append(item) continue wav = item["waveform"][0] # [C, L] # soxr expects [L, C] (time-first), float64 wav_np = wav.permute(1, 0).double().numpy() # [L, C] wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ") wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L] out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]}) changed += 1 print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True) return (out,) class SelvaDatasetLUFSNormalizer: """Normalize each clip to a target integrated LUFS level + true peak limit.""" @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "target_lufs": ("FLOAT", { "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5, "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.", }), "true_peak_dbtp": ("FLOAT", { "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5, "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "normalize" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. " "Skips clips that are too short for LUFS measurement (< 0.4 s)." ) def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float): import pyloudnorm as pyln tp_linear = 10.0 ** (true_peak_dbtp / 20.0) out = [] skipped = 0 for item in dataset: wav = item["waveform"][0] # [C, L] sr = item["sample_rate"] # pyloudnorm wants [L] mono or [L, C] multichannel, float64 wav_np = wav.permute(1, 0).double().numpy() # [L, C] if wav_np.shape[1] == 1: wav_np = wav_np[:, 0] # [L] mono meter = pyln.Meter(sr) try: loudness = meter.integrated_loudness(wav_np) except Exception: skipped += 1 out.append(item) continue if not np.isfinite(loudness): skipped += 1 out.append(item) continue gain_db = target_lufs - loudness gain_linear = 10.0 ** (gain_db / 20.0) wav_norm = wav * gain_linear # True peak limit peak = wav_norm.abs().max().item() if peak > tp_linear: wav_norm = wav_norm * (tp_linear / peak) out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]}) print( f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized " f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}", flush=True, ) return (out,) class SelvaDatasetCompressor: """Apply mild parallel compression to reduce within-clip loudness variance. Uses pedalboard.Compressor (2:1–3:1 ratio). Parallel (New York) style: blends compressed signal with dry so transients are preserved while the dynamic range is gently tightened. Apply after LUFS normalization. """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "threshold_db": ("FLOAT", { "default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0, "tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.", }), "ratio": ("FLOAT", { "default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5, "tooltip": "Compression ratio. 2:1–3:1 is mild; stay below 4:1 to avoid pumping.", }), "attack_ms": ("FLOAT", { "default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0, "tooltip": "Attack time in ms. Slower attack preserves transients.", }), "release_ms": ("FLOAT", { "default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0, "tooltip": "Release time in ms.", }), "mix": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05, "tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.3–0.5 is typical.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "compress" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Mild parallel compression to reduce within-clip dynamic range. " "Blends compressed signal with dry at the given mix ratio. " "Apply after LUFS normalization." ) def compress(self, dataset, threshold_db: float, ratio: float, attack_ms: float, release_ms: float, mix: float): from pedalboard import Compressor, Pedalboard board = Pedalboard([Compressor( threshold_db=threshold_db, ratio=ratio, attack_ms=attack_ms, release_ms=release_ms, )]) out = [] for item in dataset: wav = item["waveform"][0] # [C, L] sr = item["sample_rate"] # pedalboard expects [C, L] float32 numpy wav_np = wav.float().numpy() # [C, L] compressed = board(wav_np, sr) # [C, L] mixed = (1.0 - mix) * wav_np + mix * compressed wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, L] out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]}) print( f"[DatasetCompressor] {len(out)} clips compressed " f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}", flush=True, ) return (out,) def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool: """Return True if clip looks codec-compressed (hard HF shelf above 15 kHz). Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT. A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts. """ if sr < 32000: return False # can't assess HF at low sample rates n_fft = 2048 hop = 512 mono = wav[0].mean(0) # [L] window = torch.hann_window(n_fft, device=mono.device) stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True) mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs] freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device) band_lo = (freqs >= 1000) & (freqs < 5000) band_hi = (freqs >= 15000) & (freqs < 20000) if band_hi.sum() == 0: return False energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12) energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12) ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item() return ratio_db > 40.0 def _estimate_snr(wav: torch.Tensor) -> float: """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS.""" mono = wav[0].mean(0) # [L] if mono.shape[0] < 2048: return 60.0 # clip too short to frame — assume clean frames = mono.unfold(0, 2048, 512) # [N, 2048] rms = frames.pow(2).mean(-1).sqrt() # [N] p95 = torch.quantile(rms, 0.95).item() p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item() return 20.0 * np.log10(p95 / p05 + 1e-8) class SelvaDatasetInspector: """Analyze each clip for quality issues and optionally filter out flagged clips.""" @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "skip_rejected": ("BOOLEAN", { "default": True, "tooltip": "If True, flagged clips are removed from the output dataset. " "If False, all clips pass through but the report still lists issues.", }), "min_snr_db": ("FLOAT", { "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0, "tooltip": "Clips with estimated SNR below this value are flagged.", }), "check_codec_artifacts": ("BOOLEAN", { "default": True, "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).", }), "max_silence_fraction": ("FLOAT", { "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05, "tooltip": "Flag clips where more than this fraction of frames are near-silent " "(< -60 dBFS). Set to 0 to disable silence detection.", }), } } RETURN_TYPES = (AUDIO_DATASET, "STRING") RETURN_NAMES = ("dataset", "report") FUNCTION = "inspect" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Analyze each clip for clipping, low SNR, and codec artifacts. " "Outputs a filtered AUDIO_DATASET and a text report. " "Connect report to a ShowText node to preview in the UI." ) def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool, max_silence_fraction: float = 0.5): clean = [] flagged = [] lines = ["SelVA Dataset Inspector Report", "=" * 40] for item in dataset: wav = item["waveform"] sr = item["sample_rate"] name = item["name"] issues = [] # Clipping peak = wav.abs().max().item() if peak > 0.99: issues.append(f"clipping (peak={peak:.3f})") # Low SNR snr = _estimate_snr(wav) if snr < min_snr_db: issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)") # Codec artifacts if check_codec_artifacts and _check_hf_shelf(wav, sr): issues.append("codec artifact (HF shelf > 15 kHz)") # Silence detection if max_silence_fraction > 0: mono = wav[0].mean(0) if mono.shape[0] >= 2048: frames = mono.unfold(0, 2048, 512) rms = frames.pow(2).mean(-1).sqrt() silent_frac = (rms < 1e-3).float().mean().item() if silent_frac > max_silence_fraction: issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)") if issues: flagged.append(name) lines.append(f" FLAGGED {name}: {', '.join(issues)}") if not skip_rejected: clean.append(item) else: clean.append(item) lines.append(f" OK {name}") lines.append("=" * 40) lines.append( f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}" + (" (removed)" if skip_rejected else " (kept)") ) report = "\n".join(lines) print(f"[DatasetInspector]\n{report}", flush=True) return (clean, report) class SelvaDatasetItemExtractor: """Extract a single AUDIO item from an AUDIO_DATASET by index. Bridges the dataset pipeline to any node that accepts a standard AUDIO input — save audio, HF Smoother, Spectral Matcher, etc. """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "index": ("INT", { "default": 0, "min": 0, "max": 9999, "tooltip": "0-based index. Wraps around if index >= dataset length.", }), } } RETURN_TYPES = ("AUDIO", "STRING", "INT") RETURN_NAMES = ("audio", "name", "total") FUNCTION = "extract" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Extract one clip from an AUDIO_DATASET by index. " "Returns standard AUDIO (compatible with all audio nodes), " "the clip name, and the total dataset length." ) def extract(self, dataset, index: int): if not dataset: raise RuntimeError("[DatasetItemExtractor] Dataset is empty.") idx = index % len(dataset) item = dataset[idx] audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]} print( f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} " f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}", flush=True, ) return (audio, item["name"], len(dataset)) class SelvaDatasetSaver: """Save all clips in an AUDIO_DATASET to disk as FLAC files. Optionally copies matching .npz feature files from a source directory, keeping FLAC/NPZ pairs in sync after the inspector has filtered clips. """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "output_dir": ("STRING", { "default": "", "tooltip": "Absolute path to output folder. Created if it does not exist.", }), }, "optional": { "npz_source_dir": ("STRING", { "default": "", "tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. " "Missing NPZs are warned but do not abort the save.", }), }, } RETURN_TYPES = ("STRING",) RETURN_NAMES = ("report",) OUTPUT_NODE = True FUNCTION = "save" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Save every clip in an AUDIO_DATASET to output_dir as FLAC. " "If npz_source_dir is provided, copies the matching .npz file for each clip — " "so rejected clips never get their NPZ copied." ) def save(self, dataset, output_dir: str, npz_source_dir: str = ""): import shutil import soundfile as sf out = Path(output_dir.strip()) out.mkdir(parents=True, exist_ok=True) npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None saved = 0 npz_copied = 0 npz_missing = [] for item in dataset: name = item["name"] wav = item["waveform"][0] # [C, L] sr = item["sample_rate"] # soundfile wants [L] mono or [L, C] multichannel, float32 wav_np = wav.permute(1, 0).float().numpy() # [L, C] if wav_np.shape[1] == 1: wav_np = wav_np[:, 0] # [L] mono flac_path = out / f"{name}.flac" sf.write(str(flac_path), wav_np, sr, subtype="PCM_24") saved += 1 if npz_src is not None: # Augmented clips store their origin name — use it to find the .npz lookup = item.get("origin_name", name) npz_path = npz_src / f"{lookup}.npz" if npz_path.exists(): shutil.copy2(str(npz_path), str(out / f"{name}.npz")) npz_copied += 1 else: npz_missing.append(name) lines = [ f"[DatasetSaver] Saved {saved} clips → {out}", ] if npz_src is not None: lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}") for n in npz_missing: lines.append(f" MISSING NPZ: {n}") report = "\n".join(lines) print(report, flush=True) return (report,) # ── Batch wrappers for audio preprocessors ─────────────────────────────────── class SelvaDatasetSpectralMatcher: """Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET. Wraps SelvaSpectralMatcher so it works on batch datasets instead of individual AUDIO items. Same parameters — see that node for details. """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "mode": (["44k", "16k"], { "tooltip": "Must match the SelVA model you are training. " "44k = large model, 16k = small model.", }), "strength": ("FLOAT", { "default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05, "tooltip": "0 = no correction, 1 = full match to VAE distribution.", }), "max_gain_db": ("FLOAT", { "default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0, "tooltip": "Clamps per-band gain to ±dB.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "process" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Apply adaptive spectral matching to every clip in a dataset. " "Batch version of SelVA Spectral Matcher — same per-band EQ toward the " "VAE's expected distribution." ) def process(self, dataset, mode: str, strength: float, max_gain_db: float): from .selva_audio_preprocessors import SelvaSpectralMatcher matcher = SelvaSpectralMatcher() out = [] for item in dataset: audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]} (result,) = matcher.process(audio, mode, strength, max_gain_db) new_item = dict(item) # preserve origin_name and any extra keys new_item["waveform"] = result["waveform"] new_item["sample_rate"] = result["sample_rate"] out.append(new_item) print(f"[DatasetSpectralMatcher] {len(out)} clips processed " f"mode={mode} strength={strength}", flush=True) return (out,) class SelvaDatasetHfSmoother: """Apply SelVA HF Smoother to every clip in an AUDIO_DATASET. Wraps SelvaHfSmoother so it works on batch datasets instead of individual AUDIO items. Same parameters — see that node for details. """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "cutoff_hz": ("FLOAT", { "default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0, "tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.", }), "blend": ("FLOAT", { "default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05, "tooltip": "0 = original, 1 = fully filtered.", }), } } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "process" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Apply soft HF attenuation to every clip in a dataset. " "Batch version of SelVA HF Smoother — blends a low-pass filtered copy " "with the original to tame extreme HF content." ) def process(self, dataset, cutoff_hz: float, blend: float): from .selva_audio_preprocessors import SelvaHfSmoother smoother = SelvaHfSmoother() out = [] for item in dataset: audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]} (result,) = smoother.process(audio, cutoff_hz, blend) new_item = dict(item) # preserve origin_name and any extra keys new_item["waveform"] = result["waveform"] new_item["sample_rate"] = result["sample_rate"] out.append(new_item) print(f"[DatasetHfSmoother] {len(out)} clips processed " f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True) return (out,) # ── Dataset augmenter ──────────────────────────────────────────────────────── class SelvaDatasetAugmenter: """Create augmented variants of each clip to expand a small dataset. Supports gain variation (always available) and optionally pitch shift and time stretch via audiomentations. Install audiomentations for the full feature set: ``pip install audiomentations`` Each original clip produces ``variants_per_clip`` augmented copies. Originals are kept by default (toggle ``keep_originals``). """ @classmethod def INPUT_TYPES(cls): return { "required": { "dataset": (AUDIO_DATASET,), "variants_per_clip": ("INT", { "default": 2, "min": 1, "max": 20, "tooltip": "Number of augmented copies per original clip.", }), "gain_range_db": ("FLOAT", { "default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5, "tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.", }), "seed": ("INT", {"default": 42}), }, "optional": { "pitch_range_semitones": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25, "tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.", }), "time_stretch_range": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05, "tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). " "Requires audiomentations. 0 = disabled.", }), "keep_originals": ("BOOLEAN", { "default": True, "tooltip": "Include the original unaugmented clips in the output.", }), }, } RETURN_TYPES = (AUDIO_DATASET,) RETURN_NAMES = ("dataset",) FUNCTION = "augment" CATEGORY = SELVA_CATEGORY DESCRIPTION = ( "Create augmented variants of each clip (gain, pitch, time stretch) " "to expand small training datasets. Gain is always available; pitch and " "time stretch require audiomentations (pip install audiomentations)." ) def augment(self, dataset, variants_per_clip: int, gain_range_db: float, seed: int, pitch_range_semitones: float = 0.0, time_stretch_range: float = 0.0, keep_originals: bool = True): rng = np.random.RandomState(seed) # Try audiomentations for pitch/stretch use_am = False am_compose = None needs_am = pitch_range_semitones > 0 or time_stretch_range > 0 if needs_am: try: import audiomentations as am transforms = [] if pitch_range_semitones > 0: transforms.append(am.PitchShift( min_semitones=-pitch_range_semitones, max_semitones=pitch_range_semitones, p=0.5, )) if time_stretch_range > 0: transforms.append(am.TimeStretch( min_rate=1.0 - time_stretch_range, max_rate=1.0 + time_stretch_range, leave_length_unchanged=True, p=0.5, )) am_compose = am.Compose(transforms) use_am = True except ImportError: print("[DatasetAugmenter] audiomentations not installed — " "pitch_shift and time_stretch disabled. " "Install: pip install audiomentations", flush=True) out = [] if keep_originals: out.extend(dataset) for item in dataset: wav = item["waveform"] # [1, C, L] sr = item["sample_rate"] name = item["name"] for v in range(variants_per_clip): # Gain variation (always applied) gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0 gain_lin = 10.0 ** (gain_db / 20.0) wav_aug = wav * gain_lin # Pitch/stretch via audiomentations if use_am and am_compose is not None: wav_np = wav_aug[0].numpy() # [C, L] float32 if wav_np.shape[0] == 1: wav_np = wav_np[0] # [L] mono for audiomentations wav_np = am_compose(samples=wav_np, sample_rate=sr) if wav_np.ndim == 1: wav_np = wav_np[np.newaxis, :] # back to [1, L] wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L] # Prevent clipping peak = wav_aug.abs().max() if peak > 1.0: wav_aug = wav_aug / peak out.append({ "waveform": wav_aug, "sample_rate": sr, "name": f"{name}_aug{v:02d}", "origin_name": name, }) print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips " f"gain=±{gain_range_db:.1f}dB" + (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "") + (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""), flush=True) return (out,)