fix(perf): default TF32 off; off = true fp32 (matmul + cuDNN conv)

Reported as "darker", but a fixed-seed spectral A/B shows TF32 is tonally
neutral (centroid 564→565 Hz, HF>8k 0.00825→0.00833) — the perceived change
is the seed=0 random-noise confound, not TF32. Still, TF32 is only ~1.15x and
not bit-exact, so default it OFF for reference-fp32 output and let compile
(~2.1x, op fusion) be the headline speedup. apply_tf32 now also toggles
cuDNN conv-TF32 (PyTorch leaves it on by default), so off is genuinely fp32.
Docs updated with the seed-confound A/B guidance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 10:47:39 +02:00
parent 104cd4bf5f
commit 94178a4851
3 changed files with 29 additions and 19 deletions
+9 -4
View File
@@ -159,16 +159,21 @@ def resolve_model_ref(model: str, local_path: str = "") -> tuple:
def apply_tf32(enabled: bool):
"""Enable/disable TF32 matmul on Ampere+ GPUs. ~1.15x speedup, perceptually
lossless but NOT bit-exact (10 mantissa bits vs 23). Global process setting."""
"""Enable/disable TF32 for BOTH matmul and cuDNN convolutions on Ampere+ GPUs.
~1.15x when on. In our spectral A/B (centroid + HF energy) TF32 was tonally
neutral, but it is NOT bit-exact (10 mantissa bits vs 23), so it's off by
default. Off sets true fp32 — note PyTorch otherwise leaves cuDNN conv-TF32 ON
by default, so we explicitly disable it here too. Global process setting."""
try:
torch.set_float32_matmul_precision("high" if enabled else "highest")
torch.set_float32_matmul_precision("high" if enabled else "highest") # matmul TF32
torch.backends.cudnn.allow_tf32 = enabled # conv TF32
except Exception:
pass
def load_model(model: str, device: str, local_path: str = "", config_path: str = "",
tf32: bool = True, compile_model: bool = False):
tf32: bool = False, compile_model: bool = False):
"""Load (and cache) a UniverSR model. Returns (model_obj, cache_key)."""
apply_tf32(tf32) # global; apply before the cache short-circuit so toggling takes effect
kind, path = resolve_model_ref(model, local_path)