diff --git a/bim_vfi_arch/costvol.py b/bim_vfi_arch/costvol.py index 0d58304..06cb6e7 100644 --- a/bim_vfi_arch/costvol.py +++ b/bim_vfi_arch/costvol.py @@ -243,7 +243,7 @@ def cuda_launch(strKey:str): # end -def _pytorch_costvol(tenOne, tenTwo, intKernelSize): +def _pytorch_costvol_impl(tenOne, tenTwo, intKernelSize): """Pure-PyTorch local cost volume via unfold + dot product.""" B, C, H, W = tenOne.shape pad = (intKernelSize - 1) // 2 @@ -265,6 +265,22 @@ def _pytorch_costvol(tenOne, tenTwo, intKernelSize): return tenOut +_costvol_fn = None + +def _pytorch_costvol(tenOne, tenTwo, intKernelSize): + global _costvol_fn + if _costvol_fn is None: + try: + _costvol_fn = torch.compile(_pytorch_costvol_impl) + except Exception: + _costvol_fn = _pytorch_costvol_impl + try: + return _costvol_fn(tenOne, tenTwo, intKernelSize) + except Exception: + _costvol_fn = _pytorch_costvol_impl + return _costvol_fn(tenOne, tenTwo, intKernelSize) + + ########################################################## diff --git a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py index 60eaec1..2f21014 100644 --- a/gimm_vfi_arch/generalizable_INR/modules/softsplat.py +++ b/gimm_vfi_arch/generalizable_INR/modules/softsplat.py @@ -292,7 +292,7 @@ def cuda_launch(strKey: str): ########################################################## -def _pytorch_softsplat(tenIn, tenFlow): +def _pytorch_softsplat_impl(tenIn, tenFlow): """Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" B, C, H, W = tenIn.shape tenOut = tenIn.new_zeros(B, C, H, W) @@ -335,6 +335,22 @@ def _pytorch_softsplat(tenIn, tenFlow): return tenOut +_softsplat_fn = None + +def _pytorch_softsplat(tenIn, tenFlow): + global _softsplat_fn + if _softsplat_fn is None: + try: + _softsplat_fn = torch.compile(_pytorch_softsplat_impl) + except Exception: + _softsplat_fn = _pytorch_softsplat_impl + try: + return _softsplat_fn(tenIn, tenFlow) + except Exception: + _softsplat_fn = _pytorch_softsplat_impl + return _softsplat_fn(tenIn, tenFlow) + + @torch.compiler.disable() def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False): assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"] diff --git a/sgm_vfi_arch/softsplat.py b/sgm_vfi_arch/softsplat.py index ad2d293..c47f07a 100644 --- a/sgm_vfi_arch/softsplat.py +++ b/sgm_vfi_arch/softsplat.py @@ -244,7 +244,7 @@ def cuda_launch(strKey:str): ########################################################## -def _pytorch_softsplat(tenIn, tenFlow): +def _pytorch_softsplat_impl(tenIn, tenFlow): """Pure-PyTorch forward warp via bilinear splatting (scatter_add).""" B, C, H, W = tenIn.shape tenOut = tenIn.new_zeros(B, C, H, W) @@ -273,7 +273,6 @@ def _pytorch_softsplat(tenIn, tenFlow): w_se = frac_x * frac_y * valid out_flat = tenOut.view(B, C, -1) - in_flat = tenIn for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]: tx = nw_x + dx @@ -281,11 +280,27 @@ def _pytorch_softsplat(tenIn, tenFlow): in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H) w_masked = w * in_bounds idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)) - idx = idx.unsqueeze(1).expand_as(in_flat) - weighted = in_flat * w_masked.unsqueeze(1) + idx = idx.unsqueeze(1).expand_as(tenIn) + weighted = tenIn * w_masked.unsqueeze(1) out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1)) return tenOut + + +_softsplat_fn = None + +def _pytorch_softsplat(tenIn, tenFlow): + global _softsplat_fn + if _softsplat_fn is None: + try: + _softsplat_fn = torch.compile(_pytorch_softsplat_impl) + except Exception: + _softsplat_fn = _pytorch_softsplat_impl + try: + return _softsplat_fn(tenIn, tenFlow) + except Exception: + _softsplat_fn = _pytorch_softsplat_impl + return _softsplat_fn(tenIn, tenFlow) # end