feat: equal-quality speed options (TF32 + torch.compile)
Add two equal-quality inference speed options to the Model Loader. Both are validated to leave the output perceptually identical (deviation at the fp32 rounding floor):

- tf32 (default on): TF32 matmul on Ampere+ (~1.15x).
- compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to ~2.5x (measured 4.3s -> 1.7s on a 12s clip).

torch.compile needs a static shape (the model's adaptive-avg-pool can't trace dynamic shapes), so the sampler pads every chunk to chunk_seconds. Clips of any length then reuse one compiled graph, with no per-length recompiles: an 8s clip run after a 12s clip took 0.9s and did not trigger a recompile.

Researched and profiled first: CFG-batching, channel/chunk batching, and channels_last gave ~0 gain because the GPU is already compute-bound at batch 1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
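The pad-then-crop idea from the commit message can be sketched in plain Python. This is an illustration under stated assumptions, not the commit's code: `fit` is a hypothetical stand-in for the sampler's `_fit` helper (its body is not part of this diff) and is assumed to zero-pad or crop to a target length, and `enhance` is a placeholder for the compiled network that only records the input length it was given.

```python
def fit(samples, length):
    """Zero-pad or crop `samples` to exactly `length` items (assumed behaviour of `_fit`)."""
    if len(samples) >= length:
        return samples[:length]
    return samples + [0.0] * (length - len(samples))


def enhance(samples, seen_shapes):
    """Placeholder for the compiled UNet: records the input length and echoes the input."""
    seen_shapes.add(len(samples))
    return list(samples)


def seg_enhance(seg, pad_to, seen_shapes):
    """Pad to a fixed length before the model call, then crop back to the original length."""
    original_len = len(seg)
    if pad_to and pad_to > original_len:
        seg = fit(seg, pad_to)
    return fit(enhance(seg, seen_shapes), original_len)


seen = set()
out_a = seg_enhance([0.1] * 8, pad_to=12, seen_shapes=seen)   # short clip, padded to 12
out_b = seg_enhance([0.2] * 12, pad_to=12, seen_shapes=seen)  # already chunk-sized
print(len(out_a), len(out_b), seen)
```

Both calls reach the placeholder model with the same input length, which is the property that lets a statically compiled graph be reused across clip lengths.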
@@ -30,6 +30,7 @@ muffled or band‑limited audio gets believable "air" and detail back.
 - [UniverSR Load Video Audio](#universr-load-video-audio)
 - [UniverSR Video Combiner](#universr-video-combiner)
 - [Choosing `input_sr`](#choosing-input_sr-the-one-setting-that-matters-most)
+- [Performance (speed)](#performance-speed)
 - [Recommended settings](#recommended-settings)
 - [Long audio & chunking](#long-audio--chunking)
 - [Example workflow](#example-workflow)
@@ -125,6 +126,8 @@ Loads (and caches) a checkpoint. Output: **`UNIVERSR_MODEL`**.
 |---|---|---|---|
 | `model` | choice | `universr-audio` | Preset to download, or a local checkpoint folder found under `models/universr/`. |
 | `device` | `auto` / `cuda` / `cpu` | `auto` | Where to load the weights. `auto` picks CUDA when available. |
+| `tf32` *(opt.)* | bool | `True` | TF32 matmul on Ampere+ (~1.15×). Perceptually lossless, not bit-exact. |
+| `compile` *(opt.)* | bool | `False` | `torch.compile` the network (~2×). See [Performance](#performance-speed). |
 | `local_path` *(opt.)* | string | `""` | Override: a folder with `config.yaml` + `pytorch_model.bin`, **or** a raw training checkpoint (`.pth` / `.ckpt`). |
 | `config_path` *(opt.)* | string | `""` | `config.yaml` to pair with a raw checkpoint. Empty → the bundled default config. |
@@ -223,6 +226,30 @@ Two ways to use it:

 ---

+## Performance (speed)
+
+Two **equal-quality** speedups live on the Model Loader (both leave the output perceptually identical —
+measured deviation is at the fp32 rounding floor, ≈ −64 dB):
+
+| Setting | Speedup (measured) | Notes |
+|---|---|---|
+| `tf32` (default **on**) | ~1.15× | TF32 matmul on Ampere+. One global flag, no caveats worth worrying about. |
+| `compile` (opt-in) | ~2.1× | `torch.compile` the network. **Stacks with TF32 → ~2.5× total.** |
+
+On the reference machine, a 12 s clip went **4.3 s → 1.7 s (2.48×)** with both enabled, with a max
+sample deviation of `2e-4` vs plain fp32.
+
+**About `compile`:** the first run pays a one-time compile (~10–35 s); after that the compiled model is
+cached for the whole ComfyUI session. The model can only be compiled for a **fixed input shape**, so the
+node automatically **pads every chunk to `chunk_seconds`** — meaning clips of *any* length reuse the same
+compiled graph (no per-length recompiles). Set the sampler's `chunk_seconds` near your typical clip length
+so short clips aren't padded up wastefully. Requires CUDA; falls back to eager if compilation fails.
+
+> These are the only speedups that don't change the output. Things that *don't* help here: CFG-batching,
+> channel/chunk batching, and `channels_last` — the GPU is already compute-bound at batch 1, so they
+> gave ~0 gain in testing. Going faster than this requires bf16/fp16, which is **not** equal-quality
+> (verify by ear first).
+
 ## Recommended settings

 | Content | `input_sr` | `guidance_scale` | `ode_method` / `ode_steps` |
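To make the README's `chunk_seconds` advice concrete, here is a small sketch of how much of a compiled call is spent on zero padding when a clip is shorter than one chunk. The function name `padding_overhead` is hypothetical, and the sketch covers only the single-chunk case; it does not model overlap or clips that span several chunks.

```python
def padding_overhead(clip_seconds: float, chunk_seconds: float) -> float:
    """Fraction of a padded chunk that is zero padding (single-chunk case only)."""
    if clip_seconds <= 0 or chunk_seconds <= 0:
        raise ValueError("durations must be positive")
    if clip_seconds >= chunk_seconds:
        return 0.0  # clip fills at least one whole chunk; not modelled further here
    return (chunk_seconds - clip_seconds) / chunk_seconds


for clip in (3.0, 8.0, 12.0):
    share = padding_overhead(clip, 12.0)
    print(f"{clip:.1f}s clip with 12s chunks: {share:.0%} of the chunk is padding")
```

The closer `chunk_seconds` is to the typical clip length, the smaller this fraction becomes.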
@@ -55,6 +55,17 @@ class UniverSRModelLoader:
                 }),
             },
             "optional": {
+                "tf32": ("BOOLEAN", {
+                    "default": True,
+                    "tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
+                               "but not bit-exact; global setting. Turn off for reference fp32.",
+                }),
+                "compile": ("BOOLEAN", {
+                    "default": False,
+                    "tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
+                               "and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
+                               "sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
+                }),
                 "local_path": ("STRING", {
                     "default": "",
                     "tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
@@ -72,17 +83,20 @@ class UniverSRModelLoader:
     RETURN_NAMES = ("model",)
     FUNCTION = "load"

-    def load(self, model, device, local_path="", config_path=""):
+    def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
         dev = _default_device() if device == "auto" else device
         if dev == "cuda" and not torch.cuda.is_available():
             print("[UniverSR] CUDA unavailable, falling back to CPU")
             dev = "cpu"
-        model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
+        model_obj, cache_key = usr.load_model(
+            model, dev, local_path=local_path, config_path=config_path,
+            tf32=tf32, compile_model=compile,
+        )
         return ({"model": model_obj, "device": dev, "cache_key": cache_key},)

     @classmethod
-    def IS_CHANGED(cls, model, device, local_path="", config_path=""):
-        return f"{model}:{device}:{local_path}:{config_path}"
+    def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
+        return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}"


 # --------------------------------------------------------------------------- #
+52 -7
@@ -158,10 +158,21 @@ def resolve_model_ref(model: str, local_path: str = "") -> tuple:
     )


-def load_model(model: str, device: str, local_path: str = "", config_path: str = ""):
+def apply_tf32(enabled: bool):
+    """Enable/disable TF32 matmul on Ampere+ GPUs. ~1.15x speedup, perceptually
+    lossless but NOT bit-exact (10 mantissa bits vs 23). Global process setting."""
+    try:
+        torch.set_float32_matmul_precision("high" if enabled else "highest")
+    except Exception:
+        pass
+
+
+def load_model(model: str, device: str, local_path: str = "", config_path: str = "",
+               tf32: bool = True, compile_model: bool = False):
     """Load (and cache) a UniverSR model. Returns (model_obj, cache_key)."""
+    apply_tf32(tf32)  # global; apply before the cache short-circuit so toggling takes effect
     kind, path = resolve_model_ref(model, local_path)
-    cache_key = f"{kind}:{os.path.abspath(path)}:{device}"
+    cache_key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}"

     with _CACHE_LOCK:
         if cache_key in _MODEL_CACHE:
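The `apply_tf32` docstring notes that TF32 keeps 10 mantissa bits where fp32 keeps 23. A rough way to see the size of that rounding is to emulate it in plain Python by clearing the low 13 mantissa bits of a float32 bit pattern. This is an illustration only: `to_tf32_like` is a hypothetical helper, it truncates instead of rounding, and real TF32 is applied to matmul inputs inside the GPU, not to stored values.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32_like(x: float) -> float:
    """Keep the sign, 8 exponent bits, and the top 10 of 23 mantissa bits (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the 13 least-significant mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 0.123456789
fp32 = to_float32(x)
tf32 = to_tf32_like(x)
rel_err = abs(fp32 - tf32) / abs(fp32)
print(f"fp32={fp32!r}  tf32-like={tf32!r}  relative error={rel_err:.2e}")
```

With truncation the relative error per value stays below 2^-10, which is why the result is close to fp32 but not bit-exact.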
@@ -182,8 +193,24 @@ def load_model(model: str, device: str, local_path: str = "", config_path: str =
         model_obj = UniverSR.from_local(ckpt_path=path, config_path=cfg, device=device)

     model_obj.eval()
+
+    # torch.compile the UNet (~2.1x measured). Static shapes only — the model's
+    # adaptive-avg-pool can't trace dynamic shapes — so the sampler pads every
+    # chunk to a fixed length (see _universr_compiled flag) to compile exactly once.
+    compiled = False
+    if compile_model and device == "cuda":
+        try:
+            import torch._dynamo as _dynamo
+            _dynamo.config.cache_size_limit = max(getattr(_dynamo.config, "cache_size_limit", 8), 32)
+            model_obj.model = torch.compile(model_obj.model, mode="default", dynamic=False)
+            compiled = True
+            print("[UniverSR] torch.compile enabled (first run compiles ~10-35s, then ~2x).")
+        except Exception as e:
+            print(f"[UniverSR] torch.compile unavailable, continuing eager: {e}")
+    model_obj._universr_compiled = compiled
+
     n = sum(p.numel() for p in model_obj.parameters()) / 1e6
-    print(f"[UniverSR] Ready - {n:.1f}M params on {device}")
+    print(f"[UniverSR] Ready - {n:.1f}M params on {device} (tf32={tf32}, compile={compiled})")
     _MODEL_CACHE[cache_key] = model_obj
     return model_obj, cache_key
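In the hunks above, the cache key gains the `compile` flag but not `tf32`, which fits the comment that TF32 is a global setting applied before the cache lookup. A toy version of that caching pattern is sketched below. `load_cached` and its `build` callable are hypothetical stand-ins for the real checkpoint load, and holding the lock during the build is an assumption made to keep the sketch short.

```python
import os
import threading

_CACHE = {}
_LOCK = threading.Lock()


def load_cached(kind, path, device, compile_model, build):
    """Return (obj, key); `build` is only called on a cache miss."""
    key = f"{kind}:{os.path.abspath(path)}:{device}:compile={bool(compile_model)}"
    with _LOCK:
        if key not in _CACHE:
            _CACHE[key] = build()
        return _CACHE[key], key


calls = []


def fake_build():
    """Stand-in for loading a checkpoint; counts how often it runs."""
    calls.append(1)
    return object()


eager, k1 = load_cached("preset", "models/x", "cuda", False, fake_build)
again, k2 = load_cached("preset", "models/x", "cuda", False, fake_build)
compiled, k3 = load_cached("preset", "models/x", "cuda", True, fake_build)
print(len(calls), k1 == k2, k1 == k3)
```

Keying on `compile` keeps the eager and compiled variants as separate cache entries, so toggling the option never hands back the wrong one.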
@@ -282,12 +309,21 @@ def _chunk_starts(total: int, chunk: int, hop: int) -> list:

 @torch.no_grad()
 def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
-                     guidance_scale, chunk: int, ov: int, pbar) -> torch.Tensor:
+                     guidance_scale, chunk: int, ov: int, pbar, pad_to: int = 0) -> torch.Tensor:
     T = ch48.shape[-1]
+
+    def seg_enhance(seg: torch.Tensor) -> torch.Tensor:
+        # pad_to>0 (compiled model) → zero-pad to a fixed length so the UNet always
+        # sees one input shape (compile once), then crop the result back.
+        L = seg.shape[-1]
+        if pad_to and pad_to > L:
+            seg = _fit(seg, pad_to)
+        return _fit(_enhance_segment(model, seg, input_sr, ode_method, ode_steps, guidance_scale), L)
+
     if chunk <= 0 or T <= chunk:
         if pbar is not None:
             pbar.update(1)
-        return _fit(_enhance_segment(model, ch48, input_sr, ode_method, ode_steps, guidance_scale), T)
+        return _fit(seg_enhance(ch48), T)

     hop = max(1, chunk - ov)
     starts = _chunk_starts(T, chunk, hop)
@@ -297,7 +333,7 @@ def _enhance_channel(model, ch48: torch.Tensor, input_sr, ode_method, ode_steps,
         if HAS_COMFY:
             mm.throw_exception_if_processing_interrupted()
         e = min(s + chunk, T)
-        enh = _fit(_enhance_segment(model, ch48[s:e], input_sr, ode_method, ode_steps, guidance_scale), e - s)
+        enh = seg_enhance(ch48[s:e])
         w = _crossfade_window(e - s, ov, first=(i == 0), last=(e >= T))
         out[s:e] += enh * w
         wsum[s:e] += w
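The loop in this hunk accumulates each enhanced chunk into `out` with a window `w` and tracks the summed weights in `wsum`. A minimal sketch of that weighted overlap-add follows. It assumes the result is later divided by `wsum` (that step lies outside this hunk) and uses a hypothetical linear-ramp `crossfade_window`; the real `_crossfade_window` and `_chunk_starts` are not shown in this diff and may behave differently.

```python
def crossfade_window(n, ov, first, last):
    """Weights of 1.0 with linear ramps of length `ov` on interior edges (assumed shape)."""
    w = [1.0] * n
    ramp = min(ov, n)
    for i in range(ramp):
        frac = (i + 1) / (ramp + 1)  # strictly positive, so summed weights never hit zero
        if not first:
            w[i] = min(w[i], frac)                  # fade in
        if not last:
            w[n - 1 - i] = min(w[n - 1 - i], frac)  # fade out
    return w


def overlap_add(signal, chunk, ov, process):
    """Process `signal` in overlapping chunks and blend them by normalised weights."""
    total = len(signal)
    hop = max(1, chunk - ov)
    out = [0.0] * total
    wsum = [0.0] * total
    s = 0
    while True:
        e = min(s + chunk, total)
        enh = process(signal[s:e])
        w = crossfade_window(e - s, ov, first=(s == 0), last=(e >= total))
        for i in range(e - s):
            out[s + i] += enh[i] * w[i]
            wsum[s + i] += w[i]
        if e >= total:
            break
        s += hop
    return [o / ws for o, ws in zip(out, wsum)]


signal = [float(i) for i in range(25)]
result = overlap_add(signal, chunk=10, ov=4, process=lambda seg: list(seg))
```

With an identity `process`, the normalised blend returns the input unchanged, which is a quick sanity check that the windows and weight sums line up.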
@@ -331,6 +367,15 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
         print(f"[UniverSR] chunk_seconds too small; raising to the model floor "
               f"({MODEL_MIN_SAMPLES / TARGET_SR:.2f}s).")
         chunk = MODEL_MIN_SAMPLES
+
+    # A torch.compile'd model needs a fixed input shape, so force chunking and pad
+    # every chunk to `chunk` samples (compile once, reuse). Without compile, pad_to=0.
+    compiled = getattr(model, "_universr_compiled", False)
+    if compiled and chunk <= 0:
+        chunk = int(round(10.0 * TARGET_SR))
+        print("[UniverSR] compile: forcing 10.0s chunks for fixed input shapes.")
+    pad_to = chunk if compiled else 0
+
     ov = max(0, min(int(round(overlap_seconds * TARGET_SR)), chunk - 1)) if chunk > 0 else 0
     n_per_ch = len(_chunk_starts(T48, chunk, max(1, chunk - ov))) if chunk > 0 else 1
@@ -351,7 +396,7 @@ def super_resolve(model, waveform: torch.Tensor, sr: int, input_sr: int,
             for c in range(C):
                 wet[b, c] = _fit(
                     _enhance_channel(model, dry48[b, c], input_sr, ode_method, ode_steps,
-                                     guidance_scale, chunk, ov, pbar),
+                                     guidance_scale, chunk, ov, pbar, pad_to=pad_to),
                     T48,
                 )
     finally: