feat: auto input_sr — detect bandwidth and pick the best value
New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited (→ matched input_sr), full-band (→ 24000), and genuine low-rate files (→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding the model an empty band. Logs its decision. Falls back to 24000 when unsure. Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo. Verified end-to-end on the real model (8 kHz clip -> auto picks 16000). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -414,6 +414,74 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
|
||||
return out.clamp(-1.0, 1.0), dry48
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Auto input_sr — detect the audio's effective bandwidth
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _supported_nyquists() -> list:
|
||||
return [s // 2 for s in SUPPORTED_INPUT_SR] # [4000, 6000, 8000, 12000]
|
||||
|
||||
|
||||
def _map_cutoff_to_input_sr(cutoff_hz: float) -> int:
|
||||
"""Largest supported Nyquist <= cutoff (+300 Hz snap) -> input_sr. Round DOWN on
|
||||
purpose: a Nyquist above the real cutoff makes the model treat an empty band as
|
||||
valid and skip regenerating it."""
|
||||
nyqs = _supported_nyquists()
|
||||
below = [n for n in nyqs if n <= cutoff_hz + 300.0]
|
||||
return (max(below) if below else min(nyqs)) * 2
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_input_sr(waveform: torch.Tensor, sr: int, conf_db: float = 25.0) -> tuple:
|
||||
"""Estimate the effective bandwidth of `waveform` and choose the best input_sr.
|
||||
|
||||
Cliff/edge detector: find the steepest drop in the time-averaged magnitude
|
||||
spectrum; its size (dB) is the confidence. Effective cutoff = that cliff if a
|
||||
confident one sits below ~0.95*(sr/2), else sr/2 (signal fills its band).
|
||||
|
||||
Returns (input_sr:int, info:dict{cutoff_hz, drop_db, confident, reason}).
|
||||
"""
|
||||
x = waveform.detach().float().cpu()
|
||||
if x.dim() == 3:
|
||||
x = x.mean(dim=(0, 1))
|
||||
elif x.dim() == 2:
|
||||
x = x.mean(dim=0)
|
||||
x = x.reshape(-1)
|
||||
nyq = sr / 2.0
|
||||
|
||||
def _fallback(reason):
|
||||
isr = _map_cutoff_to_input_sr(nyq)
|
||||
return isr, {"cutoff_hz": nyq, "drop_db": 0.0, "confident": False, "reason": reason}
|
||||
|
||||
if x.numel() < 2048 or float(x.abs().max()) < 1e-6:
|
||||
return _fallback(f"too short/silent -> sr/2={nyq/1000:.1f} kHz")
|
||||
|
||||
n_fft = 4096 if x.numel() >= 4096 else 1 << int(np.floor(np.log2(x.numel())))
|
||||
spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
|
||||
window=torch.hann_window(n_fft), return_complex=True).abs()
|
||||
mag = spec.mean(dim=1)
|
||||
k = 9
|
||||
mag = torch.nn.functional.avg_pool1d(mag[None, None], k, 1, k // 2)[0, 0]
|
||||
db = 20.0 * torch.log10((mag / mag.max().clamp(min=1e-12)).clamp(min=1e-12))
|
||||
freqs = torch.linspace(0, nyq, mag.shape[0])
|
||||
|
||||
grad = db[1:] - db[:-1]
|
||||
i = int(grad.argmin()) # steepest drop = candidate cliff edge
|
||||
edge_hz = float(freqs[i])
|
||||
pre = db[max(0, i - 10):i + 1].median()
|
||||
post = db[i + 1:i + 40].median()
|
||||
drop = float(pre - post)
|
||||
|
||||
confident = drop >= conf_db and edge_hz < 0.95 * nyq
|
||||
if confident:
|
||||
cutoff = edge_hz
|
||||
reason = f"cutoff {cutoff/1000:.1f} kHz (drop {drop:.0f} dB)"
|
||||
else:
|
||||
cutoff = nyq
|
||||
reason = f"no clear cutoff -> sr/2={nyq/1000:.1f} kHz"
|
||||
isr = _map_cutoff_to_input_sr(cutoff)
|
||||
return isr, {"cutoff_hz": cutoff, "drop_db": drop, "confident": confident, "reason": reason}
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Spectrogram comparison (optional IMAGE output)
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
Reference in New Issue
Block a user