Initial release: ComfyUI-UniverSR
ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching. - UniverSR Model Loader: presets auto-download to models/universr, plus local dir / raw .pth (from_local) loading, with caching. - UniverSR Super-Resolution: chunked overlap-add for long audio, per-channel stereo, seed control with global-RNG isolation, wet/dry blend, and an optional before/after spectrogram. - Vendors the universr inference package under vendor/ (prefers an installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Vendored
+54
@@ -0,0 +1,54 @@
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ConditionalProbabilityPath(nn.Module, ABC):
|
||||
"""Abstract base class for conditional probability paths in flow matching."""
|
||||
|
||||
@abstractmethod
|
||||
def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
|
||||
"""Sample from the source distribution. shape_ref is used only for shape/device."""
|
||||
|
||||
@abstractmethod
|
||||
def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
"""Interpolate between source x0 and target x1 at time t."""
|
||||
|
||||
@abstractmethod
|
||||
def get_target_vector_field(
|
||||
self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Compute the target vector field u_t(xt | x1)."""
|
||||
|
||||
class OriginalCFMPath(ConditionalProbabilityPath):
|
||||
def __init__(self, sigma_min: float = 1e-4):
|
||||
super().__init__()
|
||||
self.sigma_min = sigma_min
|
||||
|
||||
def sample_source(self, shape_ref):
|
||||
return torch.randn_like(shape_ref)
|
||||
|
||||
def sample_xt(self, x0, x1, t):
|
||||
return t * x1 + (1 - t + self.sigma_min * t) * x0
|
||||
|
||||
def get_target_vector_field(self, xt, x0, x1, t):
|
||||
return x1 - (1 - self.sigma_min) * x0
|
||||
|
||||
def get_path(config):
|
||||
class_path = config.get("class_path")
|
||||
|
||||
if not class_path:
|
||||
raise ValueError("Configuration must contain a 'class_path' key")
|
||||
try:
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid class_path '{class_path}'. Must contain at least one")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
Class = getattr(module, class_name)
|
||||
init_args = config.get("init_args", {})
|
||||
return Class(**init_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user