From 104cd4bf5fe7216feb9022da4e2aafd1eaaf880d Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Tue, 16 Jun 2026 17:16:21 +0200 Subject: [PATCH] feat: equal-quality speed options (TF32 + torch.compile) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add two opt-in inference speedups to the Model Loader, validated to leave the output perceptually identical (deviation at the fp32 rounding floor): - tf32 (default on): TF32 matmul on Ampere+ (~1.15x). - compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to ~2.5x (measured 4.3s -> 1.7s on a 12s clip). torch.compile needs a static shape (the model's adaptive-avg-pool can't trace dynamic shapes), so the sampler pads every chunk to chunk_seconds — clips of any length reuse one compiled graph (no per-length recompiles; verified an 8s clip after a 12s clip ran in 0.9s with no recompile). Researched + profiled first: CFG-batching, channel/chunk batching, and channels_last gave ~0 gain because the GPU is already compute-bound at batch 1. Co-Authored-By: Claude Opus 4.8 --- README.md | 27 +++++++++++++++++++++ nodes.py | 22 ++++++++++++++--- universr_wrapper.py | 59 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 97 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index fa2fbe3..765a622 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ muffled or band‑limited audio gets believable "air" and detail back. - [UniverSR Load Video Audio](#universr-load-video-audio) - [UniverSR Video Combiner](#universr-video-combiner) - [Choosing `input_sr`](#choosing-input_sr-the-one-setting-that-matters-most) +- [Performance (speed)](#performance-speed) - [Recommended settings](#recommended-settings) - [Long audio & chunking](#long-audio--chunking) - [Example workflow](#example-workflow) @@ -125,6 +126,8 @@ Loads (and caches) a checkpoint. Output: **`UNIVERSR_MODEL`**. |---|---|---|---| | `model` | choice | `universr-audio` | Preset to download, or a local checkpoint folder found under `models/universr/`. | | `device` | `auto` / `cuda` / `cpu` | `auto` | Where to load the weights. `auto` picks CUDA when available. | +| `tf32` *(opt.)* | bool | `True` | TF32 matmul on Ampere+ (~1.15×). Perceptually lossless, not bit-exact. | +| `compile` *(opt.)* | bool | `False` | `torch.compile` the network (~2×). See [Performance](#performance-speed). | | `local_path` *(opt.)* | string | `""` | Override: a folder with `config.yaml` + `pytorch_model.bin`, **or** a raw training checkpoint (`.pth` / `.ckpt`). | | `config_path` *(opt.)* | string | `""` | `config.yaml` to pair with a raw checkpoint. Empty → the bundled default config. | @@ -223,6 +226,30 @@ Two ways to use it: --- +## Performance (speed) + +Two **equal-quality** speedups live on the Model Loader (both leave the output perceptually identical — +measured deviation is at the fp32 rounding floor, ≈ −64 dB): + +| Setting | Speedup (measured) | Notes | +|---|---|---| +| `tf32` (default **on**) | ~1.15× | TF32 matmul on Ampere+. One global flag, no caveats worth worrying about. | +| `compile` (opt-in) | ~2.1× | `torch.compile` the network. **Stacks with TF32 → ~2.5× total.** | + +On the reference machine, a 12 s clip went **4.3 s → 1.7 s (2.48×)** with both enabled, with a max +sample deviation of `2e-4` vs plain fp32. + +**About `compile`:** the first run pays a one-time compile (~10–35 s); after that the compiled model is +cached for the whole ComfyUI session. The model can only be compiled for a **fixed input shape**, so the +node automatically **pads every chunk to `chunk_seconds`** — meaning clips of *any* length reuse the same +compiled graph (no per-length recompiles). Set the sampler's `chunk_seconds` near your typical clip length +so short clips aren't padded up wastefully. Requires CUDA; falls back to eager if compilation fails. + +> These are the only speedups that don't change the output. Things that *don't* help here: CFG-batching, +> channel/chunk batching, and `channels_last` — the GPU is already compute-bound at batch 1, so they +> gave ~0 gain in testing. Going faster than this requires bf16/fp16, which is **not** equal-quality +> (verify by ear first). + ## Recommended settings | Content | `input_sr` | `guidance_scale` | `ode_method` / `ode_steps` | diff --git a/nodes.py b/nodes.py index ca9c5b3..8d8c59b 100644 --- a/nodes.py +++ b/nodes.py @@ -55,6 +55,17 @@ class UniverSRModelLoader: }), }, "optional": { + "tf32": ("BOOLEAN", { + "default": True, + "tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless " + "but not bit-exact; global setting. Turn off for reference fp32.", + }), + "compile": ("BOOLEAN", { + "default": False, + "tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast " + "and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the " + "sampler's chunk_seconds near your typical clip length to avoid wasted compute.", + }), "local_path": ("STRING", { "default": "", "tooltip": "Override: a folder with config.yaml + pytorch_model.bin, " @@ -72,17 +83,20 @@ class UniverSRModelLoader: RETURN_NAMES = ("model",) FUNCTION = "load" - def load(self, model, device, local_path="", config_path=""): + def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""): dev = _default_device() if device == "auto" else device if dev == "cuda" and not torch.cuda.is_available(): print("[UniverSR] CUDA unavailable, falling back to CPU") dev = "cpu" - model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path) + model_obj, cache_key = usr.load_model( + model, dev, local_path=local_path, config_path=config_path, + tf32=tf32, compile_model=compile, + ) return ({"model": model_obj, "device": dev, "cache_key": cache_key},) @classmethod - def IS_CHANGED(cls, model, device, local_path="", config_path=""): - return f"{model}:{device}:{local_path}:{config_path}" + def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""): + return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}" # --------------------------------------------------------------------------- # diff --git a/universr_wrapper.py b/universr_wrapper.py index bc7a2b3..4bcec07 100644 --- a/universr_wrapper.py +++ b/universr_wrapper.py @@ -158,10 +158,21 @@ def resolve_model_ref(model: str, local_path: str = "") -> tuple: ) -def load_model(model: str, device: str, local_path: str = "", config_path: str = ""): +def apply_tf32(enabled: bool): + """Enable/disable TF32 matmul on Ampere+ GPUs. ~1.15x speedup, perceptually + lossless but NOT bit-exact (10 mantissa bits vs 23). Global process setting.""" + try: + torch.set_float32_matmul_precision("high" if enabled else "highest") + except Exception: + pass + + +def load_model(model: str, device: str, local_path: str = "", config_path: str = "", + tf32: bool = True, compile_model: bool = False): """Load (and cache) a UniverSR model. Returns (model_obj, cache_key).""" + apply_tf32(tf32) # global; apply before the cache short-circuit so toggling takes effect kind, path = resolve_model_ref(model, local_path) - cache_key = f"{kind}:{os.path.abspath(path)}:{device}" + cache_key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}" with _CACHE_LOCK: if cache_key in _MODEL_CACHE: @@ -182,8 +193,24 @@ def load_model(model: str, device: str, local_path: str = "", config_path: str = model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device) model_obj.eval() + + # torch.compile the UNet (~2.1x measured). Static shapes only — the model's + # adaptive-avg-pool can't trace dynamic shapes — so the sampler pads every + # chunk to a fixed length (see _universr_compiled flag) to compile exactly once. + compiled = False + if compile_model and device == "cuda": + try: + import torch._dynamo as _dynamo + _dynamo.config.cache_size_limit = max(getattr(_dynamo.config, "cache_size_limit", 8), 32) + model_obj.model = torch.compile(model_obj.model, mode="default", dynamic=False) + compiled = True + print("[UniverSR] torch.compile enabled (first run compiles ~10-35s, then ~2x).") + except Exception as e: + print(f"[UniverSR] torch.compile unavailable, continuing eager: {e}") + model_obj._universr_compiled = compiled + n = sum(p.numel() for p in model_obj.parameters()) / 1e6 - print(f"[UniverSR] Ready - {n:.1f}M params on {device}") + print(f"[UniverSR] Ready - {n:.1f}M params on {device} (tf32={tf32}, compile={compiled})") _MODEL_CACHE[cache_key] = model_obj return model_obj, cache_key @@ -282,12 +309,21 @@ def _chunk_starts(total: int, chunk: int, hop: int) -> list: @torch.no_grad() def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps, - guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor: + guidance_scale, chunk: int, ov: int, pbar, pad_to: int = 0) -> torch.Tensor: T = ch48.shape[-1] + + def seg_enhance(seg: torch.Tensor) -> torch.Tensor: + # pad_to>0 (compiled model) → zero-pad to a fixed length so the UNet always + # sees one input shape (compile once), then crop the result back. + L = seg.shape[-1] + if pad_to and pad_to > L: + seg = _fit(seg, pad_to) + return _fit(_enhance_segment(model, seg, input_sr, ode_method, ode_steps, guidance_scale), L) + if chunk <= 0 or T <= chunk: if pbar is not None: pbar.update(1) - return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T) + return _fit(seg_enhance(ch48), T) hop = max(1, chunk - ov) starts = _chunk_starts(T, chunk, hop) @@ -297,7 +333,7 @@ def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps, if HAS_COMFY: mm.throw_exception_if_processing_interrupted() e = min(s + chunk, T) - enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s) + enh = seg_enhance(ch48[s:e]) w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T)) out[s:e] += enh * w wsum[s:e] += w @@ -331,6 +367,15 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int, print(f"[UniverSR] chunk_seconds too small; raising to the model floor " f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).") chunk = MODEL_MIN_SAMPLES + + # A torch.compile'd model needs a fixed input shape, so force chunking and pad + # every chunk to `chunk` samples (compile once, reuse). Without compile, pad_to=0. + compiled = getattr(model, "_universr_compiled", False) + if compiled and chunk <= 0: + chunk = int(round(10.0 * TARGET_SR)) + print("[UniverSR] compile: forcing 10.0s chunks for fixed input shapes.") + pad_to = chunk if compiled else 0 + ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0 n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1 @@ -351,7 +396,7 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int, for c in range(C): wet[b, c] = _fit( _enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps, - guidance_scale, chunk, ov, pbar), + guidance_scale, chunk, ov, pbar, pad_to=pad_to), T48, ) finally: