# Reconstructed from a whitespace-collapsed git diff hunk:
#   diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py
#   index f360e5d..398f8df 100644
#   @@ -187,3 +187,120 @@ class SelvaDatasetLUFSNormalizer:
# The hunk's unchanged context lines (the tail of SelvaDatasetLUFSNormalizer,
# ending in `return (out,)`) are not reproduced here.


def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
    """Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).

    Method: compare mean energy in the 1-5 kHz band vs the 15-20 kHz band via
    STFT. A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec
    artifacts.

    Args:
        wav: Waveform tensor; ``wav[0].mean(0)`` must yield a 1-D signal
            (assumes a [batch, channels, samples] layout -- TODO confirm).
        sr: Sample rate in Hz.

    Returns:
        False when the clip cannot be assessed (sample rate below 32 kHz, or
        too few samples for a single STFT frame); otherwise whether the
        low/high band energy ratio exceeds 40 dB.
    """
    if sr < 32000:
        return False  # can't assess HF at low sample rates

    n_fft = 2048
    hop = 512
    mono = wav[0].mean(0)  # [L]

    # torch.stft centers frames with reflect padding of n_fft // 2 samples,
    # which requires the signal to be strictly longer than the padding.
    if mono.shape[-1] <= n_fft // 2:
        return False

    # Build the window on the signal's device/dtype so stft does not fail on
    # GPU or non-float32 waveforms.
    window = torch.hann_window(n_fft, dtype=mono.dtype, device=mono.device)
    stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
    mag_sq = stft.abs().pow(2).mean(-1)  # [n_freqs]

    freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mag_sq.device)
    band_lo = (freqs >= 1000) & (freqs < 5000)
    band_hi = (freqs >= 15000) & (freqs < 20000)

    if not bool(band_hi.any()):
        return False

    # Clamp both bands so silence yields a 0 dB ratio instead of inf/NaN.
    energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
    energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
    ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
    return ratio_db > 40.0


def _estimate_snr(wav: torch.Tensor) -> float:
    """Rough SNR estimate: 95th- over 5th-percentile frame RMS, in dB.

    Frames are 2048 samples long with a hop of 512. A clip shorter than one
    frame is treated as a single frame (giving roughly 0 dB); an empty clip
    returns 0.0.

    Args:
        wav: Waveform tensor; ``wav[0].mean(0)`` must yield a 1-D signal
            (assumes a [batch, channels, samples] layout -- TODO confirm).

    Returns:
        The estimate in dB as a plain Python float.
    """
    frame_len = 2048
    hop = 512
    mono = wav[0].mean(0)  # [L]

    n_samples = mono.shape[-1]
    if n_samples == 0:
        return 0.0
    if n_samples < frame_len:
        # unfold() raises when the window is longer than the signal.
        frames = mono.unsqueeze(0)  # [1, L]
    else:
        frames = mono.unfold(0, frame_len, hop)  # [N, frame_len]

    # torch.quantile only accepts float/double input.
    rms = frames.float().pow(2).mean(-1).sqrt()  # [N]
    p95 = torch.quantile(rms, 0.95)
    p05 = torch.quantile(rms, 0.05).clamp(min=1e-8)
    # The epsilon keeps log10 finite when the loud percentile is exactly zero.
    return 20.0 * torch.log10(p95 / p05 + 1e-8).item()


class SelvaDatasetInspector:
    """Analyze each clip for quality issues and optionally filter out flagged clips."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "skip_rejected": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "If True, flagged clips are removed from the output dataset. \n"
                               "If False, all clips pass through but the report still lists issues.",
                }),
                "min_snr_db": ("FLOAT", {
                    "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
                    "tooltip": "Clips with estimated SNR below this value are flagged.",
                }),
                "check_codec_artifacts": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET, "STRING")
    RETURN_NAMES = ("dataset", "report")
    FUNCTION = "inspect"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Analyze each clip for clipping, low SNR, and codec artifacts. "
        "Outputs a filtered AUDIO_DATASET and a text report. "
        "Connect report to a ShowText node to preview in the UI."
    )

    def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
        """Check every clip and build the output dataset plus a text report.

        Args:
            dataset: Iterable of items exposing "waveform", "sample_rate" and
                "name" keys.
            skip_rejected: When True, flagged clips are left out of the
                returned dataset; when False they pass through.
            min_snr_db: Clips whose estimated SNR is below this are flagged.
            check_codec_artifacts: Whether to run the HF-shelf check.

        Returns:
            A ``(clips, report)`` tuple: the list of clips to pass downstream
            and the report text (which is also printed).
        """
        kept = []        # clips passed downstream
        flagged = []     # names of clips with at least one issue
        clean_count = 0  # clips with no issues, independent of skip_rejected
        lines = ["SelVA Dataset Inspector Report", "=" * 40]

        for item in dataset:
            wav = item["waveform"]
            sr = item["sample_rate"]
            name = item["name"]
            issues = []

            if wav.numel() == 0:
                # max() over an empty tensor raises; report it instead of
                # aborting the whole run.
                issues.append("empty clip (no samples)")
            else:
                # Clipping
                peak = wav.abs().max().item()
                if peak > 0.99:
                    issues.append(f"clipping (peak={peak:.3f})")

                # Low SNR
                snr = _estimate_snr(wav)
                if snr < min_snr_db:
                    issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")

                # Codec artifacts
                if check_codec_artifacts and _check_hf_shelf(wav, sr):
                    issues.append("codec artifact (HF shelf > 15 kHz)")

            if issues:
                flagged.append(name)
                lines.append(f" FLAGGED {name}: {', '.join(issues)}")
                if not skip_rejected:
                    kept.append(item)
            else:
                clean_count += 1
                kept.append(item)
                lines.append(f" OK {name}")

        # Count while iterating so the summary stays correct when flagged
        # clips are kept, and so `dataset` need not support len().
        total = clean_count + len(flagged)
        lines.append("=" * 40)
        lines.append(
            f"Total: {total} Clean: {clean_count} Flagged: {len(flagged)}"
            + (" (removed)" if skip_rejected else " (kept)")
        )
        report = "\n".join(lines)
        print(f"[DatasetInspector]\n{report}", flush=True)
        return (kept, report)