perf: add torch.compile to PyTorch fallback kernels

Wraps _pytorch_softsplat and _pytorch_costvol with torch.compile
for ~6x speedup on ROCm/non-cupy setups. Falls back to eager
execution gracefully if compilation fails.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 10:13:30 +02:00
parent 2e75e2d076
commit 83e4b5dd98
3 changed files with 53 additions and 6 deletions
+17 -1
View File
@@ -243,7 +243,7 @@ def cuda_launch(strKey:str):
# end # end
def _pytorch_costvol(tenOne, tenTwo, intKernelSize): def _pytorch_costvol_impl(tenOne, tenTwo, intKernelSize):
"""Pure-PyTorch local cost volume via unfold + dot product.""" """Pure-PyTorch local cost volume via unfold + dot product."""
B, C, H, W = tenOne.shape B, C, H, W = tenOne.shape
pad = (intKernelSize - 1) // 2 pad = (intKernelSize - 1) // 2
@@ -265,6 +265,22 @@ def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
return tenOut return tenOut
_costvol_fn = None
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
global _costvol_fn
if _costvol_fn is None:
try:
_costvol_fn = torch.compile(_pytorch_costvol_impl)
except Exception:
_costvol_fn = _pytorch_costvol_impl
try:
return _costvol_fn(tenOne, tenTwo, intKernelSize)
except Exception:
_costvol_fn = _pytorch_costvol_impl
return _costvol_fn(tenOne, tenTwo, intKernelSize)
########################################################## ##########################################################
@@ -292,7 +292,7 @@ def cuda_launch(strKey: str):
########################################################## ##########################################################
def _pytorch_softsplat(tenIn, tenFlow): def _pytorch_softsplat_impl(tenIn, tenFlow):
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" """Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
B, C, H, W = tenIn.shape B, C, H, W = tenIn.shape
tenOut = tenIn.new_zeros(B, C, H, W) tenOut = tenIn.new_zeros(B, C, H, W)
@@ -335,6 +335,22 @@ def _pytorch_softsplat(tenIn, tenFlow):
return tenOut return tenOut
_softsplat_fn = None
def _pytorch_softsplat(tenIn, tenFlow):
global _softsplat_fn
if _softsplat_fn is None:
try:
_softsplat_fn = torch.compile(_pytorch_softsplat_impl)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
try:
return _softsplat_fn(tenIn, tenFlow)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
return _softsplat_fn(tenIn, tenFlow)
@torch.compiler.disable() @torch.compiler.disable()
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False): def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"] assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
+19 -4
View File
@@ -244,7 +244,7 @@ def cuda_launch(strKey:str):
########################################################## ##########################################################
def _pytorch_softsplat(tenIn, tenFlow): def _pytorch_softsplat_impl(tenIn, tenFlow):
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" """Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
B, C, H, W = tenIn.shape B, C, H, W = tenIn.shape
tenOut = tenIn.new_zeros(B, C, H, W) tenOut = tenIn.new_zeros(B, C, H, W)
@@ -273,7 +273,6 @@ def _pytorch_softsplat(tenIn, tenFlow):
w_se = frac_x * frac_y * valid w_se = frac_x * frac_y * valid
out_flat = tenOut.view(B, C, -1) out_flat = tenOut.view(B, C, -1)
in_flat = tenIn
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]: for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
tx = nw_x + dx tx = nw_x + dx
@@ -281,11 +280,27 @@ def _pytorch_softsplat(tenIn, tenFlow):
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H) in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
w_masked = w * in_bounds w_masked = w * in_bounds
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)) idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
idx = idx.unsqueeze(1).expand_as(in_flat) idx = idx.unsqueeze(1).expand_as(tenIn)
weighted = in_flat * w_masked.unsqueeze(1) weighted = tenIn * w_masked.unsqueeze(1)
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1)) out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
return tenOut return tenOut
_softsplat_fn = None
def _pytorch_softsplat(tenIn, tenFlow):
global _softsplat_fn
if _softsplat_fn is None:
try:
_softsplat_fn = torch.compile(_pytorch_softsplat_impl)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
try:
return _softsplat_fn(tenIn, tenFlow)
except Exception:
_softsplat_fn = _pytorch_softsplat_impl
return _softsplat_fn(tenIn, tenFlow)
# end # end