Files
ComfyUI-Tween/bim_vfi_arch/costvol.py
T
Ethanfel 2e98e453a4 Add pure-PyTorch fallback for BIM-VFI cost volume kernel
When cupy is unavailable, the costvol_func.forward() now falls back to a
pure-PyTorch implementation using unfold + dot product instead of raising
a RuntimeError. The CUDA/cupy kernel path is preserved unchanged for when
cupy is available. This allows BIM-VFI to run on systems without cupy
(including CPU-only setups), matching the pattern used for the softsplat
fallbacks in SGM-VFI and GIMM-VFI.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-11 02:09:56 +02:00

443 lines
17 KiB
Python

#!/usr/bin/env python
import collections
import os
import re
import torch
import torch.nn.functional as F
import typing
cupy = None
# True once an import of cupy has been attempted. A failed import is not
# cached by Python, so without this flag every call would repeat the failing
# import (forward() calls _ensure_cupy() on every invocation).
_cupy_import_attempted = False
def _ensure_cupy():
    """Lazily import cupy into the module-global ``cupy`` on first use.

    ``cupy`` stays ``None`` when the import fails; callers test for that and
    fall back to the pure-PyTorch implementation. The import is attempted at
    most once per process.
    """
    global cupy, _cupy_import_attempted
    if cupy is not None or _cupy_import_attempted:
        return
    _cupy_import_attempted = True
    try:
        import cupy as _cupy
    except ImportError:
        return  # cupy unavailable; PyTorch fallback will be used
    cupy = _cupy
##########################################################
# Cache of specialised kernel sources, keyed by the string cuda_kernel() returns.
# The special entry 'device' holds the CUDA device name seen on first use.
objCudacache = {}


def cuda_int32(intIn:int):
    """Return ``intIn`` as a ``cupy.int32`` scalar for use as a kernel argument.

    Raises:
        RuntimeError: if cupy cannot be imported (instead of an opaque
            ``AttributeError`` on ``None``).
    """
    _ensure_cupy()
    if cupy is None:
        raise RuntimeError('cuda_int32 requires cupy, which could not be imported')
    return cupy.int32(intIn)


def cuda_float32(fltIn:float):
    """Return ``fltIn`` as a ``cupy.float32`` scalar for use as a kernel argument.

    Raises:
        RuntimeError: if cupy cannot be imported.
    """
    _ensure_cupy()
    if cupy is None:
        raise RuntimeError('cuda_float32 requires cupy, which could not be imported')
    return cupy.float32(fltIn)


# Macro patterns expanded by cuda_kernel(). These are raw strings because '\(' in
# a plain string literal is an invalid escape sequence (SyntaxWarning on 3.12+).
_SIZE_MACRO = re.compile(r'(SIZE_)([0-4])(\()([^\)]*)(\))')
_OFFSET_MACRO = re.compile(r'(OFFSET_)([0-4])(\()')
_VALUE_MACRO = re.compile(r'(VALUE_)([0-4])(\()')

# CUDA C type name substituted for {{type}}, per supported tensor dtype.
_CUDA_TYPE_NAMES = {
    torch.uint8: 'unsigned char',
    torch.float16: 'half',
    torch.float32: 'float',
    torch.float64: 'double',
    torch.int32: 'int',
    torch.int64: 'long',
}


def _as_python_number(objValue):
    """Return ``objValue`` unchanged, or ``.item()`` of it if it is a tensor."""
    return objValue.item() if torch.is_tensor(objValue) else objValue


def _matching_paren(strKernel:str, intStart:int) -> int:
    """Return the index of the ``)`` closing the ``(`` located just before ``intStart``."""
    intDepth = 1
    intStop = intStart
    while True:
        strChar = strKernel[intStop]
        if strChar == '(':
            intDepth += 1
        elif strChar == ')':
            intDepth -= 1
            if intDepth == 0:
                return intStop
        intStop += 1


def _expand_index_macro(strKernel:str, objPattern, strMacro:str, objVariables:typing.Dict, boolValue:bool) -> str:
    """Expand every ``OFFSET_n(...)`` (or ``VALUE_n(...)``) occurrence in ``strKernel``.

    ``OFFSET_n(t, i0, ..., i{n-1})`` becomes ``(((i0)*s0)+...+((i{n-1})*s{n-1}))``
    using the actual strides ``s`` of tensor ``objVariables['t']``. With
    ``boolValue`` the result is ``t[((i0)*s0)+...]`` instead. Curly braces in the
    index expressions are rewritten to parentheses.
    """
    while True:
        objMatch = objPattern.search(strKernel)
        if objMatch is None:
            return strKernel
        intStart = objMatch.end()
        intStop = _matching_paren(strKernel, intStart)
        intArgs = int(objMatch.group(2))
        strInner = strKernel[intStart:intStop]
        strArgs = strInner.split(',')
        if intArgs != len(strArgs) - 1:
            raise ValueError(f'{strMacro}{intArgs} expects {intArgs} indices, got {len(strArgs) - 1}: {strInner!r}')
        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strOffset = '+'.join(
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*'
            + str(_as_python_number(intStrides[intArg])) + ')'
            for intArg in range(intArgs)
        )
        strReplacement = strTensor + '[' + strOffset + ']' if boolValue else '(' + strOffset + ')'
        strKernel = strKernel.replace(strMacro + str(intArgs) + '(' + strInner + ')', strReplacement)


def _specialise_kernel(strKernel:str, objVariables:typing.Dict) -> str:
    """Substitute placeholders and expand SIZE/OFFSET/VALUE macros in ``strKernel``."""
    for strVariable, objValue in objVariables.items():
        if objValue is None:
            continue
        if type(objValue) in (int, float, bool):
            strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
        elif type(objValue) is str:
            strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
        elif isinstance(objValue, torch.Tensor):
            strType = _CUDA_TYPE_NAMES.get(objValue.dtype)
            if strType is None:
                raise TypeError(f'cuda_kernel: unsupported dtype {objValue.dtype} for {strVariable!r}')
            # After the first tensor replaces all {{type}} occurrences, later
            # tensors find nothing left to replace: the first tensor's dtype wins.
            strKernel = strKernel.replace('{{type}}', strType)
        else:
            raise TypeError(f'cuda_kernel: unsupported value type {type(objValue)} for {strVariable!r}')

    # SIZE_i(t) -> size of dimension i of tensor t.
    while True:
        objMatch = _SIZE_MACRO.search(strKernel)
        if objMatch is None:
            break
        intSize = _as_python_number(objVariables[objMatch.group(4)].size()[int(objMatch.group(2))])
        strKernel = strKernel.replace(objMatch.group(), str(intSize))

    strKernel = _expand_index_macro(strKernel, _OFFSET_MACRO, 'OFFSET_', objVariables, False)
    strKernel = _expand_index_macro(strKernel, _VALUE_MACRO, 'VALUE_', objVariables, True)
    return strKernel


def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialise a templated CUDA kernel for concrete arguments and cache the source.

    Template syntax:
      * ``{{name}}``: replaced by the int/float/bool/str value ``objVariables['name']``.
      * ``{{type}}``: replaced by the CUDA C type of the first tensor in ``objVariables``.
      * ``SIZE_i(t)``: size of dimension ``i`` of tensor ``t``.
      * ``OFFSET_n(t, i0, ...)``: flat element offset computed from ``t``'s strides.
      * ``VALUE_n(t, i0, ...)``: ``t[<offset>]``.

    The cache key is built from the function name, the variable names, the
    scalar values, each tensor's dtype/shape/stride, and the CUDA device name
    recorded on first use. This function does no compilation, so it does not
    need cupy.

    Args:
        strFunction: name of the ``extern "C"`` kernel entry point.
        strKernel: templated CUDA source.
        objVariables: mapping of template names to values; ``None`` values are skipped.

    Returns:
        The key under which ``{'strFunction', 'strKernel'}`` is stored in
        ``objCudacache``; pass it to ``cuda_launch``.

    Raises:
        TypeError: if a value or tensor dtype in ``objVariables`` is unsupported.
        ValueError: if an OFFSET_/VALUE_ macro has the wrong number of indices.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()

    strKey = strFunction
    for strVariable, objValue in objVariables.items():
        strKey += strVariable
        if objValue is None:
            continue
        if type(objValue) in (int, float, bool):
            strKey += str(objValue)
        elif type(objValue) is str:
            strKey += objValue
        elif isinstance(objValue, torch.Tensor):
            strKey += str(objValue.dtype) + str(objValue.shape) + str(objValue.stride())
        else:
            raise TypeError(f'cuda_kernel: unsupported value type {type(objValue)} for {strVariable!r}')
    strKey += objCudacache['device']

    if strKey not in objCudacache:
        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': _specialise_kernel(strKernel, objVariables),
        }
    return strKey
# Compiled kernel functions, keyed by (cuda_kernel() key, CUDA device id). A CUDA
# module is loaded into the device context that is current when it is loaded, so
# a function compiled on one GPU must not be reused on another.
_cuda_launch_cache = {}
def cuda_launch(strKey:str):
    """Compile (once per device) and return the cupy kernel function for ``strKey``.

    Args:
        strKey: key returned by ``cuda_kernel``; ``objCudacache[strKey]`` holds
            the specialised source and the entry-point name.

    Returns:
        A callable cupy kernel function for the current CUDA device.

    Raises:
        RuntimeError: if cupy cannot be imported.
        KeyError: if ``strKey`` is not in ``objCudacache``.
    """
    _ensure_cupy()
    if cupy is None:
        raise RuntimeError('cuda_launch requires cupy, which could not be imported')
    objCacheKey = (strKey, cupy.cuda.runtime.getDevice())
    if objCacheKey not in _cuda_launch_cache:
        if 'CUDA_HOME' not in os.environ:
            os.environ['CUDA_HOME'] = '/usr/local/cuda/'
        objEntry = objCudacache[strKey]
        _cuda_launch_cache[objCacheKey] = cupy.RawModule(code=objEntry['strKernel']).get_function(objEntry['strFunction'])
    return _cuda_launch_cache[objCacheKey]
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
    """Pure-PyTorch local cost volume, matching the ``costvol_out`` CUDA kernel.

    For each pixel (y, x) and each displacement (dy, dx) with
    ``|dy|, |dx| <= r = (intKernelSize - 1) // 2``, computes the dot product over
    channels of ``tenOne[:, :, y, x]`` and ``tenTwo[:, :, y + dy, x + dx]``.
    Displacements that fall outside the image give 0. Displacement (dy, dx) is
    stored in output channel ``(dy + r) * intKernelSize + (dx + r)``.

    Displacements are processed one at a time. The extra memory is therefore a
    padded copy of ``tenTwo`` plus one (B, C, H, W) temporary, instead of an
    ``intKernelSize**2``-times expanded patch tensor.

    Args:
        tenOne: (B, C, H, W) tensor.
        tenTwo: tensor with the same shape as ``tenOne``.
        intKernelSize: positive odd window size.

    Returns:
        (B, intKernelSize**2, H, W) tensor with ``tenOne``'s dtype and device.

    Raises:
        ValueError: if ``intKernelSize`` is not a positive odd integer, or the
            shapes of ``tenOne`` and ``tenTwo`` differ.
    """
    if intKernelSize < 1 or intKernelSize % 2 == 0:
        raise ValueError(f'intKernelSize must be a positive odd integer, got {intKernelSize}')
    if tenOne.shape != tenTwo.shape:
        raise ValueError(f'tenOne and tenTwo must have the same shape, got {tuple(tenOne.shape)} and {tuple(tenTwo.shape)}')
    intBatch, _, intHeight, intWidth = tenOne.shape
    intPad = (intKernelSize - 1) // 2
    # Zero padding makes out-of-bounds displacements contribute 0, as in the CUDA kernel.
    tenTwoPadded = F.pad(tenTwo, [intPad, intPad, intPad, intPad])
    tenOut = tenOne.new_empty([intBatch, intKernelSize * intKernelSize, intHeight, intWidth])
    for intDy in range(intKernelSize):
        for intDx in range(intKernelSize):
            tenShifted = tenTwoPadded[:, :, intDy:intDy + intHeight, intDx:intDx + intWidth]
            tenOut[:, intDy * intKernelSize + intDx] = (tenOne * tenShifted).sum(dim=1)
    return tenOut
##########################################################
# Templated CUDA sources (expanded by cuda_kernel()). Each thread handles one
# (n, y, x) position and walks the intKernelSize x intKernelSize window around it.
_COSTVOL_OUT_KERNEL = '''
extern "C" __global__ void __launch_bounds__(512) costvol_out(
    const int n,
    const {{type}}* __restrict__ tenOne,
    const {{type}}* __restrict__ tenTwo,
    const int intKernelSize,
    {{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
    const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
    const int intX = ( intIndex ) % SIZE_3(tenOut);
    {{type}} fltOne[{{intChans}}];
    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
        fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
    }
    int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);
    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
            {{type}} fltValue = 0.0f;
            if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
                }
            }
            tenOut[intOffset] = fltValue;
            intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
        }
    }
} }
'''

_COSTVOL_ONEGRAD_KERNEL = '''
extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
    const int n,
    const {{type}}* __restrict__ tenOne,
    const {{type}}* __restrict__ tenTwo,
    const {{type}}* __restrict__ tenOutgrad,
    const int intKernelSize,
    {{type}}* __restrict__ tenOnegrad,
    {{type}}* __restrict__ tenTwograd
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
    const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
    const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad);
    const int intX = ( intIndex ) % SIZE_3(tenOnegrad);
    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);
    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx);
                }
            }
            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
        }
    }
} }
'''

_COSTVOL_TWOGRAD_KERNEL = '''
extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
    const int n,
    const {{type}}* __restrict__ tenOne,
    const {{type}}* __restrict__ tenTwo,
    const {{type}}* __restrict__ tenOutgrad,
    const int intKernelSize,
    {{type}}* __restrict__ tenOnegrad,
    {{type}}* __restrict__ tenTwograd
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
    const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
    const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd);
    const int intX = ( intIndex ) % SIZE_3(tenTwograd);
    {{type}} fltOne[{{intChans}}];
    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
        fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
    }
    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);
    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], tenOutgrad[intOffset] * fltOne[intValue]);
                }
            }
            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
        }
    }
} }
'''

# cupy's kernel call only needs an object exposing the raw stream handle as `.ptr`.
_Stream = collections.namedtuple('Stream', 'ptr')


def _launch_costvol_kernel(strFunction, strKernel, objVariables, intThreads, objArgs):
    """Specialise, compile and launch one costvol kernel over ``intThreads`` threads.

    ``objArgs`` are the kernel arguments after the leading thread count ``n``.
    Launches on the current torch CUDA stream with 512-thread blocks.
    """
    if intThreads <= 0:
        return  # nothing to compute; a zero-sized grid is not a valid launch
    cuda_launch(cuda_kernel(strFunction, strKernel, objVariables))(
        grid=((intThreads + 512 - 1) // 512, 1, 1),
        block=(512, 1, 1),
        args=[cuda_int32(intThreads)] + objArgs,
        stream=_Stream(torch.cuda.current_stream().cuda_stream),
    )


def _data_ptr(tenIn):
    """Device pointer of ``tenIn``, or 0 (null) when it is ``None``.

    The gradient kernels take both gradient buffers but only write one, so the
    unused one may be null.
    """
    return tenIn.data_ptr() if tenIn is not None else 0


def _pytorch_costvol_backward(tenOne, tenTwo, tenOutgrad, intKernelSize, boolOnegrad, boolTwograd):
    """Pure-PyTorch gradients of the cost volume with respect to ``tenOne`` / ``tenTwo``.

    Mirrors the costvol_onegrad / costvol_twograd CUDA kernels: displacements
    that fall outside the image contribute nothing. Output channel
    ``dy * intKernelSize + dx`` of ``tenOutgrad`` corresponds to the padded-window
    offset (dy, dx). Returns ``(tenOnegrad, tenTwograd)``; an entry is ``None``
    when its flag is False.
    """
    intPad = (intKernelSize - 1) // 2
    intHeight, intWidth = tenOne.shape[2], tenOne.shape[3]
    tenOnegrad = torch.zeros_like(tenOne) if boolOnegrad else None
    tenTwoPadded = F.pad(tenTwo, [intPad, intPad, intPad, intPad]) if boolOnegrad else None
    # Gradient for tenTwo is accumulated into a zero-padded buffer; the border
    # (out-of-image displacements) is cropped away afterwards.
    tenTwogradPadded = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2] + 2 * intPad, tenTwo.shape[3] + 2 * intPad]) if boolTwograd else None
    for intDy in range(intKernelSize):
        for intDx in range(intKernelSize):
            tenGrad = tenOutgrad[:, intDy * intKernelSize + intDx].unsqueeze(1)
            if tenOnegrad is not None:
                tenOnegrad += tenGrad * tenTwoPadded[:, :, intDy:intDy + intHeight, intDx:intDx + intWidth]
            if tenTwogradPadded is not None:
                tenTwogradPadded[:, :, intDy:intDy + intHeight, intDx:intDx + intWidth] += tenGrad * tenOne
    tenTwograd = None
    if tenTwogradPadded is not None:
        tenTwograd = tenTwogradPadded[:, :, intPad:intPad + intHeight, intPad:intPad + intWidth].contiguous()
    return tenOnegrad, tenTwograd


class costvol_func(torch.autograd.Function):
    """Local cost volume (correlation) between two feature maps.

    ``costvol_func.apply(tenOne, tenTwo, intKernelSize)`` returns a
    (B, intKernelSize**2, H, W) tensor. The cupy CUDA kernels are used when the
    inputs are CUDA tensors and cupy imports. Otherwise the pure-PyTorch
    implementation is used, for both the forward and the backward pass.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, tenOne, tenTwo, intKernelSize):
        _ensure_cupy()
        boolCupy = tenOne.is_cuda and cupy is not None
        if boolCupy:
            tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
            # Run on the inputs' device so the current stream and the compiled
            # module belong to the GPU that owns the data pointers.
            with torch.cuda.device(tenOne.device):
                _launch_costvol_kernel('costvol_out', _COSTVOL_OUT_KERNEL, {
                    'intChans': tenOne.shape[1],
                    'tenOne': tenOne,
                    'tenTwo': tenTwo,
                    'intKernelSize': intKernelSize,
                    'tenOut': tenOut,
                }, tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3],
                    [tenOne.data_ptr(), tenTwo.data_ptr(), cuda_int32(intKernelSize), tenOut.data_ptr()])
        else:
            tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
        ctx.save_for_backward(tenOne, tenTwo)
        ctx.intKernelSize = intKernelSize
        # backward() must use the same implementation as forward().
        ctx.boolCupy = boolCupy
        return tenOut

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(ctx, tenOutgrad):
        tenOne, tenTwo = ctx.saved_tensors
        intKernelSize = ctx.intKernelSize
        tenOutgrad = tenOutgrad.contiguous()
        boolOnegrad = bool(ctx.needs_input_grad[0])
        boolTwograd = bool(ctx.needs_input_grad[1])

        if not ctx.boolCupy:
            tenOnegrad, tenTwograd = _pytorch_costvol_backward(tenOne, tenTwo, tenOutgrad, intKernelSize, boolOnegrad, boolTwograd)
            return tenOnegrad, tenTwograd, None

        if not tenOutgrad.is_cuda:
            raise RuntimeError('costvol_func: expected a CUDA gradient for a CUDA forward pass')
        tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if boolOnegrad else None
        tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if boolTwograd else None

        with torch.cuda.device(tenOne.device):
            if tenOnegrad is not None:
                _launch_costvol_kernel('costvol_onegrad', _COSTVOL_ONEGRAD_KERNEL, {
                    'intChans': tenOne.shape[1],
                    'tenOne': tenOne,
                    'tenTwo': tenTwo,
                    'tenOutgrad': tenOutgrad,
                    'intKernelSize': intKernelSize,
                    'tenOnegrad': tenOnegrad,
                    'tenTwograd': tenTwograd,
                }, tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3],
                    [tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), cuda_int32(intKernelSize), tenOnegrad.data_ptr(), _data_ptr(tenTwograd)])
            if tenTwograd is not None:
                _launch_costvol_kernel('costvol_twograd', _COSTVOL_TWOGRAD_KERNEL, {
                    'intChans': tenOne.shape[1],
                    'tenOne': tenOne,
                    'tenTwo': tenTwo,
                    'tenOutgrad': tenOutgrad,
                    'intKernelSize': intKernelSize,
                    'tenOnegrad': tenOnegrad,
                    'tenTwograd': tenTwograd,
                }, tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3],
                    [tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), cuda_int32(intKernelSize), _data_ptr(tenOnegrad), tenTwograd.data_ptr()])

        # One gradient per forward input (tenOne, tenTwo, intKernelSize).
        return tenOnegrad, tenTwograd, None