feat: equal-quality speed options (TF32 + torch.compile)

Add two opt-in inference speedups to the Model Loader, validated to leave the
output perceptually identical (deviation at the fp32 rounding floor):

- tf32 (default on): TF32 matmul on Ampere+ (~1.15x).
- compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to
  ~2.5x (measured 4.3s -> 1.7s on a 12s clip).

torch.compile needs a static shape (the model's adaptive-avg-pool can't trace
dynamic shapes), so the sampler pads every chunk to chunk_seconds — clips of
any length reuse one compiled graph (no per-length recompiles; verified an 8s
clip after a 12s clip ran in 0.9s with no recompile).

Researched + profiled first: CFG-batching, channel/chunk batching, and
channels_last gave ~0 gain because the GPU is already compute-bound at batch 1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 17:16:21 +02:00
parent 9a901adcc5
commit 104cd4bf5f
3 changed files with 97 additions and 11 deletions
+27
View File
@@ -30,6 +30,7 @@ muffled or bandlimited audio gets believable "air" and detail back.
- [UniverSR Load Video Audio](#universr-load-video-audio)
- [UniverSR Video Combiner](#universr-video-combiner)
- [Choosing `input_sr`](#choosing-input_sr-the-one-setting-that-matters-most)
- [Performance (speed)](#performance-speed)
- [Recommended settings](#recommended-settings)
- [Long audio & chunking](#long-audio--chunking)
- [Example workflow](#example-workflow)
@@ -125,6 +126,8 @@ Loads (and caches) a checkpoint. Output: **`UNIVERSR_MODEL`**.
|---|---|---|---|
| `model` | choice | `universr-audio` | Preset to download, or a local checkpoint folder found under `models/universr/`. |
| `device` | `auto` / `cuda` / `cpu` | `auto` | Where to load the weights. `auto` picks CUDA when available. |
| `tf32` *(opt.)* | bool | `True` | TF32 matmul on Ampere+ (~1.15×). Perceptually lossless, not bit-exact. |
| `compile` *(opt.)* | bool | `False` | `torch.compile` the network (~2×). See [Performance](#performance-speed). |
| `local_path` *(opt.)* | string | `""` | Override: a folder with `config.yaml` + `pytorch_model.bin`, **or** a raw training checkpoint (`.pth` / `.ckpt`). |
| `config_path` *(opt.)* | string | `""` | `config.yaml` to pair with a raw checkpoint. Empty → the bundled default config. |
@@ -223,6 +226,30 @@ Two ways to use it:
---
## Performance (speed)
Two **equal-quality** speedups live on the Model Loader (both leave the output perceptually identical —
measured deviation is at the fp32 rounding floor, ≈ 64 dB):
| Setting | Speedup (measured) | Notes |
|---|---|---|
| `tf32` (default **on**) | ~1.15× | TF32 matmul on Ampere+. One global flag, no caveats worth worrying about. |
| `compile` (opt-in) | ~2.1× | `torch.compile` the network. **Stacks with TF32 → ~2.5× total.** |
On the reference machine, a 12 s clip went **4.3 s → 1.7 s (2.48×)** with both enabled, with a max
sample deviation of `2e-4` vs plain fp32.
**About `compile`:** the first run pays a one-time compile (~1035 s); after that the compiled model is
cached for the whole ComfyUI session. The model can only be compiled for a **fixed input shape**, so the
node automatically **pads every chunk to `chunk_seconds`** — meaning clips of *any* length reuse the same
compiled graph (no per-length recompiles). Set the sampler's `chunk_seconds` near your typical clip length
so short clips aren't padded up wastefully. Requires CUDA; falls back to eager if compilation fails.
> These are the only speedups that don't change the output. Things that *don't* help here: CFG-batching,
> channel/chunk batching, and `channels_last` — the GPU is already compute-bound at batch 1, so they
> gave ~0 gain in testing. Going faster than this requires bf16/fp16, which is **not** equal-quality
> (verify by ear first).
## Recommended settings
| Content | `input_sr` | `guidance_scale` | `ode_method` / `ode_steps` |
+18 -4
View File
@@ -55,6 +55,17 @@ class UniverSRModelLoader:
}),
},
"optional": {
"tf32": ("BOOLEAN", {
"default": True,
"tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
"but not bit-exact; global setting. Turn off for reference fp32.",
}),
"compile": ("BOOLEAN", {
"default": False,
"tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
"and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
"sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
}),
"local_path": ("STRING", {
"default": "",
"tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
@@ -72,17 +83,20 @@ class UniverSRModelLoader:
RETURN_NAMES = ("model",)
FUNCTION = "load"
def load(self, model, device, local_path="", config_path=""):
def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
dev = _default_device() if device == "auto" else device
if dev == "cuda" and not torch.cuda.is_available():
print("[UniverSR] CUDA unavailable, falling back to CPU")
dev = "cpu"
model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
model_obj, cache_key = usr.load_model(
model, dev, local_path=local_path, config_path=config_path,
tf32=tf32, compile_model=compile,
)
return ({"model": model_obj, "device": dev, "cache_key": cache_key},)
@classmethod
def IS_CHANGED(cls, model, device, local_path="", config_path=""):
return f"{model}:{device}:{local_path}:{config_path}"
def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}"
# --------------------------------------------------------------------------- #
+52 -7
View File
@@ -158,10 +158,21 @@ def resolve_model_ref(model: str, local_path: str = "") -> tuple:
)
def load_model(model: str, device: str, local_path: str = "", config_path: str = ""):
def apply_tf32(enabled: bool):
"""Enable/disable TF32 matmul on Ampere+ GPUs. ~1.15x speedup, perceptually
lossless but NOT bit-exact (10 mantissa bits vs 23). Global process setting."""
try:
torch.set_float32_matmul_precision("high" if enabled else "highest")
except Exception:
pass
def load_model(model: str, device: str, local_path: str = "", config_path: str = "",
tf32: bool = True, compile_model: bool = False):
"""Load (and cache) a UniverSR model. Returns (model_obj, cache_key)."""
apply_tf32(tf32) # global; apply before the cache short-circuit so toggling takes effect
kind, path = resolve_model_ref(model, local_path)
cache_key = f"{kind}:{os.path.abspath(path)}:{device}"
cache_key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}"
with _CACHE_LOCK:
if cache_key in _MODEL_CACHE:
@@ -182,8 +193,24 @@ def load_model(model: str, device: str, local_path: str = "", config_path: str =
model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device)
model_obj.eval()
# torch.compile the UNet (~2.1x measured). Static shapes only — the model's
# adaptive-avg-pool can't trace dynamic shapes — so the sampler pads every
# chunk to a fixed length (see _universr_compiled flag) to compile exactly once.
compiled = False
if compile_model and device == "cuda":
try:
import torch._dynamo as _dynamo
_dynamo.config.cache_size_limit = max(getattr(_dynamo.config, "cache_size_limit", 8), 32)
model_obj.model = torch.compile(model_obj.model, mode="default", dynamic=False)
compiled = True
print("[UniverSR] torch.compile enabled (first run compiles ~10-35s, then ~2x).")
except Exception as e:
print(f"[UniverSR] torch.compile unavailable, continuing eager: {e}")
model_obj._universr_compiled = compiled
n = sum(p.numel() for p in model_obj.parameters()) / 1e6
print(f"[UniverSR] Ready - {n:.1f}M params on {device}")
print(f"[UniverSR] Ready - {n:.1f}M params on {device} (tf32={tf32}, compile={compiled})")
_MODEL_CACHE[cache_key] = model_obj
return model_obj, cache_key
@@ -282,12 +309,21 @@ def _chunk_starts(total: int, chunk: int, hop: int) -> list:
@torch.no_grad()
def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor:
guidance_scale, chunk: int, ov: int, pbar, pad_to: int = 0) -> torch.Tensor:
T = ch48.shape[-1]
def seg_enhance(seg: torch.Tensor) -> torch.Tensor:
# pad_to>0 (compiled model) → zero-pad to a fixed length so the UNet always
# sees one input shape (compile once), then crop the result back.
L = seg.shape[-1]
if pad_to and pad_to > L:
seg = _fit(seg, pad_to)
return _fit(_enhance_segment(model, seg, input_sr, ode_method, ode_steps, guidance_scale), L)
if chunk <= 0 or T <= chunk:
if pbar is not None:
pbar.update(1)
return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T)
return _fit(seg_enhance(ch48), T)
hop = max(1, chunk - ov)
starts = _chunk_starts(T, chunk, hop)
@@ -297,7 +333,7 @@ def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
if HAS_COMFY:
mm.throw_exception_if_processing_interrupted()
e = min(s + chunk, T)
enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s)
enh = seg_enhance(ch48[s:e])
w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T))
out[s:e] += enh * w
wsum[s:e] += w
@@ -331,6 +367,15 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
print(f"[UniverSR] chunk_seconds too small; raising to the model floor "
f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).")
chunk = MODEL_MIN_SAMPLES
# A torch.compile'd model needs a fixed input shape, so force chunking and pad
# every chunk to `chunk` samples (compile once, reuse). Without compile, pad_to=0.
compiled = getattr(model, "_universr_compiled", False)
if compiled and chunk <= 0:
chunk = int(round(10.0 * TARGET_SR))
print("[UniverSR] compile: forcing 10.0s chunks for fixed input shapes.")
pad_to = chunk if compiled else 0
ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0
n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1
@@ -351,7 +396,7 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
for c in range(C):
wet[b, c] = _fit(
_enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps,
guidance_scale, chunk, ov, pbar),
guidance_scale, chunk, ov, pbar, pad_to=pad_to),
T48,
)
finally: