From e5110b88e1f12f5aca90766f33a689a406514ee5 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 17 Jun 2026 12:46:02 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20auto=20input=5Fsr=20=E2=80=94=20detect?= =?UTF-8?q?=20bandwidth=20and=20pick=20the=20best=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited (→ matched input_sr), full-band (→ 24000), and genuine low-rate files (→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding the model an empty band. Logs its decision. Falls back to 24000 when unsure. Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo. Verified end-to-end on the real model (8 kHz clip -> auto picks 16000). Co-Authored-By: Claude Opus 4.8 --- README.md | 9 ++++- nodes.py | 23 ++++++++---- tests/test_auto_input_sr.py | 73 +++++++++++++++++++++++++++++++++++++ universr_wrapper.py | 68 ++++++++++++++++++++++++++++++++++ 4 files changed, 165 insertions(+), 8 deletions(-) create mode 100644 tests/test_auto_input_sr.py diff --git a/README.md b/README.md index c8db433..1a3472a 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ Runs the super-resolution. Outputs: **`AUDIO`** (48 kHz) and **`IMAGE`** (spectr |---|---|---|---|---| | `audio` | AUDIO | — | — | Input audio (any sample rate / mono or stereo). | | `model` | UNIVERSR_MODEL | — | — | From the Model Loader. | -| `input_sr` | choice | `8000` | 8000 / 12000 / 16000 / 24000 | **Effective input bandwidth (Hz).** Content is treated as valid up to `input_sr/2` and **regenerated above it**. See below. 
| +| `input_sr` | choice | `auto` | auto / 8000 / 12000 / 16000 / 24000 | **Effective input bandwidth (Hz).** Content is valid up to `input_sr/2` and **regenerated above it**. `auto` detects the cutoff for you (see below). | | `ode_method` | choice | `midpoint` | euler / midpoint / rk4 | ODE solver. `euler` fastest → `midpoint` balanced → `rk4` best. | | `ode_steps` | int | `4` | 1–64 | Flow-matching integration steps. `4` is fast & validated; `4–10` is a good range. | | `guidance_scale` | float | `1.5` | 0–6 | Classifier-free guidance. Higher = denser highs but less faithful. `0` disables CFG. | @@ -210,6 +210,13 @@ audio **and** the `video` reference into the combiner. Ready-made graph: | `16000` | 8 kHz | 8 – 24 kHz | | `24000` | 12 kHz | 12 – 24 kHz | +**`auto` (default)** analyses the input's spectrum, finds the **cutoff cliff**, and picks the largest +supported bandwidth at or below it (rounding *down*, to avoid feeding the model an empty band). It +prints its decision, e.g. `auto: cutoff 8.0 kHz (drop 53 dB) -> input_sr=16000`. When there's **no clear +cutoff** (full-band or gently rolled-off audio) it treats the audio as filling its own band (`sr/2`): that gives `24000` (least aggressive) for 44.1/48 kHz input, and a lower value for genuinely low-rate files (e.g. `8000` for an 8 kHz file). Auto is +most reliable on genuinely band-limited material (codecs, downsamples, telephone); for fine control or +deliberate over-brightening, pick a value manually. + Two ways to use it: 1. **Genuine low-rate audio (classic super-resolution).** You have an 8 kHz (or 16/24 kHz) recording diff --git a/nodes.py b/nodes.py index 45849aa..30fb12a 100644 --- a/nodes.py +++ b/nodes.py @@ -118,11 +118,12 @@ class UniverSRSampler: "required": { "audio": ("AUDIO", {}), "model": ("UNIVERSR_MODEL", {}), - "input_sr": (["8000", "12000", "16000", "24000"], { - "default": "8000", + "input_sr": (["auto", "8000", "12000", "16000", "24000"], { + "default": "auto", "tooltip": "Effective input bandwidth (Hz). Content is treated as valid up to " - "input_sr/2 and regenerated above it. 
8000 = genuine low-rate audio " - "(strongest, 8 kHz->48 kHz). 16000 = brighten muffled audio above 8 kHz.", + "input_sr/2 and regenerated above it. 'auto' detects the audio's cutoff " + "and picks for you (falls back to 24000 if no clear cutoff). " + "8000 = genuine low-rate audio (strongest). 16000 = brighten muffled audio.", }), }, "optional": { @@ -179,11 +180,19 @@ class UniverSRSampler: model_obj = model["model"] waveform, sr = usr.comfy_audio_to_tensor(audio) dur = waveform.shape[-1] / max(sr, 1) + + # Resolve auto bandwidth detection to a concrete input_sr. + if str(input_sr) == "auto": + isr, info = usr.detect_input_sr(waveform, sr) + print(f"[UniverSR] auto: {info['reason']} -> input_sr={isr}") + else: + isr = int(input_sr) + print(f"[UniverSR] {tuple(waveform.shape)} @ {sr} Hz ({dur:.2f}s) -> 48 kHz | " - f"input_sr={input_sr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}") + f"input_sr={isr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}") out, dry48 = usr.super_resolve( - model_obj, waveform, sr, int(input_sr), + model_obj, waveform, sr, isr, ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=guidance_scale, seed=int(seed), chunk_seconds=float(chunk_seconds), overlap_seconds=float(overlap_seconds), blend=float(blend), @@ -195,7 +204,7 @@ class UniverSRSampler: if show_spectrogram: in_mono = dry48[0].mean(0).numpy() out_mono = out[0].mean(0).numpy() - spec = usr.make_spectrogram_image(in_mono, out_mono, int(input_sr)) + spec = usr.make_spectrogram_image(in_mono, out_mono, isr) if unload_model: usr.evict_model(model["cache_key"]) diff --git a/tests/test_auto_input_sr.py b/tests/test_auto_input_sr.py new file mode 100644 index 0000000..a9b4d2d --- /dev/null +++ b/tests/test_auto_input_sr.py @@ -0,0 +1,73 @@ +"""Tests for auto input_sr bandwidth detection (detect_input_sr). + +Runnable with pytest, or standalone: python tests/test_auto_input_sr.py +Only needs torch (no ComfyUI). Loads universr_wrapper by path. 
+""" +import importlib.util +import os + +import torch + +_ND = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_spec = importlib.util.spec_from_file_location("_usr", os.path.join(_ND, "universr_wrapper.py")) +usr = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(usr) + +SR = 48000 + + +def _broadband(seconds=3.0, sr=SR): + torch.manual_seed(0) + return (torch.randn(int(sr * seconds)) * 0.3) + + +def _brickwall(x, sr, cut): + X = torch.fft.rfft(x) + f = torch.fft.rfftfreq(x.shape[-1], 1 / sr) + X[f > cut] = 0 + return torch.fft.irfft(X, n=x.shape[-1]) + + +def test_sharp_cutoffs_map_to_expected_input_sr(): + base = _broadband() + for cut, expected in [(4000, 8000), (6000, 12000), (8000, 16000), (12000, 24000)]: + x = _brickwall(base, SR, cut).reshape(1, 1, -1) + isr, info = usr.detect_input_sr(x, SR) + assert info["confident"], f"cut={cut}: expected a confident cliff, got {info}" + assert isr == expected, f"cut={cut}: got input_sr={isr} ({info['cutoff_hz']:.0f} Hz)" + + +def test_full_band_falls_back_to_24000(): + x = _broadband().reshape(1, 1, -1) + isr, info = usr.detect_input_sr(x, SR) + assert not info["confident"] + assert isr == 24000, info + + +def test_genuine_low_rate_file_maps_to_its_rate(): + # A real 8 kHz file: sr=8000, content fills its 4 kHz band -> input_sr 8000. 
+ torch.manual_seed(1) + x = (torch.randn(8000 * 3) * 0.3).reshape(1, 1, -1) + isr, info = usr.detect_input_sr(x, 8000) + assert isr == 8000, info + + +def test_silent_or_tiny_falls_back(): + isr, info = usr.detect_input_sr(torch.zeros(1, 1, 1000), SR) + assert isr == 24000 and not info["confident"] + + +def test_stereo_is_mono_mixed(): + base = _broadband() + x = _brickwall(base, SR, 8000) + stereo = torch.stack([x, x], 0).reshape(1, 2, -1) + isr, _ = usr.detect_input_sr(stereo, SR) + assert isr == 16000 + + +if __name__ == "__main__": + fns = [v for k, v in sorted(globals().items()) if k.startswith("test_")] + for fn in fns: + fn() + print(f"PASS {fn.__name__}") + print(f"\nAll {len(fns)} tests passed.") diff --git a/universr_wrapper.py b/universr_wrapper.py index e3592b6..3655585 100644 --- a/universr_wrapper.py +++ b/universr_wrapper.py @@ -414,6 +414,74 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int, return out.clamp(-1.0, 1.0), dry48 +# --------------------------------------------------------------------------- # +# Auto input_sr — detect the audio's effective bandwidth +# --------------------------------------------------------------------------- # +def _supported_nyquists() -> list: + return [s // 2 for s in SUPPORTED_INPUT_SR] # [4000, 6000, 8000, 12000] + + +def _map_cutoff_to_input_sr(cutoff_hz: float) -> int: + """Largest supported Nyquist <= cutoff (+300 Hz snap) -> input_sr. Round DOWN on + purpose: a Nyquist above the real cutoff makes the model treat an empty band as + valid and skip regenerating it.""" + nyqs = _supported_nyquists() + below = [n for n in nyqs if n <= cutoff_hz + 300.0] + return (max(below) if below else min(nyqs)) * 2 + + +@torch.no_grad() +def detect_input_sr(waveform: torch.Tensor, sr: int, conf_db: float = 25.0) -> tuple: + """Estimate the effective bandwidth of `waveform` and choose the best input_sr. 
+ + Cliff/edge detector: find the steepest drop in the time-averaged magnitude + spectrum; its size (dB) is the confidence. Effective cutoff = that cliff if a + confident one sits below ~0.95*(sr/2), else sr/2 (signal fills its band). + + Returns (input_sr:int, info:dict{cutoff_hz, drop_db, confident, reason}). + """ + x = waveform.detach().float().cpu() + if x.dim() == 3: + x = x.mean(dim=(0, 1)) + elif x.dim() == 2: + x = x.mean(dim=0) + x = x.reshape(-1) + nyq = sr / 2.0 + + def _fallback(reason): + isr = _map_cutoff_to_input_sr(nyq) + return isr, {"cutoff_hz": nyq, "drop_db": 0.0, "confident": False, "reason": reason} + + if x.numel() < 2048 or float(x.abs().max()) < 1e-6: + return _fallback(f"too short/silent -> sr/2={nyq/1000:.1f} kHz") + + n_fft = 4096 if x.numel() >= 4096 else 1 << int(np.floor(np.log2(x.numel()))) + spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4, + window=torch.hann_window(n_fft), return_complex=True).abs() + mag = spec.mean(dim=1) + k = 9 + mag = torch.nn.functional.avg_pool1d(mag[None, None], k, 1, k // 2)[0, 0] + db = 20.0 * torch.log10((mag / mag.max().clamp(min=1e-12)).clamp(min=1e-12)) + freqs = torch.linspace(0, nyq, mag.shape[0]) + + grad = db[1:] - db[:-1] + i = int(grad.argmin()) # steepest drop = candidate cliff edge + edge_hz = float(freqs[i]) + pre = db[max(0, i - 10):i + 1].median() + post = db[i + 1:i + 40].median() + drop = float(pre - post) + + confident = drop >= conf_db and edge_hz < 0.95 * nyq + if confident: + cutoff = edge_hz + reason = f"cutoff {cutoff/1000:.1f} kHz (drop {drop:.0f} dB)" + else: + cutoff = nyq + reason = f"no clear cutoff -> sr/2={nyq/1000:.1f} kHz" + isr = _map_cutoff_to_input_sr(cutoff) + return isr, {"cutoff_hz": cutoff, "drop_db": drop, "confident": confident, "reason": reason} + + # --------------------------------------------------------------------------- # # Spectrogram comparison (optional IMAGE output) # --------------------------------------------------------------------------- #