feat: equal-quality speed options (TF32 + torch.compile)

Add two optional inference speedups to the Model Loader, both validated to
leave the output perceptually identical (deviation at the fp32 rounding floor):

- tf32 (default on): TF32 matmul on Ampere+ (~1.15x).
- compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to
  ~2.5x (measured 4.3s -> 1.7s on a 12s clip).
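
The two options above correspond to standard PyTorch switches. A minimal sketch of how they could be applied, assuming a plain torch.nn.Module; the helper name apply_speed_options is hypothetical and is not the code in this commit:

```python
import torch


def apply_speed_options(model: torch.nn.Module,
                        tf32: bool = True,
                        compile_model: bool = False) -> torch.nn.Module:
    """Hypothetical helper: toggle TF32 and optionally compile the model."""
    # TF32 flags are process-global: they affect every CUDA matmul and
    # cuDNN convolution in the process, not just this model.
    torch.backends.cuda.matmul.allow_tf32 = tf32
    torch.backends.cudnn.allow_tf32 = tf32

    # Assumption: compile only when CUDA is present, mirroring the
    # "Needs CUDA" note in the tooltip. torch.compile is lazy, so the
    # actual compilation happens on the first forward pass.
    if compile_model and torch.cuda.is_available():
        model = torch.compile(model)
    return model
```

Because the TF32 flags are global, turning tf32 off is what restores reference fp32 behaviour for the whole process, not only for this node.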

torch.compile needs a static input shape (the model's adaptive-avg-pool can't
be traced with dynamic shapes), so the sampler pads every chunk to
chunk_seconds. Clips of any length then reuse one compiled graph with no
per-length recompiles (verified: an 8s clip run after a 12s clip took 0.9s and
did not recompile).
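
The padding idea can be sketched as follows. This is an illustration only: the function name pad_to_chunk, the use of NumPy, and zero-padding at the end of the last chunk are assumptions, not the sampler's actual implementation:

```python
import numpy as np


def pad_to_chunk(audio: np.ndarray, sample_rate: int, chunk_seconds: float):
    """Split 1-D audio into equal-length chunks, zero-padding the last one.

    Returns (chunks, valid_lengths); valid_lengths lets the caller trim
    the padding from each chunk's output after inference.
    """
    chunk_len = int(round(chunk_seconds * sample_rate))
    if chunk_len <= 0:
        raise ValueError("chunk_seconds * sample_rate must be positive")

    chunks, valid_lengths = [], []
    for start in range(0, len(audio), chunk_len):
        piece = audio[start:start + chunk_len]
        valid_lengths.append(len(piece))
        if len(piece) < chunk_len:
            # Pad on the right so every chunk has exactly chunk_len samples.
            piece = np.pad(piece, (0, chunk_len - len(piece)))
        chunks.append(piece)
    return chunks, valid_lengths
```

Every chunk has the same shape regardless of clip length, which is what lets one compiled graph serve all clips. The cost is that a clip much shorter than chunk_seconds still pays for a full chunk.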

Alternatives were researched and profiled first: CFG-batching, channel/chunk
batching, and channels_last gave ~0 gain because the GPU is already
compute-bound at batch 1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 17:16:21 +02:00
parent 9a901adcc5
commit 104cd4bf5f
3 changed files with 97 additions and 11 deletions
+18 -4
@@ -55,6 +55,17 @@ class UniverSRModelLoader:
                 }),
             },
             "optional": {
+                "tf32": ("BOOLEAN", {
+                    "default": True,
+                    "tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
+                               "but not bit-exact; global setting. Turn off for reference fp32.",
+                }),
+                "compile": ("BOOLEAN", {
+                    "default": False,
+                    "tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
+                               "and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
+                               "sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
+                }),
                 "local_path": ("STRING", {
                     "default": "",
                     "tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
@@ -72,17 +83,20 @@ class UniverSRModelLoader:
     RETURN_NAMES = ("model",)
     FUNCTION = "load"

-    def load(self, model, device, local_path="", config_path=""):
+    def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
         dev = _default_device() if device == "auto" else device
         if dev == "cuda" and not torch.cuda.is_available():
             print("[UniverSR] CUDA unavailable, falling back to CPU")
             dev = "cpu"
-        model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
+        model_obj, cache_key = usr.load_model(
+            model, dev, local_path=local_path, config_path=config_path,
+            tf32=tf32, compile_model=compile,
+        )
         return ({"model": model_obj, "device": dev, "cache_key": cache_key},)

     @classmethod
-    def IS_CHANGED(cls, model, device, local_path="", config_path=""):
-        return f"{model}:{device}:{local_path}:{config_path}"
+    def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
+        return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}"


 # --------------------------------------------------------------------------- #