Files
ComfyUI-UniverSR/nodes.py
T
Ethanfel 104cd4bf5f feat: equal-quality speed options (TF32 + torch.compile)
Add two opt-in inference speedups to the Model Loader, validated to leave the
output perceptually identical (deviation at the fp32 rounding floor):

- tf32 (default on): TF32 matmul on Ampere+ (~1.15x).
- compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to
  ~2.5x (measured 4.3s -> 1.7s on a 12s clip).

torch.compile needs a static shape (the model's adaptive-avg-pool can't trace
dynamic shapes), so the sampler pads every chunk to chunk_seconds — clips of
any length reuse one compiled graph (no per-length recompiles; verified an 8s
clip after a 12s clip ran in 0.9s with no recompile).

Researched + profiled first: CFG-batching, channel/chunk batching, and
channels_last gave ~0 gain because the GPU is already compute-bound at batch 1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 17:16:21 +02:00

214 lines
9.2 KiB
Python

"""ComfyUI-UniverSR nodes.
Two-node design (mirrors the ComfyUI-Flash-AudioSR pattern):
UniverSRModelLoader -> UNIVERSR_MODEL (loads + caches weights, auto-downloads)
UniverSRSampler -> AUDIO, IMAGE (runs the super-resolution)
"""
import torch
from . import universr_wrapper as usr
# Optional ComfyUI integration: when comfy's model management is importable we
# use it to resolve the "auto" device. On failure `mm` is left undefined, so
# any code touching `mm` must check HAS_COMFY first.
try:
    import comfy.model_management as mm
except Exception:  # pragma: no cover
    # Broad on purpose: importing comfy outside a ComfyUI runtime may fail in
    # ways other than ImportError, and this node should still load.
    HAS_COMFY = False
else:
    HAS_COMFY = True
def _default_device() -> str:
    """Resolve the "auto" device choice to a concrete "cuda" or "cpu" string.

    If ComfyUI's model management is available, follow the torch device it
    reports; any device type other than CUDA is mapped to "cpu". If ComfyUI
    is unavailable, or querying it raises, fall back to torch's own CUDA check.
    """
    if HAS_COMFY:
        try:
            comfy_on_cuda = mm.get_torch_device().type == "cuda"
        except Exception:
            pass  # best-effort: fall through to the plain torch check below
        else:
            return "cuda" if comfy_on_cuda else "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
# --------------------------------------------------------------------------- #
# Model loader
# --------------------------------------------------------------------------- #
class UniverSRModelLoader:
    """ComfyUI node that loads (and caches) a UniverSR checkpoint.

    The built-in presets are downloaded automatically the first time they are
    selected. The single output, UNIVERSR_MODEL, feeds the
    UniverSR Super-Resolution node.
    """

    DESCRIPTION = (
        "Load UniverSR (vocoder-free audio super-resolution, ICASSP 2026). "
        "Presets auto-download to models/universr on first use."
    )
    CATEGORY = "audio/UniverSR"
    RETURN_TYPES = ("UNIVERSR_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's widgets: model/device are required, the rest optional."""
        # Presets come first so the first entry (the default) is always a preset.
        # NOTE(review): if usr.list_local_models() can also return preset names
        # (e.g. after a preset has been downloaded), they would be listed twice — confirm.
        model_choices = list(usr.HF_REPOS.keys()) + usr.list_local_models()

        required = {
            "model": (model_choices, {
                "default": model_choices[0],
                "tooltip": "universr-audio = general (music/SFX/mixed, recommended); "
                           "universr-speech = voice only. Both download (~230 MB) on first use. "
                           "Local checkpoint folders in models/universr also appear here.",
            }),
            "device": (["auto", "cuda", "cpu"], {
                "default": "auto",
                "tooltip": "Device to load the model onto.",
            }),
        }
        optional = {
            "tf32": ("BOOLEAN", {
                "default": True,
                "tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
                           "but not bit-exact; global setting. Turn off for reference fp32.",
            }),
            "compile": ("BOOLEAN", {
                "default": False,
                "tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
                           "and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
                           "sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
            }),
            "local_path": ("STRING", {
                "default": "",
                "tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
                           "or a raw .pth/.ckpt file (uses config_path or the bundled config).",
            }),
            "config_path": ("STRING", {
                "default": "",
                "tooltip": "config.yaml for a raw checkpoint given in local_path. "
                           "Leave empty to use the bundled default config.",
            }),
        }
        return {"required": required, "optional": optional}

    def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
        """Load the selected checkpoint and wrap it for the sampler node.

        `compile` shadows the builtin, but the name must match its INPUT_TYPES
        key because ComfyUI passes widget values as keyword arguments.

        Returns:
            A 1-tuple holding {"model", "device", "cache_key"}; "device" is the
            resolved device actually used ("auto" is replaced by a concrete one).
        """
        target = device
        if target == "auto":
            target = _default_device()
        # An explicit (or auto-resolved) CUDA request degrades gracefully on CPU-only hosts.
        if target == "cuda" and not torch.cuda.is_available():
            print("[UniverSR] CUDA unavailable, falling back to CPU")
            target = "cpu"

        loaded, cache_key = usr.load_model(
            model,
            target,
            local_path=local_path,
            config_path=config_path,
            tf32=tf32,
            compile_model=compile,
        )
        return ({"model": loaded, "device": target, "cache_key": cache_key},)

    @classmethod
    def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
        """Fingerprint every loader input so ComfyUI re-runs the node when any of them changes."""
        parts = (model, device, f"tf32={tf32}", f"compile={compile}", local_path, config_path)
        return ":".join(f"{part}" for part in parts)
# --------------------------------------------------------------------------- #
# Sampler
# --------------------------------------------------------------------------- #
class UniverSRSampler:
    """Super-resolve audio to 48 kHz with UniverSR.

    Long clips are processed in overlapping chunks (click-free overlap-add)
    to stay within VRAM.
    """

    DESCRIPTION = ("Upscale low-bandwidth audio to 48 kHz with UniverSR. Pick input_sr to "
                   "match the effective bandwidth of your content (the model regenerates "
                   "everything above input_sr/2).")
    CATEGORY = "audio/UniverSR"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs: audio/model/input_sr are required, the rest optional."""
        return {
            "required": {
                "audio": ("AUDIO", {}),
                "model": ("UNIVERSR_MODEL", {}),
                "input_sr": ([8000, 12000, 16000, 24000], {
                    "default": 8000,
                    "tooltip": "Effective input bandwidth (Hz). Content is treated as valid up to "
                               "input_sr/2 and regenerated above it. 8000 = genuine low-rate audio "
                               "(strongest, 8 kHz->48 kHz). 16000 = brighten muffled audio above 8 kHz.",
                }),
            },
            "optional": {
                "ode_method": (["midpoint", "euler", "rk4"], {
                    "default": "midpoint",
                    "tooltip": "ODE solver. euler (fastest) -> midpoint (balanced) -> rk4 (best).",
                }),
                "ode_steps": ("INT", {
                    "default": 4, "min": 1, "max": 64, "step": 1,
                    "tooltip": "Flow-matching integration steps. 4 is fast and validated; 4-10 is a good range.",
                }),
                "guidance_scale": ("FLOAT", {
                    "default": 1.5, "min": 0.0, "max": 6.0, "step": 0.25,
                    "tooltip": "Classifier-free guidance. Speech 1.0-1.5, music 1.5-2.0, SFX ~1.5. "
                               "Higher = denser highs but less faithful. 0 disables CFG.",
                }),
                "seed": ("INT", {
                    "default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF,
                    "tooltip": "Noise seed for the flow-matching source. 0 = random each run.",
                }),
                "chunk_seconds": ("FLOAT", {
                    "default": 10.0, "min": 0.0, "max": 120.0, "step": 0.5,
                    "tooltip": "Process long audio in chunks of this length (seconds) to avoid OOM. "
                               "0 = process the whole clip at once.",
                }),
                "overlap_seconds": ("FLOAT", {
                    "default": 0.5, "min": 0.0, "max": 5.0, "step": 0.1,
                    "tooltip": "Crossfade overlap between chunks (seconds). Prevents seam clicks.",
                }),
                "blend": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
                    "tooltip": "Wet/dry mix. 1.0 = full super-resolution. Lower to keep more of the "
                               "original (useful when brightening already-48 kHz audio).",
                }),
                "unload_model": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Free the model from VRAM after this run.",
                }),
                "show_spectrogram": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Also output a before/after spectrogram comparison image.",
                }),
            },
        }

    RETURN_TYPES = ("AUDIO", "IMAGE")
    RETURN_NAMES = ("audio", "spectrogram")
    FUNCTION = "run"

    @classmethod
    def IS_CHANGED(cls, seed=0, **kwargs):
        """Tell ComfyUI's cache when this node must re-execute.

        ComfyUI reuses a node's previous output when its inputs and this value
        are unchanged. Without this hook, re-queuing with seed=0 would replay
        the cached result instead of drawing fresh noise as the tooltip
        promises. NaN never compares equal to itself, which forces a re-run.
        Any other seed returns a stable value, so normal input-based caching
        applies. ComfyUI passes the node's inputs as keyword arguments (linked
        ones may arrive as None — verify for your ComfyUI version), so
        everything besides `seed` is accepted via **kwargs.
        """
        if seed == 0:
            return float("nan")
        return seed

    @staticmethod
    def _to_mono_numpy(batch):
        """Channel-average the first batch item and return it as a float32 CPU numpy array.

        Assumes `batch` is (batch, channels, samples) — TODO confirm against usr.super_resolve.
        detach/float/cpu make .numpy() safe for grad-tracking, half-precision or CUDA tensors.
        """
        return batch[0].detach().float().mean(0).cpu().numpy()

    def run(self, audio, model, input_sr, ode_method="midpoint", ode_steps=4,
            guidance_scale=1.5, seed=0, chunk_seconds=10.0, overlap_seconds=0.5,
            blend=1.0, unload_model=False, show_spectrogram=True):
        """Run UniverSR on `audio` and return (AUDIO at 48 kHz, spectrogram IMAGE).

        When `show_spectrogram` is off, the IMAGE output is a blank
        torch.zeros(1, 64, 64, 3) placeholder. When `unload_model` is set, the
        model is evicted from the usr cache even if inference raises.
        """
        model_obj = model["model"]
        waveform, sr = usr.comfy_audio_to_tensor(audio)
        dur = waveform.shape[-1] / max(sr, 1)  # max() guards a zero sample rate
        print(f"[UniverSR] {tuple(waveform.shape)} @ {sr} Hz ({dur:.2f}s) -> 48 kHz | "
              f"input_sr={input_sr}, {ode_method}/{ode_steps}, cfg={guidance_scale}, blend={blend}")
        try:
            out, dry48 = usr.super_resolve(
                model_obj, waveform, sr, int(input_sr),
                ode_method=ode_method, ode_steps=int(ode_steps), guidance_scale=guidance_scale,
                seed=int(seed), chunk_seconds=float(chunk_seconds),
                overlap_seconds=float(overlap_seconds), blend=float(blend),
            )
            audio_out = usr.tensor_to_comfy_audio(out, usr.TARGET_SR)
            spec = torch.zeros(1, 64, 64, 3)
            if show_spectrogram:
                spec = usr.make_spectrogram_image(
                    self._to_mono_numpy(dry48), self._to_mono_numpy(out), int(input_sr),
                )
        finally:
            # Honour the unload request on the error path too (e.g. an OOM mid-run).
            # NOTE(review): ComfyUI's cached output of the loader node still holds a
            # reference to this model; confirm usr.evict_model actually frees VRAM.
            if unload_model:
                usr.evict_model(model["cache_key"])
        print(f"[UniverSR] Done -> {out.shape[-1] / usr.TARGET_SR:.2f}s at 48 kHz")
        return (audio_out, spec)
# ComfyUI registration: node type id -> implementing class. The ids are
# stored in saved workflows, so they are spelled out rather than derived.
NODE_CLASS_MAPPINGS = dict(
    UniverSRModelLoader=UniverSRModelLoader,
    UniverSRSampler=UniverSRSampler,
)

# Human-readable titles shown in the ComfyUI node menu, keyed by the same ids.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    UniverSRModelLoader="UniverSR Model Loader",
    UniverSRSampler="UniverSR Super-Resolution",
)