feat: equal-quality speed options (TF32 + torch.compile)
Add two opt-in inference speedups to the Model Loader, validated to leave the output perceptually identical (deviation at the fp32 rounding floor): - tf32 (default on): TF32 matmul on Ampere+ (~1.15x). - compile (opt-in): torch.compile the UNet (~2.1x). Stacks with TF32 to ~2.5x (measured 4.3s -> 1.7s on a 12s clip). torch.compile needs a static shape (the model's adaptive-avg-pool can't trace dynamic shapes), so the sampler pads every chunk to chunk_seconds — clips of any length reuse one compiled graph (no per-length recompiles; verified an 8s clip after a 12s clip ran in 0.9s with no recompile). Researched + profiled first: CFG-batching, channel/chunk batching, and channels_last gave ~0 gain because the GPU is already compute-bound at batch 1. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -55,6 +55,17 @@ class UniverSRModelLoader:
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"tf32": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Enable TF32 matmul on Ampere+ GPUs (~1.15x). Perceptually lossless "
|
||||
"but not bit-exact; global setting. Turn off for reference fp32.",
|
||||
}),
|
||||
"compile": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "torch.compile the network (~2x). First run compiles (~10-35s), then fast "
|
||||
"and cached. Needs CUDA. Chunks are auto-padded to a fixed size, so set the "
|
||||
"sampler's chunk_seconds near your typical clip length to avoid wasted compute.",
|
||||
}),
|
||||
"local_path": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": "Override: a folder with config.yaml + pytorch_model.bin, "
|
||||
@@ -72,17 +83,20 @@ class UniverSRModelLoader:
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, model, device, local_path="", config_path=""):
|
||||
def load(self, model, device, tf32=True, compile=False, local_path="", config_path=""):
|
||||
dev = _default_device() if device == "auto" else device
|
||||
if dev == "cuda" and not torch.cuda.is_available():
|
||||
print("[UniverSR] CUDA unavailable, falling back to CPU")
|
||||
dev = "cpu"
|
||||
model_obj, cache_key = usr.load_model(model, dev, local_path=local_path, config_path=config_path)
|
||||
model_obj, cache_key = usr.load_model(
|
||||
model, dev, local_path=local_path, config_path=config_path,
|
||||
tf32=tf32, compile_model=compile,
|
||||
)
|
||||
return ({"model": model_obj, "device": dev, "cache_key": cache_key},)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, model, device, local_path="", config_path=""):
|
||||
return f"{model}:{device}:{local_path}:{config_path}"
|
||||
def IS_CHANGED(cls, model, device, tf32=True, compile=False, local_path="", config_path=""):
|
||||
return f"{model}:{device}:tf32={tf32}:compile={compile}:{local_path}:{config_path}"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
Reference in New Issue
Block a user