Initial release: ComfyUI-UniverSR

ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 12:59:42 +02:00
commit 5f29b225b7
20 changed files with 2129 additions and 0 deletions
View File
+9
View File
@@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F
def flow_matching_loss(predicted_vf: torch.Tensor, target_vf: torch.Tensor) -> torch.Tensor:
    """Return the flow-matching training loss.

    The loss is the mean squared error (L2) between the vector field
    estimated by the model and the target vector field.

    Args:
        predicted_vf: Vector field predicted by the model.
        target_vf: Reference vector field to regress onto.

    Returns:
        The squared error averaged over all elements.
    """
    loss = F.mse_loss(predicted_vf, target_vf, reduction="mean")
    return loss
+54
View File
@@ -0,0 +1,54 @@
import importlib
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
class ConditionalProbabilityPath(nn.Module, ABC):
    """Interface for conditional probability paths used in flow matching.

    A concrete path defines how to draw a source sample, how to interpolate
    between a source sample and a target sample at a given time, and which
    vector field is the regression target along that path.
    """

    @abstractmethod
    def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
        """Draw a sample from the source distribution.

        ``shape_ref`` is consulted only for its shape and device.
        """

    @abstractmethod
    def sample_xt(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Return the point between source ``x0`` and target ``x1`` at time ``t``."""

    @abstractmethod
    def get_target_vector_field(
        self,
        xt: torch.Tensor,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Return the target vector field ``u_t(xt | x1)``."""
class OriginalCFMPath(ConditionalProbabilityPath):
    """Conditional flow-matching path with a ``sigma_min`` residual term.

    The source sample is standard Gaussian noise; the path moves linearly
    from the source ``x0`` towards the target ``x1`` while keeping a
    ``sigma_min``-scaled share of the source at ``t = 1``.
    """

    def __init__(self, sigma_min: float = 1e-4):
        super().__init__()
        self.sigma_min = sigma_min

    def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
        """Return standard normal noise shaped like ``shape_ref``."""
        return torch.randn_like(shape_ref)

    def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return ``t * x1 + (1 - t + sigma_min * t) * x0``."""
        # Weight on the source sample: 1 at t=0 and sigma_min at t=1.
        source_weight = 1 - t + self.sigma_min * t
        return t * x1 + source_weight * x0

    def get_target_vector_field(
        self,
        xt: torch.Tensor,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``x1 - (1 - sigma_min) * x0``.

        This is the derivative of ``sample_xt`` with respect to ``t``; it
        does not depend on ``xt`` or ``t``.
        """
        source_coeff = 1 - self.sigma_min
        return x1 - source_coeff * x0
def get_path(config):
    """Instantiate the object described by ``config``.

    Args:
        config: Mapping with a ``"class_path"`` entry holding a dotted
            ``"package.module.ClassName"`` string and, optionally, an
            ``"init_args"`` mapping of keyword arguments passed to the
            class constructor. A missing or ``None`` ``"init_args"`` means
            "no arguments".

    Returns:
        A new instance of the referenced class.

    Raises:
        ValueError: If ``"class_path"`` is missing or empty, or if it does
            not contain a ``.`` separating module and class name.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    class_path = config.get("class_path")
    if not class_path:
        raise ValueError("Configuration must contain a 'class_path' key")

    # rpartition yields an empty separator when there is no dot, which lets
    # us report the problem without an unrelated unpacking error as context.
    module_path, separator, class_name = class_path.rpartition(".")
    if not separator:
        raise ValueError(
            f"Invalid class_path '{class_path}'. Must contain at least one '.' "
            "separating the module from the class name"
        )

    module = importlib.import_module(module_path)
    target_class = getattr(module, class_name)

    # An explicit ``init_args: null`` in a config file arrives as None;
    # treat it like an absent key instead of failing on ``**None``.
    init_args = config.get("init_args") or {}
    return target_class(**init_args)
+127
View File
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
import torch
from torchdiffeq import odeint
from tqdm import tqdm
from universr.models.unet import ConditionalVectorFieldModel
class ODE(ABC):
    """Interface for an ordinary differential equation given by its drift."""

    @abstractmethod
    def drift_coefficient(
        self,
        xt: torch.Tensor,
        t: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Evaluate the drift of the ODE at state ``xt`` and time ``t``.

        Args:
            xt: State at time ``t``, shape (bs, c, h, w).
            t: Time, shape (bs, 1).
            **kwargs: Additional inputs supplied by the caller.

        Returns:
            The drift coefficient, shape (bs, c, h, w).
        """
class Solver(ABC):
    """Base class for ODE integrators that walk a fixed time grid.

    A subclass either implements :meth:`step` and inherits the stepping
    loops below, or overrides :meth:`simulate` itself.
    """

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
        """Advance the state by one simulation step.

        Args:
            xt: State at time ``t``, shape (bs, c, h, w).
            t: Time, shape (bs, 1, 1, 1).
            dt: Step size, shape (bs, 1, 1, 1).
            **kwargs: Extra inputs for the underlying dynamics.

        Returns:
            The state at time ``t + dt``, shape (bs, c, h, w).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Not an @abstractmethod, so subclasses that only override simulate()
        # stay instantiable. Failing loudly here keeps the stepping loops from
        # silently propagating ``None`` when a subclass forgets to implement it.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement step()"
        )

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate over the discretization given by ``ts``.

        Args:
            x: Initial state, shape (bs, c, h, w).
            ts: Timesteps, shape (bs, nts, 1, 1, 1).
            **kwargs: Forwarded unchanged to every :meth:`step` call.

        Returns:
            The final state at time ``ts[:, -1]``, shape (bs, c, h, w).
        """
        num_steps = ts.shape[1] - 1
        for idx in tqdm(range(num_steps)):
            t = ts[:, idx]
            h = ts[:, idx + 1] - t
            x = self.step(x, t, h, **kwargs)
        return x

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate over ``ts`` and keep every intermediate state.

        Args:
            x: Initial state, shape (bs, c, h, w).
            ts: Timesteps, shape (bs, nts, 1, 1, 1).
            **kwargs: Forwarded unchanged to every :meth:`step` call.

        Returns:
            The trajectory of states over ``ts``, stacked along dim 1,
            shape (bs, nts, c, h, w). The initial state is the first entry.
        """
        # Clone each state so later steps cannot alias stored entries.
        states = [x.clone()]
        num_steps = ts.shape[1] - 1
        for idx in tqdm(range(num_steps)):
            t = ts[:, idx]
            h = ts[:, idx + 1] - t
            x = self.step(x, t, h, **kwargs)
            states.append(x.clone())
        return torch.stack(states, dim=1)
class VectorFieldODE(ODE):
    """ODE whose drift is the output of a conditional vector-field network."""

    def __init__(self, net: ConditionalVectorFieldModel) -> None:
        super().__init__()
        self.net = net

    def drift_coefficient(
        self,
        xt: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return the network output for state ``xt``, time ``t`` and condition ``y``.

        Any extra keyword arguments are forwarded to the network unchanged.
        """
        vector_field = self.net(xt, t, y, **kwargs)
        return vector_field
class CFGVectorFieldODE(ODE):
    """ODE whose drift blends conditional and unconditional network outputs.

    Implements classifier-free guidance: the network is evaluated once with
    the condition ``y`` and once with ``None`` in its place, and the two
    results are combined using ``guidance_scale``.
    """

    def __init__(self, net: ConditionalVectorFieldModel, guidance_scale: float = 1.0) -> None:
        super().__init__()
        self.net = net
        self.guidance_scale = guidance_scale

    def drift_coefficient(
        self,
        xt: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``(1 - w) * unguided + w * guided`` with ``w = guidance_scale``.

        Extra keyword arguments are forwarded to both network evaluations.
        """
        scale = self.guidance_scale
        conditional = self.net(xt, t, y, **kwargs)
        # Passing None as the condition yields the unguided prediction.
        unconditional = self.net(xt, t, None, **kwargs)
        return (1 - scale) * unconditional + scale * conditional
class EulerSolver(Solver):
    """Explicit (forward) Euler integrator for an :class:`ODE`."""

    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs):
        """Return ``xt + drift(xt, t) * h``.

        Args:
            xt: Current state.
            t: Current time.
            h: Step size.
            **kwargs: Forwarded to ``ode.drift_coefficient``.
        """
        drift = self.ode.drift_coefficient(xt, t, **kwargs)
        return xt + drift * h
class TorchDiffeqSolver(Solver):
    """Solver that delegates the integration to ``torchdiffeq.odeint``."""

    def __init__(
        self,
        ode: ODE,
        method: str = 'euler',
        atol: float = 1e-5,
        rtol: float = 1e-5,
    ):
        super().__init__()
        self.ode = ode
        self.method = method
        self.atol = atol
        self.rtol = rtol

    @torch.no_grad()
    def simulate(self, x_init: torch.Tensor, ts: torch.Tensor, **kwargs):
        """Integrate from ``x_init`` over the time grid ``ts``.

        Args:
            x_init: Initial state, [B,C,H,W].
            ts: Time grid, [N].
            **kwargs: Forwarded to ``ode.drift_coefficient`` on every
                evaluation.

        Returns:
            The final state, [B,C,H,W].
        """

        def drift(t, x):
            # odeint calls func(t, state); adapt that to the ODE's keywords.
            return self.ode.drift_coefficient(xt=x, t=t, **kwargs)

        trajectory = odeint(
            func=drift,
            y0=x_init,
            t=ts,
            method=self.method,
            atol=self.atol,
            rtol=self.rtol,
        )  # [N,B,C,H,W]
        return trajectory[-1]