Files
ComfyUI-UniverSR/vendor/universr/flow/path.py
T
Ethanfel 5f29b225b7 Initial release: ComfyUI-UniverSR
ComfyUI nodes for UniverSR (ICASSP 2026) — vocoder-free audio
super-resolution (8/12/16/24 kHz → 48 kHz) via flow matching.

- UniverSR Model Loader: presets auto-download to models/universr,
  plus local dir / raw .pth (from_local) loading, with caching.
- UniverSR Super-Resolution: chunked overlap-add for long audio,
  per-channel stereo, seed control with global-RNG isolation,
  wet/dry blend, and an optional before/after spectrogram.
- Vendors the universr inference package under vendor/ (prefers an
  installed copy); only extra dep beyond ComfyUI's stack is torchdiffeq.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-01 12:59:42 +02:00

54 lines
1.8 KiB
Python

import importlib
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
class ConditionalProbabilityPath(nn.Module, ABC):
    """Interface for the conditional probability paths used in flow matching.

    A concrete path supplies three pieces:

    * a source distribution to draw ``x0`` from (``sample_source``);
    * the interpolant ``x_t`` between a source sample ``x0`` and a target
      ``x1`` at time ``t`` (``sample_xt``);
    * the conditional vector field ``u_t`` along that interpolant
      (``get_target_vector_field``).

    Being an ``nn.Module`` means a path can be nested inside other modules;
    the ``ABC`` base keeps this class uninstantiable until every abstract
    method below is overridden.
    """

    @abstractmethod
    def sample_source(self, shape_ref: torch.Tensor) -> torch.Tensor:
        """Draw ``x0`` from the source distribution.

        Args:
            shape_ref: Reference tensor consulted only for its shape/device;
                its values are ignored.

        Returns:
            A tensor of source samples.
        """

    @abstractmethod
    def sample_xt(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the point on the path from source ``x0`` to target ``x1`` at time ``t``.

        NOTE(review): ``t`` presumably has to broadcast against ``x0``/``x1``
        (e.g. ``(B, 1, ..., 1)``); no shape convention is enforced here —
        confirm with callers.
        """

    @abstractmethod
    def get_target_vector_field(
        self, xt: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Return the target conditional vector field u_t(xt | x1).

        All of ``xt``, ``x0``, ``x1`` and ``t`` are passed so that each
        implementation can use whichever subset its path needs.
        """
class OriginalCFMPath(ConditionalProbabilityPath):
    """Linear conditional flow-matching path with a small residual noise floor.

    Writing ``s`` for ``sigma_min``, the path and its velocity are::

        x_t = t * x1 + (1 - (1 - s) * t) * x0
        u_t = d x_t / dt = x1 - (1 - s) * x0

    so the path starts at ``x0`` (t = 0) and ends at ``x1 + s * x0`` (t = 1).
    This is the optimal-transport conditional path of Lipman et al.,
    "Flow Matching for Generative Modeling" (2023), with a standard-normal
    source.

    Args:
        sigma_min: The coefficient ``s`` of the source sample remaining at
            ``t = 1``. Stored as given, without validation.
    """

    def __init__(self, sigma_min: float = 1e-4):
        super().__init__()
        self.sigma_min = sigma_min

    def sample_source(self, shape_ref):
        """Return standard-normal noise shaped like ``shape_ref`` (same dtype/device)."""
        noise = torch.randn_like(shape_ref)
        return noise

    def sample_xt(self, x0, x1, t):
        """Return ``t * x1 + (1 - t + sigma_min * t) * x0``.

        ``t`` is used as-is, so it must broadcast against ``x0``/``x1``
        (presumably shaped like ``(B, 1, ..., 1)`` — not reshaped here).
        """
        # Weight on the source sample: 1 at t = 0, decaying to sigma_min at t = 1.
        source_weight = 1 - t + self.sigma_min * t
        return t * x1 + source_weight * x0

    def get_target_vector_field(self, xt, x0, x1, t):
        """Return ``x1 - (1 - sigma_min) * x0``, the time derivative of ``sample_xt``.

        The velocity along this linear path is constant, so ``xt`` and ``t``
        are not used; they are accepted to match the base-class interface.
        """
        source_decay = 1 - self.sigma_min
        return x1 - source_decay * x0
def get_path(config):
    """Instantiate an object from a ``class_path`` / ``init_args`` config.

    The config names a class by its dotted import path and, optionally, the
    keyword arguments to construct it with, e.g.::

        {"class_path": "package.module.ClassName",
         "init_args": {"sigma_min": 1e-4}}

    Args:
        config: Mapping with a required ``"class_path"`` entry of the form
            ``"module.path.ClassName"`` and an optional ``"init_args"``
            mapping of keyword arguments. A missing or ``None`` value for
            ``"init_args"`` means "call with no arguments".

    Returns:
        ``ClassName(**init_args)``.

    Raises:
        ValueError: If ``class_path`` is missing/empty, or does not have a
            non-empty module part and class name separated by a ``'.'``.
        ImportError: If the module part cannot be imported
            (e.g. ``ModuleNotFoundError``).
        AttributeError: If the module has no attribute named ``ClassName``.

    Note:
        This imports and calls whatever ``class_path`` names, so ``config``
        must come from a trusted source.
    """
    class_path = config.get("class_path")
    if not class_path:
        raise ValueError("Configuration must contain a 'class_path' key")

    # rpartition never raises: an empty module part means there was no '.'
    # (or a leading one, ".Foo"); an empty class name means a trailing '.'.
    module_path, _, class_name = class_path.rpartition(".")
    if not module_path or not class_name:
        raise ValueError(
            f"Invalid class_path '{class_path}'. Must contain at least one '.' "
            "separating the module path from the class name "
            "(e.g. 'package.module.ClassName')"
        )

    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)

    # A bare "init_args:" in YAML (or null in JSON) loads as None; treat it
    # like an empty mapping instead of failing on cls(**None).
    init_args = config.get("init_args")
    if init_args is None:
        init_args = {}
    return cls(**init_args)