e5110b88e1
New "auto" option (now the default) on the Sampler's input_sr. detect_input_sr finds the spectral cutoff cliff (steepest drop) and its dB confidence: effective cutoff = that cliff if confident, else sr/2 — one rule that covers band-limited (→ matched input_sr), full-band (→ 24000), and genuine low-rate files (→ their rate). Rounds DOWN to the nearest supported Nyquist to avoid feeding the model an empty band. Logs its decision. Falls back to 24000 when unsure. Tests cover sharp 4/6/8/12 kHz cutoffs, full-band, genuine-8kHz, silence, stereo. Verified end-to-end on the real model (8 kHz clip -> auto picks 16000). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
541 lines
22 KiB
Python
541 lines
22 KiB
Python
"""Core wrapper for ComfyUI-UniverSR.
|
|
|
|
Bootstraps the `universr` package (prefers a pip-installed copy, falls back to
|
|
the vendored one under ./vendor), manages model loading/caching, and runs the
|
|
super-resolution itself with optional overlap-add chunking for long audio.
|
|
|
|
UniverSR (ICASSP 2026) is a vocoder-free audio super-resolution model that
|
|
regenerates high-frequency content in the complex-STFT domain via flow matching.
|
|
A single model handles 8 / 12 / 16 / 24 kHz effective bandwidth -> 48 kHz.
|
|
|
|
Key design note — why we resample ourselves instead of handing UniverSR a file:
|
|
UniverSR's `enhance()` file path calls `torchaudio.load`, whose torchcodec
|
|
backend is fragile across environments; its *tensor* path assumes the tensor
|
|
is already at `input_sr`. ComfyUI audio arrives at an arbitrary real sample
|
|
rate, so we do the band-limit ourselves: resample to 48 kHz, downsample each
|
|
chunk to `input_sr` (pure DSP, no codec), and hand UniverSR a genuine
|
|
low-rate tensor to super-resolve. This reproduces the exact training-time
|
|
degradation and was validated in the FoleyTune BWE node.
|
|
"""
|
|
|
|
import os
|
|
import threading
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchaudio
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Optional ComfyUI integration (degrade gracefully outside ComfyUI / in tests)
|
|
# --------------------------------------------------------------------------- #
|
|
try:
|
|
import comfy.model_management as mm
|
|
import comfy.utils
|
|
HAS_COMFY = True
|
|
except Exception: # pragma: no cover - allows standalone import / pytest
|
|
HAS_COMFY = False
|
|
|
|
try:
|
|
import folder_paths
|
|
HAS_FOLDER_PATHS = True
|
|
except Exception: # pragma: no cover
|
|
HAS_FOLDER_PATHS = False
|
|
|
|
|
|
# Output rate of the model and the set of low-rate bandwidths it can super-resolve.
TARGET_SR = 48_000
SUPPORTED_INPUT_SR = (8_000, 12_000, 16_000, 24_000)
# UniverSR.enhance() zero-pads anything shorter than this (≈0.68 s @ 48 kHz) before
# running the ODE, so chunks below it just waste compute — clamp to it.
MODEL_MIN_SAMPLES = 32_768

# Paths relative to this node's directory.
_NODE_DIR = os.path.dirname(os.path.abspath(__file__))
_VENDOR_DIR = os.path.join(_NODE_DIR, "vendor")
_BUNDLED_CONFIG = os.path.join(_NODE_DIR, "configs", "config.yaml")

# HuggingFace repos for the two released checkpoints (preset name -> repo id).
HF_REPOS = {
    preset: f"woongzip1/{preset}"
    for preset in ("universr-audio", "universr-speech")
}
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Package bootstrap
|
|
# --------------------------------------------------------------------------- #
|
|
def get_universr_cls():
    """Return the `UniverSR` class, preferring an installed copy over the vendored one.

    Resolution order:
      1. an importable top-level ``universr`` package (e.g. installed by the
         FoleyTune node);
      2. the copy vendored under ``./vendor`` (that directory is prepended to
         ``sys.path`` and left there).

    Raises:
        RuntimeError: if neither copy can be imported. The underlying import
            error is chained as ``__cause__`` so its traceback is preserved.
    """
    import sys

    try:
        from universr import UniverSR  # installed (e.g. via the FoleyTune node)
        return UniverSR
    except Exception:
        # A broken or incompatible installed copy can leave `universr` and some
        # of its submodules cached in sys.modules. The vendored import below
        # would then silently reuse those cached modules (and fail the same
        # way, or mix the two copies), so drop them first.
        stale = [m for m in sys.modules if m == "universr" or m.startswith("universr.")]
        for mod_name in stale:
            sys.modules.pop(mod_name, None)

    if _VENDOR_DIR not in sys.path:
        sys.path.insert(0, _VENDOR_DIR)
    try:
        from universr import UniverSR  # vendored fallback
        return UniverSR
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "Could not import the 'universr' package (neither installed nor vendored). "
            "Try: pip install torchdiffeq (the only dependency ComfyUI does not already ship).\n"
            f"Underlying error: {e}"
        ) from e
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Model directory + cache
|
|
# --------------------------------------------------------------------------- #
|
|
def get_models_dir() -> str:
    """Return the absolute path of the ``universr`` models directory.

    Inside ComfyUI this is ``<folder_paths.models_dir>/universr``; when
    `folder_paths` is unavailable it falls back to
    ``<node dir>/../../models/universr``.
    """
    base = (
        folder_paths.models_dir
        if HAS_FOLDER_PATHS
        else os.path.join(_NODE_DIR, "..", "..", "models")
    )
    return os.path.abspath(os.path.join(base, "universr"))
|
|
|
|
|
|
def list_local_models() -> list:
    """List user-supplied checkpoint directories under models/universr.

    A subdirectory qualifies when it contains both ``config.yaml`` and
    ``pytorch_model.bin``. Names that are preset keys of HF_REPOS are excluded.
    Results are sorted alphabetically; an empty list is returned when the models
    directory does not exist.
    """
    root = get_models_dir()
    if not os.path.isdir(root):
        return []

    def _is_checkpoint_dir(path: str) -> bool:
        # Both files must be present for from_pretrained-style loading.
        return os.path.isdir(path) and all(
            os.path.exists(os.path.join(path, fname))
            for fname in ("config.yaml", "pytorch_model.bin")
        )

    return [
        entry
        for entry in sorted(os.listdir(root))
        if _is_checkpoint_dir(os.path.join(root, entry)) and entry not in HF_REPOS
    ]
|
|
|
|
|
|
_MODEL_CACHE: dict = {}
|
|
_CACHE_LOCK = threading.Lock()
|
|
|
|
|
|
def _download_preset(name: str) -> str:
    """Ensure preset checkpoint `name` exists under models/universr/<name>; return that dir.

    Only ``config.yaml`` and ``pytorch_model.bin`` are fetched from the preset's
    HuggingFace repo (looked up in HF_REPOS, so an unknown name raises KeyError).
    The download is skipped when both files are already present.
    """
    from huggingface_hub import snapshot_download

    repo_id = HF_REPOS[name]
    target = os.path.join(get_models_dir(), name)
    required = ("config.yaml", "pytorch_model.bin")
    if all(os.path.exists(os.path.join(target, fname)) for fname in required):
        return target

    os.makedirs(target, exist_ok=True)
    print(f"[UniverSR] Downloading {repo_id} -> {target} (~230 MB)...")
    snapshot_download(
        repo_id=repo_id,
        local_dir=target,
        allow_patterns=["config.yaml", "pytorch_model.bin"],
    )
    print(f"[UniverSR] Downloaded {name}.")
    return target
|
|
|
|
|
|
def resolve_model_ref(model: str, local_path: str = "") -> tuple:
    """Resolve the loader inputs to ``(kind, path)`` with kind in {'dir', 'ckpt'}.

    Precedence:
      - `local_path` (whitespace-stripped) wins if non-empty: an existing
        directory -> 'dir' (expected to hold config.yaml + pytorch_model.bin,
        not checked here); an existing file (e.g. .pth/.pt/.ckpt) -> 'ckpt'
        (loaded via from_local with a config); anything else raises
        FileNotFoundError.
      - a preset name ('universr-audio' / 'universr-speech') -> downloaded -> 'dir'.
      - a subdirectory name under models/universr -> 'dir'.

    Raises:
        FileNotFoundError: `local_path` is set but missing, or `model` matches
            neither a preset nor a local subdirectory.
    """
    explicit = (local_path or "").strip()
    if explicit:
        if os.path.isdir(explicit):
            return ("dir", explicit)
        if os.path.isfile(explicit):
            return ("ckpt", explicit)
        raise FileNotFoundError(f"local_path does not exist: {explicit}")

    if model in HF_REPOS:
        return ("dir", _download_preset(model))

    candidate = os.path.join(get_models_dir(), model)
    if not os.path.isdir(candidate):
        raise FileNotFoundError(
            f"Unknown model '{model}'. Use a preset {list(HF_REPOS)}, a local subdir of "
            f"{get_models_dir()}, or set local_path."
        )
    return ("dir", candidate)
|
|
|
|
|
|
def apply_tf32(enabled: bool):
    """Turn TF32 on or off for both matmul and cuDNN convolutions (Ampere+ GPUs).

    Roughly 1.15x faster when on. In our spectral A/B (centroid + HF energy) TF32
    was tonally neutral, but it is NOT bit-exact (10 mantissa bits vs 23), so it is
    off by default. Off means true fp32 — PyTorch otherwise leaves cuDNN conv-TF32
    ON by default, so it is explicitly disabled here as well. This is a global,
    process-wide setting.
    """
    matmul_precision = "high" if enabled else "highest"
    try:
        torch.set_float32_matmul_precision(matmul_precision)  # matmul TF32
        torch.backends.cudnn.allow_tf32 = enabled  # conv TF32
    except Exception:
        # Best-effort knob: failures here are deliberately ignored.
        pass
|
|
|
|
|
|
def load_model(model: str, device: str, local_path: str = "", config_path: str = "",
               tf32: bool = False, compile_model: bool = False):
    """Load (and cache) a UniverSR model. Returns (model_obj, cache_key).

    Args:
        model: preset name or subdirectory name under models/universr
            (see resolve_model_ref).
        device: device string passed to UniverSR (e.g. "cuda", "cuda:0", "cpu").
        local_path: optional explicit checkpoint dir or raw checkpoint file;
            takes precedence over `model`.
        config_path: config YAML for a raw checkpoint; defaults to the bundled
            config. Ignored for checkpoint directories.
        tf32: enable TF32 matmul/conv (global process setting, see apply_tf32).
        compile_model: wrap the UNet in torch.compile (only on CUDA devices).

    Returns:
        (model_obj, cache_key). Pass cache_key to evict_model() to unload.

    Raises:
        FileNotFoundError: unknown model, missing local_path, or missing config
            for a raw checkpoint.
        RuntimeError: the `universr` package cannot be imported.
    """
    apply_tf32(tf32)  # global; apply before the cache short-circuit so toggling takes effect
    kind, path = resolve_model_ref(model, local_path)

    # A raw checkpoint is built from (weights, config), so the config is part of
    # the cache identity: the same .pth loaded with a different YAML must not
    # return the previously built model.
    cfg = None
    if kind != "dir":
        cfg = (config_path or "").strip() or _BUNDLED_CONFIG
        if not os.path.exists(cfg):
            raise FileNotFoundError(
                f"config_path required for a raw checkpoint and not found: {cfg}"
            )

    cache_key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}"
    if cfg is not None:
        cache_key += f":config={os.path.abspath(cfg)}"

    with _CACHE_LOCK:
        if cache_key in _MODEL_CACHE:
            print(f"[UniverSR] Using cached model ({cache_key})")
            return _MODEL_CACHE[cache_key], cache_key

    # Load outside the lock so a slow load does not block evict_model().
    UniverSR = get_universr_cls()
    if kind == "dir":
        print(f"[UniverSR] Loading from_pretrained({path}) on {device}")
        model_obj = UniverSR.from_pretrained(path, device=device)
    else:
        print(f"[UniverSR] Loading from_local(ckpt={path}, config={cfg}) on {device}")
        model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device)

    model_obj.eval()

    # torch.compile the UNet (~2.1x measured). Static shapes only — the model's
    # adaptive-avg-pool can't trace dynamic shapes — so the sampler pads every
    # chunk to a fixed length (see _universr_compiled flag) to compile exactly once.
    compiled = False
    if compile_model and str(device).startswith("cuda"):  # also matches "cuda:0" etc.
        try:
            import torch._dynamo as _dynamo
            _dynamo.config.cache_size_limit = max(getattr(_dynamo.config, "cache_size_limit", 8), 32)
            model_obj.model = torch.compile(model_obj.model, mode="default", dynamic=False)
            compiled = True
            print("[UniverSR] torch.compile enabled (first run compiles ~10-35s, then ~2x).")
        except Exception as e:
            print(f"[UniverSR] torch.compile unavailable, continuing eager: {e}")
    elif compile_model:
        print(f"[UniverSR] torch.compile skipped: requires a CUDA device (got {device}).")
    model_obj._universr_compiled = compiled

    n = sum(p.numel() for p in model_obj.parameters()) / 1e6
    print(f"[UniverSR] Ready - {n:.1f}M params on {device} (tf32={tf32}, compile={compiled})")

    with _CACHE_LOCK:
        # If another thread finished loading the same key meanwhile, keep the
        # first instance so every caller shares one model.
        model_obj = _MODEL_CACHE.setdefault(cache_key, model_obj)
    return model_obj, cache_key
|
|
|
|
|
|
def evict_model(cache_key: str):
    """Drop `cache_key` from the model cache and try to free its memory.

    Unknown keys are ignored. After removal, a garbage-collection pass runs and,
    when CUDA is available, unoccupied cached CUDA memory is released via
    torch.cuda.empty_cache().
    """
    import gc

    with _CACHE_LOCK:
        _MODEL_CACHE.pop(cache_key, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f"[UniverSR] Model unloaded ({cache_key})")
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Audio helpers
|
|
# --------------------------------------------------------------------------- #
|
|
def comfy_audio_to_tensor(audio) -> tuple:
    """Convert ComfyUI AUDIO into ``(waveform, sr)``.

    Accepts the dict form ``{"waveform": ..., "sample_rate": ...}`` or a legacy
    ``(waveform, sample_rate)`` pair. The waveform is converted to a tensor if
    needed, detached, cast to float32 and moved to CPU. 1-D ``(T,)`` and 2-D
    ``(C, T)`` inputs are promoted to ``[B, C, T]``; tensors of any other rank
    keep their shape. The sample rate is returned as an int.
    """
    if isinstance(audio, dict):
        wav, rate = audio["waveform"], audio["sample_rate"]
    else:
        wav, rate = audio

    wav = wav if isinstance(wav, torch.Tensor) else torch.as_tensor(wav)
    wav = wav.detach().float().cpu()
    if wav.dim() in (1, 2):
        # (T,) -> (1, 1, T); (C, T) -> (1, C, T)
        for _ in range(3 - wav.dim()):
            wav = wav.unsqueeze(0)
    return wav, int(rate)
|
|
|
|
|
|
def tensor_to_comfy_audio(waveform: torch.Tensor, sr: int) -> dict:
    """Wrap a waveform as a ComfyUI AUDIO dict ``{"waveform", "sample_rate"}``.

    1-D ``(T,)`` and 2-D ``(C, T)`` tensors are promoted to ``[B, C, T]``. The
    tensor is detached, moved to CPU and made contiguous (dtype unchanged), and
    the sample rate is stored as an int.
    """
    rank = waveform.dim()
    if rank == 1:
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        waveform = waveform.unsqueeze(0)
    return {"waveform": waveform.detach().cpu().contiguous(), "sample_rate": int(sr)}
|
|
|
|
|
|
def _resample(x: torch.Tensor, orig: int, target: int) -> torch.Tensor:
|
|
if orig == target:
|
|
return x
|
|
return torchaudio.functional.resample(x, orig, target)
|
|
|
|
|
|
def _fit(x: torch.Tensor, n: int) -> torch.Tensor:
|
|
"""Crop or zero-pad a 1-D tensor to exactly n samples."""
|
|
if x.shape[-1] == n:
|
|
return x
|
|
if x.shape[-1] > n:
|
|
return x[:n]
|
|
return torch.nn.functional.pad(x, (0, n - x.shape[-1]))
|
|
|
|
|
|
def _crossfade_window(length: int, ov: int, first: bool, last: bool) -> torch.Tensor:
|
|
"""Linear fade-in/out over the overlap regions; flat 1.0 elsewhere.
|
|
|
|
Combined with weight-sum normalisation this gives click-free overlap-add.
|
|
"""
|
|
w = torch.ones(length)
|
|
if ov > 0:
|
|
f = min(ov, length)
|
|
if not first:
|
|
w[:f] = torch.minimum(w[:f], torch.linspace(0.0, 1.0, f))
|
|
if not last:
|
|
w[-f:] = torch.minimum(w[-f:], torch.linspace(1.0, 0.0, f))
|
|
return w
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Inference
|
|
# --------------------------------------------------------------------------- #
|
|
@torch.no_grad()
def _enhance_segment(model, seg48: torch.Tensor, input_sr: int,
                     ode_method: str, ode_steps: int, guidance_scale) -> torch.Tensor:
    """Super-resolve one mono segment sampled at TARGET_SR (48 kHz).

    The segment is first resampled down to `input_sr`, so the model receives a
    genuine low-rate signal. A falsy or non-positive `guidance_scale` is passed
    to the model as None. Returns the model output flattened to a 1-D float32
    CPU tensor; its length is whatever `model.enhance` produced, so callers crop
    or pad it.
    """
    low_rate = _resample(seg48[None], TARGET_SR, input_sr)[0]  # genuine LR-rate signal
    use_guidance = bool(guidance_scale) and guidance_scale > 0
    cfg = float(guidance_scale) if use_guidance else None
    enhanced = model.enhance(
        low_rate,
        input_sr=int(input_sr),
        ode_method=ode_method,
        ode_steps=int(ode_steps),
        guidance_scale=cfg,
    )
    return enhanced.reshape(-1).float().cpu()
|
|
|
|
|
|
def _chunk_starts(total: int, chunk: int, hop: int) -> list:
|
|
if chunk <= 0 or total <= chunk:
|
|
return [0]
|
|
starts = list(range(0, max(1, total - chunk) + 1, hop))
|
|
if starts[-1] + chunk < total:
|
|
starts.append(total - chunk)
|
|
return starts
|
|
|
|
|
|
@torch.no_grad()
def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
                     guidance_scale, chunk: int, ov: int, pbar, pad_to: int = 0) -> torch.Tensor:
    """Super-resolve one 48 kHz mono channel, using overlap-add chunking when it is long.

    With ``chunk <= 0`` or a channel no longer than `chunk`, the whole channel is
    processed in one call. Otherwise windows of `chunk` samples overlapping by `ov`
    are enhanced, weighted by _crossfade_window, summed and normalised by the
    summed weights. When `pad_to` > 0 (compiled model), each piece is zero-padded
    to `pad_to` before inference and cropped back afterwards. `pbar`, if given,
    advances once per piece. Returns a 1-D tensor the same length as `ch48`.
    """
    total = ch48.shape[-1]

    def run(piece: torch.Tensor) -> torch.Tensor:
        # pad_to>0 (compiled model) → zero-pad to a fixed length so the UNet always
        # sees one input shape (compile once), then crop the result back.
        n = piece.shape[-1]
        padded = _fit(piece, pad_to) if (pad_to and pad_to > n) else piece
        return _fit(_enhance_segment(model, padded, input_sr, ode_method, ode_steps, guidance_scale), n)

    def tick():
        if pbar is not None:
            pbar.update(1)

    if chunk <= 0 or total <= chunk:
        tick()
        return _fit(run(ch48), total)

    acc = torch.zeros(total)
    weight = torch.zeros(total)
    offsets = _chunk_starts(total, chunk, max(1, chunk - ov))
    for idx, start in enumerate(offsets):
        if HAS_COMFY:
            mm.throw_exception_if_processing_interrupted()
        stop = min(start + chunk, total)
        enhanced = run(ch48[start:stop])
        win = _crossfade_window(stop - start, ov, first=(idx == 0), last=(stop >= total))
        acc[start:stop] += enhanced * win
        weight[start:stop] += win
        tick()
    return acc / torch.clamp(weight, min=1e-8)
|
|
|
|
|
|
@torch.no_grad()
def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
                  ode_method: str = "midpoint", ode_steps: int = 4,
                  guidance_scale=1.5, seed: int = 0,
                  chunk_seconds: float = 10.0, overlap_seconds: float = 0.5,
                  blend: float = 1.0):
    """Run UniverSR over a [B, C, T] waveform. Returns (out [B, C, T48], dry48 [B, C, T48]).

    Args:
        model: a loaded UniverSR instance (see load_model).
        waveform: [B, C, T] audio at `sr`; converted to float32 on CPU.
        sr: sample rate of `waveform` in Hz.
        input_sr: low-rate bandwidth the model is fed; one of SUPPORTED_INPUT_SR.
        ode_method, ode_steps: forwarded to model.enhance.
        guidance_scale: forwarded to model.enhance; None or <= 0 disables it.
        seed: RNG seed for the model's noise; 0 draws a fresh random seed. The
            global torch CPU/CUDA RNG state is restored afterwards.
        chunk_seconds, overlap_seconds: overlap-add chunking at 48 kHz. A
            non-positive chunk disables chunking unless the model is
            torch.compile'd, in which case 10 s chunks are forced. Chunks are
            never shorter than MODEL_MIN_SAMPLES.
        blend: wet/dry mix, clamped to [0, 1]. 1.0 gives the model output only,
            0.0 gives the resampled input only.

    Returns:
        (out, dry48): `out` is the mix hard-clipped to [-1, 1]; `dry48` is the
        input resampled to 48 kHz (not clipped).

    Raises:
        ValueError: unsupported `input_sr`, or `waveform` is not 3-D.
    """
    if int(input_sr) not in SUPPORTED_INPUT_SR:
        raise ValueError(f"input_sr must be one of {SUPPORTED_INPUT_SR}, got {input_sr}")

    waveform = waveform.float().cpu()
    if waveform.dim() != 3:
        raise ValueError(f"Expected a [B, C, T] waveform, got shape {tuple(waveform.shape)}")
    B, C, _ = waveform.shape
    dry48 = _resample(waveform, sr, TARGET_SR)  # [B, C, T48]
    T48 = dry48.shape[-1]
    if T48 == 0:  # empty input — nothing to do
        empty = torch.zeros(B, C, 0)
        return empty, empty

    chunk = int(round(chunk_seconds * TARGET_SR)) if (chunk_seconds and chunk_seconds > 0) else 0
    if 0 < chunk < MODEL_MIN_SAMPLES:
        print(f"[UniverSR] chunk_seconds too small; raising to the model floor "
              f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).")
        chunk = MODEL_MIN_SAMPLES

    # A torch.compile'd model needs a fixed input shape, so force chunking and pad
    # every chunk to `chunk` samples (compile once, reuse). Without compile, pad_to=0.
    compiled = getattr(model, "_universr_compiled", False)
    if compiled and chunk <= 0:
        chunk = int(round(10.0 * TARGET_SR))
        print("[UniverSR] compile: forcing 10.0s chunks for fixed input shapes.")
    pad_to = chunk if compiled else 0

    # Overlap must leave at least one new sample per hop.
    ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0
    n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1

    pbar = comfy.utils.ProgressBar(B * C * n_per_ch) if HAS_COMFY else None

    # Isolate the global RNG: snapshot, seed, run, restore. Without this the model's
    # torch.randn_like noise would advance (and a fixed seed would freeze) the global
    # generator that downstream ComfyUI nodes rely on. seed=0 → fresh OS entropy.
    cpu_rng = torch.get_rng_state()
    cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    actual_seed = int(seed) if (seed and int(seed) != 0) else int.from_bytes(os.urandom(8), "little")
    try:
        torch.manual_seed(actual_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(actual_seed)
        wet = torch.zeros(B, C, T48)
        for b in range(B):
            for c in range(C):
                wet[b, c] = _fit(
                    _enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps,
                                     guidance_scale, chunk, ov, pbar, pad_to=pad_to),
                    T48,
                )
    finally:
        torch.set_rng_state(cpu_rng)
        if cuda_rng is not None:
            torch.cuda.set_rng_state_all(cuda_rng)

    # Clamp so both ends behave alike: values above 1 already meant "wet only",
    # but negative values used to extrapolate (subtracting the model output),
    # which is never a meaningful wet/dry mix.
    blend = min(max(float(blend), 0.0), 1.0)
    out = wet if blend >= 1.0 else (1.0 - blend) * dry48 + blend * wet
    return out.clamp(-1.0, 1.0), dry48
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Auto input_sr — detect the audio's effective bandwidth
|
|
# --------------------------------------------------------------------------- #
|
|
def _supported_nyquists() -> list:
    """Nyquist frequency (Hz) of each supported input_sr, in SUPPORTED_INPUT_SR order."""
    return list(rate // 2 for rate in SUPPORTED_INPUT_SR)  # [4000, 6000, 8000, 12000]
|
|
|
|
|
|
def _map_cutoff_to_input_sr(cutoff_hz: float) -> int:
    """Map a detected cutoff frequency (Hz) to the input_sr to feed UniverSR.

    Chooses the largest supported Nyquist not exceeding ``cutoff_hz + 300``
    (300 Hz of slack snaps a cutoff just under a Nyquist up to it) and returns
    twice that value. Rounds DOWN on purpose: a Nyquist above the real cutoff
    makes the model treat an empty band as valid and skip regenerating it.
    Cutoffs below every supported Nyquist fall back to the smallest input_sr.
    """
    limit = cutoff_hz + 300.0
    nyquists = _supported_nyquists()
    chosen = max((nyq for nyq in nyquists if nyq <= limit), default=min(nyquists))
    return chosen * 2
|
|
|
|
|
|
@torch.no_grad()
def detect_input_sr(waveform: torch.Tensor, sr: int, conf_db: float = 25.0) -> tuple:
    """Estimate the effective bandwidth of `waveform` and choose the best input_sr.

    Cliff/edge detector: find the steepest bin-to-bin drop in the smoothed,
    time-averaged magnitude spectrum; its size (dB, median before vs. median
    after) is the confidence. Effective cutoff = that cliff if the drop is at
    least `conf_db` and the cliff sits below 0.95*(sr/2), else sr/2 (signal fills
    its band). The cutoff is mapped to a supported input_sr by
    `_map_cutoff_to_input_sr`, which rounds down.

    Args:
        waveform: audio tensor. 3-D input is averaged over dims 0 and 1, 2-D over
            dim 0, then flattened to mono. (Presumably [B, C, T] / [C, T] as
            elsewhere in this module — TODO confirm for all callers.)
        sr: sample rate of `waveform` in Hz.
        conf_db: minimum drop in dB for a cliff to be trusted.

    Returns (input_sr:int, info:dict{cutoff_hz, drop_db, confident, reason}).
    Inputs with fewer than 2048 samples, or a peak below 1e-6, skip detection and
    use sr/2 (drop_db=0.0, confident=False).
    """
    x = waveform.detach().float().cpu()
    # Downmix to mono. NOTE(review): plain averaging can cancel anti-phase
    # content between channels; acceptable for a bandwidth estimate.
    if x.dim() == 3:
        x = x.mean(dim=(0, 1))
    elif x.dim() == 2:
        x = x.mean(dim=0)
    x = x.reshape(-1)
    nyq = sr / 2.0

    def _fallback(reason: str) -> tuple:
        # "No detection": treat the signal as full-band at its own rate.
        isr = _map_cutoff_to_input_sr(nyq)
        return isr, {"cutoff_hz": nyq, "drop_db": 0.0, "confident": False, "reason": reason}

    if x.numel() < 2048 or float(x.abs().max()) < 1e-6:
        return _fallback(f"too short/silent -> sr/2={nyq/1000:.1f} kHz")

    # 4096-point FFT, or the largest power of two <= length (i.e. 2048 for
    # 2048..4095-sample inputs).
    n_fft = 4096 if x.numel() >= 4096 else 1 << int(np.floor(np.log2(x.numel())))
    # |STFT| of the 1-D signal: (n_fft//2 + 1) frequency bins x frames.
    spec = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
                      window=torch.hann_window(n_fft), return_complex=True).abs()
    mag = spec.mean(dim=1)  # time-averaged magnitude per frequency bin
    # 9-bin moving average across frequency to suppress bin-level noise.
    # NOTE(review): avg_pool1d's count_include_pad defaults to True, so the 4
    # outermost bins at each end are averaged with zero padding and biased low.
    # Near Nyquist that shows up as a small artificial downward slope.
    k = 9
    mag = torch.nn.functional.avg_pool1d(mag[None, None], k, 1, k // 2)[0, 0]
    # dB relative to the spectral peak, floored at 1e-12 (-240 dB).
    db = 20.0 * torch.log10((mag / mag.max().clamp(min=1e-12)).clamp(min=1e-12))
    freqs = torch.linspace(0, nyq, mag.shape[0])  # bin frequencies, 0..nyq inclusive

    # grad[i] = change from bin i to bin i+1; the most negative one is the cliff.
    grad = db[1:] - db[:-1]
    i = int(grad.argmin())  # steepest drop = candidate cliff edge
    edge_hz = float(freqs[i])  # frequency of the last bin before the drop
    # Drop size: median level of up to 11 bins at/below the edge vs. up to 39
    # bins above it. NOTE(review): only that local window is inspected, so a
    # strong isolated spectral peak (e.g. a pure tone) followed by a quiet
    # region may also register as a "cliff"; nothing checks that the band
    # stays empty all the way up to Nyquist.
    pre = db[max(0, i - 10):i + 1].median()
    post = db[i + 1:i + 40].median()
    drop = float(pre - post)

    # Cliffs in the top 5% of the band are treated as the band edge itself.
    confident = drop >= conf_db and edge_hz < 0.95 * nyq
    if confident:
        cutoff = edge_hz
        reason = f"cutoff {cutoff/1000:.1f} kHz (drop {drop:.0f} dB)"
    else:
        cutoff = nyq
        reason = f"no clear cutoff -> sr/2={nyq/1000:.1f} kHz"
    isr = _map_cutoff_to_input_sr(cutoff)
    return isr, {"cutoff_hz": cutoff, "drop_db": drop, "confident": confident, "reason": reason}
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# Spectrogram comparison (optional IMAGE output)
|
|
# --------------------------------------------------------------------------- #
|
|
def _stft_db(x: np.ndarray) -> np.ndarray:
|
|
t = torch.from_numpy(np.ascontiguousarray(x)).float()
|
|
win = torch.hann_window(1024)
|
|
spec = torch.stft(t, n_fft=1024, hop_length=512, window=win, return_complex=True)
|
|
db = 20.0 * torch.log10(spec.abs().clamp(min=1e-5))
|
|
db = db - db.max()
|
|
return db.numpy()
|
|
|
|
|
|
def make_spectrogram_image(input48_mono: np.ndarray, output48_mono: np.ndarray,
                           input_sr: int) -> torch.Tensor:
    """Before/after spectrogram comparison -> IMAGE tensor [1, H, W, 3] in [0, 1].

    The left panel is the band-limited input (content valid up to input_sr/2);
    the right panel is the 48 kHz output. The dashed line marks the LR Nyquist,
    so the regenerated high-frequency band is the energy above it on the right.
    At most the first 8 s are plotted.

    Rendering uses matplotlib's object-oriented Figure with an Agg canvas
    directly. The process-wide pyplot backend is never changed, and no figure
    stays registered in pyplot if rendering fails partway. On any failure (e.g.
    matplotlib missing) a blank 64x64 image is returned instead.
    """
    try:
        from matplotlib.backends.backend_agg import FigureCanvasAgg
        from matplotlib.figure import Figure

        # Visualise the band-limit the model actually saw, not the raw container.
        lr = torch.from_numpy(np.ascontiguousarray(input48_mono)).float()[None]
        lr = _resample(_resample(lr, TARGET_SR, int(input_sr)), int(input_sr), TARGET_SR).squeeze(0).numpy()
        n = min(len(lr), len(output48_mono), int(8.0 * TARGET_SR))
        lr, hr = lr[:n], output48_mono[:n]
        nyq = int(input_sr) / 2.0

        fig = Figure(figsize=(12, 4.0), facecolor="#0d0f16")
        canvas = FigureCanvasAgg(fig)
        axes = fig.subplots(1, 2)
        for ax, sig, title, cmap in (
            (axes[0], lr, f"Input (<= {int(input_sr)//1000} kHz)", "magma"),
            (axes[1], hr, "UniverSR output (48 kHz)", "viridis"),
        ):
            db = _stft_db(sig)
            ax.imshow(db, origin="lower", aspect="auto", cmap=cmap,
                      extent=[0, n / TARGET_SR, 0, TARGET_SR / 2], vmin=-80, vmax=0)
            ax.axhline(nyq, color="w", ls="--", lw=0.8, alpha=0.6)
            ax.set_title(title, color="#cfe0ff", fontsize=10)
            ax.set_xlabel("Time (s)", color="#7a93bd", fontsize=8)
            ax.set_ylabel("Hz", color="#7a93bd", fontsize=8)
            ax.tick_params(colors="#5a6e90", labelsize=7)
            ax.set_facecolor("#0d0f16")
        fig.tight_layout()

        canvas.draw()
        # np.asarray(buffer_rgba()) yields (H, W, 4) at the real pixel size.
        img = np.asarray(canvas.buffer_rgba())[..., :3].astype(np.float32) / 255.0
        return torch.from_numpy(np.ascontiguousarray(img))[None]
    except Exception as e:  # matplotlib missing / headless edge cases
        print(f"[UniverSR] Spectrogram render skipped: {e}")
        return torch.zeros(1, 64, 64, 3)
|