Initial release: ComfyUI-UniverSR

ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 12:59:42 +02:00
commit 5f29b225b7
20 changed files with 2129 additions and 0 deletions
+10
View File
@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*.egg-info/
.pytest_cache/
.DS_Store
# anchored to repo root so the vendored universr/models/ package is NOT ignored
/models/
/ckpts/
*.wav
*.flac
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Woongjib Choi, DSPAI Lab, Yonsei University
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+108
View File
@@ -0,0 +1,108 @@
# ComfyUI-UniverSR
ComfyUI nodes for **[UniverSR](https://github.com/woongzip1/UniverSR)** — *Unified and Versatile
Audio Super-Resolution via Vocoder-Free Flow Matching* (ICASSP 2026,
[arXiv:2510.00771](https://arxiv.org/abs/2510.00771)).
A single model upscales **8 / 12 / 16 / 24 kHz** effective bandwidth → **48 kHz** across speech,
music and sound effects. It works directly in the complexSTFT domain with flow matching — no neural
vocoder — and regenerates the missing highfrequency band rather than just interpolating.
![overview](https://raw.githubusercontent.com/woongzip1/UniverSR/master/assets/overview.png)
---
## Nodes
| Node | Output | Purpose |
|---|---|---|
| **UniverSR Model Loader** | `UNIVERSR_MODEL` | Loads + caches a checkpoint. Auto-downloads the presets to `models/universr/`. |
| **UniverSR Super-Resolution** | `AUDIO`, `IMAGE` | Runs the SR. Chunks long audio (click-free overlap-add). Optional before/after spectrogram. |
Wire it up:
```
LoadAudio ─────────────┐
UniverSR Model Loader ─► UniverSR Super-Resolution ─► SaveAudio
└─ spectrogram ─► PreviewImage
```
### Model Loader
- **model** — `universr-audio` (general; music/SFX/mixed, recommended) or `universr-speech` (voice).
Each downloads ~230 MB to `models/universr/<name>` on first use. Local checkpoint folders placed
in `models/universr/` also appear in this list.
- **device** — `auto` / `cuda` / `cpu`.
- **local_path** *(optional)* — override with a folder (`config.yaml` + `pytorch_model.bin`) or a raw
`.pth`/`.ckpt` training checkpoint.
- **config_path** *(optional)*`config.yaml` for a raw checkpoint. Empty → the bundled default config.
### Super-Resolution
- **input_sr** — the *effective bandwidth* of your content in Hz. The model treats everything up to
`input_sr/2` as valid and **regenerates above it**.
- `8000` → genuine low-rate audio (8 kHz → 48 kHz; the strongest, best-trained case).
- `16000` → brighten muffled but full-rate audio by regenerating only above 8 kHz (most natural).
- **ode_method** — `euler` (fastest) → `midpoint` (balanced) → `rk4` (best).
- **ode_steps** — flow-matching steps. `4` is fast and validated; `410` is a good range.
- **guidance_scale** — classifier-free guidance. Speech `1.01.5`, music `1.52.0`, SFX `~1.5`.
Higher = denser highs but less faithful. `0` disables CFG.
- **seed** — noise seed (`0` = random each run).
- **chunk_seconds** / **overlap_seconds** — long-audio handling (see below). `chunk_seconds=0`
processes the whole clip at once.
- **blend** — wet/dry mix. `1.0` = full SR. Lower keeps more of the original (handy for *bandwidth
extension* of already-48 kHz audio).
- **unload_model** — free VRAM after the run.
- **show_spectrogram** — also output a before/after spectrogram comparison `IMAGE`.
---
## Long audio & chunking
UniverSR runs the whole clip through a flow-matching ODE in one shot, which OOMs on long files
(the upstream notebook added chunking specifically to survive clips > 2 min). This node chunks in the
time domain and stitches the results with **overlap-add + linear crossfade** (weight-normalised), so
seams are click-free — an improvement over the upstream GUI's naive concatenation. Drop
`chunk_seconds` if you hit VRAM limits; raise `overlap_seconds` if you ever hear a seam. Stereo is
processed per-channel and preserved.
> Compared to the `FoleyTune BWE` node (which brightens short foley clips and processes the whole clip
> at once), this node adds the chunking needed for arbitrarily long sequences.
---
## Installation
```bash
cd ComfyUI/custom_nodes
git clone <this repo> ComfyUI-UniverSR
pip install -r ComfyUI-UniverSR/requirements.txt
```
The `universr` model code is **vendored** under `vendor/` (an installed `pip` copy is preferred if
present), so the only dependency beyond ComfyUI's stack is **`torchdiffeq`** (plus `einops`, `timm`,
`huggingface_hub`, `pyyaml`, which ComfyUI usually already has). Weights download automatically on
first use.
---
## How it works (implementation note)
ComfyUI audio arrives at an arbitrary real sample rate. UniverSR's *file* path relies on
`torchaudio.load` (fragile torchcodec backend), and its *tensor* path assumes the tensor is already at
`input_sr`. So this node does the band-limit itself: resample to 48 kHz → downsample each chunk to
`input_sr` (pure DSP, no codec) → hand UniverSR a genuine low-rate tensor to super-resolve. This
exactly reproduces the model's training-time degradation.
## Credits & license
UniverSR © Woongjib Choi et al., DSPAI Lab, Yonsei University — released under the MIT License
(see `LICENSE`). This node wrapper vendors the UniverSR inference code unmodified under `vendor/`.
```bibtex
@inproceedings{choi2026universr,
title = {{UniverSR}: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching},
author = {Choi, Woongjib and Lee, Sangmin and Lim, Hyungseob and Kang, Hong-Goo},
booktitle = {IEEE ICASSP},
year = {2026}
}
```
+10
View File
@@ -0,0 +1,10 @@
"""ComfyUI-UniverSR — vocoder-free audio super-resolution (8/12/16/24 kHz -> 48 kHz)."""
try:
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
except Exception as e: # surface import errors in the ComfyUI log without crashing startup
print(f"[ComfyUI-UniverSR] Failed to load nodes: {e}")
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
+87
View File
@@ -0,0 +1,87 @@
seed: 42
wandb:
project_name: "UniverSR"
entity: null # set to your wandb username or team
run_name: "audio"
notes: ""
dataloader:
batch_size: 4
num_workers: 4
prefetch_factor: 2
persistent_workers: True
pin_memory: True
collator:
sampling_rates_probs:
8: 0.7
12: 0.1
16: 0.1
24: 0.1
validation_probs:
8: 1.0
dataset:
common:
num_samples: 32767
sr: 48000
train:
file_list: "./data/train.txt"
val:
file_list: "./data/val.txt"
path:
class_path: universr.flow.path.OriginalCFMPath
init_args:
sigma_min: 1.0e-4
transform:
window_fn: 'hann'
n_fft: 1024
sampling_rate: 48000
hop_length: 512
alpha: 0.2
beta: 1
comp_eps: 1.0e-4
model:
in_channels: 2
out_channels: 2
dims: [96, 192, 384, 768]
depths: [2, 2, 4, 2]
drop_path: 0
time_dim: 256
cond_dim: 384
total_freq_bins: 512
hr_freq_bins: 432
feature_enc_layers: 4
cond_dropout_prob: 0.1
sr_to_lr_bins: {8: 80, 12: 128, 16: 170, 24: 256}
scheduler:
type: CosineLR
init_args:
num_warmup_steps: 10000
num_training_steps: 5000000
optimizer:
lr: 2.0e-4
betas: [0.9, 0.99]
train:
num_epochs: 200
max_steps: 5000000
ckpt_save_dir: ./ckpts/audio/
ckpt_load_path: null
log_step_interval: 1000
val_step_interval: 50000
num_val_log_samples: 5
val_ode_steps: 4
val_max_sec: 5
eval:
ode_steps: 4
guidance_scale: 1.5
max_batches: null
num_log_samples: 6
@@ -0,0 +1,95 @@
{
"last_node_id": 5,
"last_link_id": 4,
"nodes": [
{
"id": 1,
"type": "LoadAudio",
"pos": [120, 200],
"size": [320, 124],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{"name": "AUDIO", "type": "AUDIO", "links": [1], "slot_index": 0}
],
"properties": {"Node name for S&R": "LoadAudio"},
"widgets_values": ["example.wav", null, ""]
},
{
"id": 2,
"type": "UniverSRModelLoader",
"pos": [120, 380],
"size": [340, 130],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{"name": "model", "type": "UNIVERSR_MODEL", "links": [2], "slot_index": 0}
],
"properties": {"Node name for S&R": "UniverSRModelLoader"},
"widgets_values": ["universr-audio", "auto", "", ""]
},
{
"id": 3,
"type": "UniverSRSampler",
"pos": [540, 200],
"size": [340, 320],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{"name": "audio", "type": "AUDIO", "link": 1},
{"name": "model", "type": "UNIVERSR_MODEL", "link": 2}
],
"outputs": [
{"name": "audio", "type": "AUDIO", "links": [3], "slot_index": 0},
{"name": "spectrogram", "type": "IMAGE", "links": [4], "slot_index": 1}
],
"properties": {"Node name for S&R": "UniverSRSampler"},
"widgets_values": [8000, "midpoint", 4, 1.5, 0, "randomize", 10.0, 0.5, 1.0, false, true]
},
{
"id": 4,
"type": "PreviewAudio",
"pos": [940, 200],
"size": [320, 100],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{"name": "audio", "type": "AUDIO", "link": 3}
],
"outputs": [],
"properties": {"Node name for S&R": "PreviewAudio"},
"widgets_values": []
},
{
"id": 5,
"type": "PreviewImage",
"pos": [940, 340],
"size": [320, 280],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{"name": "images", "type": "IMAGE", "link": 4}
],
"outputs": [],
"properties": {"Node name for S&R": "PreviewImage"},
"widgets_values": []
}
],
"links": [
[1, 1, 0, 3, 0, "AUDIO"],
[2, 2, 0, 3, 1, "UNIVERSR_MODEL"],
[3, 3, 0, 4, 0, "AUDIO"],
[4, 3, 1, 5, 0, "IMAGE"]
],
"groups": [],
"config": {},
"extra": {},
"version": 0.4
}
+199
View File
@@ -0,0 +1,199 @@
"""ComfyUI-UniverSR nodes.
Two-node design (mirrors the ComfyUI-Flash-AudioSR pattern):
UniverSRModelLoader -> UNIVERSR_MODEL (loads + caches weights, auto-downloads)
UniverSRSampler -> AUDIO, IMAGE (runs the super-resolution)
"""
import torch
from . import universr_wrapper as usr
try:
import comfy.model_management as mm
HAS_COMFY = True
except Exception: # pragma: no cover
HAS_COMFY = False
def _default_device() -> str:
if HAS_COMFY:
try:
return "cuda" if mm.get_torch_device().type == "cuda" else "cpu"
except Exception:
pass
return "cuda" if torch.cuda.is_available() else "cpu"
# --------------------------------------------------------------------------- #
# Model loader
# --------------------------------------------------------------------------- #
class UniverSRModelLoader:
"""Load a UniverSR checkpoint. Auto-downloads the presets on first use.
Output: UNIVERSR_MODEL -> connect to UniverSR Super-Resolution.
"""
DESCRIPTION = ("Load UniverSR (vocoder-free audio super-resolution, ICASSP 2026). "
"Presets auto-download to models/universr on first use.")
CATEGORY = "audio/UniverSR"
@classmethod
def INPUT_TYPES(cls):
choices = list(usr.HF_REPOS.keys()) + usr.list_local_models()
return {
"required": {
"model": (choices, {
"default": choices[0],
"tooltip": "universr-audio = general (music/SFX/mixed, recommended); "
"universr-speech = voice only. Both download (~230 MB) on first use. "
"Local checkpoint folders in models/universr also appear here.",
}),
"device": (["auto", "cuda", "cpu"], {
"default": "auto",
"tooltip": "Device to load the model onto.",
}),
},
"optional": {
"local_path": ("STRING", {
"default": "",
"tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
"or a raw .pth/.ckpt file (uses config_path or the bundled config).",
}),
"config_path": ("STRING", {
"default": "",
"tooltip": "config.yaml for a raw checkpoint given in local_path. "
"Leave empty to use the bundled default config.",
}),
},
}
RETURN_TYPES = ("UNIVERSR_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load"
def load(self, model, device, local_path="", config_path=""):
dev = _default_device() if device == "auto" else device
if dev == "cuda" and not torch.cuda.is_available():
print("[UniverSR] CUDA unavailable, falling back to CPU")
dev = "cpu"
model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
return ({"model": model_obj, "device": dev, "cache_key": cache_key},)
@classmethod
def IS_CHANGED(cls, model, device, local_path="", config_path=""):
return f"{model}:{device}:{local_path}:{config_path}"
# --------------------------------------------------------------------------- #
# Sampler
# --------------------------------------------------------------------------- #
class UniverSRSampler:
"""Super-resolve audio to 48 kHz with UniverSR. Long clips are processed in
overlapping chunks (click-free overlap-add) to stay within VRAM."""
DESCRIPTION = ("Upscale low-bandwidth audio to 48 kHz with UniverSR. Pick input_sr to "
"match the effective bandwidth of your content (the model regenerates "
"everything above input_sr/2).")
CATEGORY = "audio/UniverSR"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO", {}),
"model": ("UNIVERSR_MODEL", {}),
"input_sr": ([8000, 12000, 16000, 24000], {
"default": 8000,
"tooltip": "Effective input bandwidth (Hz). Content is treated as valid up to "
"input_sr/2 and regenerated above it. 8000 = genuine low-rate audio "
"(strongest, 8 kHz->48 kHz). 16000 = brighten muffled audio above 8 kHz.",
}),
},
"optional": {
"ode_method": (["midpoint", "euler", "rk4"], {
"default": "midpoint",
"tooltip": "ODE solver. euler (fastest) -> midpoint (balanced) -> rk4 (best).",
}),
"ode_steps": ("INT", {
"default": 4, "min": 1, "max": 64, "step": 1,
"tooltip": "Flow-matching integration steps. 4 is fast and validated; 4-10 is a good range.",
}),
"guidance_scale": ("FLOAT", {
"default": 1.5, "min": 0.0, "max": 6.0, "step": 0.25,
"tooltip": "Classifier-free guidance. Speech 1.0-1.5, music 1.5-2.0, SFX ~1.5. "
"Higher = denser highs but less faithful. 0 disables CFG.",
}),
"seed": ("INT", {
"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "Noise seed for the flow-matching source. 0 = random each run.",
}),
"chunk_seconds": ("FLOAT", {
"default": 10.0, "min": 0.0, "max": 120.0, "step": 0.5,
"tooltip": "Process long audio in chunks of this length (seconds) to avoid OOM. "
"0 = process the whole clip at once.",
}),
"overlap_seconds": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 5.0, "step": 0.1,
"tooltip": "Crossfade overlap between chunks (seconds). Prevents seam clicks.",
}),
"blend": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Wet/dry mix. 1.0 = full super-resolution. Lower to keep more of the "
"original (useful when brightening already-48 kHz audio).",
}),
"unload_model": ("BOOLEAN", {
"default": False,
"tooltip": "Free the model from VRAM after this run.",
}),
"show_spectrogram": ("BOOLEAN", {
"default": True,
"tooltip": "Also output a before/after spectrogram comparison image.",
}),
},
}
RETURN_TYPES = ("AUDIO", "IMAGE")
RETURN_NAMES = ("audio", "spectrogram")
FUNCTION = "run"
def run(self, audio, model, input_sr, ode_method="midpoint", ode_steps=4,
guidance_scale=1.5, seed=0, chunk_seconds=10.0, overlap_seconds=0.5,
blend=1.0, unload_model=False, show_spectrogram=True):
model_obj = model["model"]
waveform, sr = usr.comfy_audio_to_tensor(audio)
dur = waveform.shape[-1] / max(sr, 1)
print(f"[UniverSR] {tuple(waveform.shape)} @ {sr} Hz ({dur:.2f}s) -> 48 kHz | "
f"input_sr={input_sr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}")
out, dry48 = usr.super_resolve(
model_obj, waveform, sr, int(input_sr),
ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=guidance_scale,
seed=int(seed), chunk_seconds=float(chunk_seconds),
overlap_seconds=float(overlap_seconds), blend=float(blend),
)
audio_out = usr.tensor_to_comfy_audio(out, usr.TARGET_SR)
spec = torch.zeros(1, 64, 64, 3)
if show_spectrogram:
in_mono = dry48[0].mean(0).numpy()
out_mono = out[0].mean(0).numpy()
spec = usr.make_spectrogram_image(in_mono, out_mono, int(input_sr))
if unload_model:
usr.evict_model(model["cache_key"])
print(f"[UniverSR] Done -> {out.shape[-1] / usr.TARGET_SR:.2f}s at 48 kHz")
return (audio_out, spec)
NODE_CLASS_MAPPINGS = {
"UniverSRModelLoader": UniverSRModelLoader,
"UniverSRSampler": UniverSRSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"UniverSRModelLoader": "UniverSR Model Loader",
"UniverSRSampler": "UniverSR Super-Resolution",
}
+19
View File
@@ -0,0 +1,19 @@
[project]
name = "comfyui-universr"
description = "ComfyUI nodes for UniverSR — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching."
version = "1.0.0"
license = {file = "LICENSE"}
dependencies = [
"torchdiffeq>=0.2.3",
"einops>=0.7",
"timm>=0.9",
"huggingface_hub>=0.20",
"pyyaml>=6.0",
]
[project.urls]
Repository = "https://github.com/woongzip1/UniverSR"
Paper = "https://arxiv.org/abs/2510.00771"
[tool.comfy]
DisplayName = "UniverSR"
+8
View File
@@ -0,0 +1,8 @@
# ComfyUI-UniverSR runtime deps.
# torch / torchaudio / numpy / matplotlib are already shipped by ComfyUI.
# The vendored `universr` package only needs these extras on top of ComfyUI's stack:
torchdiffeq>=0.2.3
einops>=0.7
timm>=0.9
huggingface_hub>=0.20
pyyaml>=6.0
+422
View File
@@ -0,0 +1,422 @@
"""Core wrapper for ComfyUI-UniverSR.
Bootstraps the `universr` package (prefers a pip-installed copy, falls back to
the vendored one under ./vendor), manages model loading/caching, and runs the
super-resolution itself with optional overlap-add chunking for long audio.
UniverSR (ICASSP 2026) is a vocoder-free audio super-resolution model that
regenerates high-frequency content in the complex-STFT domain via flow matching.
A single model handles 8 / 12 / 16 / 24 kHz effective bandwidth -> 48 kHz.
Key design note — why we resample ourselves instead of handing UniverSR a file:
UniverSR's `enhance()` file path calls `torchaudio.load`, whose torchcodec
backend is fragile across environments; its *tensor* path assumes the tensor
is already at `input_sr`. ComfyUI audio arrives at an arbitrary real sample
rate, so we do the band-limit ourselves: resample to 48 kHz, downsample each
chunk to `input_sr` (pure DSP, no codec), and hand UniverSR a genuine
low-rate tensor to super-resolve. This reproduces the exact training-time
degradation and was validated in the FoleyTune BWE node.
"""
import os
import threading
import numpy as np
import torch
import torchaudio
# --------------------------------------------------------------------------- #
# Optional ComfyUI integration (degrade gracefully outside ComfyUI / in tests)
# --------------------------------------------------------------------------- #
try:
import comfy.model_management as mm
import comfy.utils
HAS_COMFY = True
except Exception: # pragma: no cover - allows standalone import / pytest
HAS_COMFY = False
try:
import folder_paths
HAS_FOLDER_PATHS = True
except Exception: # pragma: no cover
HAS_FOLDER_PATHS = False
TARGET_SR = 48_000
SUPPORTED_INPUT_SR = (8000, 12000, 16000, 24000)
# UniverSR.enhance() zero-pads anything shorter than this (≈0.68 s @ 48 kHz) before
# running the ODE, so chunks below it just waste compute — clamp to it.
MODEL_MIN_SAMPLES = 32_768
_NODE_DIR = os.path.dirname(os.path.abspath(__file__))
_VENDOR_DIR = os.path.join(_NODE_DIR, "vendor")
_BUNDLED_CONFIG = os.path.join(_NODE_DIR, "configs", "config.yaml")
# HuggingFace repos for the two released checkpoints.
HF_REPOS = {
"universr-audio": "woongzip1/universr-audio",
"universr-speech": "woongzip1/universr-speech",
}
# --------------------------------------------------------------------------- #
# Package bootstrap
# --------------------------------------------------------------------------- #
def get_universr_cls():
"""Return the `UniverSR` class, preferring an installed copy over the vendored one."""
try:
from universr import UniverSR # installed (e.g. via the FoleyTune node)
return UniverSR
except Exception:
pass
import sys
if _VENDOR_DIR not in sys.path:
sys.path.insert(0, _VENDOR_DIR)
try:
from universr import UniverSR # vendored fallback
return UniverSR
except Exception as e: # pragma: no cover
raise RuntimeError(
"Could not import the 'universr' package (neither installed nor vendored). "
"Try: pip install torchdiffeq (the only dependency ComfyUI does not already ship).\n"
f"Underlying error: {e}"
)
# --------------------------------------------------------------------------- #
# Model directory + cache
# --------------------------------------------------------------------------- #
def get_models_dir() -> str:
if HAS_FOLDER_PATHS:
base = folder_paths.models_dir
else:
base = os.path.join(_NODE_DIR, "..", "..", "models")
return os.path.abspath(os.path.join(base, "universr"))
def list_local_models() -> list:
"""Subdirectories of models/universr that look like a UniverSR checkpoint dir."""
root = get_models_dir()
found = []
if os.path.isdir(root):
for name in sorted(os.listdir(root)):
d = os.path.join(root, name)
if os.path.isdir(d) and os.path.exists(os.path.join(d, "config.yaml")) \
and os.path.exists(os.path.join(d, "pytorch_model.bin")):
if name not in HF_REPOS:
found.append(name)
return found
_MODEL_CACHE: dict = {}
_CACHE_LOCK = threading.Lock()
def _download_preset(name: str) -> str:
"""Download a preset checkpoint into models/universr/<name> and return that dir."""
from huggingface_hub import snapshot_download
repo_id = HF_REPOS[name]
target = os.path.join(get_models_dir(), name)
have = os.path.exists(os.path.join(target, "config.yaml")) and \
os.path.exists(os.path.join(target, "pytorch_model.bin"))
if not have:
os.makedirs(target, exist_ok=True)
print(f"[UniverSR] Downloading {repo_id} -> {target} (~230 MB)...")
snapshot_download(
repo_id=repo_id,
local_dir=target,
allow_patterns=["config.yaml", "pytorch_model.bin"],
)
print(f"[UniverSR] Downloaded {name}.")
return target
def resolve_model_ref(model: str, local_path: str = "") -> tuple:
"""Resolve the loader inputs to (kind, path). kind in {'dir', 'ckpt'}.
- local_path wins if set: a directory (config.yaml + pytorch_model.bin) -> 'dir';
a .pth/.pt/.ckpt file -> 'ckpt' (loaded via from_local with a config).
- a preset name ('universr-audio' / 'universr-speech') -> download -> 'dir'.
- a local subdir name discovered under models/universr -> 'dir'.
"""
local_path = (local_path or "").strip()
if local_path:
if os.path.isdir(local_path):
return ("dir", local_path)
if os.path.isfile(local_path):
return ("ckpt", local_path)
raise FileNotFoundError(f"local_path does not exist: {local_path}")
if model in HF_REPOS:
return ("dir", _download_preset(model))
cand = os.path.join(get_models_dir(), model)
if os.path.isdir(cand):
return ("dir", cand)
raise FileNotFoundError(
f"Unknown model '{model}'. Use a preset {list(HF_REPOS)}, a local subdir of "
f"{get_models_dir()}, or set local_path."
)
def load_model(model: str, device: str, local_path: str = "", config_path: str = ""):
"""Load (and cache) a UniverSR model. Returns (model_obj, cache_key)."""
kind, path = resolve_model_ref(model, local_path)
cache_key = f"{kind}:{os.path.abspath(path)}:{device}"
with _CACHE_LOCK:
if cache_key in _MODEL_CACHE:
print(f"[UniverSR] Using cached model ({cache_key})")
return _MODEL_CACHE[cache_key], cache_key
UniverSR = get_universr_cls()
if kind == "dir":
print(f"[UniverSR] Loading from_pretrained({path}) on {device}")
model_obj = UniverSR.from_pretrained(path, device=device)
else:
cfg = (config_path or "").strip() or _BUNDLED_CONFIG
if not os.path.exists(cfg):
raise FileNotFoundError(
f"config_path required for a raw checkpoint and not found: {cfg}"
)
print(f"[UniverSR] Loading from_local(ckpt={path}, config={cfg}) on {device}")
model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device)
model_obj.eval()
n = sum(p.numel() for p in model_obj.parameters()) / 1e6
print(f"[UniverSR] Ready - {n:.1f}M params on {device}")
_MODEL_CACHE[cache_key] = model_obj
return model_obj, cache_key
def evict_model(cache_key: str):
import gc
with _CACHE_LOCK:
_MODEL_CACHE.pop(cache_key, None)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"[UniverSR] Model unloaded ({cache_key})")
# --------------------------------------------------------------------------- #
# Audio helpers
# --------------------------------------------------------------------------- #
def comfy_audio_to_tensor(audio) -> tuple:
"""ComfyUI AUDIO (dict or legacy tuple) -> (waveform [B, C, T] float32 cpu, sr)."""
if isinstance(audio, dict):
waveform, sr = audio["waveform"], audio["sample_rate"]
else:
waveform, sr = audio
if not isinstance(waveform, torch.Tensor):
waveform = torch.as_tensor(waveform)
waveform = waveform.detach().float().cpu()
if waveform.dim() == 1: # (T,)
waveform = waveform[None, None, :]
elif waveform.dim() == 2: # (C, T)
waveform = waveform[None, :, :]
return waveform, int(sr)
def tensor_to_comfy_audio(waveform: torch.Tensor, sr: int) -> dict:
if waveform.dim() == 1:
waveform = waveform[None, None, :]
elif waveform.dim() == 2:
waveform = waveform[None, :, :]
return {"waveform": waveform.detach().cpu().contiguous(), "sample_rate": int(sr)}
def _resample(x: torch.Tensor, orig: int, target: int) -> torch.Tensor:
if orig == target:
return x
return torchaudio.functional.resample(x, orig, target)
def _fit(x: torch.Tensor, n: int) -> torch.Tensor:
"""Crop or zero-pad a 1-D tensor to exactly n samples."""
if x.shape[-1] == n:
return x
if x.shape[-1] > n:
return x[:n]
return torch.nn.functional.pad(x, (0, n - x.shape[-1]))
def _crossfade_window(length: int, ov: int, first: bool, last: bool) -> torch.Tensor:
"""Linear fade-in/out over the overlap regions; flat 1.0 elsewhere.
Combined with weight-sum normalisation this gives click-free overlap-add.
"""
w = torch.ones(length)
if ov > 0:
f = min(ov, length)
if not first:
w[:f] = torch.minimum(w[:f], torch.linspace(0.0, 1.0, f))
if not last:
w[-f:] = torch.minimum(w[-f:], torch.linspace(1.0, 0.0, f))
return w
# --------------------------------------------------------------------------- #
# Inference
# --------------------------------------------------------------------------- #
@torch.no_grad()
def _enhance_segment(model, seg48: torch.Tensor, input_sr: int,
ode_method: str, ode_steps: int, guidance_scale) -> torch.Tensor:
"""Super-resolve one 48 kHz mono segment. Returns 1-D tensor @48 kHz on CPU."""
low = _resample(seg48.unsqueeze(0), TARGET_SR, input_sr).squeeze(0) # genuine LR-rate signal
cfg = float(guidance_scale) if (guidance_scale and guidance_scale > 0) else None
out = model.enhance(
low, input_sr=int(input_sr),
ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=cfg,
)
return out.reshape(-1).float().cpu()
def _chunk_starts(total: int, chunk: int, hop: int) -> list:
if chunk <= 0 or total <= chunk:
return [0]
starts = list(range(0, max(1, total - chunk) + 1, hop))
if starts[-1] + chunk < total:
starts.append(total - chunk)
return starts
@torch.no_grad()
def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor:
T = ch48.shape[-1]
if chunk <= 0 or T <= chunk:
if pbar is not None:
pbar.update(1)
return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T)
hop = max(1, chunk - ov)
starts = _chunk_starts(T, chunk, hop)
out = torch.zeros(T)
wsum = torch.zeros(T)
for i, s in enumerate(starts):
if HAS_COMFY:
mm.throw_exception_if_processing_interrupted()
e = min(s + chunk, T)
enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s)
w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T))
out[s:e] += enh * w
wsum[s:e] += w
if pbar is not None:
pbar.update(1)
return out / torch.clamp(wsum, min=1e-8)
@torch.no_grad()
def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
ode_method: str = "midpoint", ode_steps: int = 4,
guidance_scale=1.5, seed: int = 0,
chunk_seconds: float = 10.0, overlap_seconds: float = 0.5,
blend: float = 1.0):
"""Run UniverSR over a [B, C, T] waveform. Returns (out [B, C, T48], dry48 [B, C, T48])."""
if int(input_sr) not in SUPPORTED_INPUT_SR:
raise ValueError(f"input_sr must be one of {SUPPORTED_INPUT_SR}, got {input_sr}")
waveform = waveform.float().cpu()
if waveform.dim() != 3:
raise ValueError(f"Expected a [B, C, T] waveform, got shape {tuple(waveform.shape)}")
B, C, _ = waveform.shape
dry48 = _resample(waveform, sr, TARGET_SR) # [B, C, T48]
T48 = dry48.shape[-1]
if T48 == 0: # empty input — nothing to do
empty = torch.zeros(B, C, 0)
return empty, empty
chunk = int(round(chunk_seconds * TARGET_SR)) if (chunk_seconds and chunk_seconds > 0) else 0
if 0 < chunk < MODEL_MIN_SAMPLES:
print(f"[UniverSR] chunk_seconds too small; raising to the model floor "
f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).")
chunk = MODEL_MIN_SAMPLES
ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0
n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1
pbar = comfy.utils.ProgressBar(B * C * n_per_ch) if HAS_COMFY else None
# Isolate the global RNG: snapshot, seed, run, restore. Without this the model's
# torch.randn_like noise would advance (and a fixed seed would freeze) the global
# generator that downstream ComfyUI nodes rely on. seed=0 → fresh OS entropy.
cpu_rng = torch.get_rng_state()
cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
actual_seed = int(seed) if (seed and int(seed) != 0) else int.from_bytes(os.urandom(8), "little")
try:
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
wet = torch.zeros(B, C, T48)
for b in range(B):
for c in range(C):
wet[b, c] = _fit(
_enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps,
guidance_scale, chunk, ov, pbar),
T48,
)
finally:
torch.set_rng_state(cpu_rng)
if cuda_rng is not None:
torch.cuda.set_rng_state_all(cuda_rng)
blend = float(blend)
out = wet if blend >= 1.0 else (1.0 - blend) * dry48 + blend * wet
return out.clamp(-1.0, 1.0), dry48
# --------------------------------------------------------------------------- #
# Spectrogram comparison (optional IMAGE output)
# --------------------------------------------------------------------------- #
def _stft_db(x: np.ndarray) -> np.ndarray:
t = torch.from_numpy(np.ascontiguousarray(x)).float()
win = torch.hann_window(1024)
spec = torch.stft(t, n_fft=1024, hop_length=512, window=win, return_complex=True)
db = 20.0 * torch.log10(spec.abs().clamp(min=1e-5))
db = db - db.max()
return db.numpy()
def make_spectrogram_image(input48_mono: np.ndarray, output48_mono: np.ndarray,
input_sr: int) -> torch.Tensor:
"""Before/after spectrogram comparison -> IMAGE tensor [1, H, W, 3] in [0, 1].
Left panel is the band-limited input (content valid up to input_sr/2); right
panel is the 48 kHz output. The dashed line marks the LR Nyquist, so the
regenerated high-frequency band is the energy above it on the right.
"""
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Visualise the band-limit the model actually saw, not the raw container.
lr = torch.from_numpy(np.ascontiguousarray(input48_mono)).float()[None]
lr = _resample(_resample(lr, TARGET_SR, int(input_sr)), int(input_sr), TARGET_SR).squeeze(0).numpy()
n = min(len(lr), len(output48_mono), int(8.0 * TARGET_SR))
lr, hr = lr[:n], output48_mono[:n]
nyq = int(input_sr) / 2.0
fig, axes = plt.subplots(1, 2, figsize=(12, 4.0), facecolor="#0d0f16")
for ax, sig, title, cmap in (
(axes[0], lr, f"Input (<= {int(input_sr)//1000} kHz)", "magma"),
(axes[1], hr, "UniverSR output (48 kHz)", "viridis"),
):
db = _stft_db(sig)
ax.imshow(db, origin="lower", aspect="auto", cmap=cmap,
extent=[0, n / TARGET_SR, 0, TARGET_SR / 2], vmin=-80, vmax=0)
ax.axhline(nyq, color="w", ls="--", lw=0.8, alpha=0.6)
ax.set_title(title, color="#cfe0ff", fontsize=10)
ax.set_xlabel("Time (s)", color="#7a93bd", fontsize=8)
ax.set_ylabel("Hz", color="#7a93bd", fontsize=8)
ax.tick_params(colors="#5a6e90", labelsize=7)
ax.set_facecolor("#0d0f16")
fig.tight_layout()
fig.canvas.draw()
# np.asarray(buffer_rgba()) yields (H, W, 4) at the real pixel size — robust to HiDPI.
img = np.asarray(fig.canvas.buffer_rgba())[..., :3].astype(np.float32) / 255.0
plt.close(fig)
return torch.from_numpy(np.ascontiguousarray(img))[None]
except Exception as e: # matplotlib missing / headless edge cases
print(f"[UniverSR] Spectrogram render skipped: {e}")
return torch.zeros(1, 64, 64, 3)
+4
View File
@@ -0,0 +1,4 @@
from universr.inference import UniverSR
__version__ = "0.1.0"
__all__ = ["UniverSR"]
View File
+9
View File
@@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F
def flow_matching_loss(predicted_vf: torch.Tensor, target_vf: torch.Tensor) -> torch.Tensor:
"""
Flow matching loss; L2 loss between estimated and target vector field.
"""
return F.mse_loss(predicted_vf, target_vf)
+54
View File
@@ -0,0 +1,54 @@
import importlib
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
class ConditionalProbabilityPath(nn.Module, ABC):
"""Abstract base class for conditional probability paths in flow matching."""
@abstractmethod
def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
"""Sample from the source distribution. shape_ref is used only for shape/device."""
@abstractmethod
def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Interpolate between source x0 and target x1 at time t."""
@abstractmethod
def get_target_vector_field(
self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
"""Compute the target vector field u_t(xt | x1)."""
class OriginalCFMPath(ConditionalProbabilityPath):
def __init__(self, sigma_min: float = 1e-4):
super().__init__()
self.sigma_min = sigma_min
def sample_source(self, shape_ref):
return torch.randn_like(shape_ref)
def sample_xt(self, x0, x1, t):
return t * x1 + (1 - t + self.sigma_min * t) * x0
def get_target_vector_field(self, xt, x0, x1, t):
return x1 - (1 - self.sigma_min) * x0
def get_path(config):
class_path = config.get("class_path")
if not class_path:
raise ValueError("Configuration must contain a 'class_path' key")
try:
module_path, class_name = class_path.rsplit(".", 1)
except ValueError:
raise ValueError(f"Invalid class_path '{class_path}'. Must contain at least one")
module = importlib.import_module(module_path)
Class = getattr(module, class_name)
init_args = config.get("init_args", {})
return Class(**init_args)
+127
View File
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
import torch
from torchdiffeq import odeint
from tqdm import tqdm
from universr.models.unet import ConditionalVectorFieldModel
class ODE(ABC):
@abstractmethod
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Returns the drift coefficient of the ODE.
Args:
- xt: state at time t, shape (bs, c, h, w)
- t: time, shape (bs, 1)
Returns:
- drift_coefficient: shape (bs, c, h, w)
"""
pass
class Solver(ABC):
# @abstractmethod
def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
"""
Takes one simulation step
Args:
- xt: state at time t, shape (bs, c, h, w)
- t: time, shape (bs, 1, 1, 1)
- dt: time, shape (bs, 1, 1, 1)
Returns:
- nxt: state at time t + dt (bs, c, h, w)
"""
pass
@torch.no_grad()
def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
Simulates using the discretization gives by ts
Args:
- x_init: initial state, shape (bs, c, h, w)
- ts: timesteps, shape (bs, nts, 1, 1, 1)
Returns:
- x_final: final state at time ts[-1], shape (bs, c, h, w)
"""
nts = ts.shape[1]
for t_idx in tqdm(range(nts - 1)):
t = ts[:, t_idx]
h = ts[:, t_idx + 1] - ts[:, t_idx]
x = self.step(x, t, h, **kwargs)
return x
@torch.no_grad()
def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
Simulates using the discretization gives by ts
Args:
- x: initial state, shape (bs, c, h, w)
- ts: timesteps, shape (bs, nts, 1, 1, 1)
Returns:
- xs: trajectory of xts over ts, shape (batch_size, nts, c, h, w)
"""
xs = [x.clone()]
nts = ts.shape[1]
for t_idx in tqdm(range(nts - 1)):
t = ts[:,t_idx]
h = ts[:, t_idx + 1] - ts[:, t_idx]
x = self.step(x, t, h, **kwargs)
xs.append(x.clone())
return torch.stack(xs, dim=1)
class VectorFieldODE(ODE):
def __init__(self, net:ConditionalVectorFieldModel) -> None:
super().__init__()
self.net = net
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(xt, t, y, **kwargs)
class CFGVectorFieldODE(ODE):
""" For Classifier Free Guidance """
def __init__(self, net:ConditionalVectorFieldModel, guidance_scale: float = 1.0) -> None:
super().__init__()
self.net = net
self.guidance_scale = guidance_scale
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
guided_vector_field = self.net(xt, t, y, **kwargs)
unguided_vector_field = self.net(xt, t, None, **kwargs)
return (1-self.guidance_scale) * unguided_vector_field + self.guidance_scale * guided_vector_field
class EulerSolver(Solver):
def __init__(self, ode: ODE):
self.ode = ode
def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs):
return xt + self.ode.drift_coefficient(xt,t, **kwargs) * h
class TorchDiffeqSolver(Solver):
def __init__(self,
ode: ODE,
method: str = 'euler',
atol: float = 1e-5,
rtol: float = 1e-5,
):
super().__init__()
self.ode = ode
self.method = method
self.atol = atol
self.rtol = rtol
@torch.no_grad()
def simulate(self, x_init: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
x_init: [B,C,H,W]
ts: [N]
return: final state [B,C,H,W]
"""
func = lambda t, x: self.ode.drift_coefficient(xt=x, t=t, **kwargs)
xs = odeint(
func=func,
y0=x_init, t=ts,
method=self.method,
atol=self.atol, rtol=self.rtol) # [N,B,C,H,W]
return xs[-1]
+351
View File
@@ -0,0 +1,351 @@
"""
UniverSR: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching
Inference wrapper module.
"""
import os
from typing import Optional, Union
import numpy as np
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from universr.models.unet import ConvNeXtUNetCond
from universr.flow.path import OriginalCFMPath
from universr.flow.solver import CFGVectorFieldODE, VectorFieldODE, TorchDiffeqSolver
from universr.utils.spectral_ops import AmplitudeCompressedComplexSTFT
# Supported input sample rates (kHz) and their corresponding LR frequency bins
SUPPORTED_INPUT_SR = {8000, 12000, 16000, 24000}
TARGET_SR = 48000
class UniverSR(torch.nn.Module):
"""
UniverSR inference wrapper.
Performs audio super-resolution from low sample rates (8/12/16/24 kHz)
to 48 kHz using vocoder-free flow matching in the complex STFT domain.
Example:
>>> model = UniverSR.from_pretrained("woongzip1/universr-speech")
>>> output = model.enhance("input.wav", input_sr=16000)
>>> torchaudio.save("output.wav", output.cpu(), 48000)
"""
def __init__(
self,
model: ConvNeXtUNetCond,
transform: AmplitudeCompressedComplexSTFT,
path: OriginalCFMPath,
device: str = "cuda",
):
super().__init__()
self.model = model
self.transform = transform
self.path = path
self._device = device
@classmethod
def from_pretrained(
cls,
repo_id_or_path: str,
device: str = "cuda",
revision: Optional[str] = None,
) -> "UniverSR":
"""
Load a pretrained UniverSR model.
Args:
repo_id_or_path: HuggingFace repo ID (e.g. "woongzip1/universr-speech")
or local directory path containing config.yaml and pytorch_model.bin.
device: Device to load the model on.
revision: Optional HuggingFace revision (branch, tag, or commit hash).
Returns:
UniverSR instance ready for inference.
"""
if os.path.isdir(repo_id_or_path):
config_path = os.path.join(repo_id_or_path, "config.yaml")
model_path = os.path.join(repo_id_or_path, "pytorch_model.bin")
else:
config_path = hf_hub_download(
repo_id=repo_id_or_path, filename="config.yaml", revision=revision
)
model_path = hf_hub_download(
repo_id=repo_id_or_path, filename="pytorch_model.bin", revision=revision
)
# Load config
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# Build model
model = ConvNeXtUNetCond(**config["model"])
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.to(device).eval()
# Build transform
transform = AmplitudeCompressedComplexSTFT(**config["transform"])
transform.to(device)
# Build probability path
path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4})
path = OriginalCFMPath(**path_args)
return cls(model=model, transform=transform, path=path, device=device)
@classmethod
def from_local(
cls,
ckpt_path: str,
config_path: str,
device: str = "cuda",
) -> "UniverSR":
"""
Load UniverSR from a local checkpoint (e.g. training checkpoint with optimizer state).
This handles the standard training checkpoint format where weights are stored
under the 'model_state_dict' key, as opposed to from_pretrained() which expects
a clean state_dict saved as pytorch_model.bin.
Args:
ckpt_path: Path to checkpoint file (.pth).
config_path: Path to YAML config file.
device: Device to load the model on.
Returns:
UniverSR instance ready for inference.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
model = ConvNeXtUNetCond(**config["model"])
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# Handle both formats: raw state_dict or training checkpoint
if "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
else:
model.load_state_dict(ckpt)
model.to(device).eval()
transform = AmplitudeCompressedComplexSTFT(**config["transform"])
transform.to(device)
path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4})
path = OriginalCFMPath(**path_args)
return cls(model=model, transform=transform, path=path, device=device)
# ------------------------------------------------------------------ #
# Public API #
# ------------------------------------------------------------------ #
@torch.no_grad()
def enhance(
self,
audio: Union[str, torch.Tensor, np.ndarray],
input_sr: Optional[int] = None,
target_sr: int = TARGET_SR,
ode_method: str = "midpoint",
ode_steps: int = 4,
guidance_scale: Optional[float] = 1.5,
) -> torch.Tensor:
"""
Enhance a low-resolution audio signal to high-resolution.
Args:
audio: Input audio. Can be:
- str: path to a .wav file
- torch.Tensor: waveform tensor of shape (T,), (1, T), or (1, 1, T)
- np.ndarray: waveform array
input_sr: Effective bandwidth of the input in Hz (e.g. 8000, 16000).
For file input: auto-detected from the file's native sample rate
if it matches a supported rate (8/12/16/24 kHz). Required if the
file is already at 48 kHz but has limited bandwidth.
For tensor/array input: always required.
target_sr: Target sample rate in Hz. Default: 48000.
ode_method: ODE solver method. One of 'euler', 'midpoint', 'rk4'.
ode_steps: Number of ODE integration steps.
guidance_scale: Classifier-free guidance scale. None or 0 disables CFG.
Returns:
Enhanced waveform tensor of shape (1,T) at target_sr.
"""
# Load audio
wav, file_sr = self._load_audio(audio, input_sr=input_sr)
wav = wav.to(self._device)
# Determine the effective bandwidth SR
effective_sr = input_sr if input_sr is not None else file_sr
if effective_sr not in SUPPORTED_INPUT_SR:
if effective_sr == target_sr and input_sr is None:
raise ValueError(
f"Input audio is already at {target_sr} Hz. "
f"Please specify input_sr to indicate the effective bandwidth "
f"(e.g., input_sr=16000). Supported: {sorted(SUPPORTED_INPUT_SR)}"
)
raise ValueError(
f"Effective input sample rate {effective_sr} Hz is not supported. "
f"Supported rates: {sorted(SUPPORTED_INPUT_SR)}"
)
# Prepare the 48 kHz LR input for the model
if file_sr == target_sr:
# Simulate the training degradation: downsample → upsample to match
wav = self._apply_bandwidth_limit(wav, effective_sr, target_sr)
elif file_sr != target_sr:
# File is truly low-resolution; resample up to 48 kHz
wav = torchaudio.functional.resample(wav, orig_freq=file_sr, new_freq=target_sr)
# Minimum length guard
MIN_SAMPLES = 32_768
original_len = wav.shape[-1]
wav = torch.nn.functional.pad(wav, (0, max(0, MIN_SAMPLES - wav.shape[-1])))
# Ensure shape is [B, C, T] = [1, 1, T]
if wav.dim() == 1:
wav = wav.unsqueeze(0).unsqueeze(0)
elif wav.dim() == 2:
wav = wav.unsqueeze(0)
sr_khz = effective_sr // 1000
# Run flow matching SR
output = self._inference(wav, sr_khz, ode_method, ode_steps, guidance_scale)
# (1,T)
return output[..., :original_len]
# ------------------------------------------------------------------ #
# Internal methods #
# ------------------------------------------------------------------ #
def _load_audio(
self, audio: Union[str, torch.Tensor, np.ndarray], input_sr: Optional[int] = None,
) -> tuple:
"""
Load and validate audio input.
Returns:
(waveform, file_sr): The waveform tensor and its *actual* sample rate.
"""
if isinstance(audio, str):
wav, file_sr = torchaudio.load(audio)
# Mix to mono if stereo
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
return wav, file_sr
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
if isinstance(audio, torch.Tensor):
if input_sr is None:
raise ValueError("input_sr is required when passing a tensor or array.")
return audio.float(), input_sr
raise TypeError(f"Unsupported audio type: {type(audio)}")
def _apply_bandwidth_limit(
self, wav: torch.Tensor, effective_sr: int, target_sr: int,
) -> torch.Tensor:
"""
Simulate low-resolution input from a high-sample-rate waveform.
Applies the same downsample-then-upsample pipeline used during training
(see WaveformCollator._apply_lpf) so that the spectral cutoff pattern
matches what the model expects.
Args:
wav: Waveform at target_sr. Shape: (1, T) or (T,).
effective_sr: The effective bandwidth in Hz (e.g. 8000).
target_sr: The native sample rate of wav (e.g. 48000).
Returns:
Bandwidth-limited waveform at target_sr, same length as input.
"""
original_len = wav.shape[-1]
lr = torchaudio.functional.resample(wav, orig_freq=target_sr, new_freq=effective_sr)
lr = torchaudio.functional.resample(lr, orig_freq=effective_sr, new_freq=target_sr)
return lr[..., :original_len]
def _preprocess(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Convert waveform to amplitude-compressed complex STFT representation.
[B, C, T] -> [B, 2, F-1, T_frames] (real/imag channels, drop Nyquist bin)
"""
spec = self.transform(waveform) # [B, C, F, T_frames] complex
real = torch.view_as_real(spec.squeeze(1)) # [B, F, T_frames, 2]
real = real.permute(0, 3, 1, 2) # [B, 2, F, T_frames]
return real[:, :, :-1, :] # drop Nyquist bin
def _postprocess(self, spec: torch.Tensor) -> torch.Tensor:
"""
Convert STFT representation back to waveform.
[B, 2, F-1, T_frames] -> [B, T]
"""
spec = torch.nn.functional.pad(spec, [0, 0, 0, 1], value=0) # restore Nyquist
spec = spec.permute(0, 2, 3, 1).contiguous() # [B, F, T, 2]
spec = torch.view_as_complex(spec) # [B, F, T] complex
waveform = self.transform.invert(spec) # [B, T]
return waveform
def _inference(
self,
lr_audio: torch.Tensor,
sr_khz: int,
ode_method: str,
ode_steps: int,
guidance_scale: Optional[float],
) -> torch.Tensor:
"""
Core inference pipeline:
1. STFT the (resampled) LR audio
2. Extract LR condition bins
3. Sample noise for HF region
4. Solve ODE (flow matching)
5. Concatenate LR + generated HF
6. iSTFT to waveform
"""
# Frequency bin bookkeeping
lr_bin_count = self.model.sr_to_lr_bins[sr_khz]
hf_start_bin = self.model.total_freq_bins - self.model.hr_freq_bins
# STFT
Y = self._preprocess(lr_audio) # [B, 2, F-1, T]
Y_lr = Y[:, :, :lr_bin_count, :] # LR condition
Y_hr = Y[:, :, hf_start_bin:, :] # HR target region (for shape reference)
# Initial noise
x0 = self.path.sample_source(Y_hr).to(self._device)
# Build ODE solver
if guidance_scale is not None and guidance_scale > 0:
ode = CFGVectorFieldODE(net=self.model, guidance_scale=guidance_scale)
else:
ode = VectorFieldODE(net=self.model)
solver = TorchDiffeqSolver(ode, method=ode_method)
# Time discretization
ts = torch.linspace(0, 1, ode_steps + 1, device=self._device)
# Solve ODE
x1_spec = solver.simulate(
x0, ts=ts, y=Y_lr, sr_values=torch.tensor([sr_khz], device=self._device)
)
# Concatenate LR bins + generated HF bins (handle overlapping region)
slice_start = max(0, lr_bin_count - hf_start_bin)
x1_spec = x1_spec[:, :, slice_start:, :]
full_spec = torch.cat([Y_lr, x1_spec], dim=2)
# iSTFT
output = self._postprocess(full_spec)
return output
View File
+470
View File
@@ -0,0 +1,470 @@
import math
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
class ConditionalVectorFieldModel(nn.Module, ABC):
"""
Base class for DNN-based VF model
MLP-parameterization of the learned vector field u_t^theta(x)
"""
@abstractmethod
def forward(self, x:torch.Tensor, t:torch.Tensor, y:torch.Tensor):
"""
Args:
- x: (bs, c, h, w)
- t: (bs, 1, 1, 1)
- y: (bs,)
Returns:
- u_t^theta(x|y): (bs, c, h, w)
"""
pass
class SinusoidalTimeEmbedding(nn.Module):
"""
Based on https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/karras_unet.py#L183
& DiffWave / WaveFM
"""
def __init__(self, dim: int=128, mode: str='learnable', time_scale=1):
super().__init__()
assert dim % 2 == 0, "Dimension must be an even number"
assert mode in ['fixed', 'learnable'], "Mode must be 'fixed' or 'learnable'"
self.dim = dim # D
self.half_dim = dim // 2
self.mode = mode
self.time_scale = time_scale # 1(diffusion) or 100(flow)
if self.mode == 'learnable':
self.weights = nn.Parameter(torch.randn(1, self.half_dim)) # [1,D/2]
def forward(self, t: torch.Tensor) -> torch.Tensor:
"""
Args:
- t: Time tensor. Shape can be [B] or [B, 1].
Returns:
- embeddings: Time embeddings of shape [B, D]
"""
# Ensure t has shape [B, 1] for broadcasting
t = t.view(-1, 1)
device = t.device
if self.mode == 'fixed':
# Create a sequence from 0 to D/2 - 1
pos = torch.arange(self.half_dim, device=device).unsqueeze(0) # [1,D/2]
freqs = self.time_scale * t * 10.0 ** (pos * 4.0 / (self.half_dim - 1)) # 100 is a magnitude hyperparameter
sin_embed = torch.sin(freqs)
cos_embed = torch.cos(freqs)
return torch.cat([sin_embed, cos_embed], dim=-1)
elif self.mode == 'learnable':
freqs = t * self.weights * 2 * math.pi
sin_embed = torch.sin(freqs)
cos_embed = torch.cos(freqs)
return torch.cat([sin_embed, cos_embed], dim=-1) * math.sqrt(2)
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class LayerNorm(nn.Module):
""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class Block(nn.Module):
""" ConvNeXt V2 Block.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
"""
def __init__(self, dim, drop_path=0.):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, padding_mode="reflect")
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim)
self.act = nn.GELU()
self.grn = GRN(4 * dim) # GRN for V2
self.pwconv2 = nn.Linear(4 * dim, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
# This Block preserves the input shape (C, H, W) -> (C, H, W)
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # [N,C,H,W] -> [N,H,W,C]
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
x = x.permute(0, 3, 1, 2) # [N,H,W,C] -> [N,C,H,W]
x = input + self.drop_path(x) # Residual connection
return x
class BlockWithEmbedding(nn.Module):
""" ConvNeXt block with time embedding injection
"""
def __init__(self, dim, drop_path=0., time_embed_dim=128):
super().__init__()
self.block = Block(dim, drop_path)
self.time_adapter = nn.Sequential(
nn.Linear(time_embed_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, dim),
)
def forward(self, x, t_embed):
t_embed = self.time_adapter(t_embed).unsqueeze(-1).unsqueeze(-1) # [B,C,1,1]
x = x + t_embed
x = self.block(x)
return x
class EncoderBlock(nn.Module):
def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim):
super().__init__()
self.blocks= nn.ModuleList(
[BlockWithEmbedding(dim_in, drop_path, time_embed_dim)
for _ in range(num_blocks)]
)
self.downsampler = nn.Sequential(
LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2),
)
def forward(self, x, t_emb):
for block in self.blocks:
x = block(x, t_emb)
x = self.downsampler(x)
return x
class Midcoder(nn.Module):
def __init__(self, dim, num_blocks, drop_path, time_embed_dim):
super().__init__()
self.blocks = nn.ModuleList(
[BlockWithEmbedding(dim, drop_path, time_embed_dim)
for _ in range(num_blocks)]
)
def forward(self, x, t_emb):
for block in self.blocks:
x = block(x, t_emb)
return x
class DecoderBlock(nn.Module):
def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim):
super().__init__()
self.upsampler = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=2, stride=2)
self.blocks = nn.ModuleList(
[BlockWithEmbedding(dim_out, drop_path, time_embed_dim)
for _ in range(num_blocks)]
)
def forward(self, x, t_emb):
x = self.upsampler(x)
for block in self.blocks:
x = block(x, t_emb)
return x
class ConditioningEncoder2D(nn.Module):
def __init__(self, cond_dim, num_blocks=3):
"""
Args:
cond_dim (int): The main conditioning dimension (D).
num_blocks (int): The number of shared 2D ConvNeXt blocks.
"""
super().__init__()
self.cond_dim = cond_dim
self.film_generator = nn.Linear(cond_dim, 4)
self.head = nn.Conv2d(2, cond_dim, kernel_size=1)
self.sr_adapter = nn.Sequential(
nn.Linear(cond_dim, cond_dim),
nn.GELU(),
nn.Linear(cond_dim, cond_dim * 2)
)
self.blocks = nn.Sequential(*[
Block(dim=cond_dim) for _ in range(num_blocks)
])
self.freq_pool = nn.AdaptiveAvgPool2d((1,None))
def forward(self, y_lr, f_emb_lr, sr_emb):
"""
Args:
y_lr (Tensor): LR Spec [B, 2, F1, T]
f_emb : Freq positional embedding for lr spec [F1,D]
sr_emb: Sampling rate embedding [B,D]
Returns:
z (Tensor): Conditioning Emb [B, D, T]
"""
film_params = self.film_generator(f_emb_lr) # [F1, 4]
gamma, beta = torch.chunk(film_params, chunks=2, dim=-1) # [F1,2]
gamma = rearrange(gamma, 'f c -> 1 c f 1') # [1,2,F1,1]
beta = rearrange(beta, 'f c -> 1 c f 1') # [1,2,F1,1]
z = y_lr * gamma + beta # [B, 2, F1, T]
z = self.head(z) # [B,D,F1,T]
sr_film_params = self.sr_adapter(sr_emb) # [B, 2*D]
sr_gamma, sr_beta = torch.chunk(sr_film_params, 2, dim=-1) # [B,D]
sr_gamma = sr_gamma.unsqueeze(-1).unsqueeze(-1) # [B,D,1,1]
sr_beta = sr_beta.unsqueeze(-1).unsqueeze(-1) # [B,D,1,1]
z = z * sr_gamma + sr_beta # [B,D,F1,T]
z = self.blocks(z) # [B,D,F1,T]
z = self.freq_pool(z).squeeze(2) # [B,D,T]
return z
class FrequencyPositionalEmbedding(nn.Module):
def __init__(self, num_bins: int, emb_dim: int):
super().__init__()
# (F, D)
pe = torch.zeros(num_bins, emb_dim)
position = torch.arange(num_bins, dtype=torch.float32).unsqueeze(1) # (F,1)
div_term = torch.exp(
torch.arange(0, emb_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / emb_dim)
) # (D/2,)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self):
# returns (F, D)
return self.pe
class ConvNeXtUNetCond(ConditionalVectorFieldModel):
def __init__(self, in_channels=2, out_channels=2,
dims=[64,128,256,512], depths=[2,2,2,4],
drop_path=0., time_dim=128,
cond_dim=256, # D1
total_freq_bins=512,
hr_freq_bins=432,
feature_enc_layers=10,
cond_dropout_prob=0.1,
sr_to_lr_bins={8: 80, 12: 128, 16: 170, 24: 256},
):
super().__init__()
self.strides = 2**len(dims)
self.time_embedder = SinusoidalTimeEmbedding(dim=time_dim)
self.total_freq_bins = total_freq_bins
self.hr_freq_bins = hr_freq_bins
self.sr_to_lr_bins = sr_to_lr_bins
self.sr_values_list = sorted(list(sr_to_lr_bins.keys())) # (8,12,16,24) kHz
self.sr_to_idx = {sr: i for i, sr in enumerate(self.sr_values_list)}
self.sr_embedder = nn.Embedding(len(self.sr_values_list), cond_dim) # [4,D]
self.cond_dropout_prob = cond_dropout_prob
self.cond_dim = cond_dim
self.uncond_emb = nn.Parameter(torch.randn(cond_dim))
self.sr_projector = nn.Linear(cond_dim, time_dim) # projector to t_emb
self.freq_pos_enc = FrequencyPositionalEmbedding(num_bins=total_freq_bins, emb_dim=cond_dim)
self.film_generator = nn.Linear(cond_dim, cond_dim * 2)
self.conditioning_encoder = ConditioningEncoder2D(
cond_dim=cond_dim,
num_blocks=feature_enc_layers,
)
self.init_conv = nn.Sequential(
nn.Conv2d(in_channels+cond_dim, dims[0], kernel_size=1),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
# Encoder
for i in range(len(depths)):
dim_in = dims[i]
dim_out = dims[i+1] if i+1 < len(dims) else dims[i]
self.encoders.append(EncoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim))
# Midcoder
self.midcoder = Midcoder(dims[-1], depths[-1], drop_path, time_dim)
# Decoder
for i in reversed(range(len(depths))):
dim_in = dims[i+1] if i+1 < len(dims) else dims[i]
dim_out = dims[i]
self.decoders.append(DecoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim))
self.final_conv = nn.Conv2d(dims[0], out_channels, kernel_size=1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
def _pad_frames(self, x):
num_frames = x.shape[-1]
pad_len = (self.strides - num_frames % self.strides) % self.strides
if pad_len:
x = torch.nn.functional.pad(x, [0,pad_len,0,0], mode='reflect')
assert x.shape[-1] % self.strides == 0, \
f"After padding, time dim:{x.shape(-1)} must be multiples of {self.strides}"
return x, pad_len
def forward(self, x, t, y, sr_values):
"""
x : x_t noisy spec [B,2,F,T]
t : time embedding [B,1] or [B]
y : condition lr spectrum [B,2,F,T]
sr_values: input sampling_rate [B] or [1]
"""
# Pad logic
x, pad_len = self._pad_frames(x)
if pad_len > 0 and y is not None:
y = torch.nn.functional.pad(y, [0, pad_len, 0, 0], mode='reflect')
B, _, F, T = x.shape
# get number of lr bins for input sr
if isinstance(sr_values, int):
current_sr = sr_values
else:
current_sr = sr_values[0].item() if hasattr(sr_values[0], 'item') else sr_values[0]
lr_bin_count = self.sr_to_lr_bins[current_sr]
# freq pe
pe_full = self.freq_pos_enc() # [F,D]
pe_low = pe_full[:lr_bin_count,:] # [F1,D]
hf_start_bin = self.total_freq_bins - self.hr_freq_bins # 512 - 432
pe_high = pe_full[hf_start_bin:, :] # [F2=432,D]
# time / sr embedding
t_embed = self.time_embedder(t) # [B,timedim]
sr_idx = self.sr_to_idx[current_sr]
sr_emb = self.sr_embedder(torch.tensor([sr_idx], device=x.device)).expand(B,-1) # [B, D]
t_embed = t_embed + self.sr_projector(sr_emb) # [B, timedim]
if y is not None: # (Training)
y_cond_real = self.conditioning_encoder(y, pe_low, sr_emb) # [B,D,T]
# Uncond token masking
if self.training and self.cond_dropout_prob > 0:
# random mask for uncond
mask = (torch.rand(B, device=x.device) < self.cond_dropout_prob) # [B]
uncond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T) # [B,D,T]
y_cond = torch.where(mask.reshape(B,1,1), uncond, y_cond_real)
else:
y_cond = y_cond_real
else: # Unconditional (inference)
y_cond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T)
y_cond = y_cond.unsqueeze(2) # [B,D,1,T]
# FiLM Conditioning of freq-bins
film_params = self.film_generator(pe_high) # [F2,D] -> [F2,2D]
gamma_high, beta_high = torch.chunk(film_params, chunks=2, dim=-1) # [F2, D]
gamma_high = rearrange(gamma_high, 'f d -> 1 d f 1') # [1,D,F2,1]
beta_high = rearrange(beta_high, 'f d -> 1 d f 1') # [1,D,F2,1]
spatial_cond = y_cond * gamma_high + beta_high # [B,D,F2,T]
x = torch.cat([x, spatial_cond], dim=1) # [B,2+D,F2,T]
x = self.init_conv(x)
skip_connections = [x]
for encoder in self.encoders:
x = encoder(x, t_embed)
skip_connections.append(x)
x = self.midcoder(x, t_embed)
for decoder in self.decoders:
skip = skip_connections.pop()
if x.shape != skip.shape:
x = nn.functional.interpolate(x, size=skip.shape[2:])
x = x + skip
x = decoder(x, t_embed)
skip = skip_connections.pop()
x = x + skip
x = self.final_conv(x)
# Crop out
if pad_len:
x = x[...,:-pad_len]
return x
def main():
"""
Dummy forward pass test for ConvNeXtUNetCond.
"""
from torchinfo import summary
batch_size = 2
hr_freq_bins = 432 # High-res bins to be generated (fixed)
lr_freq_bins = 128 # Low-res bins for this specific test case (e.g., for 8kHz)
T = 256 # Number of time frames
sr_config = {8: 80, 12: 128, 16: 170, 24: 256}
model = ConvNeXtUNetCond(
in_channels=2,
out_channels=2,
dims=[96, 192, 384, 768],
depths=[2, 2, 4, 2],
time_dim=256,
cond_dim=384,
total_freq_bins=512,
hr_freq_bins=hr_freq_bins,
feature_enc_layers=4,
cond_dropout_prob=0.1,
sr_to_lr_bins=sr_config, # Pass the dictionary
)
x = torch.randn(batch_size, 2, hr_freq_bins, T)
y = torch.randn(batch_size, 2, lr_freq_bins, T)
t = torch.randint(0, 1000, (batch_size,))
sr_values = [12] * batch_size
print("\n--- Model Summary ---")
summary(
model,
input_data=[x, t, y, sr_values],
depth=4,
col_names=("input_size", "output_size", "num_params",
"kernel_size", "mult_adds", "trainable"),
verbose=1
)
if __name__ == "__main__":
main()
View File
+135
View File
@@ -0,0 +1,135 @@
import math
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
class InvertibleFeatureExtractor(nn.Module, ABC):
"""
An invertible feature extractor, i.e. a one-to-one mapping that has a forward and a true inverse.
It should hold up to numerical error that `extractor.invert(extractor(x)) == x`.
"""
@abstractmethod
def forward(self, x, **kwargs):
pass
@abstractmethod
def invert(self, x, **kwargs):
pass
def analysis_synthesis(self, x, **kwargs):
return self.invert(self.forward(x, **kwargs), **kwargs)
class AmplitudeCompressedComplexSTFT(InvertibleFeatureExtractor):
"""
A convenient composition of ComplexSTFT() and CompressAmplitudesAndScale().
"""
def __init__(
self,
window_fn, n_fft, sampling_rate,
alpha, beta, comp_eps,
hop_length=None, n_hops=None,
learnable_window=False,
*args, **kwargs,
):
super().__init__(*args, **kwargs)
self.complex_stft = ComplexSTFT(
window_fn, n_fft, sampling_rate, hop_length=hop_length, n_hops=n_hops,
learnable_window=learnable_window,
)
self.compress = CompressAmplitudesAndScale(
compression_exponent=alpha,
scale_factor=beta,
comp_eps=comp_eps,
)
def forward(self, x: Tensor, **kwargs):
X = self.complex_stft(x, **kwargs)
out = self.compress(X, **kwargs)
return out
def invert(self, X: Tensor, **kwargs):
X = self.compress.invert(X, **kwargs)
x = self.complex_stft.invert(X, **kwargs)
return x
class ComplexSTFT(InvertibleFeatureExtractor):
def __init__(
self, window_fn, n_fft, sampling_rate, hop_length=None, n_hops=None, learnable_window=False,
*args, **kwargs):
super().__init__(*args, **kwargs)
assert (hop_length is not None) ^ (n_hops is not None),\
"Exactly one of {hop_length, n_hops} must be specified!"
if hop_length is None:
hop_length = int(math.ceil(n_fft / n_hops))
window_fn = getattr(torch.signal.windows, window_fn)
self.learnable_window = learnable_window
self.window = nn.Parameter(window_fn(n_fft), requires_grad=learnable_window)
self.n_fft = n_fft
self.hop_length = hop_length
self.sampling_rate = sampling_rate
self.center = True
def forward(self, x: Tensor, **kwargs):
"""Assumes x is an audio tensor of shape [B, C, T] or [B, T]
[B,C,T] -> [B,C,F,T]
[B,C,T] -> [B,F,T]
"""
bc = "b c" if x.ndim == 3 else "b"
X = torch.stft(
rearrange(x, f"{bc} t -> ({bc}) t"), n_fft=self.n_fft, hop_length=self.hop_length,
window=self.window.to(x.device), center=self.center,
onesided=True, return_complex=True,
)
X = rearrange(X, f"({bc}) f t -> {bc} f t", b=x.shape[0])
return X
def invert(self, X: Tensor, orig_length: Optional[int] = None, **kwargs):
"""Assumes X is a (complex) spectrogram tensor of shape [B, C, F, T] or [B, F, T]"""
bc = "b c" if X.ndim == 4 else "b"
x = torch.istft(
rearrange(X, f"{bc} f t -> ({bc}) f t"), n_fft=self.n_fft, hop_length=self.hop_length,
window=self.window.to(X.device), center=self.center,
onesided=True, return_complex=False,
length=orig_length,
)
x = rearrange(x, f"({bc}) t -> {bc} t", b=X.shape[0])
return x
class CompressAmplitudesAndScale(InvertibleFeatureExtractor):
def __init__(self, compression_exponent: float, scale_factor: float, comp_eps: float, *args, **kwargs):
super().__init__()
self.compression_exponent = compression_exponent
self.scale_factor = scale_factor
self.comp_eps = comp_eps
def forward(self, X: Tensor, **kwargs):
"""
Assumes X is a complex STFT (complex spectrogram).
"""
alpha = self.compression_exponent
beta = self.scale_factor
if alpha != 1:
X = X + self.comp_eps
X = X.abs()**alpha * torch.exp(1j * X.angle())
return X * beta
def invert(self, X: Tensor, **kwargs):
"""
Assumes X is an amplitude-compressed and scaled complex STFT.
"""
alpha = self.compression_exponent
beta = self.scale_factor
X = X / beta
if alpha != 1:
X = X.abs()**(1/alpha) * torch.exp(1j * X.angle())
return X