Add pure-PyTorch fallback for BIM-VFI cost volume kernel

When cupy is unavailable, the costvol_func.forward() now falls back to a
pure-PyTorch implementation using unfold + dot product instead of raising
a RuntimeError. The CUDA/cupy kernel path is preserved unchanged for when
cupy is available. This allows BIM-VFI to run on systems without cupy
(including CPU-only setups), matching the pattern used for the softsplat
fallbacks in SGM-VFI and GIMM-VFI.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 02:09:56 +02:00
parent daf0304243
commit 2e98e453a4
+28 -5
View File
@@ -4,6 +4,7 @@ import collections
import os
import re
import torch
import torch.nn.functional as F
import typing
cupy = None
@@ -15,11 +16,7 @@ def _ensure_cupy():
import cupy as _cupy
cupy = _cupy
except ImportError:
raise RuntimeError(
"cupy is required for BIM-VFI. Install it with:\n"
" pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
"Or run install.py from the ComfyUI-Tween directory."
)
pass # cupy unavailable; PyTorch fallback will be used
##########################################################
@@ -246,6 +243,28 @@ def cuda_launch(strKey:str):
# end
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
"""Pure-PyTorch local cost volume via unfold + dot product."""
B, C, H, W = tenOne.shape
pad = (intKernelSize - 1) // 2
# Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel)
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
# Unfold into patches: (B, C, H, W, K, K)
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
# Reshape to (B, C, H, W, K*K)
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
# Dot product over C dimension: (B, H, W, K*K)
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
# Permute to (B, K*K, H, W) to match CUDA output layout
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
return tenOut
##########################################################
@@ -253,6 +272,8 @@ class costvol_func(torch.autograd.Function):
@staticmethod
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(self, tenOne, tenTwo, intKernelSize):
_ensure_cupy()
if tenOne.is_cuda and cupy is not None:
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
cuda_launch(cuda_kernel('costvol_out', '''
@@ -302,6 +323,8 @@ class costvol_func(torch.autograd.Function):
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
else:
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
self.save_for_backward(tenOne, tenTwo)
self.intKernelSize = intKernelSize