Initial release: ComfyUI-UniverSR
ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching. - UniverSR Model Loader: presets auto-download to models/universr, plus local dir / raw .pth (from_local) loading, with caching. - UniverSR Super-Resolution: chunked overlap-add for long audio, per-channel stereo, seed control with global-RNG isolation, wet/dry blend, and an optional before/after spectrogram. - Vendors the universr inference package under vendor/ (prefers an installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+135
@@ -0,0 +1,135 @@
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
class InvertibleFeatureExtractor(nn.Module, ABC):
|
||||
"""
|
||||
An invertible feature extractor, i.e. a one-to-one mapping that has a forward and a true inverse.
|
||||
It should hold up to numerical error that `extractor.invert(extractor(x)) == x`.
|
||||
"""
|
||||
@abstractmethod
|
||||
def forward(self, x, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def invert(self, x, **kwargs):
|
||||
pass
|
||||
|
||||
def analysis_synthesis(self, x, **kwargs):
|
||||
return self.invert(self.forward(x, **kwargs), **kwargs)
|
||||
|
||||
class AmplitudeCompressedComplexSTFT(InvertibleFeatureExtractor):
|
||||
"""
|
||||
A convenient composition of ComplexSTFT() and CompressAmplitudesAndScale().
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
window_fn, n_fft, sampling_rate,
|
||||
alpha, beta, comp_eps,
|
||||
hop_length=None, n_hops=None,
|
||||
learnable_window=False,
|
||||
*args, **kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.complex_stft = ComplexSTFT(
|
||||
window_fn, n_fft, sampling_rate, hop_length=hop_length, n_hops=n_hops,
|
||||
learnable_window=learnable_window,
|
||||
)
|
||||
self.compress = CompressAmplitudesAndScale(
|
||||
compression_exponent=alpha,
|
||||
scale_factor=beta,
|
||||
comp_eps=comp_eps,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, **kwargs):
|
||||
X = self.complex_stft(x, **kwargs)
|
||||
out = self.compress(X, **kwargs)
|
||||
return out
|
||||
|
||||
def invert(self, X: Tensor, **kwargs):
|
||||
X = self.compress.invert(X, **kwargs)
|
||||
x = self.complex_stft.invert(X, **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
class ComplexSTFT(InvertibleFeatureExtractor):
|
||||
def __init__(
|
||||
self, window_fn, n_fft, sampling_rate, hop_length=None, n_hops=None, learnable_window=False,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert (hop_length is not None) ^ (n_hops is not None),\
|
||||
"Exactly one of {hop_length, n_hops} must be specified!"
|
||||
if hop_length is None:
|
||||
hop_length = int(math.ceil(n_fft / n_hops))
|
||||
|
||||
window_fn = getattr(torch.signal.windows, window_fn)
|
||||
self.learnable_window = learnable_window
|
||||
self.window = nn.Parameter(window_fn(n_fft), requires_grad=learnable_window)
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.sampling_rate = sampling_rate
|
||||
self.center = True
|
||||
|
||||
def forward(self, x: Tensor, **kwargs):
|
||||
"""Assumes x is an audio tensor of shape [B, C, T] or [B, T]
|
||||
|
||||
[B,C,T] -> [B,C,F,T]
|
||||
[B,C,T] -> [B,F,T]
|
||||
|
||||
"""
|
||||
bc = "b c" if x.ndim == 3 else "b"
|
||||
X = torch.stft(
|
||||
rearrange(x, f"{bc} t -> ({bc}) t"), n_fft=self.n_fft, hop_length=self.hop_length,
|
||||
window=self.window.to(x.device), center=self.center,
|
||||
onesided=True, return_complex=True,
|
||||
)
|
||||
X = rearrange(X, f"({bc}) f t -> {bc} f t", b=x.shape[0])
|
||||
return X
|
||||
|
||||
def invert(self, X: Tensor, orig_length: Optional[int] = None, **kwargs):
|
||||
"""Assumes X is a (complex) spectrogram tensor of shape [B, C, F, T] or [B, F, T]"""
|
||||
bc = "b c" if X.ndim == 4 else "b"
|
||||
x = torch.istft(
|
||||
rearrange(X, f"{bc} f t -> ({bc}) f t"), n_fft=self.n_fft, hop_length=self.hop_length,
|
||||
window=self.window.to(X.device), center=self.center,
|
||||
onesided=True, return_complex=False,
|
||||
length=orig_length,
|
||||
)
|
||||
x = rearrange(x, f"({bc}) t -> {bc} t", b=X.shape[0])
|
||||
return x
|
||||
|
||||
class CompressAmplitudesAndScale(InvertibleFeatureExtractor):
|
||||
def __init__(self, compression_exponent: float, scale_factor: float, comp_eps: float, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.compression_exponent = compression_exponent
|
||||
self.scale_factor = scale_factor
|
||||
self.comp_eps = comp_eps
|
||||
|
||||
def forward(self, X: Tensor, **kwargs):
|
||||
"""
|
||||
Assumes X is a complex STFT (complex spectrogram).
|
||||
"""
|
||||
alpha = self.compression_exponent
|
||||
beta = self.scale_factor
|
||||
if alpha != 1:
|
||||
X = X + self.comp_eps
|
||||
X = X.abs()**alpha * torch.exp(1j * X.angle())
|
||||
return X * beta
|
||||
|
||||
def invert(self, X: Tensor, **kwargs):
|
||||
"""
|
||||
Assumes X is an amplitude-compressed and scaled complex STFT.
|
||||
"""
|
||||
alpha = self.compression_exponent
|
||||
beta = self.scale_factor
|
||||
X = X / beta
|
||||
if alpha != 1:
|
||||
X = X.abs()**(1/alpha) * torch.exp(1j * X.angle())
|
||||
return X
|
||||
|
||||
|
||||
Reference in New Issue
Block a user