feat: auto input_sr — detect bandwidth and pick the best value

New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr
finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective
cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited
(→ matched input_sr), full-band (→ 24000), and genuine low-rate files
(→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding
the model an empty band. Logs its decision. Falls back to 24000 when unsure.

Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo.
Verified end-to-end on the real model (8 kHz clip -> auto picks 16000).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 12:46:02 +02:00
parent 8d4cd71723
commit e5110b88e1
4 changed files with 165 additions and 8 deletions
+8 -1
View File
@@ -142,7 +142,7 @@ Runs the super-resolution. Outputs: **`AUDIO`** (48 kHz) and **`IMAGE`** (spectr
|---|---|---|---|---|
| `audio` | AUDIO | — | — | Input audio (any sample rate / mono or stereo). |
| `model` | UNIVERSR_MODEL | — | — | From the Model Loader. |
| `input_sr` | choice | `8000` | 8000 / 12000 / 16000 / 24000 | **Effective input bandwidth (Hz).** Content is treated as valid up to `input_sr/2` and **regenerated above it**. See below. |
| `input_sr` | choice | `auto` | auto / 8000 / 12000 / 16000 / 24000 | **Effective input bandwidth (Hz).** Content is valid up to `input_sr/2` and **regenerated above it**. `auto` detects the cutoff for you (see below). |
| `ode_method` | choice | `midpoint` | euler / midpoint / rk4 | ODE solver. `euler` fastest → `midpoint` balanced → `rk4` best. |
| `ode_steps` | int | `4` | 164 | Flow-matching integration steps. `4` is fast & validated; `410` is a good range. |
| `guidance_scale` | float | `1.5` | 06 | Classifier-free guidance. Higher = denser highs but less faithful. `0` disables CFG. |
@@ -210,6 +210,13 @@ audio **and** the `video` reference into the combiner. Ready-made graph:
| `16000` | 8 kHz | 8 24 kHz |
| `24000` | 12 kHz | 12 24 kHz |
**`auto` (default)** analyses the input's spectrum, finds the **cutoff cliff**, and picks the largest
supported bandwidth at or below it (rounding *down*, to avoid feeding the model an empty band). It
prints its decision, e.g. `auto: cutoff 8.0 kHz (drop 53 dB) -> input_sr=16000`. When there's **no clear
cutoff** (full-band or gently rolled-off audio) it falls back to `24000` (least aggressive). Auto is
most reliable on genuinely band-limited material (codecs, downsamples, telephone); for fine control or
deliberate over-brightening, pick a value manually.
Two ways to use it:
1. **Genuine low-rate audio (classic super-resolution).** You have an 8 kHz (or 16/24 kHz) recording
+16 -7
View File
@@ -118,11 +118,12 @@ class UniverSRSampler:
"required": {
"audio": ("AUDIO", {}),
"model": ("UNIVERSR_MODEL", {}),
"input_sr": (["8000", "12000", "16000", "24000"], {
"default": "8000",
"input_sr": (["auto", "8000", "12000", "16000", "24000"], {
"default": "auto",
"tooltip": "Effective input bandwidth (Hz). Content is treated as valid up to "
"input_sr/2 and regenerated above it. 8000 = genuine low-rate audio "
"(strongest, 8 kHz->48 kHz). 16000 = brighten muffled audio above 8 kHz.",
"input_sr/2 and regenerated above it. 'auto' detects the audio's cutoff "
"and picks for you (falls back to 24000 if no clear cutoff). "
"8000 = genuine low-rate audio (strongest). 16000 = brighten muffled audio.",
}),
},
"optional": {
@@ -179,11 +180,19 @@ class UniverSRSampler:
model_obj = model["model"]
waveform, sr = usr.comfy_audio_to_tensor(audio)
dur = waveform.shape[-1] / max(sr, 1)
# Resolve auto bandwidth detection to a concrete input_sr.
if str(input_sr) == "auto":
isr, info = usr.detect_input_sr(waveform, sr)
print(f"[UniverSR] auto: {info['reason']} -> input_sr={isr}")
else:
isr = int(input_sr)
print(f"[UniverSR] {tuple(waveform.shape)} @ {sr} Hz ({dur:.2f}s) -> 48 kHz | "
f"input_sr={input_sr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}")
f"input_sr={isr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}")
out, dry48 = usr.super_resolve(
model_obj, waveform, sr, int(input_sr),
model_obj, waveform, sr, isr,
ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=guidance_scale,
seed=int(seed), chunk_seconds=float(chunk_seconds),
overlap_seconds=float(overlap_seconds), blend=float(blend),
@@ -195,7 +204,7 @@ class UniverSRSampler:
if show_spectrogram:
in_mono = dry48[0].mean(0).numpy()
out_mono = out[0].mean(0).numpy()
spec = usr.make_spectrogram_image(in_mono, out_mono, int(input_sr))
spec = usr.make_spectrogram_image(in_mono, out_mono, isr)
if unload_model:
usr.evict_model(model["cache_key"])
+73
View File
@@ -0,0 +1,73 @@
"""Tests for auto input_sr bandwidth detection (detect_input_sr).
Runnable with pytest, or standalone: python tests/test_auto_input_sr.py
Only needs torch (no ComfyUI). Loads universr_wrapper by path.
"""
import importlib.util
import os
import torch
_ND = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_spec = importlib.util.spec_from_file_location("_usr", os.path.join(_ND, "universr_wrapper.py"))
usr = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(usr)
SR = 48000
def _broadband(seconds=3.0, sr=SR):
torch.manual_seed(0)
return (torch.randn(int(sr * seconds)) * 0.3)
def _brickwall(x, sr, cut):
X = torch.fft.rfft(x)
f = torch.fft.rfftfreq(x.shape[-1], 1 / sr)
X[f > cut] = 0
return torch.fft.irfft(X, n=x.shape[-1])
def test_sharp_cutoffs_map_to_expected_input_sr():
base = _broadband()
for cut, expected in [(4000, 8000), (6000, 12000), (8000, 16000), (12000, 24000)]:
x = _brickwall(base, SR, cut).reshape(1, 1, -1)
isr, info = usr.detect_input_sr(x, SR)
assert info["confident"], f"cut={cut}: expected a confident cliff, got {info}"
assert isr == expected, f"cut={cut}: got input_sr={isr} ({info['cutoff_hz']:.0f} Hz)"
def test_full_band_falls_back_to_24000():
x = _broadband().reshape(1, 1, -1)
isr, info = usr.detect_input_sr(x, SR)
assert not info["confident"]
assert isr == 24000, info
def test_genuine_low_rate_file_maps_to_its_rate():
# A real 8 kHz file: sr=8000, content fills its 4 kHz band -> input_sr 8000.
torch.manual_seed(1)
x = (torch.randn(8000 * 3) * 0.3).reshape(1, 1, -1)
isr, info = usr.detect_input_sr(x, 8000)
assert isr == 8000, info
def test_silent_or_tiny_falls_back():
isr, info = usr.detect_input_sr(torch.zeros(1, 1, 1000), SR)
assert isr == 24000 and not info["confident"]
def test_stereo_is_mono_mixed():
base = _broadband()
x = _brickwall(base, SR, 8000)
stereo = torch.stack([x, x], 0).reshape(1, 2, -1)
isr, _ = usr.detect_input_sr(stereo, SR)
assert isr == 16000
if __name__ == "__main__":
fns = [v for k, v in sorted(globals().items()) if k.startswith("test_")]
for fn in fns:
fn()
print(f"PASS {fn.__name__}")
print(f"\nAll {len(fns)} tests passed.")
+68
View File
@@ -414,6 +414,74 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
return out.clamp(-1.0, 1.0), dry48
# --------------------------------------------------------------------------- #
# Auto input_sr — detect the audio's effective bandwidth
# --------------------------------------------------------------------------- #
def _supported_nyquists() -> list:
return [s // 2 for s in SUPPORTED_INPUT_SR] # [4000, 6000, 8000, 12000]
def _map_cutoff_to_input_sr(cutoff_hz: float) -> int:
"""Largest supported Nyquist <= cutoff (+300 Hz snap) -> input_sr. Round DOWN on
purpose: a Nyquist above the real cutoff makes the model treat an empty band as
valid and skip regenerating it."""
nyqs = _supported_nyquists()
below = [n for n in nyqs if n <= cutoff_hz + 300.0]
return (max(below) if below else min(nyqs)) * 2
@torch.no_grad()
def detect_input_sr(waveform: torch.Tensor, sr: int, conf_db: float = 25.0) -> tuple:
"""Estimate the effective bandwidth of `waveform` and choose the best input_sr.
Cliff/edge detector: find the steepest drop in the time-averaged magnitude
spectrum; its size (dB) is the confidence. Effective cutoff = that cliff if a
confident one sits below ~0.95*(sr/2), else sr/2 (signal fills its band).
Returns (input_sr:int, info:dict{cutoff_hz, drop_db, confident, reason}).
"""
x = waveform.detach().float().cpu()
if x.dim() == 3:
x = x.mean(dim=(0, 1))
elif x.dim() == 2:
x = x.mean(dim=0)
x = x.reshape(-1)
nyq = sr / 2.0
def _fallback(reason):
isr = _map_cutoff_to_input_sr(nyq)
return isr, {"cutoff_hz": nyq, "drop_db": 0.0, "confident": False, "reason": reason}
if x.numel() < 2048 or float(x.abs().max()) < 1e-6:
return _fallback(f"too short/silent -> sr/2={nyq/1000:.1f} kHz")
n_fft = 4096 if x.numel() >= 4096 else 1 << int(np.floor(np.log2(x.numel())))
spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
window=torch.hann_window(n_fft), return_complex=True).abs()
mag = spec.mean(dim=1)
k = 9
mag = torch.nn.functional.avg_pool1d(mag[None, None], k, 1, k // 2)[0, 0]
db = 20.0 * torch.log10((mag / mag.max().clamp(min=1e-12)).clamp(min=1e-12))
freqs = torch.linspace(0, nyq, mag.shape[0])
grad = db[1:] - db[:-1]
i = int(grad.argmin()) # steepest drop = candidate cliff edge
edge_hz = float(freqs[i])
pre = db[max(0, i - 10):i + 1].median()
post = db[i + 1:i + 40].median()
drop = float(pre - post)
confident = drop >= conf_db and edge_hz < 0.95 * nyq
if confident:
cutoff = edge_hz
reason = f"cutoff {cutoff/1000:.1f} kHz (drop {drop:.0f} dB)"
else:
cutoff = nyq
reason = f"no clear cutoff -> sr/2={nyq/1000:.1f} kHz"
isr = _map_cutoff_to_input_sr(cutoff)
return isr, {"cutoff_hz": cutoff, "drop_db": drop, "confident": confident, "reason": reason}
# --------------------------------------------------------------------------- #
# Spectrogram comparison (optional IMAGE output)
# --------------------------------------------------------------------------- #