feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -187,3 +187,120 @@ class SelvaDatasetLUFSNormalizer:
|
|||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
return (out,)
|
return (out,)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
|
||||||
|
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
|
||||||
|
|
||||||
|
Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT.
|
||||||
|
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
|
||||||
|
"""
|
||||||
|
if sr < 32000:
|
||||||
|
return False # can't assess HF at low sample rates
|
||||||
|
|
||||||
|
n_fft = 2048
|
||||||
|
hop = 512
|
||||||
|
window = torch.hann_window(n_fft)
|
||||||
|
mono = wav[0].mean(0) # [L]
|
||||||
|
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
|
||||||
|
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
|
||||||
|
|
||||||
|
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
|
||||||
|
band_lo = (freqs >= 1000) & (freqs < 5000)
|
||||||
|
band_hi = (freqs >= 15000) & (freqs < 20000)
|
||||||
|
|
||||||
|
if band_hi.sum() == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
|
||||||
|
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
|
||||||
|
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
|
||||||
|
return ratio_db > 40.0
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_snr(wav: torch.Tensor) -> float:
|
||||||
|
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
|
||||||
|
mono = wav[0].mean(0) # [L]
|
||||||
|
frames = mono.unfold(0, 2048, 512) # [N, 2048]
|
||||||
|
rms = frames.pow(2).mean(-1).sqrt() # [N]
|
||||||
|
p95 = torch.quantile(rms, 0.95).item()
|
||||||
|
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
|
||||||
|
return 20.0 * np.log10(p95 / p05 + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaDatasetInspector:
    """Analyze each clip for quality issues and optionally filter out flagged clips.

    Three checks run per clip: peak level above 0.99 (clipping), estimated
    SNR below a configurable threshold, and an optional codec-artifact check.
    The class-level attributes below are node metadata read by the host
    framework (ComfyUI-style convention — registration happens outside this
    block).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs with their widget types, defaults and tooltips."""
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "skip_rejected": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "If True, flagged clips are removed from the output dataset. "
                               "If False, all clips pass through but the report still lists issues.",
                }),
                "min_snr_db": ("FLOAT", {
                    "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
                    "tooltip": "Clips with estimated SNR below this value are flagged.",
                }),
                "check_codec_artifacts": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET, "STRING")
    RETURN_NAMES = ("dataset", "report")
    FUNCTION = "inspect"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Analyze each clip for clipping, low SNR, and codec artifacts. "
        "Outputs a filtered AUDIO_DATASET and a text report. "
        "Connect report to a ShowText node to preview in the UI."
    )

    def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
        """Check every clip in ``dataset`` and build a filtered dataset plus a report.

        Args:
            dataset: Sized iterable of items; each item is indexed with the
                keys ``"waveform"``, ``"sample_rate"`` and ``"name"``.
            skip_rejected: When True, flagged clips are left out of the
                returned dataset; when False, every clip is passed through.
            min_snr_db: Clips whose estimated SNR is below this are flagged.
            check_codec_artifacts: Whether to run the HF-shelf codec check.

        Returns:
            Tuple ``(clips, report)``: the list of clips to pass downstream
            and a multi-line text report (also printed to stdout).
        """
        output = []      # clips passed downstream (may include flagged ones)
        flagged = []     # names of clips with at least one issue
        clean_count = 0  # clips with no issues, independent of skip_rejected
        lines = ["SelVA Dataset Inspector Report", "=" * 40]

        for item in dataset:
            wav = item["waveform"]
            sr = item["sample_rate"]
            name = item["name"]
            issues = []

            # Clipping
            peak = wav.abs().max().item()
            if peak > 0.99:
                issues.append(f"clipping (peak={peak:.3f})")

            # Low SNR
            snr = _estimate_snr(wav)
            if snr < min_snr_db:
                issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")

            # Codec artifacts
            if check_codec_artifacts and _check_hf_shelf(wav, sr):
                issues.append("codec artifact (HF shelf > 15 kHz)")

            if issues:
                flagged.append(name)
                lines.append(f" FLAGGED {name}: {', '.join(issues)}")
                if not skip_rejected:
                    output.append(item)
            else:
                clean_count += 1
                output.append(item)
                lines.append(f" OK {name}")

        lines.append("=" * 40)
        # Report the number of issue-free clips rather than the size of the
        # output list: with skip_rejected=False the output also holds flagged
        # clips, which previously inflated the "Clean" figure.
        lines.append(
            f"Total: {len(dataset)} Clean: {clean_count} Flagged: {len(flagged)}"
            + (" (removed)" if skip_rejected else " (kept)")
        )
        report = "\n".join(lines)
        print(f"[DatasetInspector]\n{report}", flush=True)
        return (output, report)
|
||||||
|
|||||||
Reference in New Issue
Block a user