feat: auto input_sr — detect bandwidth and pick the best value
New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited (→ matched input_sr), full-band (→ 24000), and genuine low-rate files (→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding the model an empty band. Logs its decision. Falls back to 24000 when unsure. Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo. Verified end-to-end on the real model (8 kHz clip -> auto picks 16000). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
"""Tests for auto input_sr bandwidth detection (detect_input_sr).
|
||||
|
||||
Runnable with pytest, or standalone: python tests/test_auto_input_sr.py
|
||||
Only needs torch (no ComfyUI). Loads universr_wrapper by path.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
_ND = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
_spec = importlib.util.spec_from_file_location("_usr", os.path.join(_ND, "universr_wrapper.py"))
|
||||
usr = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(usr)
|
||||
|
||||
SR = 48000
|
||||
|
||||
|
||||
def _broadband(seconds=3.0, sr=SR):
|
||||
torch.manual_seed(0)
|
||||
return (torch.randn(int(sr * seconds)) * 0.3)
|
||||
|
||||
|
||||
def _brickwall(x, sr, cut):
|
||||
X = torch.fft.rfft(x)
|
||||
f = torch.fft.rfftfreq(x.shape[-1], 1 / sr)
|
||||
X[f > cut] = 0
|
||||
return torch.fft.irfft(X, n=x.shape[-1])
|
||||
|
||||
|
||||
def test_sharp_cutoffs_map_to_expected_input_sr():
|
||||
base = _broadband()
|
||||
for cut, expected in [(4000, 8000), (6000, 12000), (8000, 16000), (12000, 24000)]:
|
||||
x = _brickwall(base, SR, cut).reshape(1, 1, -1)
|
||||
isr, info = usr.detect_input_sr(x, SR)
|
||||
assert info["confident"], f"cut={cut}: expected a confident cliff, got {info}"
|
||||
assert isr == expected, f"cut={cut}: got input_sr={isr} ({info['cutoff_hz']:.0f} Hz)"
|
||||
|
||||
|
||||
def test_full_band_falls_back_to_24000():
|
||||
x = _broadband().reshape(1, 1, -1)
|
||||
isr, info = usr.detect_input_sr(x, SR)
|
||||
assert not info["confident"]
|
||||
assert isr == 24000, info
|
||||
|
||||
|
||||
def test_genuine_low_rate_file_maps_to_its_rate():
|
||||
# A real 8 kHz file: sr=8000, content fills its 4 kHz band -> input_sr 8000.
|
||||
torch.manual_seed(1)
|
||||
x = (torch.randn(8000 * 3) * 0.3).reshape(1, 1, -1)
|
||||
isr, info = usr.detect_input_sr(x, 8000)
|
||||
assert isr == 8000, info
|
||||
|
||||
|
||||
def test_silent_or_tiny_falls_back():
|
||||
isr, info = usr.detect_input_sr(torch.zeros(1, 1, 1000), SR)
|
||||
assert isr == 24000 and not info["confident"]
|
||||
|
||||
|
||||
def test_stereo_is_mono_mixed():
|
||||
base = _broadband()
|
||||
x = _brickwall(base, SR, 8000)
|
||||
stereo = torch.stack([x, x], 0).reshape(1, 2, -1)
|
||||
isr, _ = usr.detect_input_sr(stereo, SR)
|
||||
assert isr == 16000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fns = [v for k, v in sorted(globals().items()) if k.startswith("test_")]
|
||||
for fn in fns:
|
||||
fn()
|
||||
print(f"PASS {fn.__name__}")
|
||||
print(f"\nAll {len(fns)} tests passed.")
|
||||
Reference in New Issue
Block a user