Add pure-PyTorch fallback for BIM-VFI cost volume kernel

When cupy is unavailable, costvol_func.forward() now falls back to a
pure-PyTorch implementation (unfold + dot product) instead of raising
a RuntimeError. The CUDA/cupy kernel path is preserved unchanged and is
still used whenever the input is a CUDA tensor and cupy can be imported.
This allows BIM-VFI to run on systems without cupy (including CPU-only
setups), matching the pattern used for the softsplat fallbacks in SGM-VFI
and GIMM-VFI.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-11 02:09:56 +02:00
parent daf0304243
commit 2e98e453a4
+73 -50
View File
@@ -4,6 +4,7 @@ import collections
import os
import re
import torch
import torch.nn.functional as F
import typing
cupy = None
@@ -15,11 +16,7 @@ def _ensure_cupy():
import cupy as _cupy
cupy = _cupy
except ImportError:
raise RuntimeError(
"cupy is required for BIM-VFI. Install it with:\n"
" pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
"Or run install.py from the ComfyUI-Tween directory."
)
pass # cupy unavailable; PyTorch fallback will be used
##########################################################
@@ -246,6 +243,28 @@ def cuda_launch(strKey:str):
# end
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
"""Pure-PyTorch local cost volume via unfold + dot product."""
B, C, H, W = tenOne.shape
pad = (intKernelSize - 1) // 2
# Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel)
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
# Unfold into patches: (B, C, H, W, K, K)
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
# Reshape to (B, C, H, W, K*K)
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
# Dot product over C dimension: (B, H, W, K*K)
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
# Permute to (B, K*K, H, W) to match CUDA output layout
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
return tenOut
##########################################################
@@ -253,55 +272,59 @@ class costvol_func(torch.autograd.Function):
@staticmethod
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(self, tenOne, tenTwo, intKernelSize):
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
_ensure_cupy()
if tenOne.is_cuda and cupy is not None:
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
cuda_launch(cuda_kernel('costvol_out', '''
extern "C" __global__ void __launch_bounds__(512) costvol_out(
const int n,
const {{type}}* __restrict__ tenOne,
const {{type}}* __restrict__ tenTwo,
const int intKernelSize,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
const int intX = ( intIndex ) % SIZE_3(tenOut);
cuda_launch(cuda_kernel('costvol_out', '''
extern "C" __global__ void __launch_bounds__(512) costvol_out(
const int n,
const {{type}}* __restrict__ tenOne,
const {{type}}* __restrict__ tenTwo,
const int intKernelSize,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
const int intX = ( intIndex ) % SIZE_3(tenOut);
{{type}} fltOne[{{intChans}}];
{{type}} fltOne[{{intChans}}];
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
}
int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
{{type}} fltValue = 0.0f;
if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
}
}
tenOut[intOffset] = fltValue;
intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
}
}
} }
''', {
'intChans': tenOne.shape[1],
'tenOne': tenOne,
'tenTwo': tenTwo,
'intKernelSize': intKernelSize,
'tenOut': tenOut
}))(
grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
{{type}} fltValue = 0.0f;
if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
}
}
tenOut[intOffset] = fltValue;
intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
}
}
} }
''', {
'intChans': tenOne.shape[1],
'tenOne': tenOne,
'tenTwo': tenTwo,
'intKernelSize': intKernelSize,
'tenOut': tenOut
}))(
grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
else:
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
self.save_for_backward(tenOne, tenTwo)
self.intKernelSize = intKernelSize