Initial release: ComfyUI-UniverSR

ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 12:59:42 +02:00
commit 5f29b225b7
20 changed files with 2129 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
# Package entry point: re-export the inference wrapper at the top level.
from universr.inference import UniverSR

# Version string of the package.
__version__ = "0.1.0"
# Names exported by `from <package> import *`.
__all__ = ["UniverSR"]
View File
+9
View File
@@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F
def flow_matching_loss(predicted_vf: torch.Tensor, target_vf: torch.Tensor) -> torch.Tensor:
    """Return the flow-matching training loss.

    The loss is the mean squared error (L2) between the vector field
    estimated by the network and the target vector field.
    """
    loss = F.mse_loss(input=predicted_vf, target=target_vf)
    return loss
+54
View File
@@ -0,0 +1,54 @@
import importlib
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
class ConditionalProbabilityPath(nn.Module, ABC):
    """Interface for the conditional probability paths used in flow matching.

    A concrete path defines three things: how to draw a source sample, how to
    interpolate between a source and a target sample at a given time, and the
    vector field that serves as the regression target.
    """

    @abstractmethod
    def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
        """Draw a source-distribution sample; ``shape_ref`` only supplies shape/device."""

    @abstractmethod
    def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the point on the path between source ``x0`` and target ``x1`` at time ``t``."""

    @abstractmethod
    def get_target_vector_field(
        self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Return the target vector field u_t(xt | x1)."""
class OriginalCFMPath(ConditionalProbabilityPath):
    """Conditional flow-matching path with a standard-normal source.

    The interpolant is ``t * x1 + (1 - t + sigma_min * t) * x0`` and the
    regression target is ``x1 - (1 - sigma_min) * x0``.
    """

    def __init__(self, sigma_min: float = 1e-4):
        super().__init__()
        self.sigma_min = sigma_min

    def sample_source(self, shape_ref):
        """Draw standard-normal noise with the shape, dtype and device of ``shape_ref``."""
        noise = torch.randn_like(shape_ref)
        return noise

    def sample_xt(self, x0, x1, t):
        """Interpolate between source ``x0`` and target ``x1`` at time ``t``."""
        source_weight = 1 - t + self.sigma_min * t
        return t * x1 + source_weight * x0

    def get_target_vector_field(self, xt, x0, x1, t):
        """Return the target vector field; it does not depend on ``xt`` or ``t``."""
        source_scale = 1 - self.sigma_min
        return x1 - source_scale * x0
def get_path(config):
    """Instantiate a probability path (or any class) from a config mapping.

    Args:
        config: Mapping with a ``class_path`` entry holding the dotted import
            path of the class (``"package.module.ClassName"``) and an
            optional ``init_args`` mapping of constructor keyword arguments.
            A missing or null ``init_args`` means "no arguments".

    Returns:
        An instance of the class named by ``class_path``.

    Raises:
        ValueError: If ``class_path`` is missing/empty or contains no ``.``.
    """
    class_path = config.get("class_path")
    if not class_path:
        raise ValueError("Configuration must contain a 'class_path' key")

    module_path, dot, class_name = class_path.rpartition(".")
    if not dot:
        raise ValueError(
            f"Invalid class_path '{class_path}'. Must contain at least one '.'"
        )

    module = importlib.import_module(module_path)
    target_cls = getattr(module, class_name)
    # `init_args:` with no value in YAML parses to None; treat it as empty.
    init_args = config.get("init_args") or {}
    return target_cls(**init_args)
+127
View File
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
import torch
from torchdiffeq import odeint
from tqdm import tqdm
from universr.models.unet import ConditionalVectorFieldModel
class ODE(ABC):
    """Abstract ordinary differential equation defined by its drift term."""

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """Evaluate the drift of the ODE at a given state and time.

        Args:
            xt: State at time ``t``; documented shape ``(bs, c, h, w)``.
            t: Time; documented shape ``(bs, 1)``.
            **kwargs: Extra inputs forwarded by the solver (e.g. conditioning).

        Returns:
            The drift, with the same documented shape as ``xt``.
        """
class Solver(ABC):
    """Base class for fixed-grid ODE solvers.

    Subclasses either implement :meth:`step` (and inherit the stepping loops
    below) or override :meth:`simulate` entirely.
    """

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
        """Advance the state by one step.

        Deliberately not marked abstract so that solvers overriding
        :meth:`simulate` can still be instantiated. The base implementation
        raises instead of silently returning ``None``, which would otherwise
        propagate through the stepping loops as a confusing downstream error.

        Args:
            xt: State at time ``t``; documented shape ``(bs, c, h, w)``.
            t: Current time; documented shape ``(bs, 1, 1, 1)``.
            dt: Step size; documented shape ``(bs, 1, 1, 1)``.

        Returns:
            The state at time ``t + dt``, same shape as ``xt``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement step()"
        )

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate along the discretization given by ``ts``.

        Args:
            x: Initial state; documented shape ``(bs, c, h, w)``.
            ts: Timesteps; documented shape ``(bs, nts, 1, 1, 1)``. The loop
                indexes the time grid along dimension 1.
            **kwargs: Forwarded unchanged to :meth:`step`.

        Returns:
            The final state at the last timestep, same shape as ``x``.
        """
        nts = ts.shape[1]
        for t_idx in tqdm(range(nts - 1)):
            t = ts[:, t_idx]
            h = ts[:, t_idx + 1] - ts[:, t_idx]
            x = self.step(x, t, h, **kwargs)
        return x

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate along ``ts`` and keep every intermediate state.

        Args:
            x: Initial state; documented shape ``(bs, c, h, w)``.
            ts: Timesteps; documented shape ``(bs, nts, 1, 1, 1)``.
            **kwargs: Forwarded unchanged to :meth:`step`.

        Returns:
            All states stacked along a new dimension 1, i.e. documented shape
            ``(bs, nts, c, h, w)``; the first entry is a copy of the input.
        """
        xs = [x.clone()]
        nts = ts.shape[1]
        for t_idx in tqdm(range(nts - 1)):
            t = ts[:, t_idx]
            h = ts[:, t_idx + 1] - ts[:, t_idx]
            x = self.step(x, t, h, **kwargs)
            xs.append(x.clone())
        return torch.stack(xs, dim=1)
class VectorFieldODE(ODE):
    """ODE whose drift is the output of a conditional vector-field network."""

    def __init__(self, net: ConditionalVectorFieldModel) -> None:
        super().__init__()
        self.net = net

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
        """Evaluate the network at state ``xt``, time ``t`` and condition ``y``."""
        drift = self.net(xt, t, y, **kwargs)
        return drift
class CFGVectorFieldODE(ODE):
    """ODE with a classifier-free-guidance drift.

    The drift mixes a conditional and an unconditional network evaluation as
    ``(1 - w) * unconditional + w * conditional`` with ``w = guidance_scale``.
    """

    def __init__(self, net: ConditionalVectorFieldModel, guidance_scale: float = 1.0) -> None:
        super().__init__()
        self.net = net
        self.guidance_scale = guidance_scale

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the guided drift; the network is evaluated twice per call."""
        scale = self.guidance_scale
        conditional = self.net(xt, t, y, **kwargs)
        # Passing None as the condition selects the network's unconditional branch.
        unconditional = self.net(xt, t, None, **kwargs)
        return (1 - scale) * unconditional + scale * conditional
class EulerSolver(Solver):
    """Explicit (forward) Euler solver."""

    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs):
        """Take one Euler step of size ``h``: ``xt + drift(xt, t) * h``."""
        drift = self.ode.drift_coefficient(xt, t, **kwargs)
        return xt + drift * h
class TorchDiffeqSolver(Solver):
    """Solver that delegates the integration to ``torchdiffeq.odeint``."""

    def __init__(
        self,
        ode: ODE,
        method: str = 'euler',
        atol: float = 1e-5,
        rtol: float = 1e-5,
    ):
        super().__init__()
        self.ode = ode
        self.method = method
        self.atol = atol
        self.rtol = rtol

    @torch.no_grad()
    def simulate(self, x_init: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate from ``x_init`` over the time grid ``ts``.

        Args:
            x_init: Initial state; documented shape ``[B, C, H, W]``.
            ts: 1-D time grid; documented shape ``[N]``.
            **kwargs: Forwarded to ``ode.drift_coefficient`` at every
                function evaluation.

        Returns:
            The state at the last time of ``ts``; documented shape ``[B, C, H, W]``.
        """

        def dynamics(t, x):
            # odeint calls func(t, state); adapt to the ODE's keyword interface.
            return self.ode.drift_coefficient(xt=x, t=t, **kwargs)

        trajectory = odeint(
            func=dynamics,
            y0=x_init,
            t=ts,
            method=self.method,
            atol=self.atol,
            rtol=self.rtol,
        )  # documented shape [N, B, C, H, W]
        return trajectory[-1]
+351
View File
@@ -0,0 +1,351 @@
"""
UniverSR: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching
Inference wrapper module.
"""
import os
from typing import Optional, Union
import numpy as np
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from universr.models.unet import ConvNeXtUNetCond
from universr.flow.path import OriginalCFMPath
from universr.flow.solver import CFGVectorFieldODE, VectorFieldODE, TorchDiffeqSolver
from universr.utils.spectral_ops import AmplitudeCompressedComplexSTFT
# Supported effective input sample rates, in Hz; the matching LR frequency-bin counts are read from the model (sr_to_lr_bins)
SUPPORTED_INPUT_SR = {8000, 12000, 16000, 24000}
TARGET_SR = 48000
class UniverSR(torch.nn.Module):
    """
    UniverSR inference wrapper.

    Performs audio super-resolution from low sample rates (8/12/16/24 kHz)
    to 48 kHz using vocoder-free flow matching in the complex STFT domain.

    Example:
        >>> model = UniverSR.from_pretrained("woongzip1/universr-speech")
        >>> output = model.enhance("input.wav", input_sr=16000)
        >>> torchaudio.save("output.wav", output.cpu(), 48000)
    """

    def __init__(
        self,
        model: ConvNeXtUNetCond,
        transform: AmplitudeCompressedComplexSTFT,
        path: OriginalCFMPath,
        device: str = "cuda",
    ):
        super().__init__()
        self.model = model
        self.transform = transform
        self.path = path
        # Device requested at construction. It is only a fallback: the wrapper
        # is an nn.Module, so `.to(...)` may move the weights afterwards.
        self._device = device

    def _current_device(self) -> torch.device:
        """Return the device the network weights currently live on.

        Reading the device from the parameters keeps inference working after
        the module has been moved with ``.to(...)``; a cached device string
        would be stale in that case.
        """
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            # Parameter-free model: use the device given at construction.
            return torch.device(self._device)

    @classmethod
    def _from_config_and_weights(cls, config: dict, state_dict, device: str) -> "UniverSR":
        """Build network, STFT transform and probability path from a parsed config.

        Shared by :meth:`from_pretrained` and :meth:`from_local`.
        """
        model = ConvNeXtUNetCond(**config["model"])
        model.load_state_dict(state_dict)
        model.to(device).eval()

        transform = AmplitudeCompressedComplexSTFT(**config["transform"])
        transform.to(device)

        # Fall back to sigma_min=1e-4 when the config has no path section.
        path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4})
        path = OriginalCFMPath(**path_args)

        return cls(model=model, transform=transform, path=path, device=device)

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        device: str = "cuda",
        revision: Optional[str] = None,
    ) -> "UniverSR":
        """
        Load a pretrained UniverSR model.

        Args:
            repo_id_or_path: HuggingFace repo ID (e.g. "woongzip1/universr-speech")
                or local directory path containing config.yaml and pytorch_model.bin.
            device: Device to load the model on.
            revision: Optional HuggingFace revision (branch, tag, or commit hash).
                Ignored when ``repo_id_or_path`` is a local directory.

        Returns:
            UniverSR instance ready for inference.
        """
        if os.path.isdir(repo_id_or_path):
            config_path = os.path.join(repo_id_or_path, "config.yaml")
            model_path = os.path.join(repo_id_or_path, "pytorch_model.bin")
        else:
            config_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="config.yaml", revision=revision
            )
            model_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="pytorch_model.bin", revision=revision
            )

        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        return cls._from_config_and_weights(config, state_dict, device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str,
        config_path: str,
        device: str = "cuda",
    ) -> "UniverSR":
        """
        Load UniverSR from a local checkpoint (e.g. training checkpoint with optimizer state).

        This handles the standard training checkpoint format where weights are stored
        under the 'model_state_dict' key, as opposed to from_pretrained() which expects
        a clean state_dict saved as pytorch_model.bin.

        Args:
            ckpt_path: Path to checkpoint file (.pth).
            config_path: Path to YAML config file.
            device: Device to load the model on.

        Returns:
            UniverSR instance ready for inference.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from a
        # trusted source; prefer from_pretrained() for downloaded weights.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Handle both formats: training checkpoint or raw state_dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        return cls._from_config_and_weights(config, state_dict, device)

    # ------------------------------------------------------------------ #
    #  Public API                                                          #
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def enhance(
        self,
        audio: Union[str, torch.Tensor, np.ndarray],
        input_sr: Optional[int] = None,
        target_sr: int = TARGET_SR,
        ode_method: str = "midpoint",
        ode_steps: int = 4,
        guidance_scale: Optional[float] = 1.5,
    ) -> torch.Tensor:
        """
        Enhance a low-resolution audio signal to high-resolution.

        Args:
            audio: Input audio. Can be:
                - str: path to a .wav file
                - torch.Tensor: waveform tensor of shape (T,), (1, T), or (1, 1, T)
                - np.ndarray: waveform array
            input_sr: Effective bandwidth of the input in Hz (e.g. 8000, 16000).
                For file input: auto-detected from the file's native sample rate
                if it matches a supported rate (8/12/16/24 kHz). Required if the
                file is already at 48 kHz but has limited bandwidth.
                For tensor/array input: always required; the tensor is treated
                as being sampled at ``input_sr``.
            target_sr: Target sample rate in Hz. Default: 48000.
            ode_method: ODE solver method. One of 'euler', 'midpoint', 'rk4'.
            ode_steps: Number of ODE integration steps.
            guidance_scale: Classifier-free guidance scale. None or 0 disables CFG.

        Returns:
            Enhanced waveform tensor of shape (1,T) at target_sr.

        Raises:
            ValueError: If the effective sample rate is not supported, or a
                tensor/array is passed without ``input_sr``.
            TypeError: If ``audio`` is of an unsupported type.
        """
        # Load audio and move it next to the weights.
        wav, file_sr = self._load_audio(audio, input_sr=input_sr)
        wav = wav.to(self._current_device())

        # Determine the effective bandwidth SR
        effective_sr = input_sr if input_sr is not None else file_sr
        if effective_sr not in SUPPORTED_INPUT_SR:
            if effective_sr == target_sr and input_sr is None:
                raise ValueError(
                    f"Input audio is already at {target_sr} Hz. "
                    f"Please specify input_sr to indicate the effective bandwidth "
                    f"(e.g., input_sr=16000). Supported: {sorted(SUPPORTED_INPUT_SR)}"
                )
            raise ValueError(
                f"Effective input sample rate {effective_sr} Hz is not supported. "
                f"Supported rates: {sorted(SUPPORTED_INPUT_SR)}"
            )

        # Prepare the 48 kHz LR input for the model
        if file_sr == target_sr:
            # Simulate the training degradation: downsample → upsample to match
            wav = self._apply_bandwidth_limit(wav, effective_sr, target_sr)
        else:
            # File is truly low-resolution; resample up to 48 kHz
            wav = torchaudio.functional.resample(wav, orig_freq=file_sr, new_freq=target_sr)

        # Minimum length guard: zero-pad short clips, remember the true length
        # so the padding can be trimmed from the output again.
        MIN_SAMPLES = 32_768
        original_len = wav.shape[-1]
        wav = torch.nn.functional.pad(wav, (0, max(0, MIN_SAMPLES - wav.shape[-1])))

        # Ensure shape is [B, C, T] = [1, 1, T]
        if wav.dim() == 1:
            wav = wav.unsqueeze(0).unsqueeze(0)
        elif wav.dim() == 2:
            wav = wav.unsqueeze(0)

        # The model's bin table is keyed by sample rate in kHz.
        sr_khz = effective_sr // 1000

        # Run flow matching SR
        output = self._inference(wav, sr_khz, ode_method, ode_steps, guidance_scale)

        # (1,T)
        return output[..., :original_len]

    # ------------------------------------------------------------------ #
    #  Internal methods                                                    #
    # ------------------------------------------------------------------ #
    def _load_audio(
        self, audio: Union[str, torch.Tensor, np.ndarray], input_sr: Optional[int] = None,
    ) -> tuple:
        """
        Load and validate audio input.

        Returns:
            (waveform, file_sr): The waveform tensor and its *actual* sample rate.
            For tensor/array input the sample rate is taken to be ``input_sr``.

        Raises:
            ValueError: If a tensor/array is passed without ``input_sr``.
            TypeError: If ``audio`` is neither a path, a tensor nor an array.
        """
        if isinstance(audio, str):
            wav, file_sr = torchaudio.load(audio)
            # Mix to mono if stereo
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            return wav, file_sr

        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        if isinstance(audio, torch.Tensor):
            if input_sr is None:
                raise ValueError("input_sr is required when passing a tensor or array.")
            return audio.float(), input_sr

        raise TypeError(f"Unsupported audio type: {type(audio)}")

    def _apply_bandwidth_limit(
        self, wav: torch.Tensor, effective_sr: int, target_sr: int,
    ) -> torch.Tensor:
        """
        Simulate low-resolution input from a high-sample-rate waveform.

        Applies the same downsample-then-upsample pipeline used during training
        (see WaveformCollator._apply_lpf) so that the spectral cutoff pattern
        matches what the model expects.

        Args:
            wav: Waveform at target_sr. Shape: (1, T) or (T,).
            effective_sr: The effective bandwidth in Hz (e.g. 8000).
            target_sr: The native sample rate of wav (e.g. 48000).

        Returns:
            Bandwidth-limited waveform at target_sr, trimmed to at most the
            input length.
        """
        original_len = wav.shape[-1]
        lr = torchaudio.functional.resample(wav, orig_freq=target_sr, new_freq=effective_sr)
        lr = torchaudio.functional.resample(lr, orig_freq=effective_sr, new_freq=target_sr)
        return lr[..., :original_len]

    def _preprocess(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Convert waveform to amplitude-compressed complex STFT representation.
        [B, C, T] -> [B, 2, F-1, T_frames]  (real/imag channels, drop Nyquist bin)
        """
        spec = self.transform(waveform)  # [B, C, F, T_frames] complex
        real = torch.view_as_real(spec.squeeze(1))  # [B, F, T_frames, 2]
        real = real.permute(0, 3, 1, 2)  # [B, 2, F, T_frames]
        return real[:, :, :-1, :]  # drop Nyquist bin

    def _postprocess(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Convert STFT representation back to waveform.
        [B, 2, F-1, T_frames] -> [B, T]
        """
        # Re-append a zero-valued Nyquist bin that _preprocess dropped.
        spec = torch.nn.functional.pad(spec, [0, 0, 0, 1], value=0)
        spec = spec.permute(0, 2, 3, 1).contiguous()  # [B, F, T, 2]
        spec = torch.view_as_complex(spec)  # [B, F, T] complex
        waveform = self.transform.invert(spec)  # [B, T]
        return waveform

    def _inference(
        self,
        lr_audio: torch.Tensor,
        sr_khz: int,
        ode_method: str,
        ode_steps: int,
        guidance_scale: Optional[float],
    ) -> torch.Tensor:
        """
        Core inference pipeline:
        1. STFT the (resampled) LR audio
        2. Extract LR condition bins
        3. Sample noise for HF region
        4. Solve ODE (flow matching)
        5. Concatenate LR + generated HF
        6. iSTFT to waveform
        """
        device = self._current_device()

        # Frequency bin bookkeeping
        lr_bin_count = self.model.sr_to_lr_bins[sr_khz]
        hf_start_bin = self.model.total_freq_bins - self.model.hr_freq_bins

        # STFT
        Y = self._preprocess(lr_audio)  # [B, 2, F-1, T]
        Y_lr = Y[:, :, :lr_bin_count, :]  # LR condition
        Y_hr = Y[:, :, hf_start_bin:, :]  # HR target region (for shape reference)

        # Initial noise
        x0 = self.path.sample_source(Y_hr).to(device)

        # Build ODE solver
        if guidance_scale is not None and guidance_scale > 0:
            ode = CFGVectorFieldODE(net=self.model, guidance_scale=guidance_scale)
        else:
            ode = VectorFieldODE(net=self.model)
        solver = TorchDiffeqSolver(ode, method=ode_method)

        # Time discretization
        ts = torch.linspace(0, 1, ode_steps + 1, device=device)

        # Solve ODE
        x1_spec = solver.simulate(
            x0, ts=ts, y=Y_lr, sr_values=torch.tensor([sr_khz], device=device)
        )

        # Concatenate LR bins + generated HF bins. When the LR band reaches
        # into the generated band, drop the overlapping generated bins so the
        # observed LR content wins.
        slice_start = max(0, lr_bin_count - hf_start_bin)
        x1_spec = x1_spec[:, :, slice_start:, :]
        full_spec = torch.cat([Y_lr, x1_spec], dim=2)

        # iSTFT
        output = self._postprocess(full_spec)
        return output
View File
+470
View File
@@ -0,0 +1,470 @@
import math
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
class ConditionalVectorFieldModel(nn.Module, ABC):
    """Abstract base class for networks that parameterize the learned
    conditional vector field u_t^theta(x | y)."""

    @abstractmethod
    def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor):
        """Evaluate the vector field.

        Args:
            x: State; documented shape ``(bs, c, h, w)``.
            t: Time; documented shape ``(bs, 1, 1, 1)``.
            y: Conditioning input; documented shape ``(bs,)``.

        Returns:
            u_t^theta(x|y), with documented shape ``(bs, c, h, w)``.
        """
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar time value.

    Based on https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/karras_unet.py#L183
    & DiffWave / WaveFM

    Two modes are available:
      * ``'fixed'``: log-spaced frequencies ``10 ** (4 * k / (D/2 - 1))``
        scaled by ``time_scale``.
      * ``'learnable'``: frequencies are a learned parameter; the output is
        scaled by ``sqrt(2)``.
    """

    def __init__(self, dim: int = 128, mode: str = 'learnable', time_scale=1):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be an even number"
        assert mode in ['fixed', 'learnable'], "Mode must be 'fixed' or 'learnable'"
        self.dim = dim  # D
        self.half_dim = dim // 2
        self.mode = mode
        self.time_scale = time_scale  # 1(diffusion) or 100(flow)
        if self.mode == 'learnable':
            self.weights = nn.Parameter(torch.randn(1, self.half_dim))  # [1,D/2]

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            - t: Time tensor. Shape can be [B] or [B, 1].
        Returns:
            - embeddings: Time embeddings of shape [B, D]
              (sine half followed by cosine half).
        """
        # Ensure t has shape [B, 1] for broadcasting
        t = t.view(-1, 1)

        if self.mode == 'fixed':
            pos = torch.arange(self.half_dim, device=t.device).unsqueeze(0)  # [1,D/2]
            # With D == 2 there is a single frequency; guard the denominator
            # so the exponent is 0 instead of 0/0 = NaN.
            denom = max(self.half_dim - 1, 1)
            freqs = self.time_scale * t * 10.0 ** (pos * 4.0 / denom)
            return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
        elif self.mode == 'learnable':
            freqs = t * self.weights * 2 * math.pi
            return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1) * math.sqrt(2)
        # Only reachable if `mode` was changed after construction.
        raise ValueError(f"Unknown mode: {self.mode!r}")
class GRN(nn.Module):
    """Global Response Normalization layer.

    The parameters have shape ``(1, 1, 1, dim)``, so the input is expected to
    carry its ``dim`` features on the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # L2 norm over axes 1 and 2, one value per sample and feature.
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across features (epsilon avoids 0/0).
        feature_mean = global_norm.mean(dim=-1, keepdim=True)
        relative_norm = global_norm / (feature_mean + 1e-6)
        # Learned re-scaling plus the identity shortcut.
        return self.gamma * (x * relative_norm) + self.beta + x
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two memory layouts.

    ``data_format`` selects where the channel axis is:
      * ``"channels_last"`` (default): inputs shaped (batch, height, width, channels)
      * ``"channels_first"``: inputs shaped (batch, channels, height, width)
    Any other value raises ``NotImplementedError`` at construction.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalization along axis 1 (the channel axis).
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normed + shift
        if self.data_format == "channels_last":
            # Channels are already last: use the built-in kernel.
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
class Block(nn.Module):
    """ConvNeXt V2 block; the output has the same shape as the input.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        hidden_dim = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, padding_mode="reflect")
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)  # GRN for V2
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        residual = x
        # Depthwise 7x7 convolution in channels-first layout.
        h = self.dwconv(x)
        # The pointwise layers are nn.Linear, so move channels to the last axis.
        h = h.permute(0, 2, 3, 1)  # [N,C,H,W] -> [N,H,W,C]
        h = self.pwconv2(self.grn(self.act(self.pwconv1(self.norm(h)))))
        h = h.permute(0, 3, 1, 2)  # [N,H,W,C] -> [N,C,H,W]
        # Residual connection with optional stochastic depth.
        return residual + self.drop_path(h)
class BlockWithEmbedding(nn.Module):
    """ConvNeXt block preceded by an additive time-embedding injection."""

    def __init__(self, dim, drop_path=0., time_embed_dim=128):
        super().__init__()
        self.block = Block(dim, drop_path)
        # Small MLP mapping the time embedding to one bias per channel.
        adapter_layers = [
            nn.Linear(time_embed_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, dim),
        ]
        self.time_adapter = nn.Sequential(*adapter_layers)

    def forward(self, x, t_embed):
        channel_bias = self.time_adapter(t_embed)
        # Append two singleton axes so the bias broadcasts over H and W.
        channel_bias = channel_bias.unsqueeze(-1).unsqueeze(-1)  # [B,C,1,1]
        return self.block(x + channel_bias)
class EncoderBlock(nn.Module):
    """Encoder stage: time-conditioned blocks followed by a 2x downsampling."""

    def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim):
        super().__init__()
        stage = []
        for _ in range(num_blocks):
            stage.append(BlockWithEmbedding(dim_in, drop_path, time_embed_dim))
        self.blocks = nn.ModuleList(stage)
        # Normalize, then halve both spatial axes with a strided 2x2 conv.
        self.downsampler = nn.Sequential(
            LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
            nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2),
        )

    def forward(self, x, t_emb):
        h = x
        for layer in self.blocks:
            h = layer(h, t_emb)
        return self.downsampler(h)
class Midcoder(nn.Module):
    """Bottleneck stage: a stack of time-conditioned blocks at constant width."""

    def __init__(self, dim, num_blocks, drop_path, time_embed_dim):
        super().__init__()
        stage = []
        for _ in range(num_blocks):
            stage.append(BlockWithEmbedding(dim, drop_path, time_embed_dim))
        self.blocks = nn.ModuleList(stage)

    def forward(self, x, t_emb):
        h = x
        for layer in self.blocks:
            h = layer(h, t_emb)
        return h
class DecoderBlock(nn.Module):
    """Decoder stage: a 2x transposed-conv upsampling followed by time-conditioned blocks."""

    def __init__(self, dim_in, dim_out, num_blocks, drop_path, time_embed_dim):
        super().__init__()
        self.upsampler = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=2, stride=2)
        stage = []
        for _ in range(num_blocks):
            stage.append(BlockWithEmbedding(dim_out, drop_path, time_embed_dim))
        self.blocks = nn.ModuleList(stage)

    def forward(self, x, t_emb):
        h = self.upsampler(x)
        for layer in self.blocks:
            h = layer(h, t_emb)
        return h
class ConditioningEncoder2D(nn.Module):
    """Encode the low-resolution spectrogram into a per-frame conditioning sequence.

    The input is modulated per frequency bin (FiLM from the frequency
    positional embedding), lifted to ``cond_dim`` channels, modulated per
    sample (FiLM from the sample-rate embedding), processed by ConvNeXt
    blocks, and finally averaged over the frequency axis.
    """

    def __init__(self, cond_dim, num_blocks=3):
        """
        Args:
            cond_dim (int): The main conditioning dimension (D).
            num_blocks (int): The number of shared 2D ConvNeXt blocks.
        """
        super().__init__()
        self.cond_dim = cond_dim
        # Scale and shift for each of the 2 (real/imag) channels -> 4 values.
        self.film_generator = nn.Linear(cond_dim, 4)
        self.head = nn.Conv2d(2, cond_dim, kernel_size=1)
        self.sr_adapter = nn.Sequential(
            nn.Linear(cond_dim, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim * 2),
        )
        stack = [Block(dim=cond_dim) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*stack)
        # Collapse the frequency axis to size 1, keep the time axis as is.
        self.freq_pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, y_lr, f_emb_lr, sr_emb):
        """
        Args:
            y_lr (Tensor): LR Spec [B, 2, F1, T]
            f_emb_lr: Freq positional embedding for lr spec [F1,D]
            sr_emb: Sampling rate embedding [B,D]
        Returns:
            z (Tensor): Conditioning Emb [B, D, T]
        """
        # Frequency-wise FiLM on the raw real/imag channels.
        freq_scale, freq_shift = torch.chunk(
            self.film_generator(f_emb_lr), chunks=2, dim=-1
        )  # each [F1,2]
        freq_scale = rearrange(freq_scale, 'f c -> 1 c f 1')  # [1,2,F1,1]
        freq_shift = rearrange(freq_shift, 'f c -> 1 c f 1')  # [1,2,F1,1]
        z = self.head(y_lr * freq_scale + freq_shift)  # [B,D,F1,T]

        # Sample-rate FiLM on the lifted features.
        sr_scale, sr_shift = torch.chunk(self.sr_adapter(sr_emb), 2, dim=-1)  # each [B,D]
        sr_scale = sr_scale.unsqueeze(-1).unsqueeze(-1)  # [B,D,1,1]
        sr_shift = sr_shift.unsqueeze(-1).unsqueeze(-1)  # [B,D,1,1]
        z = z * sr_scale + sr_shift  # [B,D,F1,T]

        z = self.blocks(z)  # [B,D,F1,T]
        return self.freq_pool(z).squeeze(2)  # [B,D,T]
class FrequencyPositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding over frequency-bin indices.

    Even columns hold sines and odd columns hold cosines of
    ``bin_index * 10000 ** (-2k / emb_dim)``. The table is stored as a
    non-trainable buffer named ``pe``.

    Args:
        num_bins: Number of frequency bins (rows of the table).
        emb_dim: Embedding dimension (columns); even or odd.
    """

    def __init__(self, num_bins: int, emb_dim: int):
        super().__init__()
        # (F, D)
        pe = torch.zeros(num_bins, emb_dim)
        position = torch.arange(num_bins, dtype=torch.float32).unsqueeze(1)  # (F,1)
        div_term = torch.exp(
            torch.arange(0, emb_dim, 2, dtype=torch.float32) *
            -(math.log(10000.0) / emb_dim)
        )  # (ceil(D/2),)
        angles = position * div_term  # (F, ceil(D/2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd D there is one cosine column fewer than sine columns;
        # slicing keeps the assignment shapes matched.
        pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])
        self.register_buffer('pe', pe)

    def forward(self):
        """Return the full embedding table of shape (F, D)."""
        return self.pe
class ConvNeXtUNetCond(ConditionalVectorFieldModel):
    """ConvNeXt U-Net that predicts the vector field for the high-frequency band.

    The network is conditioned on the low-resolution spectrum (through
    ``ConditioningEncoder2D``), on the input sample rate (learned embedding)
    and on time (sinusoidal embedding). Passing ``y=None`` to ``forward``
    selects the learned unconditional embedding instead of the encoded
    condition.
    """

    def __init__(self, in_channels=2, out_channels=2,
                 dims=[64,128,256,512], depths=[2,2,2,4],
                 drop_path=0., time_dim=128,
                 cond_dim=256,  # D1
                 total_freq_bins=512,
                 hr_freq_bins=432,
                 feature_enc_layers=10,
                 cond_dropout_prob=0.1,
                 sr_to_lr_bins={8: 80, 12: 128, 16: 170, 24: 256},
                 ):
        """
        Args:
            in_channels / out_channels: Channels of the state and of the output.
            dims: Channel width of each resolution level.
            depths: Number of blocks per level.
            drop_path: Stochastic depth rate passed to every block.
            time_dim: Size of the time embedding.
            cond_dim: Size of the conditioning embedding (D).
            total_freq_bins: Number of bins covered by the frequency embedding.
            hr_freq_bins: Number of (top) bins the network generates.
            feature_enc_layers: Number of blocks in the conditioning encoder.
            cond_dropout_prob: Probability of replacing the condition by the
                unconditional embedding during training.
            sr_to_lr_bins: Mapping from sample rate in kHz to the number of
                low-resolution bins.
        """
        super().__init__()
        # Required divisor of the frame count: one factor 2 per level.
        self.strides = 2**len(dims)

        self.time_embedder = SinusoidalTimeEmbedding(dim=time_dim)
        self.total_freq_bins = total_freq_bins
        self.hr_freq_bins = hr_freq_bins

        # Copy so the instance never aliases the shared default dict (or a
        # dict the caller keeps mutating).
        self.sr_to_lr_bins = dict(sr_to_lr_bins)
        self.sr_values_list = sorted(self.sr_to_lr_bins.keys())  # (8,12,16,24) kHz
        self.sr_to_idx = {sr: i for i, sr in enumerate(self.sr_values_list)}
        self.sr_embedder = nn.Embedding(len(self.sr_values_list), cond_dim)  # [4,D]

        self.cond_dropout_prob = cond_dropout_prob
        self.cond_dim = cond_dim
        self.uncond_emb = nn.Parameter(torch.randn(cond_dim))
        self.sr_projector = nn.Linear(cond_dim, time_dim)  # projector to t_emb

        self.freq_pos_enc = FrequencyPositionalEmbedding(num_bins=total_freq_bins, emb_dim=cond_dim)
        self.film_generator = nn.Linear(cond_dim, cond_dim * 2)

        self.conditioning_encoder = ConditioningEncoder2D(
            cond_dim=cond_dim,
            num_blocks=feature_enc_layers,
        )

        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels+cond_dim, dims[0], kernel_size=1),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # Encoder
        for i in range(len(depths)):
            dim_in = dims[i]
            dim_out = dims[i+1] if i+1 < len(dims) else dims[i]
            self.encoders.append(EncoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim))

        # Midcoder
        self.midcoder = Midcoder(dims[-1], depths[-1], drop_path, time_dim)

        # Decoder
        for i in reversed(range(len(depths))):
            dim_in = dims[i+1] if i+1 < len(dims) else dims[i]
            dim_out = dims[i]
            self.decoders.append(DecoderBlock(dim_in, dim_out, depths[i], drop_path, time_dim))

        self.final_conv = nn.Conv2d(dims[0], out_channels, kernel_size=1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv and linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Bias-free layers have `bias is None`.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _pad_frames(self, x):
        """Reflect-pad the last (time) axis up to a multiple of ``self.strides``.

        Returns:
            (padded_x, pad_len): the padded tensor and the number of frames added.
        """
        num_frames = x.shape[-1]
        pad_len = (self.strides - num_frames % self.strides) % self.strides
        if pad_len:
            x = torch.nn.functional.pad(x, [0,pad_len,0,0], mode='reflect')
        assert x.shape[-1] % self.strides == 0, \
            f"After padding, time dim:{x.shape[-1]} must be multiples of {self.strides}"
        return x, pad_len

    def forward(self, x, t, y, sr_values):
        """
        x : x_t noisy spec of the generated band [B,2,F2,T]
        t : time [B,1] or [B]
        y : condition lr spectrum [B,2,F1,T], or None for the unconditional branch
        sr_values: input sampling rate in kHz; an int, or a sequence/tensor
            whose first element is used for the whole batch
        """
        # Pad the frame axis (and the condition alongside it).
        x, pad_len = self._pad_frames(x)
        if pad_len > 0 and y is not None:
            y = torch.nn.functional.pad(y, [0, pad_len, 0, 0], mode='reflect')
        B, _, _, T = x.shape

        # get number of lr bins for input sr
        if isinstance(sr_values, int):
            current_sr = sr_values
        else:
            current_sr = sr_values[0].item() if hasattr(sr_values[0], 'item') else sr_values[0]
        lr_bin_count = self.sr_to_lr_bins[current_sr]

        # freq pe
        pe_full = self.freq_pos_enc()  # [F,D]
        pe_low = pe_full[:lr_bin_count,:]  # [F1,D]
        hf_start_bin = self.total_freq_bins - self.hr_freq_bins  # 512 - 432 with defaults
        pe_high = pe_full[hf_start_bin:, :]  # [F2,D]

        # time / sr embedding
        t_embed = self.time_embedder(t)  # [B,timedim]
        sr_idx = self.sr_to_idx[current_sr]
        sr_emb = self.sr_embedder(torch.tensor([sr_idx], device=x.device)).expand(B,-1)  # [B, D]
        t_embed = t_embed + self.sr_projector(sr_emb)  # [B, timedim]

        if y is not None:  # conditional branch (training and guided inference)
            y_cond_real = self.conditioning_encoder(y, pe_low, sr_emb)  # [B,D,T]
            # Condition dropout is only active in training mode.
            if self.training and self.cond_dropout_prob > 0:
                # random mask for uncond
                mask = (torch.rand(B, device=x.device) < self.cond_dropout_prob)  # [B]
                uncond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T)  # [B,D,T]
                y_cond = torch.where(mask.reshape(B,1,1), uncond, y_cond_real)
            else:
                y_cond = y_cond_real
        else:  # unconditional branch (used for classifier-free guidance)
            y_cond = self.uncond_emb.reshape(1,self.cond_dim,1).expand(B,self.cond_dim,T)

        y_cond = y_cond.unsqueeze(2)  # [B,D,1,T]

        # FiLM Conditioning of freq-bins
        film_params = self.film_generator(pe_high)  # [F2,D] -> [F2,2D]
        gamma_high, beta_high = torch.chunk(film_params, chunks=2, dim=-1)  # [F2, D]
        gamma_high = rearrange(gamma_high, 'f d -> 1 d f 1')  # [1,D,F2,1]
        beta_high = rearrange(beta_high, 'f d -> 1 d f 1')  # [1,D,F2,1]
        spatial_cond = y_cond * gamma_high + beta_high  # [B,D,F2,T]

        x = torch.cat([x, spatial_cond], dim=1)  # [B,2+D,F2,T]
        x = self.init_conv(x)

        skip_connections = [x]
        for encoder in self.encoders:
            x = encoder(x, t_embed)
            skip_connections.append(x)

        x = self.midcoder(x, t_embed)

        for decoder in self.decoders:
            skip = skip_connections.pop()
            # Resize if the spatial sizes no longer line up.
            if x.shape != skip.shape:
                x = nn.functional.interpolate(x, size=skip.shape[2:])
            x = x + skip
            x = decoder(x, t_embed)

        # The remaining skip is the output of init_conv.
        skip = skip_connections.pop()
        x = x + skip
        x = self.final_conv(x)

        # Crop out the frames added by _pad_frames.
        if pad_len:
            x = x[...,:-pad_len]
        return x
def main():
    """
    Smoke test: build a ConvNeXtUNetCond and print a torchinfo summary of a
    dummy forward pass on random inputs.
    """
    from torchinfo import summary

    batch_size = 2
    hr_freq_bins = 432  # High-res bins to be generated (fixed)
    lr_freq_bins = 128  # Low-res bins for this test case (the 12 kHz entry below)
    T = 256  # Number of time frames
    sr_config = {8: 80, 12: 128, 16: 170, 24: 256}

    model_kwargs = dict(
        in_channels=2,
        out_channels=2,
        dims=[96, 192, 384, 768],
        depths=[2, 2, 4, 2],
        time_dim=256,
        cond_dim=384,
        total_freq_bins=512,
        hr_freq_bins=hr_freq_bins,
        feature_enc_layers=4,
        cond_dropout_prob=0.1,
        sr_to_lr_bins=sr_config,  # Pass the dictionary
    )
    model = ConvNeXtUNetCond(**model_kwargs)

    # Random stand-ins for the state, the condition, the time and the rate.
    x = torch.randn(batch_size, 2, hr_freq_bins, T)
    y = torch.randn(batch_size, 2, lr_freq_bins, T)
    t = torch.randint(0, 1000, (batch_size,))
    sr_values = [12] * batch_size

    print("\n--- Model Summary ---")
    columns = ("input_size", "output_size", "num_params",
               "kernel_size", "mult_adds", "trainable")
    summary(
        model,
        input_data=[x, t, y, sr_values],
        depth=4,
        col_names=columns,
        verbose=1
    )


if __name__ == "__main__":
    main()
View File
+135
View File
@@ -0,0 +1,135 @@
import math
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
class InvertibleFeatureExtractor(nn.Module, ABC):
    """
    Interface for a feature extractor that has both a forward map and an inverse.

    Implementations are expected to satisfy, up to numerical error,
    `extractor.invert(extractor(x)) == x`.
    """

    @abstractmethod
    def forward(self, x, **kwargs):
        """Map the input to its feature representation."""

    @abstractmethod
    def invert(self, x, **kwargs):
        """Map a feature representation back to the input domain."""

    def analysis_synthesis(self, x, **kwargs):
        """Run ``forward`` followed by ``invert``, passing the same kwargs to both."""
        features = self.forward(x, **kwargs)
        return self.invert(features, **kwargs)
class AmplitudeCompressedComplexSTFT(InvertibleFeatureExtractor):
    """
    Composition of ComplexSTFT() followed by CompressAmplitudesAndScale().

    ``alpha`` is used as the compression exponent and ``beta`` as the scale
    factor; the remaining arguments configure the STFT.
    """

    def __init__(
        self,
        window_fn, n_fft, sampling_rate,
        alpha, beta, comp_eps,
        hop_length=None, n_hops=None,
        learnable_window=False,
        *args, **kwargs,
    ):
        super().__init__(*args, **kwargs)
        stft_options = dict(
            hop_length=hop_length,
            n_hops=n_hops,
            learnable_window=learnable_window,
        )
        self.complex_stft = ComplexSTFT(window_fn, n_fft, sampling_rate, **stft_options)
        self.compress = CompressAmplitudesAndScale(
            compression_exponent=alpha,
            scale_factor=beta,
            comp_eps=comp_eps,
        )

    def forward(self, x: Tensor, **kwargs):
        """Waveform -> compressed complex spectrogram."""
        spectrogram = self.complex_stft(x, **kwargs)
        return self.compress(spectrogram, **kwargs)

    def invert(self, X: Tensor, **kwargs):
        """Compressed complex spectrogram -> waveform (inverse stages in reverse order)."""
        spectrogram = self.compress.invert(X, **kwargs)
        return self.complex_stft.invert(spectrogram, **kwargs)
class ComplexSTFT(InvertibleFeatureExtractor):
    """One-sided, centered complex STFT with an optionally learnable window.

    Exactly one of ``hop_length`` and ``n_hops`` must be given; with
    ``n_hops`` the hop length is ``ceil(n_fft / n_hops)``. ``window_fn`` is
    the name of a function in ``torch.signal.windows``.
    """

    def __init__(
        self, window_fn, n_fft, sampling_rate, hop_length=None, n_hops=None, learnable_window=False,
        *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert (hop_length is not None) ^ (n_hops is not None),\
            "Exactly one of {hop_length, n_hops} must be specified!"
        if hop_length is None:
            hop_length = int(math.ceil(n_fft / n_hops))

        make_window = getattr(torch.signal.windows, window_fn)
        self.learnable_window = learnable_window
        # Stored as a Parameter; it only receives gradients when learnable.
        self.window = nn.Parameter(make_window(n_fft), requires_grad=learnable_window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sampling_rate = sampling_rate
        self.center = True

    def forward(self, x: Tensor, **kwargs):
        """Assumes x is an audio tensor of shape [B, C, T] or [B, T]
        [B,C,T] -> [B,C,F,T]
        [B,T] -> [B,F,T]
        """
        lead = "b c" if x.ndim == 3 else "b"
        # torch.stft takes a batch of 1-D signals: fold the leading axes together.
        flat_signal = rearrange(x, f"{lead} t -> ({lead}) t")
        flat_spec = torch.stft(
            flat_signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(x.device),
            center=self.center,
            onesided=True,
            return_complex=True,
        )
        return rearrange(flat_spec, f"({lead}) f t -> {lead} f t", b=x.shape[0])

    def invert(self, X: Tensor, orig_length: Optional[int] = None, **kwargs):
        """Assumes X is a (complex) spectrogram tensor of shape [B, C, F, T] or [B, F, T].

        ``orig_length`` is passed to ``torch.istft`` as ``length``.
        """
        lead = "b c" if X.ndim == 4 else "b"
        flat_spec = rearrange(X, f"{lead} f t -> ({lead}) f t")
        flat_signal = torch.istft(
            flat_spec,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(X.device),
            center=self.center,
            onesided=True,
            return_complex=False,
            length=orig_length,
        )
        return rearrange(flat_signal, f"({lead}) t -> {lead} t", b=X.shape[0])
class CompressAmplitudesAndScale(InvertibleFeatureExtractor):
    """Power-law magnitude compression followed by a constant scaling.

    The phase is kept; the magnitude is raised to ``compression_exponent``
    (skipped when the exponent is 1) and the result is multiplied by
    ``scale_factor``.
    """

    def __init__(self, compression_exponent: float, scale_factor: float, comp_eps: float, *args, **kwargs):
        super().__init__()
        self.compression_exponent = compression_exponent
        self.scale_factor = scale_factor
        self.comp_eps = comp_eps

    def forward(self, X: Tensor, **kwargs):
        """
        Assumes X is a complex STFT (complex spectrogram).
        """
        exponent = self.compression_exponent
        if exponent != 1:
            # Offset by comp_eps before taking magnitude and phase.
            X = X + self.comp_eps
            magnitude = X.abs()**exponent
            phase = torch.exp(1j * X.angle())
            X = magnitude * phase
        return X * self.scale_factor

    def invert(self, X: Tensor, **kwargs):
        """
        Assumes X is an amplitude-compressed and scaled complex STFT.
        The comp_eps offset applied in forward() is not removed here.
        """
        exponent = self.compression_exponent
        X = X / self.scale_factor
        if exponent != 1:
            magnitude = X.abs()**(1/exponent)
            phase = torch.exp(1j * X.angle())
            X = magnitude * phase
        return X