Initial release: ComfyUI-UniverSR

ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 12:59:42 +02:00
commit 5f29b225b7
20 changed files with 2129 additions and 0 deletions
+135
View File
@@ -0,0 +1,135 @@
import math
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
class InvertibleFeatureExtractor(nn.Module, ABC):
"""
An invertible feature extractor, i.e. a one-to-one mapping that has a forward and a true inverse.
It should hold up to numerical error that `extractor.invert(extractor(x)) == x`.
"""
@abstractmethod
def forward(self, x, **kwargs):
pass
@abstractmethod
def invert(self, x, **kwargs):
pass
def analysis_synthesis(self, x, **kwargs):
return self.invert(self.forward(x, **kwargs), **kwargs)
class AmplitudeCompressedComplexSTFT(InvertibleFeatureExtractor):
"""
A convenient composition of ComplexSTFT() and CompressAmplitudesAndScale().
"""
def __init__(
self,
window_fn, n_fft, sampling_rate,
alpha, beta, comp_eps,
hop_length=None, n_hops=None,
learnable_window=False,
*args, **kwargs,
):
super().__init__(*args, **kwargs)
self.complex_stft = ComplexSTFT(
window_fn, n_fft, sampling_rate, hop_length=hop_length, n_hops=n_hops,
learnable_window=learnable_window,
)
self.compress = CompressAmplitudesAndScale(
compression_exponent=alpha,
scale_factor=beta,
comp_eps=comp_eps,
)
def forward(self, x: Tensor, **kwargs):
X = self.complex_stft(x, **kwargs)
out = self.compress(X, **kwargs)
return out
def invert(self, X: Tensor, **kwargs):
X = self.compress.invert(X, **kwargs)
x = self.complex_stft.invert(X, **kwargs)
return x
class ComplexSTFT(InvertibleFeatureExtractor):
def __init__(
self, window_fn, n_fft, sampling_rate, hop_length=None, n_hops=None, learnable_window=False,
*args, **kwargs):
super().__init__(*args, **kwargs)
assert (hop_length is not None) ^ (n_hops is not None),\
"Exactly one of {hop_length, n_hops} must be specified!"
if hop_length is None:
hop_length = int(math.ceil(n_fft / n_hops))
window_fn = getattr(torch.signal.windows, window_fn)
self.learnable_window = learnable_window
self.window = nn.Parameter(window_fn(n_fft), requires_grad=learnable_window)
self.n_fft = n_fft
self.hop_length = hop_length
self.sampling_rate = sampling_rate
self.center = True
def forward(self, x: Tensor, **kwargs):
"""Assumes x is an audio tensor of shape [B, C, T] or [B, T]
[B,C,T] -> [B,C,F,T]
[B,C,T] -> [B,F,T]
"""
bc = "b c" if x.ndim == 3 else "b"
X = torch.stft(
rearrange(x, f"{bc} t -> ({bc}) t"), n_fft=self.n_fft, hop_length=self.hop_length,
window=self.window.to(x.device), center=self.center,
onesided=True, return_complex=True,
)
X = rearrange(X, f"({bc}) f t -> {bc} f t", b=x.shape[0])
return X
def invert(self, X: Tensor, orig_length: Optional[int] = None, **kwargs):
"""Assumes X is a (complex) spectrogram tensor of shape [B, C, F, T] or [B, F, T]"""
bc = "b c" if X.ndim == 4 else "b"
x = torch.istft(
rearrange(X, f"{bc} f t -> ({bc}) f t"), n_fft=self.n_fft, hop_length=self.hop_length,
window=self.window.to(X.device), center=self.center,
onesided=True, return_complex=False,
length=orig_length,
)
x = rearrange(x, f"({bc}) t -> {bc} t", b=X.shape[0])
return x
class CompressAmplitudesAndScale(InvertibleFeatureExtractor):
def __init__(self, compression_exponent: float, scale_factor: float, comp_eps: float, *args, **kwargs):
super().__init__()
self.compression_exponent = compression_exponent
self.scale_factor = scale_factor
self.comp_eps = comp_eps
def forward(self, X: Tensor, **kwargs):
"""
Assumes X is a complex STFT (complex spectrogram).
"""
alpha = self.compression_exponent
beta = self.scale_factor
if alpha != 1:
X = X + self.comp_eps
X = X.abs()**alpha * torch.exp(1j * X.angle())
return X * beta
def invert(self, X: Tensor, **kwargs):
"""
Assumes X is an amplitude-compressed and scaled complex STFT.
"""
alpha = self.compression_exponent
beta = self.scale_factor
X = X / beta
if alpha != 1:
X = X.abs()**(1/alpha) * torch.exp(1j * X.angle())
return X