Add pure-PyTorch fallback for BIM-VFI cost volume kernel
When cupy is unavailable, the costvol_func.forward() now falls back to a pure-PyTorch implementation using unfold + dot product instead of raising a RuntimeError. The CUDA/cupy kernel path is preserved unchanged for when cupy is available. This allows BIM-VFI to run on systems without cupy (including CPU-only setups), matching the pattern used for the softsplat fallbacks in SGM-VFI and GIMM-VFI. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+28
-5
@@ -4,6 +4,7 @@ import collections
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
cupy = None
|
cupy = None
|
||||||
@@ -15,11 +16,7 @@ def _ensure_cupy():
|
|||||||
import cupy as _cupy
|
import cupy as _cupy
|
||||||
cupy = _cupy
|
cupy = _cupy
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError(
|
pass # cupy unavailable; PyTorch fallback will be used
|
||||||
"cupy is required for BIM-VFI. Install it with:\n"
|
|
||||||
" pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
|
|
||||||
"Or run install.py from the ComfyUI-Tween directory."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
@@ -246,6 +243,28 @@ def cuda_launch(strKey:str):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
|
||||||
|
"""Pure-PyTorch local cost volume via unfold + dot product."""
|
||||||
|
B, C, H, W = tenOne.shape
|
||||||
|
pad = (intKernelSize - 1) // 2
|
||||||
|
|
||||||
|
# Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel)
|
||||||
|
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
|
||||||
|
|
||||||
|
# Unfold into patches: (B, C, H, W, K, K)
|
||||||
|
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
|
||||||
|
# Reshape to (B, C, H, W, K*K)
|
||||||
|
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
|
||||||
|
|
||||||
|
# Dot product over C dimension: (B, H, W, K*K)
|
||||||
|
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
|
||||||
|
|
||||||
|
# Permute to (B, K*K, H, W) to match CUDA output layout
|
||||||
|
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
@@ -253,6 +272,8 @@ class costvol_func(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||||
def forward(self, tenOne, tenTwo, intKernelSize):
|
def forward(self, tenOne, tenTwo, intKernelSize):
|
||||||
|
_ensure_cupy()
|
||||||
|
if tenOne.is_cuda and cupy is not None:
|
||||||
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
||||||
|
|
||||||
cuda_launch(cuda_kernel('costvol_out', '''
|
cuda_launch(cuda_kernel('costvol_out', '''
|
||||||
@@ -302,6 +323,8 @@ class costvol_func(torch.autograd.Function):
|
|||||||
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
||||||
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
|
||||||
|
|
||||||
self.save_for_backward(tenOne, tenTwo)
|
self.save_for_backward(tenOne, tenTwo)
|
||||||
self.intKernelSize = intKernelSize
|
self.intKernelSize = intKernelSize
|
||||||
|
|||||||
Reference in New Issue
Block a user