Add pure-PyTorch fallback for BIM-VFI cost volume kernel
When cupy is unavailable, the costvol_func.forward() now falls back to a pure-PyTorch implementation using unfold + dot product instead of raising a RuntimeError. The CUDA/cupy kernel path is preserved unchanged for when cupy is available. This allows BIM-VFI to run on systems without cupy (including CPU-only setups), matching the pattern used for the softsplat fallbacks in SGM-VFI and GIMM-VFI. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+28
-5
@@ -4,6 +4,7 @@ import collections
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import typing
|
||||
|
||||
cupy = None
|
||||
@@ -15,11 +16,7 @@ def _ensure_cupy():
|
||||
import cupy as _cupy
|
||||
cupy = _cupy
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"cupy is required for BIM-VFI. Install it with:\n"
|
||||
" pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
|
||||
"Or run install.py from the ComfyUI-Tween directory."
|
||||
)
|
||||
pass # cupy unavailable; PyTorch fallback will be used
|
||||
|
||||
|
||||
##########################################################
|
||||
@@ -246,6 +243,28 @@ def cuda_launch(strKey:str):
|
||||
# end
|
||||
|
||||
|
||||
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
|
||||
"""Pure-PyTorch local cost volume via unfold + dot product."""
|
||||
B, C, H, W = tenOne.shape
|
||||
pad = (intKernelSize - 1) // 2
|
||||
|
||||
# Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel)
|
||||
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
|
||||
|
||||
# Unfold into patches: (B, C, H, W, K, K)
|
||||
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
|
||||
# Reshape to (B, C, H, W, K*K)
|
||||
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
|
||||
|
||||
# Dot product over C dimension: (B, H, W, K*K)
|
||||
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
|
||||
|
||||
# Permute to (B, K*K, H, W) to match CUDA output layout
|
||||
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return tenOut
|
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
@@ -253,6 +272,8 @@ class costvol_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||
def forward(self, tenOne, tenTwo, intKernelSize):
|
||||
_ensure_cupy()
|
||||
if tenOne.is_cuda and cupy is not None:
|
||||
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
||||
|
||||
cuda_launch(cuda_kernel('costvol_out', '''
|
||||
@@ -302,6 +323,8 @@ class costvol_func(torch.autograd.Function):
|
||||
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
||||
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||
)
|
||||
else:
|
||||
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
|
||||
|
||||
self.save_for_backward(tenOne, tenTwo)
|
||||
self.intKernelSize = intKernelSize
|
||||
|
||||
Reference in New Issue
Block a user