From 5f29b225b7315af3840c918d9555264eda0d76fb Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Mon, 1 Jun 2026 12:59:42 +0200 Subject: [PATCH] Initial release: ComfyUI-UniverSR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching. - UniverSR Model Loader: presets auto-download to models/universr, plus local dir / raw .pth (from_local) loading, with caching. - UniverSR Super-Resolution: chunked overlap-add for long audio, per-channel stereo, seed control with global-RNG isolation, wet/dry blend, and an optional before/after spectrogram. - Vendors the universr inference package under vendor/ (prefers an installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq. Co-Authored-By: Claude Opus 4.8 --- .gitignore | 10 + LICENSE | 21 + README.md | 108 ++++ __init__.py | 10 + configs/config.yaml | 87 ++++ .../universr_super_resolution.json | 95 ++++ nodes.py | 199 ++++++++ pyproject.toml | 19 + requirements.txt | 8 + universr_wrapper.py | 422 ++++++++++++++++ vendor/universr/__init__.py | 4 + vendor/universr/flow/__init__.py | 0 vendor/universr/flow/loss.py | 9 + vendor/universr/flow/path.py | 54 ++ vendor/universr/flow/solver.py | 127 +++++ vendor/universr/inference.py | 351 +++++++++++++ vendor/universr/models/__init__.py | 0 vendor/universr/models/unet.py | 470 ++++++++++++++++++ vendor/universr/utils/__init__.py | 0 vendor/universr/utils/spectral_ops.py | 135 +++++ 20 files changed, 2129 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 __init__.py create mode 100644 configs/config.yaml create mode 100644 example_workflows/universr_super_resolution.json create mode 100644 nodes.py create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 universr_wrapper.py create mode 100644 vendor/universr/__init__.py create mode 100644 vendor/universr/flow/__init__.py create mode 100644 vendor/universr/flow/loss.py create mode 100644 vendor/universr/flow/path.py create mode 100644 vendor/universr/flow/solver.py create mode 100644 vendor/universr/inference.py create mode 100644 vendor/universr/models/__init__.py create mode 100644 vendor/universr/models/unet.py create mode 100644 vendor/universr/utils/__init__.py create mode 100644 vendor/universr/utils/spectral_ops.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b37ed5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.py[cod] +*.egg-info/ +.pytest_cache/ +.DS_Store +# anchored to repo root so the vendored universr/models/ package is NOT ignored +/models/ +/ckpts/ +*.wav +*.flac diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5c0fdcc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Woongjib Choi, DSPAI Lab, Yonsei University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..fed9e93 --- /dev/null +++ b/README.md @@ -0,0 +1,108 @@ +# ComfyUI-UniverSR + +ComfyUI nodes for **[UniverSR](https://github.com/woongzip1/UniverSR)** — *Unified and Versatile +Audio Super-Resolution via Vocoder-Free Flow Matching* (ICASSP 2026, +[arXiv:2510.00771](https://arxiv.org/abs/2510.00771)). + +A single model upscales **8 / 12 / 16 / 24 kHz** effective bandwidth → **48 kHz** across speech, +music and sound effects. It works directly in the complex‑STFT domain with flow matching — no neural +vocoder — and regenerates the missing high‑frequency band rather than just interpolating. + +![overview](https://raw.githubusercontent.com/woongzip1/UniverSR/master/assets/overview.png) + +--- + +## Nodes + +| Node | Output | Purpose | +|---|---|---| +| **UniverSR Model Loader** | `UNIVERSR_MODEL` | Loads + caches a checkpoint. Auto-downloads the presets to `models/universr/`. | +| **UniverSR Super-Resolution** | `AUDIO`, `IMAGE` | Runs the SR. Chunks long audio (click-free overlap-add). Optional before/after spectrogram. | + +Wire it up: + +``` +LoadAudio ─────────────┐ + ▼ +UniverSR Model Loader ─► UniverSR Super-Resolution ─► SaveAudio + └─ spectrogram ─► PreviewImage +``` + +### Model Loader +- **model** — `universr-audio` (general; music/SFX/mixed, recommended) or `universr-speech` (voice). + Each downloads ~230 MB to `models/universr/` on first use. Local checkpoint folders placed + in `models/universr/` also appear in this list. +- **device** — `auto` / `cuda` / `cpu`. +- **local_path** *(optional)* — override with a folder (`config.yaml` + `pytorch_model.bin`) or a raw + `.pth`/`.ckpt` training checkpoint. +- **config_path** *(optional)* — `config.yaml` for a raw checkpoint. Empty → the bundled default config. + +### Super-Resolution +- **input_sr** — the *effective bandwidth* of your content in Hz. The model treats everything up to + `input_sr/2` as valid and **regenerates above it**. + - `8000` → genuine low-rate audio (8 kHz → 48 kHz; the strongest, best-trained case). + - `16000` → brighten muffled but full-rate audio by regenerating only above 8 kHz (most natural). +- **ode_method** — `euler` (fastest) → `midpoint` (balanced) → `rk4` (best). +- **ode_steps** — flow-matching steps. `4` is fast and validated; `4–10` is a good range. +- **guidance_scale** — classifier-free guidance. Speech `1.0–1.5`, music `1.5–2.0`, SFX `~1.5`. + Higher = denser highs but less faithful. `0` disables CFG. +- **seed** — noise seed (`0` = random each run). +- **chunk_seconds** / **overlap_seconds** — long-audio handling (see below). `chunk_seconds=0` + processes the whole clip at once. +- **blend** — wet/dry mix. `1.0` = full SR. Lower keeps more of the original (handy for *bandwidth + extension* of already-48 kHz audio). +- **unload_model** — free VRAM after the run. +- **show_spectrogram** — also output a before/after spectrogram comparison `IMAGE`. + +--- + +## Long audio & chunking + +UniverSR runs the whole clip through a flow-matching ODE in one shot, which OOMs on long files +(the upstream notebook added chunking specifically to survive clips > 2 min). This node chunks in the +time domain and stitches the results with **overlap-add + linear crossfade** (weight-normalised), so +seams are click-free — an improvement over the upstream GUI's naive concatenation. Drop +`chunk_seconds` if you hit VRAM limits; raise `overlap_seconds` if you ever hear a seam. Stereo is +processed per-channel and preserved. + +> Compared to the `FoleyTune BWE` node (which brightens short foley clips and processes the whole clip +> at once), this node adds the chunking needed for arbitrarily long sequences. + +--- + +## Installation + +```bash +cd ComfyUI/custom_nodes +git clone ComfyUI-UniverSR +pip install -r ComfyUI-UniverSR/requirements.txt +``` + +The `universr` model code is **vendored** under `vendor/` (an installed `pip` copy is preferred if +present), so the only dependency beyond ComfyUI's stack is **`torchdiffeq`** (plus `einops`, `timm`, +`huggingface_hub`, `pyyaml`, which ComfyUI usually already has). Weights download automatically on +first use. + +--- + +## How it works (implementation note) + +ComfyUI audio arrives at an arbitrary real sample rate. UniverSR's *file* path relies on +`torchaudio.load` (fragile torchcodec backend), and its *tensor* path assumes the tensor is already at +`input_sr`. So this node does the band-limit itself: resample to 48 kHz → downsample each chunk to +`input_sr` (pure DSP, no codec) → hand UniverSR a genuine low-rate tensor to super-resolve. This +exactly reproduces the model's training-time degradation. + +## Credits & license + +UniverSR © Woongjib Choi et al., DSPAI Lab, Yonsei University — released under the MIT License +(see `LICENSE`). This node wrapper vendors the UniverSR inference code unmodified under `vendor/`. + +```bibtex +@inproceedings{choi2026universr, + title = {{UniverSR}: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching}, + author = {Choi, Woongjib and Lee, Sangmin and Lim, Hyungseob and Kang, Hong-Goo}, + booktitle = {IEEE ICASSP}, + year = {2026} +} +``` diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..067ddb3 --- /dev/null +++ b/__init__.py @@ -0,0 +1,10 @@ +"""ComfyUI-UniverSR — vocoder-free audio super-resolution (8/12/16/24 kHz -> 48 kHz).""" + +try: + from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +except Exception as e: # surface import errors in the ComfyUI log without crashing startup + print(f"[ComfyUI-UniverSR] Failed to load nodes: {e}") + NODE_CLASS_MAPPINGS = {} + NODE_DISPLAY_NAME_MAPPINGS = {} + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..b25ca33 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,87 @@ +seed: 42 + +wandb: + project_name: "UniverSR" + entity: null # set to your wandb username or team + run_name: "audio" + notes: "" + +dataloader: + batch_size: 4 + num_workers: 4 + prefetch_factor: 2 + persistent_workers: True + pin_memory: True + +collator: + sampling_rates_probs: + 8: 0.7 + 12: 0.1 + 16: 0.1 + 24: 0.1 + validation_probs: + 8: 1.0 + +dataset: + common: + num_samples: 32767 + sr: 48000 + train: + file_list: "./data/train.txt" + val: + file_list: "./data/val.txt" + +path: + class_path: universr.flow.path.OriginalCFMPath + init_args: + sigma_min: 1.0e-4 + +transform: + window_fn: 'hann' + n_fft: 1024 + sampling_rate: 48000 + hop_length: 512 + alpha: 0.2 + beta: 1 + comp_eps: 1.0e-4 + +model: + in_channels: 2 + out_channels: 2 + dims: [96, 192, 384, 768] + depths: [2, 2, 4, 2] + drop_path: 0 + time_dim: 256 + cond_dim: 384 + total_freq_bins: 512 + hr_freq_bins: 432 + feature_enc_layers: 4 + cond_dropout_prob: 0.1 + sr_to_lr_bins: {8: 80, 12: 128, 16: 170, 24: 256} + +scheduler: + type: CosineLR + init_args: + num_warmup_steps: 10000 + num_training_steps: 5000000 + +optimizer: + lr: 2.0e-4 + betas: [0.9, 0.99] + +train: + num_epochs: 200 + max_steps: 5000000 + ckpt_save_dir: ./ckpts/audio/ + ckpt_load_path: null + log_step_interval: 1000 + val_step_interval: 50000 + num_val_log_samples: 5 + val_ode_steps: 4 + val_max_sec: 5 + +eval: + ode_steps: 4 + guidance_scale: 1.5 + max_batches: null + num_log_samples: 6 \ No newline at end of file diff --git a/example_workflows/universr_super_resolution.json b/example_workflows/universr_super_resolution.json new file mode 100644 index 0000000..b23411f --- /dev/null +++ b/example_workflows/universr_super_resolution.json @@ -0,0 +1,95 @@ +{ + "last_node_id": 5, + "last_link_id": 4, + "nodes": [ + { + "id": 1, + "type": "LoadAudio", + "pos": [120, 200], + "size": [320, 124], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + {"name": "AUDIO", "type": "AUDIO", "links": [1], "slot_index": 0} + ], + "properties": {"Node name for S&R": "LoadAudio"}, + "widgets_values": ["example.wav", null, ""] + }, + { + "id": 2, + "type": "UniverSRModelLoader", + "pos": [120, 380], + "size": [340, 130], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + {"name": "model", "type": "UNIVERSR_MODEL", "links": [2], "slot_index": 0} + ], + "properties": {"Node name for S&R": "UniverSRModelLoader"}, + "widgets_values": ["universr-audio", "auto", "", ""] + }, + { + "id": 3, + "type": "UniverSRSampler", + "pos": [540, 200], + "size": [340, 320], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + {"name": "audio", "type": "AUDIO", "link": 1}, + {"name": "model", "type": "UNIVERSR_MODEL", "link": 2} + ], + "outputs": [ + {"name": "audio", "type": "AUDIO", "links": [3], "slot_index": 0}, + {"name": "spectrogram", "type": "IMAGE", "links": [4], "slot_index": 1} + ], + "properties": {"Node name for S&R": "UniverSRSampler"}, + "widgets_values": [8000, "midpoint", 4, 1.5, 0, "randomize", 10.0, 0.5, 1.0, false, true] + }, + { + "id": 4, + "type": "PreviewAudio", + "pos": [940, 200], + "size": [320, 100], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + {"name": "audio", "type": "AUDIO", "link": 3} + ], + "outputs": [], + "properties": {"Node name for S&R": "PreviewAudio"}, + "widgets_values": [] + }, + { + "id": 5, + "type": "PreviewImage", + "pos": [940, 340], + "size": [320, 280], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + {"name": "images", "type": "IMAGE", "link": 4} + ], + "outputs": [], + "properties": {"Node name for S&R": "PreviewImage"}, + "widgets_values": [] + } + ], + "links": [ + [1, 1, 0, 3, 0, "AUDIO"], + [2, 2, 0, 3, 1, "UNIVERSR_MODEL"], + [3, 3, 0, 4, 0, "AUDIO"], + [4, 3, 1, 5, 0, "IMAGE"] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..ca9c5b3 --- /dev/null +++ b/nodes.py @@ -0,0 +1,199 @@ +"""ComfyUI-UniverSR nodes. + +Two-node design (mirrors the ComfyUI-Flash-AudioSR pattern): + UniverSRModelLoader -> UNIVERSR_MODEL (loads + caches weights, auto-downloads) + UniverSRSampler -> AUDIO, IMAGE (runs the super-resolution) +""" + +import torch + +from . import universr_wrapper as usr + +try: + import comfy.model_management as mm + HAS_COMFY = True +except Exception: # pragma: no cover + HAS_COMFY = False + + +def _default_device() -> str: + if HAS_COMFY: + try: + return "cuda" if mm.get_torch_device().type == "cuda" else "cpu" + except Exception: + pass + return "cuda" if torch.cuda.is_available() else "cpu" + + +# --------------------------------------------------------------------------- # +# Model loader +# --------------------------------------------------------------------------- # +class UniverSRModelLoader: + """Load a UniverSR checkpoint. Auto-downloads the presets on first use. + + Output: UNIVERSR_MODEL -> connect to UniverSR Super-Resolution. + """ + + DESCRIPTION = ("Load UniverSR (vocoder-free audio super-resolution, ICASSP 2026). " + "Presets auto-download to models/universr on first use.") + CATEGORY = "audio/UniverSR" + + @classmethod + def INPUT_TYPES(cls): + choices = list(usr.HF_REPOS.keys()) + usr.list_local_models() + return { + "required": { + "model": (choices, { + "default": choices[0], + "tooltip": "universr-audio = general (music/SFX/mixed, recommended); " + "universr-speech = voice only. Both download (~230 MB) on first use. " + "Local checkpoint folders in models/universr also appear here.", + }), + "device": (["auto", "cuda", "cpu"], { + "default": "auto", + "tooltip": "Device to load the model onto.", + }), + }, + "optional": { + "local_path": ("STRING", { + "default": "", + "tooltip": "Override: a folder with config.yaml + pytorch_model.bin, " + "or a raw .pth/.ckpt file (uses config_path or the bundled config).", + }), + "config_path": ("STRING", { + "default": "", + "tooltip": "config.yaml for a raw checkpoint given in local_path. " + "Leave empty to use the bundled default config.", + }), + }, + } + + RETURN_TYPES = ("UNIVERSR_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load" + + def load(self, model, device, local_path="", config_path=""): + dev = _default_device() if device == "auto" else device + if dev == "cuda" and not torch.cuda.is_available(): + print("[UniverSR] CUDA unavailable, falling back to CPU") + dev = "cpu" + model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path) + return ({"model": model_obj, "device": dev, "cache_key": cache_key},) + + @classmethod + def IS_CHANGED(cls, model, device, local_path="", config_path=""): + return f"{model}:{device}:{local_path}:{config_path}" + + +# --------------------------------------------------------------------------- # +# Sampler +# --------------------------------------------------------------------------- # +class UniverSRSampler: + """Super-resolve audio to 48 kHz with UniverSR. Long clips are processed in + overlapping chunks (click-free overlap-add) to stay within VRAM.""" + + DESCRIPTION = ("Upscale low-bandwidth audio to 48 kHz with UniverSR. Pick input_sr to " + "match the effective bandwidth of your content (the model regenerates " + "everything above input_sr/2).") + CATEGORY = "audio/UniverSR" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "audio": ("AUDIO", {}), + "model": ("UNIVERSR_MODEL", {}), + "input_sr": ([8000, 12000, 16000, 24000], { + "default": 8000, + "tooltip": "Effective input bandwidth (Hz). Content is treated as valid up to " + "input_sr/2 and regenerated above it. 8000 = genuine low-rate audio " + "(strongest, 8 kHz->48 kHz). 16000 = brighten muffled audio above 8 kHz.", + }), + }, + "optional": { + "ode_method": (["midpoint", "euler", "rk4"], { + "default": "midpoint", + "tooltip": "ODE solver. euler (fastest) -> midpoint (balanced) -> rk4 (best).", + }), + "ode_steps": ("INT", { + "default": 4, "min": 1, "max": 64, "step": 1, + "tooltip": "Flow-matching integration steps. 4 is fast and validated; 4-10 is a good range.", + }), + "guidance_scale": ("FLOAT", { + "default": 1.5, "min": 0.0, "max": 6.0, "step": 0.25, + "tooltip": "Classifier-free guidance. Speech 1.0-1.5, music 1.5-2.0, SFX ~1.5. " + "Higher = denser highs but less faithful. 0 disables CFG.", + }), + "seed": ("INT", { + "default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "Noise seed for the flow-matching source. 0 = random each run.", + }), + "chunk_seconds": ("FLOAT", { + "default": 10.0, "min": 0.0, "max": 120.0, "step": 0.5, + "tooltip": "Process long audio in chunks of this length (seconds) to avoid OOM. " + "0 = process the whole clip at once.", + }), + "overlap_seconds": ("FLOAT", { + "default": 0.5, "min": 0.0, "max": 5.0, "step": 0.1, + "tooltip": "Crossfade overlap between chunks (seconds). Prevents seam clicks.", + }), + "blend": ("FLOAT", { + "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "Wet/dry mix. 1.0 = full super-resolution. Lower to keep more of the " + "original (useful when brightening already-48 kHz audio).", + }), + "unload_model": ("BOOLEAN", { + "default": False, + "tooltip": "Free the model from VRAM after this run.", + }), + "show_spectrogram": ("BOOLEAN", { + "default": True, + "tooltip": "Also output a before/after spectrogram comparison image.", + }), + }, + } + + RETURN_TYPES = ("AUDIO", "IMAGE") + RETURN_NAMES = ("audio", "spectrogram") + FUNCTION = "run" + + def run(self, audio, model, input_sr, ode_method="midpoint", ode_steps=4, + guidance_scale=1.5, seed=0, chunk_seconds=10.0, overlap_seconds=0.5, + blend=1.0, unload_model=False, show_spectrogram=True): + + model_obj = model["model"] + waveform, sr = usr.comfy_audio_to_tensor(audio) + dur = waveform.shape[-1] / max(sr, 1) + print(f"[UniverSR] {tuple(waveform.shape)} @ {sr} Hz ({dur:.2f}s) -> 48 kHz | " + f"input_sr={input_sr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}") + + out, dry48 = usr.super_resolve( + model_obj, waveform, sr, int(input_sr), + ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=guidance_scale, + seed=int(seed), chunk_seconds=float(chunk_seconds), + overlap_seconds=float(overlap_seconds), blend=float(blend), + ) + + audio_out = usr.tensor_to_comfy_audio(out, usr.TARGET_SR) + + spec = torch.zeros(1, 64, 64, 3) + if show_spectrogram: + in_mono = dry48[0].mean(0).numpy() + out_mono = out[0].mean(0).numpy() + spec = usr.make_spectrogram_image(in_mono, out_mono, int(input_sr)) + + if unload_model: + usr.evict_model(model["cache_key"]) + + print(f"[UniverSR] Done -> {out.shape[-1] / usr.TARGET_SR:.2f}s at 48 kHz") + return (audio_out, spec) + + +NODE_CLASS_MAPPINGS = { + "UniverSRModelLoader": UniverSRModelLoader, + "UniverSRSampler": UniverSRSampler, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "UniverSRModelLoader": "UniverSR Model Loader", + "UniverSRSampler": "UniverSR Super-Resolution", +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fd6fef7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "comfyui-universr" +description = "ComfyUI nodes for UniverSR — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching." +version = "1.0.0" +license = {file = "LICENSE"} +dependencies = [ + "torchdiffeq>=0.2.3", + "einops>=0.7", + "timm>=0.9", + "huggingface_hub>=0.20", + "pyyaml>=6.0", +] + +[project.urls] +Repository = "https://github.com/woongzip1/UniverSR" +Paper = "https://arxiv.org/abs/2510.00771" + +[tool.comfy] +DisplayName = "UniverSR" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2cf1644 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +# ComfyUI-UniverSR runtime deps. +# torch / torchaudio / numpy / matplotlib are already shipped by ComfyUI. +# The vendored `universr` package only needs these extras on top of ComfyUI's stack: +torchdiffeq>=0.2.3 +einops>=0.7 +timm>=0.9 +huggingface_hub>=0.20 +pyyaml>=6.0 diff --git a/universr_wrapper.py b/universr_wrapper.py new file mode 100644 index 0000000..bc7a2b3 --- /dev/null +++ b/universr_wrapper.py @@ -0,0 +1,422 @@ +"""Core wrapper for ComfyUI-UniverSR. + +Bootstraps the `universr` package (prefers a pip-installed copy, falls back to +the vendored one under ./vendor), manages model loading/caching, and runs the +super-resolution itself with optional overlap-add chunking for long audio. + +UniverSR (ICASSP 2026) is a vocoder-free audio super-resolution model that +regenerates high-frequency content in the complex-STFT domain via flow matching. +A single model handles 8 / 12 / 16 / 24 kHz effective bandwidth -> 48 kHz. + +Key design note — why we resample ourselves instead of handing UniverSR a file: + UniverSR's `enhance()` file path calls `torchaudio.load`, whose torchcodec + backend is fragile across environments; its *tensor* path assumes the tensor + is already at `input_sr`. ComfyUI audio arrives at an arbitrary real sample + rate, so we do the band-limit ourselves: resample to 48 kHz, downsample each + chunk to `input_sr` (pure DSP, no codec), and hand UniverSR a genuine + low-rate tensor to super-resolve. This reproduces the exact training-time + degradation and was validated in the FoleyTune BWE node. +""" + +import os +import threading + +import numpy as np +import torch +import torchaudio + +# --------------------------------------------------------------------------- # +# Optional ComfyUI integration (degrade gracefully outside ComfyUI / in tests) +# --------------------------------------------------------------------------- # +try: + import comfy.model_management as mm + import comfy.utils + HAS_COMFY = True +except Exception: # pragma: no cover - allows standalone import / pytest + HAS_COMFY = False + +try: + import folder_paths + HAS_FOLDER_PATHS = True +except Exception: # pragma: no cover + HAS_FOLDER_PATHS = False + + +TARGET_SR = 48_000 +SUPPORTED_INPUT_SR = (8000, 12000, 16000, 24000) +# UniverSR.enhance() zero-pads anything shorter than this (≈0.68 s @ 48 kHz) before +# running the ODE, so chunks below it just waste compute — clamp to it. +MODEL_MIN_SAMPLES = 32_768 +_NODE_DIR = os.path.dirname(os.path.abspath(__file__)) +_VENDOR_DIR = os.path.join(_NODE_DIR, "vendor") +_BUNDLED_CONFIG = os.path.join(_NODE_DIR, "configs", "config.yaml") + +# HuggingFace repos for the two released checkpoints. +HF_REPOS = { + "universr-audio": "woongzip1/universr-audio", + "universr-speech": "woongzip1/universr-speech", +} + + +# --------------------------------------------------------------------------- # +# Package bootstrap +# --------------------------------------------------------------------------- # +def get_universr_cls(): + """Return the `UniverSR` class, preferring an installed copy over the vendored one.""" + try: + from universr import UniverSR # installed (e.g. via the FoleyTune node) + return UniverSR + except Exception: + pass + import sys + if _VENDOR_DIR not in sys.path: + sys.path.insert(0, _VENDOR_DIR) + try: + from universr import UniverSR # vendored fallback + return UniverSR + except Exception as e: # pragma: no cover + raise RuntimeError( + "Could not import the 'universr' package (neither installed nor vendored). " + "Try: pip install torchdiffeq (the only dependency ComfyUI does not already ship).\n" + f"Underlying error: {e}" + ) + + +# --------------------------------------------------------------------------- # +# Model directory + cache +# --------------------------------------------------------------------------- # +def get_models_dir() -> str: + if HAS_FOLDER_PATHS: + base = folder_paths.models_dir + else: + base = os.path.join(_NODE_DIR, "..", "..", "models") + return os.path.abspath(os.path.join(base, "universr")) + + +def list_local_models() -> list: + """Subdirectories of models/universr that look like a UniverSR checkpoint dir.""" + root = get_models_dir() + found = [] + if os.path.isdir(root): + for name in sorted(os.listdir(root)): + d = os.path.join(root, name) + if os.path.isdir(d) and os.path.exists(os.path.join(d, "config.yaml")) \ + and os.path.exists(os.path.join(d, "pytorch_model.bin")): + if name not in HF_REPOS: + found.append(name) + return found + + +_MODEL_CACHE: dict = {} +_CACHE_LOCK = threading.Lock() + + +def _download_preset(name: str) -> str: + """Download a preset checkpoint into models/universr/ and return that dir.""" + from huggingface_hub import snapshot_download + repo_id = HF_REPOS[name] + target = os.path.join(get_models_dir(), name) + have = os.path.exists(os.path.join(target, "config.yaml")) and \ + os.path.exists(os.path.join(target, "pytorch_model.bin")) + if not have: + os.makedirs(target, exist_ok=True) + print(f"[UniverSR] Downloading {repo_id} -> {target} (~230 MB)...") + snapshot_download( + repo_id=repo_id, + local_dir=target, + allow_patterns=["config.yaml", "pytorch_model.bin"], + ) + print(f"[UniverSR] Downloaded {name}.") + return target + + +def resolve_model_ref(model: str, local_path: str = "") -> tuple: + """Resolve the loader inputs to (kind, path). kind in {'dir', 'ckpt'}. + + - local_path wins if set: a directory (config.yaml + pytorch_model.bin) -> 'dir'; + a .pth/.pt/.ckpt file -> 'ckpt' (loaded via from_local with a config). + - a preset name ('universr-audio' / 'universr-speech') -> download -> 'dir'. + - a local subdir name discovered under models/universr -> 'dir'. + """ + local_path = (local_path or "").strip() + if local_path: + if os.path.isdir(local_path): + return ("dir", local_path) + if os.path.isfile(local_path): + return ("ckpt", local_path) + raise FileNotFoundError(f"local_path does not exist: {local_path}") + + if model in HF_REPOS: + return ("dir", _download_preset(model)) + + cand = os.path.join(get_models_dir(), model) + if os.path.isdir(cand): + return ("dir", cand) + raise FileNotFoundError( + f"Unknown model '{model}'. Use a preset {list(HF_REPOS)}, a local subdir of " + f"{get_models_dir()}, or set local_path." + ) + + +def load_model(model: str, device: str, local_path: str = "", config_path: str = ""): + """Load (and cache) a UniverSR model. Returns (model_obj, cache_key).""" + kind, path = resolve_model_ref(model, local_path) + cache_key = f"{kind}:{os.path.abspath(path)}:{device}" + + with _CACHE_LOCK: + if cache_key in _MODEL_CACHE: + print(f"[UniverSR] Using cached model ({cache_key})") + return _MODEL_CACHE[cache_key], cache_key + + UniverSR = get_universr_cls() + if kind == "dir": + print(f"[UniverSR] Loading from_pretrained({path}) on {device}") + model_obj = UniverSR.from_pretrained(path, device=device) + else: + cfg = (config_path or "").strip() or _BUNDLED_CONFIG + if not os.path.exists(cfg): + raise FileNotFoundError( + f"config_path required for a raw checkpoint and not found: {cfg}" + ) + print(f"[UniverSR] Loading from_local(ckpt={path}, config={cfg}) on {device}") + model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device) + + model_obj.eval() + n = sum(p.numel() for p in model_obj.parameters()) / 1e6 + print(f"[UniverSR] Ready - {n:.1f}M params on {device}") + _MODEL_CACHE[cache_key] = model_obj + return model_obj, cache_key + + +def evict_model(cache_key: str): + import gc + with _CACHE_LOCK: + _MODEL_CACHE.pop(cache_key, None) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print(f"[UniverSR] Model unloaded ({cache_key})") + + +# --------------------------------------------------------------------------- # +# Audio helpers +# --------------------------------------------------------------------------- # +def comfy_audio_to_tensor(audio) -> tuple: + """ComfyUI AUDIO (dict or legacy tuple) -> (waveform [B, C, T] float32 cpu, sr).""" + if isinstance(audio, dict): + waveform, sr = audio["waveform"], audio["sample_rate"] + else: + waveform, sr = audio + if not isinstance(waveform, torch.Tensor): + waveform = torch.as_tensor(waveform) + waveform = waveform.detach().float().cpu() + if waveform.dim() == 1: # (T,) + waveform = waveform[None, None, :] + elif waveform.dim() == 2: # (C, T) + waveform = waveform[None, :, :] + return waveform, int(sr) + + +def tensor_to_comfy_audio(waveform: torch.Tensor, sr: int) -> dict: + if waveform.dim() == 1: + waveform = waveform[None, None, :] + elif waveform.dim() == 2: + waveform = waveform[None, :, :] + return {"waveform": waveform.detach().cpu().contiguous(), "sample_rate": int(sr)} + + +def _resample(x: torch.Tensor, orig: int, target: int) -> torch.Tensor: + if orig == target: + return x + return torchaudio.functional.resample(x, orig, target) + + +def _fit(x: torch.Tensor, n: int) -> torch.Tensor: + """Crop or zero-pad a 1-D tensor to exactly n samples.""" + if x.shape[-1] == n: + return x + if x.shape[-1] > n: + return x[:n] + return torch.nn.functional.pad(x, (0, n - x.shape[-1])) + + +def _crossfade_window(length: int, ov: int, first: bool, last: bool) -> torch.Tensor: + """Linear fade-in/out over the overlap regions; flat 1.0 elsewhere. + + Combined with weight-sum normalisation this gives click-free overlap-add. + """ + w = torch.ones(length) + if ov > 0: + f = min(ov, length) + if not first: + w[:f] = torch.minimum(w[:f], torch.linspace(0.0, 1.0, f)) + if not last: + w[-f:] = torch.minimum(w[-f:], torch.linspace(1.0, 0.0, f)) + return w + + +# --------------------------------------------------------------------------- # +# Inference +# --------------------------------------------------------------------------- # +@torch.no_grad() +def _enhance_segment(model, seg48: torch.Tensor, input_sr: int, + ode_method: str, ode_steps: int, guidance_scale) -> torch.Tensor: + """Super-resolve one 48 kHz mono segment. Returns 1-D tensor @48 kHz on CPU.""" + low = _resample(seg48.unsqueeze(0), TARGET_SR, input_sr).squeeze(0) # genuine LR-rate signal + cfg = float(guidance_scale) if (guidance_scale and guidance_scale > 0) else None + out = model.enhance( + low, input_sr=int(input_sr), + ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=cfg, + ) + return out.reshape(-1).float().cpu() + + +def _chunk_starts(total: int, chunk: int, hop: int) -> list: + if chunk <= 0 or total <= chunk: + return [0] + starts = list(range(0, max(1, total - chunk) + 1, hop)) + if starts[-1] + chunk < total: + starts.append(total - chunk) + return starts + + +@torch.no_grad() +def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps, + guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor: + T = ch48.shape[-1] + if chunk <= 0 or T <= chunk: + if pbar is not None: + pbar.update(1) + return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T) + + hop = max(1, chunk - ov) + starts = _chunk_starts(T, chunk, hop) + out = torch.zeros(T) + wsum = torch.zeros(T) + for i, s in enumerate(starts): + if HAS_COMFY: + mm.throw_exception_if_processing_interrupted() + e = min(s + chunk, T) + enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s) + w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T)) + out[s:e] += enh * w + wsum[s:e] += w + if pbar is not None: + pbar.update(1) + return out / torch.clamp(wsum, min=1e-8) + + +@torch.no_grad() +def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int, + ode_method: str = "midpoint", ode_steps: int = 4, + guidance_scale=1.5, seed: int = 0, + chunk_seconds: float = 10.0, overlap_seconds: float = 0.5, + blend: float = 1.0): + """Run UniverSR over a [B, C, T] waveform. Returns (out [B, C, T48], dry48 [B, C, T48]).""" + if int(input_sr) not in SUPPORTED_INPUT_SR: + raise ValueError(f"input_sr must be one of {SUPPORTED_INPUT_SR}, got {input_sr}") + + waveform = waveform.float().cpu() + if waveform.dim() != 3: + raise ValueError(f"Expected a [B, C, T] waveform, got shape {tuple(waveform.shape)}") + B, C, _ = waveform.shape + dry48 = _resample(waveform, sr, TARGET_SR) # [B, C, T48] + T48 = dry48.shape[-1] + if T48 == 0: # empty input — nothing to do + empty = torch.zeros(B, C, 0) + return empty, empty + + chunk = int(round(chunk_seconds * TARGET_SR)) if (chunk_seconds and chunk_seconds > 0) else 0 + if 0 < chunk < MODEL_MIN_SAMPLES: + print(f"[UniverSR] chunk_seconds too small; raising to the model floor " + f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).") + chunk = MODEL_MIN_SAMPLES + ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0 + n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1 + + pbar = comfy.utils.ProgressBar(B * C * n_per_ch) if HAS_COMFY else None + + # Isolate the global RNG: snapshot, seed, run, restore. Without this the model's + # torch.randn_like noise would advance (and a fixed seed would freeze) the global + # generator that downstream ComfyUI nodes rely on. seed=0 → fresh OS entropy. + cpu_rng = torch.get_rng_state() + cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None + actual_seed = int(seed) if (seed and int(seed) != 0) else int.from_bytes(os.urandom(8), "little") + try: + torch.manual_seed(actual_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(actual_seed) + wet = torch.zeros(B, C, T48) + for b in range(B): + for c in range(C): + wet[b, c] = _fit( + _enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps, + guidance_scale, chunk, ov, pbar), + T48, + ) + finally: + torch.set_rng_state(cpu_rng) + if cuda_rng is not None: + torch.cuda.set_rng_state_all(cuda_rng) + + blend = float(blend) + out = wet if blend >= 1.0 else (1.0 - blend) * dry48 + blend * wet + return out.clamp(-1.0, 1.0), dry48 + + +# --------------------------------------------------------------------------- # +# Spectrogram comparison (optional IMAGE output) +# --------------------------------------------------------------------------- # +def _stft_db(x: np.ndarray) -> np.ndarray: + t = torch.from_numpy(np.ascontiguousarray(x)).float() + win = torch.hann_window(1024) + spec = torch.stft(t, n_fft=1024, hop_length=512, window=win, return_complex=True) + db = 20.0 * torch.log10(spec.abs().clamp(min=1e-5)) + db = db - db.max() + return db.numpy() + + +def make_spectrogram_image(input48_mono: np.ndarray, output48_mono: np.ndarray, + input_sr: int) -> torch.Tensor: + """Before/after spectrogram comparison -> IMAGE tensor [1, H, W, 3] in [0, 1]. + + Left panel is the band-limited input (content valid up to input_sr/2); right + panel is the 48 kHz output. The dashed line marks the LR Nyquist, so the + regenerated high-frequency band is the energy above it on the right. + """ + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + # Visualise the band-limit the model actually saw, not the raw container. + lr = torch.from_numpy(np.ascontiguousarray(input48_mono)).float()[None] + lr = _resample(_resample(lr, TARGET_SR, int(input_sr)), int(input_sr), TARGET_SR).squeeze(0).numpy() + n = min(len(lr), len(output48_mono), int(8.0 * TARGET_SR)) + lr, hr = lr[:n], output48_mono[:n] + nyq = int(input_sr) / 2.0 + + fig, axes = plt.subplots(1, 2, figsize=(12, 4.0), facecolor="#0d0f16") + for ax, sig, title, cmap in ( + (axes[0], lr, f"Input (<= {int(input_sr)//1000} kHz)", "magma"), + (axes[1], hr, "UniverSR output (48 kHz)", "viridis"), + ): + db = _stft_db(sig) + ax.imshow(db, origin="lower", aspect="auto", cmap=cmap, + extent=[0, n / TARGET_SR, 0, TARGET_SR / 2], vmin=-80, vmax=0) + ax.axhline(nyq, color="w", ls="--", lw=0.8, alpha=0.6) + ax.set_title(title, color="#cfe0ff", fontsize=10) + ax.set_xlabel("Time (s)", color="#7a93bd", fontsize=8) + ax.set_ylabel("Hz", color="#7a93bd", fontsize=8) + ax.tick_params(colors="#5a6e90", labelsize=7) + ax.set_facecolor("#0d0f16") + fig.tight_layout() + + fig.canvas.draw() + # np.asarray(buffer_rgba()) yields (H, W, 4) at the real pixel size — robust to HiDPI. + img = np.asarray(fig.canvas.buffer_rgba())[..., :3].astype(np.float32) / 255.0 + plt.close(fig) + return torch.from_numpy(np.ascontiguousarray(img))[None] + except Exception as e: # matplotlib missing / headless edge cases + print(f"[UniverSR] Spectrogram render skipped: {e}") + return torch.zeros(1, 64, 64, 3) diff --git a/vendor/universr/__init__.py b/vendor/universr/__init__.py new file mode 100644 index 0000000..3905b64 --- /dev/null +++ b/vendor/universr/__init__.py @@ -0,0 +1,4 @@ +from universr.inference import UniverSR + +__version__ = "0.1.0" +__all__ = ["UniverSR"] diff --git a/vendor/universr/flow/__init__.py b/vendor/universr/flow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/universr/flow/loss.py b/vendor/universr/flow/loss.py new file mode 100644 index 0000000..3ba3cfe --- /dev/null +++ b/vendor/universr/flow/loss.py @@ -0,0 +1,9 @@ +import torch +import torch.nn.functional as F + + +def flow_matching_loss(predicted_vf: torch.Tensor, target_vf: torch.Tensor) -> torch.Tensor: + """ + Flow matching loss; L2 loss between estimated and target vector field. + """ + return F.mse_loss(predicted_vf, target_vf) diff --git a/vendor/universr/flow/path.py b/vendor/universr/flow/path.py new file mode 100644 index 0000000..c2a71fa --- /dev/null +++ b/vendor/universr/flow/path.py @@ -0,0 +1,54 @@ +import importlib +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn + +class ConditionalProbabilityPath(nn.Module, ABC): + """Abstract base class for conditional probability paths in flow matching.""" + + @abstractmethod + def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor: + """Sample from the source distribution. shape_ref is used only for shape/device.""" + + @abstractmethod + def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Interpolate between source x0 and target x1 at time t.""" + + @abstractmethod + def get_target_vector_field( + self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor + ) -> torch.Tensor: + """Compute the target vector field u_t(xt | x1).""" + +class OriginalCFMPath(ConditionalProbabilityPath): + def __init__(self, sigma_min: float = 1e-4): + super().__init__() + self.sigma_min = sigma_min + + def sample_source(self, shape_ref): + return torch.randn_like(shape_ref) + + def sample_xt(self, x0, x1, t): + return t * x1 + (1 - t + self.sigma_min * t) * x0 + + def get_target_vector_field(self, xt, x0, x1, t): + return x1 - (1 - self.sigma_min) * x0 + +def get_path(config): + class_path = config.get("class_path") + + if not class_path: + raise ValueError("Configuration must contain a 'class_path' key") + try: + module_path, class_name = class_path.rsplit(".", 1) + except ValueError: + raise ValueError(f"Invalid class_path '{class_path}'. Must contain at least one") + + module = importlib.import_module(module_path) + Class = getattr(module, class_name) + init_args = config.get("init_args", {}) + return Class(**init_args) + + \ No newline at end of file diff --git a/vendor/universr/flow/solver.py b/vendor/universr/flow/solver.py new file mode 100644 index 0000000..203b832 --- /dev/null +++ b/vendor/universr/flow/solver.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod + +import torch +from torchdiffeq import odeint +from tqdm import tqdm + +from universr.models.unet import ConditionalVectorFieldModel + +class ODE(ABC): + @abstractmethod + def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Returns the drift coefficient of the ODE. + Args: + - xt: state at time t, shape (bs, c, h, w) + - t: time, shape (bs, 1) + Returns: + - drift_coefficient: shape (bs, c, h, w) + """ + pass + +class Solver(ABC): + # @abstractmethod + def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs): + """ + Takes one simulation step + Args: + - xt: state at time t, shape (bs, c, h, w) + - t: time, shape (bs, 1, 1, 1) + - dt: time, shape (bs, 1, 1, 1) + Returns: + - nxt: state at time t + dt (bs, c, h, w) + """ + pass + + @torch.no_grad() + def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs): + """ + Simulates using the discretization gives by ts + Args: + - x_init: initial state, shape (bs, c, h, w) + - ts: timesteps, shape (bs, nts, 1, 1, 1) + Returns: + - x_final: final state at time ts[-1], shape (bs, c, h, w) + """ + nts = ts.shape[1] + for t_idx in tqdm(range(nts - 1)): + t = ts[:, t_idx] + h = ts[:, t_idx + 1] - ts[:, t_idx] + x = self.step(x, t, h, **kwargs) + return x + + @torch.no_grad() + def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs): + """ + Simulates using the discretization gives by ts + Args: + - x: initial state, shape (bs, c, h, w) + - ts: timesteps, shape (bs, nts, 1, 1, 1) + Returns: + - xs: trajectory of xts over ts, shape (batch_size, nts, c, h, w) + """ + xs = [x.clone()] + nts = ts.shape[1] + for t_idx in tqdm(range(nts - 1)): + t = ts[:,t_idx] + h = ts[:, t_idx + 1] - ts[:, t_idx] + x = self.step(x, t, h, **kwargs) + xs.append(x.clone()) + return torch.stack(xs, dim=1) + +class VectorFieldODE(ODE): + def __init__(self, net:ConditionalVectorFieldModel) -> None: + super().__init__() + self.net = net + + def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: + return self.net(xt, t, y, **kwargs) + +class CFGVectorFieldODE(ODE): + """ For Classifier Free Guidance """ + def __init__(self, net:ConditionalVectorFieldModel, guidance_scale: float = 1.0) -> None: + super().__init__() + self.net = net + self.guidance_scale = guidance_scale + + def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: + guided_vector_field = self.net(xt, t, y, **kwargs) + unguided_vector_field = self.net(xt, t, None, **kwargs) + + return (1-self.guidance_scale) * unguided_vector_field + self.guidance_scale * guided_vector_field + +class EulerSolver(Solver): + def __init__(self, ode: ODE): + self.ode = ode + + def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs): + return xt + self.ode.drift_coefficient(xt,t, **kwargs) * h + +class TorchDiffeqSolver(Solver): + def __init__(self, + ode: ODE, + method: str = 'euler', + atol: float = 1e-5, + rtol: float = 1e-5, + ): + super().__init__() + self.ode = ode + self.method = method + self.atol = atol + self.rtol = rtol + + @torch.no_grad() + def simulate(self, x_init: torch.Tensor, ts: torch.Tensor, **kwargs): + """ + x_init: [B,C,H,W] + ts: [N] + return: final state [B,C,H,W] + """ + func = lambda t, x: self.ode.drift_coefficient(xt=x, t=t, **kwargs) + + xs = odeint( + func=func, + y0=x_init, t=ts, + method=self.method, + atol=self.atol, rtol=self.rtol) # [N,B,C,H,W] + return xs[-1] \ No newline at end of file diff --git a/vendor/universr/inference.py b/vendor/universr/inference.py new file mode 100644 index 0000000..cbff71c --- /dev/null +++ b/vendor/universr/inference.py @@ -0,0 +1,351 @@ +""" +UniverSR: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching +Inference wrapper module. +""" + +import os +from typing import Optional, Union + +import numpy as np +import torch +import torchaudio +import yaml +from huggingface_hub import hf_hub_download + +from universr.models.unet import ConvNeXtUNetCond +from universr.flow.path import OriginalCFMPath +from universr.flow.solver import CFGVectorFieldODE, VectorFieldODE, TorchDiffeqSolver +from universr.utils.spectral_ops import AmplitudeCompressedComplexSTFT + + +# Supported input sample rates (kHz) and their corresponding LR frequency bins +SUPPORTED_INPUT_SR = {8000, 12000, 16000, 24000} +TARGET_SR = 48000 + + +class UniverSR(torch.nn.Module): + """ + UniverSR inference wrapper. + + Performs audio super-resolution from low sample rates (8/12/16/24 kHz) + to 48 kHz using vocoder-free flow matching in the complex STFT domain. + + Example: + >>> model = UniverSR.from_pretrained("woongzip1/universr-speech") + >>> output = model.enhance("input.wav", input_sr=16000) + >>> torchaudio.save("output.wav", output.cpu(), 48000) + """ + + def __init__( + self, + model: ConvNeXtUNetCond, + transform: AmplitudeCompressedComplexSTFT, + path: OriginalCFMPath, + device: str = "cuda", + ): + super().__init__() + self.model = model + self.transform = transform + self.path = path + self._device = device + + @classmethod + def from_pretrained( + cls, + repo_id_or_path: str, + device: str = "cuda", + revision: Optional[str] = None, + ) -> "UniverSR": + """ + Load a pretrained UniverSR model. + + Args: + repo_id_or_path: HuggingFace repo ID (e.g. "woongzip1/universr-speech") + or local directory path containing config.yaml and pytorch_model.bin. + device: Device to load the model on. + revision: Optional HuggingFace revision (branch, tag, or commit hash). + + Returns: + UniverSR instance ready for inference. + """ + if os.path.isdir(repo_id_or_path): + config_path = os.path.join(repo_id_or_path, "config.yaml") + model_path = os.path.join(repo_id_or_path, "pytorch_model.bin") + else: + config_path = hf_hub_download( + repo_id=repo_id_or_path, filename="config.yaml", revision=revision + ) + model_path = hf_hub_download( + repo_id=repo_id_or_path, filename="pytorch_model.bin", revision=revision + ) + + # Load config + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + # Build model + model = ConvNeXtUNetCond(**config["model"]) + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict) + model.to(device).eval() + + # Build transform + transform = AmplitudeCompressedComplexSTFT(**config["transform"]) + transform.to(device) + + # Build probability path + path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4}) + path = OriginalCFMPath(**path_args) + + return cls(model=model, transform=transform, path=path, device=device) + + @classmethod + def from_local( + cls, + ckpt_path: str, + config_path: str, + device: str = "cuda", + ) -> "UniverSR": + """ + Load UniverSR from a local checkpoint (e.g. training checkpoint with optimizer state). + + This handles the standard training checkpoint format where weights are stored + under the 'model_state_dict' key, as opposed to from_pretrained() which expects + a clean state_dict saved as pytorch_model.bin. + + Args: + ckpt_path: Path to checkpoint file (.pth). + config_path: Path to YAML config file. + device: Device to load the model on. + + Returns: + UniverSR instance ready for inference. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + model = ConvNeXtUNetCond(**config["model"]) + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + # Handle both formats: raw state_dict or training checkpoint + if "model_state_dict" in ckpt: + model.load_state_dict(ckpt["model_state_dict"]) + else: + model.load_state_dict(ckpt) + model.to(device).eval() + + transform = AmplitudeCompressedComplexSTFT(**config["transform"]) + transform.to(device) + + path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4}) + path = OriginalCFMPath(**path_args) + + return cls(model=model, transform=transform, path=path, device=device) + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + @torch.no_grad() + def enhance( + self, + audio: Union[str, torch.Tensor, np.ndarray], + input_sr: Optional[int] = None, + target_sr: int = TARGET_SR, + ode_method: str = "midpoint", + ode_steps: int = 4, + guidance_scale: Optional[float] = 1.5, + ) -> torch.Tensor: + """ + Enhance a low-resolution audio signal to high-resolution. + + Args: + audio: Input audio. Can be: + - str: path to a .wav file + - torch.Tensor: waveform tensor of shape (T,), (1, T), or (1, 1, T) + - np.ndarray: waveform array + input_sr: Effective bandwidth of the input in Hz (e.g. 8000, 16000). + For file input: auto-detected from the file's native sample rate + if it matches a supported rate (8/12/16/24 kHz). Required if the + file is already at 48 kHz but has limited bandwidth. + For tensor/array input: always required. + target_sr: Target sample rate in Hz. Default: 48000. + ode_method: ODE solver method. One of 'euler', 'midpoint', 'rk4'. + ode_steps: Number of ODE integration steps. + guidance_scale: Classifier-free guidance scale. None or 0 disables CFG. + + Returns: + Enhanced waveform tensor of shape (1,T) at target_sr. + """ + # Load audio + wav, file_sr = self._load_audio(audio, input_sr=input_sr) + wav = wav.to(self._device) + + # Determine the effective bandwidth SR + effective_sr = input_sr if input_sr is not None else file_sr + + if effective_sr not in SUPPORTED_INPUT_SR: + if effective_sr == target_sr and input_sr is None: + raise ValueError( + f"Input audio is already at {target_sr} Hz. " + f"Please specify input_sr to indicate the effective bandwidth " + f"(e.g., input_sr=16000). Supported: {sorted(SUPPORTED_INPUT_SR)}" + ) + raise ValueError( + f"Effective input sample rate {effective_sr} Hz is not supported. " + f"Supported rates: {sorted(SUPPORTED_INPUT_SR)}" + ) + + # Prepare the 48 kHz LR input for the model + if file_sr == target_sr: + # Simulate the training degradation: downsample → upsample to match + wav = self._apply_bandwidth_limit(wav, effective_sr, target_sr) + elif file_sr != target_sr: + # File is truly low-resolution; resample up to 48 kHz + wav = torchaudio.functional.resample(wav, orig_freq=file_sr, new_freq=target_sr) + + # Minimum length guard + MIN_SAMPLES = 32_768 + original_len = wav.shape[-1] + wav = torch.nn.functional.pad(wav, (0, max(0, MIN_SAMPLES - wav.shape[-1]))) + + # Ensure shape is [B, C, T] = [1, 1, T] + if wav.dim() == 1: + wav = wav.unsqueeze(0).unsqueeze(0) + elif wav.dim() == 2: + wav = wav.unsqueeze(0) + + sr_khz = effective_sr // 1000 + + # Run flow matching SR + output = self._inference(wav, sr_khz, ode_method, ode_steps, guidance_scale) + + # (1,T) + return output[..., :original_len] + + # ------------------------------------------------------------------ # + # Internal methods # + # ------------------------------------------------------------------ # + + def _load_audio( + self, audio: Union[str, torch.Tensor, np.ndarray], input_sr: Optional[int] = None, + ) -> tuple: + """ + Load and validate audio input. + + Returns: + (waveform, file_sr): The waveform tensor and its *actual* sample rate. + """ + if isinstance(audio, str): + wav, file_sr = torchaudio.load(audio) + # Mix to mono if stereo + if wav.shape[0] > 1: + wav = wav.mean(dim=0, keepdim=True) + return wav, file_sr + + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio).float() + + if isinstance(audio, torch.Tensor): + if input_sr is None: + raise ValueError("input_sr is required when passing a tensor or array.") + return audio.float(), input_sr + + raise TypeError(f"Unsupported audio type: {type(audio)}") + + def _apply_bandwidth_limit( + self, wav: torch.Tensor, effective_sr: int, target_sr: int, + ) -> torch.Tensor: + """ + Simulate low-resolution input from a high-sample-rate waveform. + + Applies the same downsample-then-upsample pipeline used during training + (see WaveformCollator._apply_lpf) so that the spectral cutoff pattern + matches what the model expects. + + Args: + wav: Waveform at target_sr. Shape: (1, T) or (T,). + effective_sr: The effective bandwidth in Hz (e.g. 8000). + target_sr: The native sample rate of wav (e.g. 48000). + + Returns: + Bandwidth-limited waveform at target_sr, same length as input. + """ + original_len = wav.shape[-1] + lr = torchaudio.functional.resample(wav, orig_freq=target_sr, new_freq=effective_sr) + lr = torchaudio.functional.resample(lr, orig_freq=effective_sr, new_freq=target_sr) + return lr[..., :original_len] + + def _preprocess(self, waveform: torch.Tensor) -> torch.Tensor: + """ + Convert waveform to amplitude-compressed complex STFT representation. + [B, C, T] -> [B, 2, F-1, T_frames] (real/imag channels, drop Nyquist bin) + """ + spec = self.transform(waveform) # [B, C, F, T_frames] complex + real = torch.view_as_real(spec.squeeze(1)) # [B, F, T_frames, 2] + real = real.permute(0, 3, 1, 2) # [B, 2, F, T_frames] + return real[:, :, :-1, :] # drop Nyquist bin + + def _postprocess(self, spec: torch.Tensor) -> torch.Tensor: + """ + Convert STFT representation back to waveform. + [B, 2, F-1, T_frames] -> [B, T] + """ + spec = torch.nn.functional.pad(spec, [0, 0, 0, 1], value=0) # restore Nyquist + spec = spec.permute(0, 2, 3, 1).contiguous() # [B, F, T, 2] + spec = torch.view_as_complex(spec) # [B, F, T] complex + waveform = self.transform.invert(spec) # [B, T] + return waveform + + def _inference( + self, + lr_audio: torch.Tensor, + sr_khz: int, + ode_method: str, + ode_steps: int, + guidance_scale: Optional[float], + ) -> torch.Tensor: + """ + Core inference pipeline: + 1. STFT the (resampled) LR audio + 2. Extract LR condition bins + 3. Sample noise for HF region + 4. Solve ODE (flow matching) + 5. Concatenate LR + generated HF + 6. iSTFT to waveform + """ + # Frequency bin bookkeeping + lr_bin_count = self.model.sr_to_lr_bins[sr_khz] + hf_start_bin = self.model.total_freq_bins - self.model.hr_freq_bins + + # STFT + Y = self._preprocess(lr_audio) # [B, 2, F-1, T] + Y_lr = Y[:, :, :lr_bin_count, :] # LR condition + Y_hr = Y[:, :, hf_start_bin:, :] # HR target region (for shape reference) + + # Initial noise + x0 = self.path.sample_source(Y_hr).to(self._device) + + # Build ODE solver + if guidance_scale is not None and guidance_scale > 0: + ode = CFGVectorFieldODE(net=self.model, guidance_scale=guidance_scale) + else: + ode = VectorFieldODE(net=self.model) + solver = TorchDiffeqSolver(ode, method=ode_method) + + # Time discretization + ts = torch.linspace(0, 1, ode_steps + 1, device=self._device) + + # Solve ODE + x1_spec = solver.simulate( + x0, ts=ts, y=Y_lr, sr_values=torch.tensor([sr_khz], device=self._device) + ) + + # Concatenate LR bins + generated HF bins (handle overlapping region) + slice_start = max(0, lr_bin_count - hf_start_bin) + x1_spec = x1_spec[:, :, slice_start:, :] + full_spec = torch.cat([Y_lr, x1_spec], dim=2) + + # iSTFT + output = self._postprocess(full_spec) + return output \ No newline at end of file diff --git a/vendor/universr/models/__init__.py b/vendor/universr/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/universr/models/unet.py b/vendor/universr/models/unet.py new file mode 100644 index 0000000..ac9d667 --- /dev/null +++ b/vendor/universr/models/unet.py @@ -0,0 +1,470 @@ +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.layers import DropPath, trunc_normal_ + +class ConditionalVectorFieldModel(nn.Module, ABC): + """ + Base class for DNN-based VF model + MLP-parameterization of the learned vector field u_t^theta(x) + """ + + @abstractmethod + def forward(self, x:torch.Tensor, t:torch.Tensor, y:torch.Tensor): + """ + Args: + - x: (bs, c, h, w) + - t: (bs, 1, 1, 1) + - y: (bs,) + Returns: + - u_t^theta(x|y): (bs, c, h, w) + """ + pass + +class SinusoidalTimeEmbedding(nn.Module): + """ + Based on https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/karras_unet.py#L183 + & DiffWave / WaveFM + """ + def __init__(self, dim: int=128, mode: str='learnable', time_scale=1): + super().__init__() + assert dim % 2 == 0, "Dimension must be an even number" + assert mode in ['fixed', 'learnable'], "Mode must be 'fixed' or 'learnable'" + + self.dim = dim # D + self.half_dim = dim // 2 + self.mode = mode + self.time_scale = time_scale # 1(diffusion) or 100(flow) + + if self.mode == 'learnable': + self.weights = nn.Parameter(torch.randn(1, self.half_dim)) # [1,D/2] + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """ + Args: + - t: Time tensor. Shape can be [B] or [B, 1]. + Returns: + - embeddings: Time embeddings of shape [B, D] + """ + # Ensure t has shape [B, 1] for broadcasting + t = t.view(-1, 1) + device = t.device + + if self.mode == 'fixed': + # Create a sequence from 0 to D/2 - 1 + pos = torch.arange(self.half_dim, device=device).unsqueeze(0) # [1,D/2] + freqs = self.time_scale * t * 10.0 ** (pos * 4.0 / (self.half_dim - 1)) # 100 is a magnitude hyperparameter + + sin_embed = torch.sin(freqs) + cos_embed = torch.cos(freqs) + + return torch.cat([sin_embed, cos_embed], dim=-1) + + elif self.mode == 'learnable': + freqs = t * self.weights * 2 * math.pi + + sin_embed = torch.sin(freqs) + cos_embed = torch.cos(freqs) + + return torch.cat([sin_embed, cos_embed], dim=-1) * math.sqrt(2) + +class GRN(nn.Module): + """ GRN (Global Response Normalization) layer + """ + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + +class LayerNorm(nn.Module): + """ LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class Block(nn.Module): + """ ConvNeXt V2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + """ + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, padding_mode="reflect") + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.grn = GRN(4 * dim) # GRN for V2 + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + # This Block preserves the input shape (C, H, W) -> (C, H, W) + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # [N,C,H,W] -> [N,H,W,C] + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # [N,H,W,C] -> [N,C,H,W] + + x = input + self.drop_path(x) # Residual connection + return x + +class BlockWithEmbedding(nn.Module): + """ ConvNeXt block with time embedding injection + """ + def __init__(self, dim, drop_path=0., time_embed_dim=128): + super().__init__() + self.block = Block(dim, drop_path) + self.time_adapter = nn.Sequential( + nn.Linear(time_embed_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, dim), + ) + def forward(self, x, t_embed): + t_embed = self.time_adapter(t_embed).unsqueeze(-1).unsqueeze(-1) # [B,C,1,1] + x = x + t_embed + x = self.block(x) + + return x + +class EncoderBlock(nn.Module): + def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim): + super().__init__() + self.blocks= nn.ModuleList( + [BlockWithEmbedding(dim_in, drop_path, time_embed_dim) + for _ in range(num_blocks)] + ) + self.downsampler = nn.Sequential( + LayerNorm(dim_in, eps=1e-6, data_format="channels_first"), + nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2), + ) + + def forward(self, x, t_emb): + for block in self.blocks: + x = block(x, t_emb) + x = self.downsampler(x) + return x + +class Midcoder(nn.Module): + def __init__(self, dim, num_blocks, drop_path, time_embed_dim): + super().__init__() + self.blocks = nn.ModuleList( + [BlockWithEmbedding(dim, drop_path, time_embed_dim) + for _ in range(num_blocks)] + ) + + def forward(self, x, t_emb): + for block in self.blocks: + x = block(x, t_emb) + return x + +class DecoderBlock(nn.Module): + def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim): + super().__init__() + self.upsampler = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=2, stride=2) + self.blocks = nn.ModuleList( + [BlockWithEmbedding(dim_out, drop_path, time_embed_dim) + for _ in range(num_blocks)] + ) + def forward(self, x, t_emb): + x = self.upsampler(x) + for block in self.blocks: + x = block(x, t_emb) + return x + +class ConditioningEncoder2D(nn.Module): + def __init__(self, cond_dim, num_blocks=3): + """ + Args: + cond_dim (int): The main conditioning dimension (D). + num_blocks (int): The number of shared 2D ConvNeXt blocks. + """ + super().__init__() + self.cond_dim = cond_dim + self.film_generator = nn.Linear(cond_dim, 4) + self.head = nn.Conv2d(2, cond_dim, kernel_size=1) + self.sr_adapter = nn.Sequential( + nn.Linear(cond_dim, cond_dim), + nn.GELU(), + nn.Linear(cond_dim, cond_dim * 2) + ) + self.blocks = nn.Sequential(*[ + Block(dim=cond_dim) for _ in range(num_blocks) + ]) + self.freq_pool = nn.AdaptiveAvgPool2d((1,None)) + + def forward(self, y_lr, f_emb_lr, sr_emb): + """ + Args: + y_lr (Tensor): LR Spec [B, 2, F1, T] + f_emb : Freq positional embedding for lr spec [F1,D] + sr_emb: Sampling rate embedding [B,D] + Returns: + z (Tensor): Conditioning Emb [B, D, T] + """ + film_params = self.film_generator(f_emb_lr) # [F1, 4] + gamma, beta = torch.chunk(film_params, chunks=2, dim=-1) # [F1,2] + gamma = rearrange(gamma, 'f c -> 1 c f 1') # [1,2,F1,1] + beta = rearrange(beta, 'f c -> 1 c f 1') # [1,2,F1,1] + z = y_lr * gamma + beta # [B, 2, F1, T] + z = self.head(z) # [B,D,F1,T] + + sr_film_params = self.sr_adapter(sr_emb) # [B, 2*D] + sr_gamma, sr_beta = torch.chunk(sr_film_params, 2, dim=-1) # [B,D] + sr_gamma = sr_gamma.unsqueeze(-1).unsqueeze(-1) # [B,D,1,1] + sr_beta = sr_beta.unsqueeze(-1).unsqueeze(-1) # [B,D,1,1] + z = z * sr_gamma + sr_beta # [B,D,F1,T] + z = self.blocks(z) # [B,D,F1,T] + z = self.freq_pool(z).squeeze(2) # [B,D,T] + return z + +class FrequencyPositionalEmbedding(nn.Module): + def __init__(self, num_bins: int, emb_dim: int): + super().__init__() + # (F, D) + pe = torch.zeros(num_bins, emb_dim) + position = torch.arange(num_bins, dtype=torch.float32).unsqueeze(1) # (F,1) + div_term = torch.exp( + torch.arange(0, emb_dim, 2, dtype=torch.float32) * + -(math.log(10000.0) / emb_dim) + ) # (D/2,) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self): + # returns (F, D) + return self.pe + +class ConvNeXtUNetCond(ConditionalVectorFieldModel): + def __init__(self, in_channels=2, out_channels=2, + dims=[64,128,256,512], depths=[2,2,2,4], + drop_path=0., time_dim=128, + cond_dim=256, # D1 + total_freq_bins=512, + hr_freq_bins=432, + feature_enc_layers=10, + cond_dropout_prob=0.1, + sr_to_lr_bins={8: 80, 12: 128, 16: 170, 24: 256}, + ): + super().__init__() + self.strides = 2**len(dims) + self.time_embedder = SinusoidalTimeEmbedding(dim=time_dim) + self.total_freq_bins = total_freq_bins + self.hr_freq_bins = hr_freq_bins + self.sr_to_lr_bins = sr_to_lr_bins + self.sr_values_list = sorted(list(sr_to_lr_bins.keys())) # (8,12,16,24) kHz + self.sr_to_idx = {sr: i for i, sr in enumerate(self.sr_values_list)} + self.sr_embedder = nn.Embedding(len(self.sr_values_list), cond_dim) # [4,D] + self.cond_dropout_prob = cond_dropout_prob + self.cond_dim = cond_dim + self.uncond_emb = nn.Parameter(torch.randn(cond_dim)) + self.sr_projector = nn.Linear(cond_dim, time_dim) # projector to t_emb + + self.freq_pos_enc = FrequencyPositionalEmbedding(num_bins=total_freq_bins, emb_dim=cond_dim) + self.film_generator = nn.Linear(cond_dim, cond_dim * 2) + + self.conditioning_encoder = ConditioningEncoder2D( + cond_dim=cond_dim, + num_blocks=feature_enc_layers, + ) + + self.init_conv = nn.Sequential( + nn.Conv2d(in_channels+cond_dim, dims[0], kernel_size=1), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + + # Encoder + for i in range(len(depths)): + dim_in = dims[i] + dim_out = dims[i+1] if i+1 < len(dims) else dims[i] + self.encoders.append(EncoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim)) + + # Midcoder + self.midcoder = Midcoder(dims[-1], depths[-1], drop_path, time_dim) + + # Decoder + for i in reversed(range(len(depths))): + dim_in = dims[i+1] if i+1 < len(dims) else dims[i] + dim_out = dims[i] + self.decoders.append(DecoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim)) + + self.final_conv = nn.Conv2d(dims[0], out_channels, kernel_size=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def _pad_frames(self, x): + num_frames = x.shape[-1] + pad_len = (self.strides - num_frames % self.strides) % self.strides + if pad_len: + x = torch.nn.functional.pad(x, [0,pad_len,0,0], mode='reflect') + assert x.shape[-1] % self.strides == 0, \ + f"After padding, time dim:{x.shape(-1)} must be multiples of {self.strides}" + return x, pad_len + + def forward(self, x, t, y, sr_values): + """ + x : x_t noisy spec [B,2,F,T] + t : time embedding [B,1] or [B] + y : condition lr spectrum [B,2,F,T] + sr_values: input sampling_rate [B] or [1] + """ + # Pad logic + x, pad_len = self._pad_frames(x) + if pad_len > 0 and y is not None: + y = torch.nn.functional.pad(y, [0, pad_len, 0, 0], mode='reflect') + B, _, F, T = x.shape + + # get number of lr bins for input sr + if isinstance(sr_values, int): + current_sr = sr_values + else: + current_sr = sr_values[0].item() if hasattr(sr_values[0], 'item') else sr_values[0] + + lr_bin_count = self.sr_to_lr_bins[current_sr] + + # freq pe + pe_full = self.freq_pos_enc() # [F,D] + pe_low = pe_full[:lr_bin_count,:] # [F1,D] + hf_start_bin = self.total_freq_bins - self.hr_freq_bins # 512 - 432 + pe_high = pe_full[hf_start_bin:, :] # [F2=432,D] + + # time / sr embedding + t_embed = self.time_embedder(t) # [B,timedim] + sr_idx = self.sr_to_idx[current_sr] + sr_emb = self.sr_embedder(torch.tensor([sr_idx], device=x.device)).expand(B,-1) # [B, D] + t_embed = t_embed + self.sr_projector(sr_emb) # [B, timedim] + + if y is not None: # (Training) + y_cond_real = self.conditioning_encoder(y, pe_low, sr_emb) # [B,D,T] + # Uncond token masking + if self.training and self.cond_dropout_prob > 0: + # random mask for uncond + mask = (torch.rand(B, device=x.device) < self.cond_dropout_prob) # [B] + uncond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T) # [B,D,T] + y_cond = torch.where(mask.reshape(B,1,1), uncond, y_cond_real) + else: + y_cond = y_cond_real + else: # Unconditional (inference) + y_cond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T) + + y_cond = y_cond.unsqueeze(2) # [B,D,1,T] + + # FiLM Conditioning of freq-bins + film_params = self.film_generator(pe_high) # [F2,D] -> [F2,2D] + gamma_high, beta_high = torch.chunk(film_params, chunks=2, dim=-1) # [F2, D] + gamma_high = rearrange(gamma_high, 'f d -> 1 d f 1') # [1,D,F2,1] + beta_high = rearrange(beta_high, 'f d -> 1 d f 1') # [1,D,F2,1] + spatial_cond = y_cond * gamma_high + beta_high # [B,D,F2,T] + + x = torch.cat([x, spatial_cond], dim=1) # [B,2+D,F2,T] + + x = self.init_conv(x) + skip_connections = [x] + + for encoder in self.encoders: + x = encoder(x, t_embed) + skip_connections.append(x) + + x = self.midcoder(x, t_embed) + + for decoder in self.decoders: + skip = skip_connections.pop() + if x.shape != skip.shape: + x = nn.functional.interpolate(x, size=skip.shape[2:]) + x = x + skip + x = decoder(x, t_embed) + + skip = skip_connections.pop() + x = x + skip + x = self.final_conv(x) + + # Crop out + if pad_len: + x = x[...,:-pad_len] + return x + +def main(): + """ + Dummy forward pass test for ConvNeXtUNetCond. + """ + from torchinfo import summary + + batch_size = 2 + hr_freq_bins = 432 # High-res bins to be generated (fixed) + lr_freq_bins = 128 # Low-res bins for this specific test case (e.g., for 8kHz) + T = 256 # Number of time frames + + sr_config = {8: 80, 12: 128, 16: 170, 24: 256} + + model = ConvNeXtUNetCond( + in_channels=2, + out_channels=2, + dims=[96, 192, 384, 768], + depths=[2, 2, 4, 2], + time_dim=256, + cond_dim=384, + total_freq_bins=512, + hr_freq_bins=hr_freq_bins, + feature_enc_layers=4, + cond_dropout_prob=0.1, + sr_to_lr_bins=sr_config, # Pass the dictionary + ) + + x = torch.randn(batch_size, 2, hr_freq_bins, T) + y = torch.randn(batch_size, 2, lr_freq_bins, T) + t = torch.randint(0, 1000, (batch_size,)) + sr_values = [12] * batch_size + + print("\n--- Model Summary ---") + summary( + model, + input_data=[x, t, y, sr_values], + depth=4, + col_names=("input_size", "output_size", "num_params", + "kernel_size", "mult_adds", "trainable"), + verbose=1 + ) +if __name__ == "__main__": + main() diff --git a/vendor/universr/utils/__init__.py b/vendor/universr/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/universr/utils/spectral_ops.py b/vendor/universr/utils/spectral_ops.py new file mode 100644 index 0000000..c131d7a --- /dev/null +++ b/vendor/universr/utils/spectral_ops.py @@ -0,0 +1,135 @@ +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor + +class InvertibleFeatureExtractor(nn.Module, ABC): + """ + An invertible feature extractor, i.e. a one-to-one mapping that has a forward and a true inverse. + It should hold up to numerical error that `extractor.invert(extractor(x)) == x`. + """ + @abstractmethod + def forward(self, x, **kwargs): + pass + + @abstractmethod + def invert(self, x, **kwargs): + pass + + def analysis_synthesis(self, x, **kwargs): + return self.invert(self.forward(x, **kwargs), **kwargs) + +class AmplitudeCompressedComplexSTFT(InvertibleFeatureExtractor): + """ + A convenient composition of ComplexSTFT() and CompressAmplitudesAndScale(). + """ + def __init__( + self, + window_fn, n_fft, sampling_rate, + alpha, beta, comp_eps, + hop_length=None, n_hops=None, + learnable_window=False, + *args, **kwargs, + ): + super().__init__(*args, **kwargs) + self.complex_stft = ComplexSTFT( + window_fn, n_fft, sampling_rate, hop_length=hop_length, n_hops=n_hops, + learnable_window=learnable_window, + ) + self.compress = CompressAmplitudesAndScale( + compression_exponent=alpha, + scale_factor=beta, + comp_eps=comp_eps, + ) + + def forward(self, x: Tensor, **kwargs): + X = self.complex_stft(x, **kwargs) + out = self.compress(X, **kwargs) + return out + + def invert(self, X: Tensor, **kwargs): + X = self.compress.invert(X, **kwargs) + x = self.complex_stft.invert(X, **kwargs) + return x + + +class ComplexSTFT(InvertibleFeatureExtractor): + def __init__( + self, window_fn, n_fft, sampling_rate, hop_length=None, n_hops=None, learnable_window=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + assert (hop_length is not None) ^ (n_hops is not None),\ + "Exactly one of {hop_length, n_hops} must be specified!" + if hop_length is None: + hop_length = int(math.ceil(n_fft / n_hops)) + + window_fn = getattr(torch.signal.windows, window_fn) + self.learnable_window = learnable_window + self.window = nn.Parameter(window_fn(n_fft), requires_grad=learnable_window) + self.n_fft = n_fft + self.hop_length = hop_length + self.sampling_rate = sampling_rate + self.center = True + + def forward(self, x: Tensor, **kwargs): + """Assumes x is an audio tensor of shape [B, C, T] or [B, T] + + [B,C,T] -> [B,C,F,T] + [B,C,T] -> [B,F,T] + + """ + bc = "b c" if x.ndim == 3 else "b" + X = torch.stft( + rearrange(x, f"{bc} t -> ({bc}) t"), n_fft=self.n_fft, hop_length=self.hop_length, + window=self.window.to(x.device), center=self.center, + onesided=True, return_complex=True, + ) + X = rearrange(X, f"({bc}) f t -> {bc} f t", b=x.shape[0]) + return X + + def invert(self, X: Tensor, orig_length: Optional[int] = None, **kwargs): + """Assumes X is a (complex) spectrogram tensor of shape [B, C, F, T] or [B, F, T]""" + bc = "b c" if X.ndim == 4 else "b" + x = torch.istft( + rearrange(X, f"{bc} f t -> ({bc}) f t"), n_fft=self.n_fft, hop_length=self.hop_length, + window=self.window.to(X.device), center=self.center, + onesided=True, return_complex=False, + length=orig_length, + ) + x = rearrange(x, f"({bc}) t -> {bc} t", b=X.shape[0]) + return x + +class CompressAmplitudesAndScale(InvertibleFeatureExtractor): + def __init__(self, compression_exponent: float, scale_factor: float, comp_eps: float, *args, **kwargs): + super().__init__() + self.compression_exponent = compression_exponent + self.scale_factor = scale_factor + self.comp_eps = comp_eps + + def forward(self, X: Tensor, **kwargs): + """ + Assumes X is a complex STFT (complex spectrogram). + """ + alpha = self.compression_exponent + beta = self.scale_factor + if alpha != 1: + X = X + self.comp_eps + X = X.abs()**alpha * torch.exp(1j * X.angle()) + return X * beta + + def invert(self, X: Tensor, **kwargs): + """ + Assumes X is an amplitude-compressed and scaled complex STFT. + """ + alpha = self.compression_exponent + beta = self.scale_factor + X = X / beta + if alpha != 1: + X = X.abs()**(1/alpha) * torch.exp(1j * X.angle()) + return X + +