Initial release: ComfyUI-UniverSR
ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching. - UniverSR Model Loader: presets auto-download to models/universr, plus local dir / raw .pth (from_local) loading, with caching. - UniverSR Super-Resolution: chunked overlap-add for long audio, per-channel stereo, seed control with global-RNG isolation, wet/dry blend, and an optional before/after spectrogram. - Vendors the universr inference package under vendor/ (prefers an installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Vendored
+351
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
UniverSR: Unified and Versatile Audio Super-Resolution via Vocoder-Free Flow Matching
|
||||
Inference wrapper module.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import yaml
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from universr.models.unet import ConvNeXtUNetCond
|
||||
from universr.flow.path import OriginalCFMPath
|
||||
from universr.flow.solver import CFGVectorFieldODE, VectorFieldODE, TorchDiffeqSolver
|
||||
from universr.utils.spectral_ops import AmplitudeCompressedComplexSTFT
|
||||
|
||||
|
||||
# Effective input sample rates accepted by UniverSR.enhance() and the default
# output sample rate. Both are expressed in Hz.
SUPPORTED_INPUT_SR = {8000, 12000, 16000, 24000}
TARGET_SR = 48000


class UniverSR(torch.nn.Module):
    """
    UniverSR inference wrapper.

    Performs audio super-resolution from low sample rates (8/12/16/24 kHz)
    to 48 kHz using vocoder-free flow matching in the complex STFT domain.

    Example:
        >>> model = UniverSR.from_pretrained("woongzip1/universr-speech")
        >>> output = model.enhance("input.wav", input_sr=16000)
        >>> torchaudio.save("output.wav", output.cpu(), 48000)
    """

    def __init__(
        self,
        model: "ConvNeXtUNetCond",
        transform: "AmplitudeCompressedComplexSTFT",
        path: "OriginalCFMPath",
        device: str = "cuda",
    ):
        """
        Wrap an already constructed network, STFT transform and probability path.

        Args:
            model: Vector-field network used by the ODE solver.
            transform: STFT transform; called on waveforms and inverted via
                ``transform.invert``.
            path: Probability path; only ``path.sample_source`` is used here.
            device: Device the components were placed on by the caller.
        """
        super().__init__()
        self.model = model
        self.transform = transform
        self.path = path
        # Kept for backward compatibility and as a fallback. The device that is
        # actually used at inference time is resolved by _runtime_device().
        self._device = device

    # ------------------------------------------------------------------ #
    #  Construction helpers                                                #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_config(config_path: str) -> dict:
        """Parse the YAML config file at ``config_path``."""
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @classmethod
    def _assemble(cls, config: dict, state_dict, device: str) -> "UniverSR":
        """Build network, transform and probability path from a parsed config."""
        model = ConvNeXtUNetCond(**config["model"])
        model.load_state_dict(state_dict)
        model.to(device).eval()

        transform = AmplitudeCompressedComplexSTFT(**config["transform"])
        transform.to(device)

        # Fall back to sigma_min=1e-4 when the config has no path.init_args entry.
        path_args = config.get("path", {}).get("init_args", {"sigma_min": 1e-4})
        path = OriginalCFMPath(**path_args)

        return cls(model=model, transform=transform, path=path, device=device)

    def _runtime_device(self) -> torch.device:
        """
        Return the device inference tensors must be placed on.

        The wrapper is an nn.Module, so callers may move it with ``.to()`` after
        construction. Reading the device from the network's parameters keeps
        inputs and weights on the same device in that case; the device given at
        construction time is only used when the network exposes no parameters.
        """
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device(self._device)

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        device: str = "cuda",
        revision: Optional[str] = None,
    ) -> "UniverSR":
        """
        Load a pretrained UniverSR model.

        Args:
            repo_id_or_path: HuggingFace repo ID (e.g. "woongzip1/universr-speech")
                or local directory path containing config.yaml and pytorch_model.bin.
            device: Device to load the model on.
            revision: Optional HuggingFace revision (branch, tag, or commit hash).
                Ignored when ``repo_id_or_path`` is a local directory.

        Returns:
            UniverSR instance ready for inference.
        """
        if os.path.isdir(repo_id_or_path):
            config_path = os.path.join(repo_id_or_path, "config.yaml")
            model_path = os.path.join(repo_id_or_path, "pytorch_model.bin")
        else:
            config_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="config.yaml", revision=revision
            )
            model_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="pytorch_model.bin", revision=revision
            )

        config = cls._read_config(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        return cls._assemble(config, state_dict, device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str,
        config_path: str,
        device: str = "cuda",
    ) -> "UniverSR":
        """
        Load UniverSR from a local checkpoint (e.g. training checkpoint with optimizer state).

        This handles the standard training checkpoint format where weights are stored
        under the 'model_state_dict' key, as opposed to from_pretrained() which expects
        a clean state_dict saved as pytorch_model.bin. A raw state_dict is accepted too.

        Args:
            ckpt_path: Path to checkpoint file (.pth).
            config_path: Path to YAML config file.
            device: Device to load the model on.

        Returns:
            UniverSR instance ready for inference.
        """
        config = cls._read_config(config_path)

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
        # crafted checkpoint can execute code. Only load files from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Handle both formats: training checkpoint or raw state_dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        return cls._assemble(config, state_dict, device)

    # ------------------------------------------------------------------ #
    #  Public API                                                          #
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def enhance(
        self,
        audio: Union[str, torch.Tensor, np.ndarray],
        input_sr: Optional[int] = None,
        target_sr: int = TARGET_SR,
        ode_method: str = "midpoint",
        ode_steps: int = 4,
        guidance_scale: Optional[float] = 1.5,
    ) -> torch.Tensor:
        """
        Enhance a low-resolution audio signal to high-resolution.

        Args:
            audio: Input audio. Can be:
                - str: path to an audio file (multi-channel files are mixed to mono)
                - torch.Tensor: waveform tensor of shape (T,), (1, T), or (1, 1, T)
                - np.ndarray: waveform array
            input_sr: Effective bandwidth of the input in Hz (e.g. 8000, 16000).
                For file input: defaults to the file's native sample rate, which
                must then be one of the supported rates (8/12/16/24 kHz). Required
                if the file is already at 48 kHz but has limited bandwidth.
                For tensor/array input: always required, and also taken as the
                actual sample rate of the data.
            target_sr: Sample rate in Hz the input is resampled to before it is
                fed to the model. Default: 48000.
                NOTE(review): the output itself is never resampled, so values
                other than 48000 are presumably unsupported -- confirm.
            ode_method: ODE solver method, forwarded to TorchDiffeqSolver
                (e.g. 'euler', 'midpoint', 'rk4').
            ode_steps: Number of ODE integration steps; must be at least 1.
            guidance_scale: Classifier-free guidance scale. None or a value <= 0
                disables CFG.

        Returns:
            Enhanced waveform tensor, trimmed to the length of the (resampled) input.

        Raises:
            ValueError: If ``ode_steps`` is smaller than 1, if ``input_sr`` is
                missing for tensor/array input, or if the effective sample rate
                is not supported.
            TypeError: If ``audio`` is not a path, tensor or array.
        """
        if ode_steps < 1:
            # A single-point time grid would skip the integration entirely.
            raise ValueError(f"ode_steps must be >= 1, got {ode_steps}.")

        wav, file_sr = self._load_audio(audio, input_sr=input_sr)

        # Determine the effective bandwidth SR and validate it before doing any
        # device transfer or resampling work.
        effective_sr = input_sr if input_sr is not None else file_sr
        if effective_sr not in SUPPORTED_INPUT_SR:
            if effective_sr == target_sr and input_sr is None:
                raise ValueError(
                    f"Input audio is already at {target_sr} Hz. "
                    f"Please specify input_sr to indicate the effective bandwidth "
                    f"(e.g., input_sr=16000). Supported: {sorted(SUPPORTED_INPUT_SR)}"
                )
            raise ValueError(
                f"Effective input sample rate {effective_sr} Hz is not supported. "
                f"Supported rates: {sorted(SUPPORTED_INPUT_SR)}"
            )

        wav = wav.to(self._runtime_device())

        # Prepare the target-rate LR input for the model.
        if file_sr == target_sr:
            # Simulate the training degradation: downsample -> upsample to match.
            wav = self._apply_bandwidth_limit(wav, effective_sr, target_sr)
        else:
            # File is truly low-resolution; resample up to the target rate.
            # NOTE(review): when file_sr differs from both input_sr and target_sr
            # the audio is resampled but not band-limited to input_sr -- confirm
            # this is intended.
            wav = torchaudio.functional.resample(wav, orig_freq=file_sr, new_freq=target_sr)

        # Minimum length guard: zero-pad short clips and trim the output afterwards.
        MIN_SAMPLES = 32_768
        original_len = wav.shape[-1]
        wav = torch.nn.functional.pad(wav, (0, max(0, MIN_SAMPLES - wav.shape[-1])))

        # Ensure shape is [B, C, T] = [1, 1, T]
        if wav.dim() == 1:
            wav = wav.unsqueeze(0).unsqueeze(0)
        elif wav.dim() == 2:
            wav = wav.unsqueeze(0)

        sr_khz = effective_sr // 1000

        # Run flow matching SR
        output = self._inference(wav, sr_khz, ode_method, ode_steps, guidance_scale)

        # Drop the samples that only exist because of the minimum-length padding.
        return output[..., :original_len]

    # ------------------------------------------------------------------ #
    #  Internal methods                                                    #
    # ------------------------------------------------------------------ #

    def _load_audio(
        self, audio: Union[str, torch.Tensor, np.ndarray], input_sr: Optional[int] = None,
    ) -> tuple:
        """
        Load and validate audio input.

        Args:
            audio: File path, waveform tensor, or waveform array.
            input_sr: Sample rate of tensor/array input; unused for file paths.

        Returns:
            (waveform, file_sr): The float waveform tensor and its sample rate.
            For files this is the rate reported by torchaudio; for tensors and
            arrays it is ``input_sr``.

        Raises:
            ValueError: If a tensor or array is passed without ``input_sr``.
            TypeError: If ``audio`` has an unsupported type.
        """
        if isinstance(audio, str):
            wav, file_sr = torchaudio.load(audio)
            # Mix to mono if the file has more than one channel.
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            return wav, file_sr

        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)

        if isinstance(audio, torch.Tensor):
            if input_sr is None:
                raise ValueError("input_sr is required when passing a tensor or array.")
            return audio.float(), input_sr

        raise TypeError(f"Unsupported audio type: {type(audio)}")

    def _apply_bandwidth_limit(
        self, wav: torch.Tensor, effective_sr: int, target_sr: int,
    ) -> torch.Tensor:
        """
        Simulate low-resolution input from a high-sample-rate waveform.

        Resamples the waveform down to ``effective_sr`` and back up to
        ``target_sr`` so that content above the effective bandwidth is removed.
        Per the original author this mirrors the degradation used during
        training (WaveformCollator._apply_lpf, not visible in this module).

        Args:
            wav: Waveform at target_sr. Shape: (1, T) or (T,).
            effective_sr: The effective bandwidth in Hz (e.g. 8000).
            target_sr: The native sample rate of wav (e.g. 48000).

        Returns:
            Bandwidth-limited waveform at target_sr, cut to at most the input length.
        """
        original_len = wav.shape[-1]
        lr = torchaudio.functional.resample(wav, orig_freq=target_sr, new_freq=effective_sr)
        lr = torchaudio.functional.resample(lr, orig_freq=effective_sr, new_freq=target_sr)
        # The round trip may change the sample count; never return more than the input.
        return lr[..., :original_len]

    def _preprocess(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Convert waveform to amplitude-compressed complex STFT representation.

        [B, C, T] -> [B, 2, F-1, T_frames]  (real/imag channels, drop Nyquist bin)
        """
        spec = self.transform(waveform)             # [B, C, F, T_frames] complex
        real = torch.view_as_real(spec.squeeze(1))  # [B, F, T_frames, 2]
        real = real.permute(0, 3, 1, 2)             # [B, 2, F, T_frames]
        return real[:, :, :-1, :]                   # drop Nyquist bin

    def _postprocess(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Convert STFT representation back to waveform.

        [B, 2, F-1, T_frames] -> [B, T]
        """
        # Re-append the Nyquist bin dropped in _preprocess as zeros.
        spec = torch.nn.functional.pad(spec, [0, 0, 0, 1], value=0)
        spec = spec.permute(0, 2, 3, 1).contiguous()  # [B, F, T, 2]
        spec = torch.view_as_complex(spec)            # [B, F, T] complex
        return self.transform.invert(spec)            # [B, T]

    def _inference(
        self,
        lr_audio: torch.Tensor,
        sr_khz: int,
        ode_method: str,
        ode_steps: int,
        guidance_scale: Optional[float],
    ) -> torch.Tensor:
        """
        Core inference pipeline:
            1. STFT the (resampled) LR audio
            2. Extract LR condition bins
            3. Sample noise for HF region
            4. Solve ODE (flow matching)
            5. Concatenate LR + generated HF
            6. iSTFT to waveform

        Args:
            lr_audio: Low-resolution waveform at the target rate, shape [B, C, T].
            sr_khz: Effective input sample rate in kHz; key into
                ``model.sr_to_lr_bins`` and passed to the network as ``sr_values``.
            ode_method: Solver method forwarded to TorchDiffeqSolver.
            ode_steps: Number of integration steps (time grid has ode_steps + 1 points).
            guidance_scale: CFG scale; None or a value <= 0 selects the plain
                vector-field ODE.

        Returns:
            Waveform produced by inverting the combined spectrogram.
        """
        device = self._runtime_device()

        # Frequency bin bookkeeping
        lr_bin_count = self.model.sr_to_lr_bins[sr_khz]
        hf_start_bin = self.model.total_freq_bins - self.model.hr_freq_bins

        # STFT
        Y = self._preprocess(lr_audio)     # [B, 2, F-1, T]
        Y_lr = Y[:, :, :lr_bin_count, :]   # LR condition
        Y_hr = Y[:, :, hf_start_bin:, :]   # HR target region (for shape reference)

        # Initial noise
        x0 = self.path.sample_source(Y_hr).to(device)

        # Build ODE solver
        if guidance_scale is not None and guidance_scale > 0:
            ode = CFGVectorFieldODE(net=self.model, guidance_scale=guidance_scale)
        else:
            ode = VectorFieldODE(net=self.model)
        solver = TorchDiffeqSolver(ode, method=ode_method)

        # Time discretization: uniform grid from t=0 to t=1.
        ts = torch.linspace(0, 1, ode_steps + 1, device=device)

        # Solve ODE
        x1_spec = solver.simulate(
            x0, ts=ts, y=Y_lr, sr_values=torch.tensor([sr_khz], device=device)
        )

        # Concatenate LR bins + generated HF bins. When the LR bins reach past the
        # start of the generated region, the overlapping generated bins are dropped
        # so the observed LR content is kept.
        slice_start = max(0, lr_bin_count - hf_start_bin)
        x1_spec = x1_spec[:, :, slice_start:, :]
        full_spec = torch.cat([Y_lr, x1_spec], dim=2)

        # iSTFT
        return self._postprocess(full_spec)
|
||||
Reference in New Issue
Block a user