diff --git a/bim_vfi_arch/costvol.py b/bim_vfi_arch/costvol.py index 93bd5d9..0d58304 100644 --- a/bim_vfi_arch/costvol.py +++ b/bim_vfi_arch/costvol.py @@ -4,6 +4,7 @@ import collections import os import re import torch +import torch.nn.functional as F import typing cupy = None @@ -15,11 +16,7 @@ def _ensure_cupy(): import cupy as _cupy cupy = _cupy except ImportError: - raise RuntimeError( - "cupy is required for BIM-VFI. Install it with:\n" - " pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n" - "Or run install.py from the ComfyUI-Tween directory." - ) + pass # cupy unavailable; PyTorch fallback will be used ########################################################## @@ -246,6 +243,28 @@ def cuda_launch(strKey:str): # end +def _pytorch_costvol(tenOne, tenTwo, intKernelSize): + """Pure-PyTorch local cost volume via unfold + dot product.""" + B, C, H, W = tenOne.shape + pad = (intKernelSize - 1) // 2 + + # Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel) + tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad]) + + # Unfold into patches: (B, C, H, W, K, K) + patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1) + # Reshape to (B, C, H, W, K*K) + patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize) + + # Dot product over C dimension: (B, H, W, K*K) + tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1) + + # Permute to (B, K*K, H, W) to match CUDA output layout + tenOut = tenOut.permute(0, 3, 1, 2).contiguous() + + return tenOut + + ########################################################## @@ -253,55 +272,59 @@ class costvol_func(torch.autograd.Function): @staticmethod @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32) def forward(self, tenOne, tenTwo, intKernelSize): - tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]]) + _ensure_cupy() + if tenOne.is_cuda and cupy is not None: + tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]]) - cuda_launch(cuda_kernel('costvol_out', ''' - extern "C" __global__ void __launch_bounds__(512) costvol_out( - const int n, - const {{type}}* __restrict__ tenOne, - const {{type}}* __restrict__ tenTwo, - const int intKernelSize, - {{type}}* __restrict__ tenOut - ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { - const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); - const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); - const int intX = ( intIndex ) % SIZE_3(tenOut); + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const int intKernelSize, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); - {{type}} fltOne[{{intChans}}]; + {{type}} fltOne[{{intChans}}]; - for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { - fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); - } - - int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); - - for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) { - for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) { - {{type}} fltValue = 0.0f; - - if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { - for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { - fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx)); - } - } - - tenOut[intOffset] = fltValue; - intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); } - } - } } - ''', { - 'intChans': tenOne.shape[1], - 'tenOne': tenOne, - 'tenTwo': tenTwo, - 'intKernelSize': intKernelSize, - 'tenOut': tenOut - }))( - grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), - block=tuple([512, 1, 1]), - args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()], - stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) - ) + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) { + for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } + + tenOut[intOffset] = fltValue; + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'intKernelSize': intKernelSize, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + else: + tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize) self.save_for_backward(tenOne, tenTwo) self.intKernelSize = intKernelSize