5f29b225b7
ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching. - UniverSR Model Loader: presets auto-download to models/universr, plus local dir / raw .pth (from_local) loading, with caching. - UniverSR Super-Resolution: chunked overlap-add for long audio, per-channel stereo, seed control with global-RNG isolation, wet/dry blend, and an optional before/after spectrogram. - Vendors the universr inference package under vendor/ (prefers an installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
import importlib
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class ConditionalProbabilityPath(nn.Module, ABC):
|
|
"""Abstract base class for conditional probability paths in flow matching."""
|
|
|
|
@abstractmethod
|
|
def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
|
|
"""Sample from the source distribution. shape_ref is used only for shape/device."""
|
|
|
|
@abstractmethod
|
|
def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
|
"""Interpolate between source x0 and target x1 at time t."""
|
|
|
|
@abstractmethod
|
|
def get_target_vector_field(
|
|
self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the target vector field u_t(xt | x1)."""
|
|
|
|
class OriginalCFMPath(ConditionalProbabilityPath):
|
|
def __init__(self, sigma_min: float = 1e-4):
|
|
super().__init__()
|
|
self.sigma_min = sigma_min
|
|
|
|
def sample_source(self, shape_ref):
|
|
return torch.randn_like(shape_ref)
|
|
|
|
def sample_xt(self, x0, x1, t):
|
|
return t * x1 + (1 - t + self.sigma_min * t) * x0
|
|
|
|
def get_target_vector_field(self, xt, x0, x1, t):
|
|
return x1 - (1 - self.sigma_min) * x0
|
|
|
|
def get_path(config):
|
|
class_path = config.get("class_path")
|
|
|
|
if not class_path:
|
|
raise ValueError("Configuration must contain a 'class_path' key")
|
|
try:
|
|
module_path, class_name = class_path.rsplit(".", 1)
|
|
except ValueError:
|
|
raise ValueError(f"Invalid class_path '{class_path}'. Must contain at least one")
|
|
|
|
module = importlib.import_module(module_path)
|
|
Class = getattr(module, class_name)
|
|
init_args = config.get("init_args", {})
|
|
return Class(**init_args)
|
|
|
|
|