Initial release: ComfyUI-UniverSR

ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 12:59:42 +02:00
commit 5f29b225b7
20 changed files with 2129 additions and 0 deletions
+127
View File
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
import torch
from torchdiffeq import odeint
from tqdm import tqdm
from universr.models.unet import ConditionalVectorFieldModel
class ODE(ABC):
@abstractmethod
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Returns the drift coefficient of the ODE.
Args:
- xt: state at time t, shape (bs, c, h, w)
- t: time, shape (bs, 1)
Returns:
- drift_coefficient: shape (bs, c, h, w)
"""
pass
class Solver(ABC):
# @abstractmethod
def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
"""
Takes one simulation step
Args:
- xt: state at time t, shape (bs, c, h, w)
- t: time, shape (bs, 1, 1, 1)
- dt: time, shape (bs, 1, 1, 1)
Returns:
- nxt: state at time t + dt (bs, c, h, w)
"""
pass
@torch.no_grad()
def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
Simulates using the discretization gives by ts
Args:
- x_init: initial state, shape (bs, c, h, w)
- ts: timesteps, shape (bs, nts, 1, 1, 1)
Returns:
- x_final: final state at time ts[-1], shape (bs, c, h, w)
"""
nts = ts.shape[1]
for t_idx in tqdm(range(nts - 1)):
t = ts[:, t_idx]
h = ts[:, t_idx + 1] - ts[:, t_idx]
x = self.step(x, t, h, **kwargs)
return x
@torch.no_grad()
def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
Simulates using the discretization gives by ts
Args:
- x: initial state, shape (bs, c, h, w)
- ts: timesteps, shape (bs, nts, 1, 1, 1)
Returns:
- xs: trajectory of xts over ts, shape (batch_size, nts, c, h, w)
"""
xs = [x.clone()]
nts = ts.shape[1]
for t_idx in tqdm(range(nts - 1)):
t = ts[:,t_idx]
h = ts[:, t_idx + 1] - ts[:, t_idx]
x = self.step(x, t, h, **kwargs)
xs.append(x.clone())
return torch.stack(xs, dim=1)
class VectorFieldODE(ODE):
def __init__(self, net:ConditionalVectorFieldModel) -> None:
super().__init__()
self.net = net
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(xt, t, y, **kwargs)
class CFGVectorFieldODE(ODE):
""" For Classifier Free Guidance """
def __init__(self, net:ConditionalVectorFieldModel, guidance_scale: float = 1.0) -> None:
super().__init__()
self.net = net
self.guidance_scale = guidance_scale
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor:
guided_vector_field = self.net(xt, t, y, **kwargs)
unguided_vector_field = self.net(xt, t, None, **kwargs)
return (1-self.guidance_scale) * unguided_vector_field + self.guidance_scale * guided_vector_field
class EulerSolver(Solver):
def __init__(self, ode: ODE):
self.ode = ode
def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs):
return xt + self.ode.drift_coefficient(xt,t, **kwargs) * h
class TorchDiffeqSolver(Solver):
def __init__(self,
ode: ODE,
method: str = 'euler',
atol: float = 1e-5,
rtol: float = 1e-5,
):
super().__init__()
self.ode = ode
self.method = method
self.atol = atol
self.rtol = rtol
@torch.no_grad()
def simulate(self, x_init: torch.Tensor, ts: torch.Tensor, **kwargs):
"""
x_init: [B,C,H,W]
ts: [N]
return: final state [B,C,H,W]
"""
func = lambda t, x: self.ode.drift_coefficient(xt=x, t=t, **kwargs)
xs = odeint(
func=func,
y0=x_init, t=ts,
method=self.method,
atol=self.atol, rtol=self.rtol) # [N,B,C,H,W]
return xs[-1]