feat: equal-quality speed options (TF32 + torch.compile)
Add two opt-in inference speedups to the Model Loader, validated to leave the output perceptually identical (deviation at the fp32 rounding floor): - tf32 (default on): TF32 matmul on Ampere+ (~1.15x). - compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to ~2.5x (measured 4.3s -> 1.7s on a 12s clip). torch.compile needs a static shape (the model's adaptive-avg-pool can't trace dynamic shapes), so the sampler pads every chunk to chunk_seconds — clips of any length reuse one compiled graph (no per-length recompiles; verified an 8s clip after a 12s clip ran in 0.9s with no recompile). Researched + profiled first: CFG-batching, channel/chunk batching, and channels_last gave ~0 gain because the GPU is already compute-bound at batch 1. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -30,6 +30,7 @@ muffled or band‑limited audio gets believable "air" and detail back.
|
||||
- [UniverSR Load Video Audio](#universr-load-video-audio)
|
||||
- [UniverSR Video Combiner](#universr-video-combiner)
|
||||
- [Choosing `input_sr`](#choosing-input_sr-the-one-setting-that-matters-most)
|
||||
- [Performance (speed)](#performance-speed)
|
||||
- [Recommended settings](#recommended-settings)
|
||||
- [Long audio & chunking](#long-audio--chunking)
|
||||
- [Example workflow](#example-workflow)
|
||||
@@ -125,6 +126,8 @@ Loads (and caches) a checkpoint. Output: **`UNIVERSR_MODEL`**.
|
||||
|---|---|---|---|
|
||||
| `model` | choice | `universr-audio` | Preset to download, or a local checkpoint folder found under `models/universr/`. |
|
||||
| `device` | `auto` / `cuda` / `cpu` | `auto` | Where to load the weights. `auto` picks CUDA when available. |
|
||||
| `tf32` *(opt.)* | bool | `True` | TF32 matmul on Ampere+ (~1.15×). Perceptually lossless, not bit-exact. |
|
||||
| `compile` *(opt.)* | bool | `False` | `torch.compile` the network (~2×). See [Performance](#performance-speed). |
|
||||
| `local_path` *(opt.)* | string | `""` | Override: a folder with `config.yaml` + `pytorch_model.bin`, **or** a raw training checkpoint (`.pth` / `.ckpt`). |
|
||||
| `config_path` *(opt.)* | string | `""` | `config.yaml` to pair with a raw checkpoint. Empty → the bundled default config. |
|
||||
|
||||
@@ -223,6 +226,30 @@ Two ways to use it:
|
||||
|
||||
---
|
||||
|
||||
## Performance (speed)
|
||||
|
||||
Two **equal-quality** speedups live on the Model Loader (both leave the output perceptually identical —
|
||||
measured deviation is at the fp32 rounding floor, ≈ −64 dB):
|
||||
|
||||
| Setting | Speedup (measured) | Notes |
|
||||
|---|---|---|
|
||||
| `tf32` (default **on**) | ~1.15× | TF32 matmul on Ampere+. One global flag, no caveats worth worrying about. |
|
||||
| `compile` (opt-in) | ~2.1× | `torch.compile` the network. **Stacks with TF32 → ~2.5× total.** |
|
||||
|
||||
On the reference machine, a 12 s clip went **4.3 s → 1.7 s (2.48×)** with both enabled, with a max
|
||||
sample deviation of `2e-4` vs plain fp32.
|
||||
|
||||
**About `compile`:** the first run pays a one-time compile (~10–35 s); after that the compiled model is
|
||||
cached for the whole ComfyUI session. The model can only be compiled for a **fixed input shape**, so the
|
||||
node automatically **pads every chunk to `chunk_seconds`** — meaning clips of *any* length reuse the same
|
||||
compiled graph (no per-length recompiles). Set the sampler's `chunk_seconds` near your typical clip length
|
||||
so short clips aren't padded up wastefully. Requires CUDA; falls back to eager if compilation fails.
|
||||
|
||||
> These are the only speedups that don't change the output. Things that *don't* help here: CFG-batching,
|
||||
> channel/chunk batching, and `channels_last` — the GPU is already compute-bound at batch 1, so they
|
||||
> gave ~0 gain in testing. Going faster than this requires bf16/fp16, which is **not** equal-quality
|
||||
> (verify by ear first).
|
||||
|
||||
## Recommended settings
|
||||
|
||||
| Content | `input_sr` | `guidance_scale` | `ode_method` / `ode_steps` |
|
||||
|
||||
@@ -55,6 +55,17 @@ class UniverSRModelLoader:
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"tf32": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
|
||||
"but not bit-exact; global setting. Turn off for reference fp32.",
|
||||
}),
|
||||
"compile": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
|
||||
"and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
|
||||
"sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
|
||||
}),
|
||||
"local_path": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
|
||||
@@ -72,17 +83,20 @@ class UniverSRModelLoader:
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, model, device, local_path="", config_path=""):
|
||||
def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
|
||||
dev = _default_device() if device == "auto" else device
|
||||
if dev == "cuda" and not torch.cuda.is_available():
|
||||
print("[UniverSR] CUDA unavailable, falling back to CPU")
|
||||
dev = "cpu"
|
||||
model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
|
||||
model_obj, cache_key = usr.load_model(
|
||||
model, dev, local_path=local_path, config_path=config_path,
|
||||
tf32=tf32, compile_model=compile,
|
||||
)
|
||||
return ({"model": model_obj, "device": dev, "cache_key": cache_key},)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, model, device, local_path="", config_path=""):
|
||||
return f"{model}:{device}:{local_path}:{config_path}"
|
||||
def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
|
||||
return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
+52
-7
@@ -158,10 +158,21 @@ def resolve_model_ref(model: str, local_path: str = "") -> tuple:
|
||||
)
|
||||
|
||||
|
||||
def load_model(model: str, device: str, local_path: str = "", config_path: str = ""):
|
||||
def apply_tf32(enabled: bool):
|
||||
"""Enable/disable TF32 matmul on Ampere+ GPUs. ~1.15x speedup, perceptually
|
||||
lossless but NOT bit-exact (10 mantissa bits vs 23). Global process setting."""
|
||||
try:
|
||||
torch.set_float32_matmul_precision("high" if enabled else "highest")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def load_model(model: str, device: str, local_path: str = "", config_path: str = "",
|
||||
tf32: bool = True, compile_model: bool = False):
|
||||
"""Load (and cache) a UniverSR model. Returns (model_obj, cache_key)."""
|
||||
apply_tf32(tf32) # global; apply before the cache short-circuit so toggling takes effect
|
||||
kind, path = resolve_model_ref(model, local_path)
|
||||
cache_key = f"{kind}:{os.path.abspath(path)}:{device}"
|
||||
cache_key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}"
|
||||
|
||||
with _CACHE_LOCK:
|
||||
if cache_key in _MODEL_CACHE:
|
||||
@@ -182,8 +193,24 @@ def load_model(model: str, device: str, local_path: str = "", config_path: str =
|
||||
model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device)
|
||||
|
||||
model_obj.eval()
|
||||
|
||||
# torch.compile the UNet (~2.1x measured). Static shapes only — the model's
|
||||
# adaptive-avg-pool can't trace dynamic shapes — so the sampler pads every
|
||||
# chunk to a fixed length (see _universr_compiled flag) to compile exactly once.
|
||||
compiled = False
|
||||
if compile_model and device == "cuda":
|
||||
try:
|
||||
import torch._dynamo as _dynamo
|
||||
_dynamo.config.cache_size_limit = max(getattr(_dynamo.config, "cache_size_limit", 8), 32)
|
||||
model_obj.model = torch.compile(model_obj.model, mode="default", dynamic=False)
|
||||
compiled = True
|
||||
print("[UniverSR] torch.compile enabled (first run compiles ~10-35s, then ~2x).")
|
||||
except Exception as e:
|
||||
print(f"[UniverSR] torch.compile unavailable, continuing eager: {e}")
|
||||
model_obj._universr_compiled = compiled
|
||||
|
||||
n = sum(p.numel() for p in model_obj.parameters()) / 1e6
|
||||
print(f"[UniverSR] Ready - {n:.1f}M params on {device}")
|
||||
print(f"[UniverSR] Ready - {n:.1f}M params on {device} (tf32={tf32}, compile={compiled})")
|
||||
_MODEL_CACHE[cache_key] = model_obj
|
||||
return model_obj, cache_key
|
||||
|
||||
@@ -282,12 +309,21 @@ def _chunk_starts(total: int, chunk: int, hop: int) -> list:
|
||||
|
||||
@torch.no_grad()
|
||||
def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
|
||||
guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor:
|
||||
guidance_scale, chunk: int, ov: int, pbar, pad_to: int = 0) -> torch.Tensor:
|
||||
T = ch48.shape[-1]
|
||||
|
||||
def seg_enhance(seg: torch.Tensor) -> torch.Tensor:
|
||||
# pad_to>0 (compiled model) → zero-pad to a fixed length so the UNet always
|
||||
# sees one input shape (compile once), then crop the result back.
|
||||
L = seg.shape[-1]
|
||||
if pad_to and pad_to > L:
|
||||
seg = _fit(seg, pad_to)
|
||||
return _fit(_enhance_segment(model, seg, input_sr, ode_method, ode_steps, guidance_scale), L)
|
||||
|
||||
if chunk <= 0 or T <= chunk:
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T)
|
||||
return _fit(seg_enhance(ch48), T)
|
||||
|
||||
hop = max(1, chunk - ov)
|
||||
starts = _chunk_starts(T, chunk, hop)
|
||||
@@ -297,7 +333,7 @@ def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
|
||||
if HAS_COMFY:
|
||||
mm.throw_exception_if_processing_interrupted()
|
||||
e = min(s + chunk, T)
|
||||
enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s)
|
||||
enh = seg_enhance(ch48[s:e])
|
||||
w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T))
|
||||
out[s:e] += enh * w
|
||||
wsum[s:e] += w
|
||||
@@ -331,6 +367,15 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
|
||||
print(f"[UniverSR] chunk_seconds too small; raising to the model floor "
|
||||
f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).")
|
||||
chunk = MODEL_MIN_SAMPLES
|
||||
|
||||
# A torch.compile'd model needs a fixed input shape, so force chunking and pad
|
||||
# every chunk to `chunk` samples (compile once, reuse). Without compile, pad_to=0.
|
||||
compiled = getattr(model, "_universr_compiled", False)
|
||||
if compiled and chunk <= 0:
|
||||
chunk = int(round(10.0 * TARGET_SR))
|
||||
print("[UniverSR] compile: forcing 10.0s chunks for fixed input shapes.")
|
||||
pad_to = chunk if compiled else 0
|
||||
|
||||
ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0
|
||||
n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1
|
||||
|
||||
@@ -351,7 +396,7 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
|
||||
for c in range(C):
|
||||
wet[b, c] = _fit(
|
||||
_enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps,
|
||||
guidance_scale, chunk, ov, pbar),
|
||||
guidance_scale, chunk, ov, pbar, pad_to=pad_to),
|
||||
T48,
|
||||
)
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user