feat: auto input_sr — detect bandwidth and pick the best value

New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr
finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective
cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited
(→ matched input_sr), full-band (→ 24000), and genuine low-rate files
(→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding
the model an empty band. Logs its decision. Falls back to 24000 when unsure.

Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo.
Verified end-to-end on the real model (8 kHz clip -> auto picks 16000).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 12:46:02 +02:00
parent 8d4cd71723
commit e5110b88e1
4 changed files with 165 additions and 8 deletions
+68
View File
@@ -414,6 +414,74 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
return out.clamp(-1.0, 1.0), dry48
# --------------------------------------------------------------------------- #
# Auto input_sr — detect the audio's effective bandwidth
# --------------------------------------------------------------------------- #
def _supported_nyquists() -> list:
return [s // 2 for s in SUPPORTED_INPUT_SR] # [4000, 6000, 8000, 12000]
def _map_cutoff_to_input_sr(cutoff_hz: float) -> int:
"""Largest supported Nyquist <= cutoff (+300 Hz snap) -> input_sr. Round DOWN on
purpose: a Nyquist above the real cutoff makes the model treat an empty band as
valid and skip regenerating it."""
nyqs = _supported_nyquists()
below = [n for n in nyqs if n <= cutoff_hz + 300.0]
return (max(below) if below else min(nyqs)) * 2
@torch.no_grad()
def detect_input_sr(waveform: torch.Tensor, sr: int, conf_db: float = 25.0) -> tuple:
"""Estimate the effective bandwidth of `waveform` and choose the best input_sr.
Cliff/edge detector: find the steepest drop in the time-averaged magnitude
spectrum; its size (dB) is the confidence. Effective cutoff = that cliff if a
confident one sits below ~0.95*(sr/2), else sr/2 (signal fills its band).
Returns (input_sr:int, info:dict{cutoff_hz, drop_db, confident, reason}).
"""
x = waveform.detach().float().cpu()
if x.dim() == 3:
x = x.mean(dim=(0, 1))
elif x.dim() == 2:
x = x.mean(dim=0)
x = x.reshape(-1)
nyq = sr / 2.0
def _fallback(reason):
isr = _map_cutoff_to_input_sr(nyq)
return isr, {"cutoff_hz": nyq, "drop_db": 0.0, "confident": False, "reason": reason}
if x.numel() < 2048 or float(x.abs().max()) < 1e-6:
return _fallback(f"too short/silent -> sr/2={nyq/1000:.1f} kHz")
n_fft = 4096 if x.numel() >= 4096 else 1 << int(np.floor(np.log2(x.numel())))
spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
window=torch.hann_window(n_fft), return_complex=True).abs()
mag = spec.mean(dim=1)
k = 9
mag = torch.nn.functional.avg_pool1d(mag[None, None], k, 1, k // 2)[0, 0]
db = 20.0 * torch.log10((mag / mag.max().clamp(min=1e-12)).clamp(min=1e-12))
freqs = torch.linspace(0, nyq, mag.shape[0])
grad = db[1:] - db[:-1]
i = int(grad.argmin()) # steepest drop = candidate cliff edge
edge_hz = float(freqs[i])
pre = db[max(0, i - 10):i + 1].median()
post = db[i + 1:i + 40].median()
drop = float(pre - post)
confident = drop >= conf_db and edge_hz < 0.95 * nyq
if confident:
cutoff = edge_hz
reason = f"cutoff {cutoff/1000:.1f} kHz (drop {drop:.0f} dB)"
else:
cutoff = nyq
reason = f"no clear cutoff -> sr/2={nyq/1000:.1f} kHz"
isr = _map_cutoff_to_input_sr(cutoff)
return isr, {"cutoff_hz": cutoff, "drop_db": drop, "confident": confident, "reason": reason}
# --------------------------------------------------------------------------- #
# Spectrogram comparison (optional IMAGE output)
# --------------------------------------------------------------------------- #