Initial commit: ComfyUI BIM-VFI node for video frame interpolation

Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 18:26:49 +01:00
commit db64fc195a
16 changed files with 1302 additions and 0 deletions

2
bim_vfi_arch/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .bim_vfi import BiMVFI
from .backwarp import clear_backwarp_cache

42
bim_vfi_arch/arch.py Normal file
View File

@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
    """Layer normalization supporting two memory layouts.

    ``channels_last``  -- input shaped (N, H, W, C); normalized directly.
    ``channels_first`` -- input shaped (N, C, H, W); permuted to
    channels-last, normalized, and permuted back.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dim by moving it last and back.
            y = F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape,
                             self.weight, self.bias, self.eps)
            return y.permute(0, 3, 1, 2)
class ResBlock(nn.Module):
    """Two-convolution residual block: returns x + conv2(lrelu(conv1(x)))."""

    def __init__(self, feat_channels, kernel_size=3, padding_mode='zeros'):
        super().__init__()
        pad = (kernel_size - 1) // 2  # "same" padding for odd kernels
        self.conv1 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)
        self.act = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)

    def forward(self, x):
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual

24
bim_vfi_arch/backwarp.py Normal file
View File

@@ -0,0 +1,24 @@
import torch
# Cache of normalized sampling grids, keyed by (dtype, device, H, W).
objBackwarpcache = {}


def clear_backwarp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    objBackwarpcache.clear()


def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor, mode='bilinear'):
    """Backward-warp ``tenIn`` with the per-pixel flow field ``tenFlow``.

    The base sampling grid for a given dtype/device/resolution is built once
    and cached in ``objBackwarpcache``; flow is rescaled from pixel units to
    the [-1, 1] grid_sample coordinate range before sampling.
    """
    strKey = 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])

    if strKey not in objBackwarpcache:
        tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3],
                                dtype=tenFlow.dtype, device=tenFlow.device)
        tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2],
                                dtype=tenFlow.dtype, device=tenFlow.device)
        objBackwarpcache[strKey] = torch.cat([
            tenHor.view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1),
            tenVer.view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]),
        ], 1)

    # Pixel displacements -> normalized coordinates: 2/(W-1) horizontally,
    # 2/(H-1) vertically (identical scale when the field is square).
    tenScale = torch.tensor(
        data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)],
        dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1)
    tenGrid = (objBackwarpcache[strKey] + tenFlow * tenScale).permute(0, 2, 3, 1)

    return torch.nn.functional.grid_sample(
        input=tenIn, grid=tenGrid, mode=mode,
        padding_mode='zeros', align_corners=True)

114
bim_vfi_arch/bim_vfi.py Normal file
View File

@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from .backwarp import backwarp
from .resnet_encoder import ResNetPyramid
from .caun import CAUN
from .bimfn import BiMFN
from .sn import SynthesisNetwork
from ..utils.padder import InputPadder
class BiMVFI(nn.Module):
    """BiM-VFI frame interpolator (inference-only vendored architecture).

    Runs coarse-to-fine over an image pyramid: at each level the flow and
    occlusion estimates from the coarser level are refined (BiMFN),
    upsampled with learned convex combinations (CAUN), and used to
    synthesize the intermediate frame (SynthesisNetwork).
    """
    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        super(BiMVFI, self).__init__()
        self.pyr_level = pyr_level  # default number of pyramid levels for forward()
        self.mfe = ResNetPyramid(feat_channels)    # motion feature extractor
        self.cfe = ResNetPyramid(feat_channels)    # context feature extractor
        self.bimfn = BiMFN(feat_channels)          # bidirectional motion field network
        self.sn = SynthesisNetwork(feat_channels)  # frame synthesis network
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)            # content-aware flow upsampling
    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """Refine flow/occlusion at one pyramid level and synthesize a frame.

        Args:
            img0, img1: normalized frames at this level's resolution.
            last_flow: (N, 4, h, w) bidirectional flow from the coarser level.
            last_occ: (N, 1, h, w) occlusion logits from the coarser level.
            time_period: interpolation instant t in (0, 1).

        Returns:
            Tuple (flow, occ, interp_img, extra_dict).
        """
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)
        B, _, H, W = feat0_pyr[-1].shape
        # Inference path: prepare uniform BiM — constant distance map r = t
        # and constant angle phi = pi, encoded as (cos, sin) channels.
        r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
        phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)
        # Bring coarser-level estimates to the coarsest feature-map scale;
        # flow magnitudes are halved along with the resolution.
        last_flow = F.interpolate(
            input=last_flow.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False)
        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)
        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]  # full-resolution flow for the next (finer) level
        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict.update({'flow_res': flow_res})
        return flow, occ, interp_img, extra_dict
    def forward(self, img0, img1, time_step,
                pyr_level=None, **kwargs):
        """Interpolate between ``img0`` and ``img1`` at ``time_step``.

        Returns:
            Dict with 'imgt_preds' (per-level predictions, still padded) and
            'imgt_pred' (finest prediction, unpadded to the input size).
        """
        if pyr_level is None: pyr_level = self.pyr_level
        N, _, H, W = img0.shape
        interp_imgs = []
        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))
        # Normalize input images with a joint mean/std over both frames, so
        # the same affine transform can be inverted on the outputs.
        with torch.set_grad_enabled(False):
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()
            img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
            img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)
        # Pad images for downsampling
        img0, img1 = padder.pad(img0, img1)
        N, _, H, W = img0.shape
        # Iterate from the coarsest level (pyr_level - 1) down to level 0.
        for level in list(range(pyr_level))[::-1]:
            # Downsample images if needed
            if level != 0:
                scale_factor = 1 / 2 ** level
                img0_this_lvl = F.interpolate(
                    input=img0, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_this_lvl = F.interpolate(
                    input=img1, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_this_lvl = img0
                img1_this_lvl = img1
            # Initialize zero flows for lowest pyramid level
            if level == pyr_level - 1:
                last_flow = torch.zeros(
                    (N, 4, H // (2 ** (level + 1)), W // (2 ** (level + 1))), device=img0.device
                )
                last_occ = torch.zeros(N, 1, H // (2 ** (level + 1)), W // (2 ** (level + 1)), device=img0.device)
            else:
                last_flow = flow
                last_occ = occ
            # Single pyramid level run
            flow, occ, interp_img, extra_dict = self.forward_one_lvl(
                img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)
            # De-normalize the prediction back to the input intensity range.
            interp_imgs.append((interp_img) * (tenStd_ + 0.0000001) + tenMean_)
        result_dict = {
            "imgt_preds": interp_imgs,
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }
        return result_dict

115
bim_vfi_arch/bimfn.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
from .backwarp import backwarp
from .arch import LayerNorm, ResBlock
from .costvol import costvol_func
class BiMMConv(nn.Module):
    """Modulated convolution injecting the bidirectional-motion descriptor.

    Three identical norm+res branches embed the feature volume, the
    distance map r, and the angle map phi; the r/phi embeddings gate a
    multiplicative residual on the feature branch.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def _branch():
            # LayerNorm over channels followed by a 1x1 residual block.
            return nn.Sequential(
                LayerNorm(in_channels, data_format='channels_first'),
                ResBlock(out_channels, 1),
            )

        self.conv_FV = _branch()
        self.conv_r = _branch()
        self.conv_phi = _branch()

    def forward(self, FV, r, phi):
        fv = self.conv_FV(FV)
        r_feat = self.conv_r(r)
        phi_feat = self.conv_phi(phi)
        modulated = fv + fv * r_feat * phi_feat
        return modulated, r_feat, phi_feat
class BiMFN(nn.Module):
    """Bidirectional Motion Field Network.

    Refines the bidirectional flow at the coarsest pyramid scale: fuses
    warped features, two-way cost volumes and the previous flow/occlusion
    estimates, modulates the result with the BiM descriptor (r, phi) via
    BiMMConv, and predicts a residual that is added to the previous flow.
    """
    def __init__(self, feat_channels):
        super(BiMFN, self).__init__()
        # Embedding for one 2-channel flow field.
        self.conv_flow = nn.Sequential(
            nn.Conv2d(2, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embedding for the 1-channel occlusion map.
        self.conv_occ = nn.Sequential(
            nn.Conv2d(1, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embedding for an 81-channel (9x9 neighbourhood) correlation volume.
        self.conv_corr = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(81, feat_channels * 2, 1, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Fusion trunk. Input is 18C: 2 corr embeddings (2C each) + 2 warped
        # feature maps (4C each at the coarsest pyramid level) + 2 flow
        # embeddings (2C each) + occlusion embedding (2C).
        self.conv0 = nn.Sequential(
            nn.Conv2d(feat_channels * 18, feat_channels * 12, 1, padding=0),
            nn.PReLU(feat_channels * 12),
            nn.Conv2d(feat_channels * 12, feat_channels * 12, 3, padding=1),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(feat_channels * 12, feat_channels * 8, 1, padding=0),
            nn.PReLU(feat_channels * 8),
            nn.Conv2d(feat_channels * 8, feat_channels * 8, 3, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        # Distance embedding module: lifts the 1-channel r map.
        self.dem = nn.Sequential(
            nn.Conv2d(1, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        # Angle embedding module: lifts the 2-channel (cos, sin) phi map.
        self.aem = nn.Sequential(
            nn.Conv2d(2, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        self.bim_mconv = BiMMConv(feat_channels * 6, feat_channels * 6)
        # Head predicting the 4-channel bidirectional flow residual.
        self.conv_out = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, 4, 1, padding=0))
    def forward(self, feat0, feat1, r, phi, last_flow, last_occ):
        """Return (flow_low, flow_res): refined flow and the residual added.

        last_flow is (N, 4, H, W): channels [:2] warp feat0, [2:] warp feat1.
        """
        feat0_warp = backwarp(feat0, (last_flow[:, :2]))
        feat1_warp = backwarp(feat1, (last_flow[:, 2:]))
        # Two-directional 9x9 local correlation volumes (81 channels each).
        volume0 = costvol_func.apply(feat0_warp, feat1_warp, 9)
        volume1 = costvol_func.apply(feat1_warp, feat0_warp, 9)
        corr0 = self.conv_corr(volume0)
        corr1 = self.conv_corr(volume1)
        flo0 = self.conv_flow(last_flow[:, :2])
        flo1 = self.conv_flow(last_flow[:, 2:])
        occ = self.conv_occ(last_occ)
        input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, flo0, flo1, occ], 1)
        FV = self.conv0(input_feat)
        FV = self.conv1(FV)
        FV = self.conv2(FV)
        FV0 = self.conv3(FV)
        # Embed the BiM descriptor and modulate the fused feature volume.
        r0 = self.dem(r)
        phi0 = self.aem(phi)
        bim_feat, _, _ = self.bim_mconv(FV0, r0, phi0)
        flow_res = self.conv_out(bim_feat)
        flow_low = flow_res + last_flow
        return flow_low, flow_res

72
bim_vfi_arch/caun.py Normal file
View File

@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backwarp import backwarp
class CAUN(nn.Module):
    """Content-Aware Upsampling Network.

    Upsamples the coarse bidirectional flow by 2x and 4x using learned
    convex combinations (softmax-weighted 3x3 neighbourhoods), conditioned
    on context features warped by that flow; the occlusion logits are
    upsampled bilinearly.
    """
    def __init__(self, feat_channels):
        super(CAUN, self).__init__()
        # Encoders fuse warped context features across the three scales.
        self.enc0 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc1 = nn.Sequential(
            nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # Heads predicting the combination logits: 2 flow directions x 9
        # neighbourhood weights = 18 output channels per factor.
        self.kernel_x2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
        )
        self.kernel_x4 = nn.Sequential(
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
        )
    def upsample_input(self, inp, mask, upsample_factor):
        """Convex upsampling of ``inp`` by ``upsample_factor``.

        Each output pixel is a softmax-weighted combination of the 3x3
        neighbourhood of its source pixel in the low-resolution input.

        Args:
            inp: (N, c, H, W) low-resolution field.
            mask: (N, 9, f, f, H, W)-viewable logits, f = upsample_factor.

        Returns:
            (N, c, f*H, f*W) upsampled field.
        """
        N, c, H, W = inp.shape
        mask = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
        mask = torch.softmax(mask, dim=2)  # convex weights over the 3x3 window
        inp = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        up_inp = F.unfold(inp, [3, 3])  # gather 3x3 neighbourhoods per pixel
        up_inp = up_inp.view(N, c, 9, 1, 1, H, W)
        up_inp = torch.sum(mask * up_inp, dim=2)
        # (N, c, f, f, H, W) -> (N, c, H, f, W, f) -> (N, c, f*H, f*W)
        up_inp = up_inp.permute(0, 1, 4, 2, 5, 3)
        return up_inp.reshape(N, c, upsample_factor*H, upsample_factor*W)
    def forward(self, flow, feat0, feat1, last_occ):
        """ Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination """
        N, _, H, W = flow.shape
        feat0_warped_list, feat1_warped_list = [], []
        # Warp the context pyramids with the flow rescaled to each level;
        # feat*[2-i] pairs pyramid level i with its matching resolution.
        for i in range(3):
            flow_bi = F.interpolate(flow * 2 ** i, scale_factor=2 ** i, mode='bilinear')
            feat0_warped = backwarp(feat0[2-i], flow_bi[:, :2])
            feat1_warped = backwarp(feat1[2-i], flow_bi[:, 2:])
            feat0_warped_list.append(feat0_warped)
            feat1_warped_list.append(feat1_warped)
        # Coarse-to-fine decoding with pixel-shuffle between scales.
        feature = torch.cat([feat0_warped_list[0], feat1_warped_list[0]], dim=1)
        feature0 = self.enc0(feature)
        feature1 = self.enc1(torch.cat([F.pixel_shuffle(feature0, 2), feat0_warped_list[1], feat1_warped_list[1]], dim=1))
        feature2 = self.enc2(torch.cat([F.pixel_shuffle(feature1, 2), feat0_warped_list[2], feat1_warped_list[2]], dim=1))
        mask_x2 = self.kernel_x2(feature1)
        mask_x4 = self.kernel_x4(feature2)
        # Rearrange logits to (N, 18, f, f, H, W) and split per direction.
        mask_x2 = mask_x2.view(N, 18, H, 2, W, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
        mask_x4 = mask_x4.view(N, 18, H, 4, W, 4).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)
        # Flow magnitudes scale with the upsampling factor.
        up_flow_x2_0 = self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2)
        up_flow_x2_1 = self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2)
        up_flow_x4_0 = self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4)
        up_flow_x4_1 = self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4)
        up_flow_x2 = torch.cat([up_flow_x2_0, up_flow_x2_1], dim=1)
        up_flow_x4 = torch.cat([up_flow_x4_0, up_flow_x4_1], dim=1)
        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        # Finest flow first: [x4, x2, original].
        return [up_flow_x4, up_flow_x2, flow], up_occ

399
bim_vfi_arch/costvol.py Normal file
View File

@@ -0,0 +1,399 @@
#!/usr/bin/env python
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
# Cache of specialized kernel sources (and the device name), keyed by the
# signature built in cuda_kernel().
objCudacache = {}


def cuda_int32(intIn:int):
    # Wrap a Python int as a cupy int32 for kernel argument passing.
    return cupy.int32(intIn)
# end


def cuda_float32(fltIn:float):
    # Wrap a Python float as a cupy float32 for kernel argument passing.
    return cupy.float32(fltIn)
# end
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialize a templated CUDA kernel source and cache it.

    Substitutes ``{{name}}`` placeholders with the values in ``objVariables``,
    picks the C type for ``{{type}}`` from the tensor dtypes, and expands the
    ``SIZE_n`` / ``OFFSET_n`` / ``VALUE_n`` helper macros using each tensor's
    concrete shape and strides. The finished source is stored in
    ``objCudacache`` under a key derived from the function name, the argument
    metadata, and the GPU device name.

    Args:
        strFunction: name of the ``extern "C"`` kernel entry point.
        strKernel: templated CUDA C source.
        objVariables: mapping of template names to ints / floats / bools /
            strings or CUDA tensors.

    Returns:
        The cache key to pass to :func:`cuda_launch`.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()
    # end

    # The generated source depends on the dtype/shape/stride of every tensor
    # argument, so all of that goes into the cache key.
    strKey = strFunction

    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        elif True:
            print(strVariable, type(objValue))

            assert(False)

        # end
    # end

    strKey += objCudacache['device']

    if strKey not in objCudacache:
        # Substitute scalar placeholders and map the tensor dtypes onto the
        # corresponding C type for the {{type}} placeholder.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace('{{type}}', 'unsigned char')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace('{{type}}', 'half')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace('{{type}}', 'float')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace('{{type}}', 'double')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace('{{type}}', 'int')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace('{{type}}', 'long')

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)

                assert(False)

            elif True:
                print(strVariable, type(objValue))

                assert(False)

            # end
        # end

        # Expand SIZE_n(tensor) into the literal size of dimension n.
        # Fix: the regex patterns below are raw strings; the originals were
        # plain strings containing '\(' / '\)', which are invalid escape
        # sequences (SyntaxWarning on Python 3.12+, scheduled to become an
        # error in a future release).
        while True:
            objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
        # end

        # Expand OFFSET_n(tensor, i0, ..., i{n-1}) into a flat strided offset.
        while True:
            objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            # Scan forward to the matching closing parenthesis.
            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
        # end

        # Expand VALUE_n(tensor, i0, ...) into a strided element access.
        while True:
            objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
        # end

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
    # Compile (memoized per device by cupy) the cached kernel source and
    # return its entry point; strKey must come from cuda_kernel().
    if 'CUDA_HOME' not in os.environ:
        # cupy's NVRTC lookup relies on CUDA_HOME; fall back to the default
        # installation path when it is unset.
        os.environ['CUDA_HOME'] = '/usr/local/cuda/'
    # end

    return cupy.RawModule(code=objCudacache[strKey]['strKernel']).get_function(objCudacache[strKey]['strFunction'])
# end
##########################################################
class costvol_func(torch.autograd.Function):
    """Custom autograd op computing a local cost (correlation) volume.

    For every pixel of ``tenOne``, computes the channel-wise dot product
    with ``tenTwo`` over a ``intKernelSize`` x ``intKernelSize``
    neighbourhood, producing an (N, K*K, H, W) output. Implemented as
    JIT-compiled cupy CUDA kernels; CUDA tensors only.
    """
    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, tenOne, tenTwo, intKernelSize):
        # NOTE: `self` here is the autograd context object (ctx).
        tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
        # One thread per output (n, y, x); the kernel loops over the K*K
        # neighbourhood and accumulates channel-wise dot products.
        cuda_launch(cuda_kernel('costvol_out', '''
            extern "C" __global__ void __launch_bounds__(512) costvol_out(
                const int n,
                const {{type}}* __restrict__ tenOne,
                const {{type}}* __restrict__ tenTwo,
                const int intKernelSize,
                {{type}}* __restrict__ tenOut
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
                const int intY = ( intIndex / SIZE_3(tenOut)                  ) % SIZE_2(tenOut);
                const int intX = ( intIndex                                   ) % SIZE_3(tenOut);

                {{type}} fltOne[{{intChans}}];

                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
                }

                int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);

                for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                    for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                        {{type}} fltValue = 0.0f;

                        if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
                            for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
                            }
                        }

                        tenOut[intOffset] = fltValue;
                        intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
                    }
                }
            } }
        ''', {
            'intChans': tenOne.shape[1],
            'tenOne': tenOne,
            'tenTwo': tenTwo,
            'intKernelSize': intKernelSize,
            'tenOut': tenOut
        }))(
            grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )
        # Stash inputs and the kernel size for the backward pass.
        self.save_for_backward(tenOne, tenTwo)
        self.intKernelSize = intKernelSize
        return tenOut
    # end
    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(self, tenOutgrad):
        # Returns gradients w.r.t. (tenOne, tenTwo, intKernelSize); the
        # kernel-size argument gets None (the trailing Nones pad the count).
        tenOne, tenTwo = self.saved_tensors
        intKernelSize = self.intKernelSize
        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
        tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None
        tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None
        if tenOnegrad is not None:
            cuda_launch(cuda_kernel('costvol_onegrad', '''
                extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
                    const int n,
                    const {{type}}* __restrict__ tenOne,
                    const {{type}}* __restrict__ tenTwo,
                    const {{type}}* __restrict__ tenOutgrad,
                    const int intKernelSize,
                    {{type}}* __restrict__ tenOnegrad,
                    {{type}}* __restrict__ tenTwograd
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
                    const int intY = ( intIndex / SIZE_3(tenOnegrad)                      ) % SIZE_2(tenOnegrad);
                    const int intX = ( intIndex                                           ) % SIZE_3(tenOnegrad);

                    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                    tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx);
                                }
                            }

                            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                        }
                    }
                } }
            ''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end
        if tenTwograd is not None:
            # The scatter into tenTwograd crosses threads, hence atomicAdd.
            cuda_launch(cuda_kernel('costvol_twograd', '''
                extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
                    const int n,
                    const {{type}}* __restrict__ tenOne,
                    const {{type}}* __restrict__ tenTwo,
                    const {{type}}* __restrict__ tenOutgrad,
                    const int intKernelSize,
                    {{type}}* __restrict__ tenOnegrad,
                    {{type}}* __restrict__ tenTwograd
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
                    const int intY = ( intIndex / SIZE_3(tenTwograd)                      ) % SIZE_2(tenTwograd);
                    const int intX = ( intIndex                                           ) % SIZE_3(tenTwograd);

                    {{type}} fltOne[{{intChans}}];

                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
                    }

                    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                    atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], tenOutgrad[intOffset] * fltOne[intValue]);
                                }
                            }

                            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                        }
                    }
                } }
            ''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end
        return tenOnegrad, tenTwograd, None, None, None
    # end
# end

View File

@@ -0,0 +1,102 @@
import torch.nn as nn
from typing import Optional, Callable
from torch import Tensor
from functools import partial
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding (bias enabled, unlike torchvision's)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=True, dilation=dilation)
def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """2x2 convolution, no padding; used as a strided downsampling conv."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=2, stride=stride)
class BasicBlock(nn.Module):
    """Pre-activation residual block: norm -> conv -> lrelu -> norm -> conv.

    When ``stride != 1`` the first convolution is a strided 2x2 conv, so a
    matching ``downsample`` module must be supplied for the identity branch;
    otherwise the residual addition fails on shape.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: spatial stride of the first convolution.
        downsample: optional module applied to the identity branch.
        groups, base_width, dilation: kept for ResNet API compatibility;
            only groups=1, base_width=64, dilation=1 are supported.
        norm_layer: normalization-layer constructor; defaults to
            ``nn.InstanceNorm2d``.
    """
    expansion: int = 1
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # Fix: the previous default, partial(nn.InstanceNorm2d,
            # data_format='channels_first'), raised a TypeError when invoked
            # because nn.InstanceNorm2d has no 'data_format' argument.
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(inplanes)
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            # Strided path downsamples with a 2x2 kernel (no padding).
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
    def forward(self, x: Tensor) -> Tensor:
        """Return the residual sum of the transformed path and the identity."""
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)
        out = self.bn2(out)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # NOTE: activation is applied before the residual addition, matching
        # the vendored upstream implementation.
        out = self.lrelu(out)
        out = out + identity
        return out
class ResNetPyramid(nn.Module):
    """A 3-level feature pyramid, which by default is shared by the motion
    estimator and synthesis network. Output strides are 1, 2 and 4 with
    feat_channels, 2*feat_channels and 4*feat_channels channels.
    """

    def __init__(self, feat_channels: int):
        super(ResNetPyramid, self).__init__()
        # Stem: RGB -> feat_channels at full resolution.
        self.conv = nn.Conv2d(3, feat_channels, kernel_size=3, stride=1, padding=1)
        self.layer0 = nn.Sequential(
            BasicBlock(feat_channels, feat_channels, norm_layer=nn.InstanceNorm2d),
        )
        # Each subsequent level halves the resolution with a strided 2x2 conv
        # and doubles the channel count.
        self.layer1 = nn.Sequential(
            nn.Conv2d(feat_channels, feat_channels * 2, 2, 2),
            BasicBlock(feat_channels * 2, feat_channels * 2, norm_layer=nn.InstanceNorm2d),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(feat_channels * 2, feat_channels * 4, 2, 2),
            BasicBlock(feat_channels * 4, feat_channels * 4, norm_layer=nn.InstanceNorm2d),
        )
        self.conv_last = nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, 1, 1)

    def forward(self, img):
        """Return [stride-1, stride-2, stride-4] feature maps for ``img``."""
        feat_s1 = self.layer0(self.conv(img))
        feat_s2 = self.layer1(feat_s1)
        feat_s4 = self.conv_last(self.layer2(feat_s2))
        return [feat_s1, feat_s2, feat_s4]

95
bim_vfi_arch/sn.py Normal file
View File

@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from .backwarp import backwarp
class SynthesisNetwork(nn.Module):
    """U-Net style synthesis network producing the interpolated frame.

    Encodes the two warped input frames plus warped context features at
    three scales, decodes with pixel-shuffle upsampling, and outputs a
    4-channel refinement: 3 residual RGB channels and 1 occlusion-logit
    residual whose sigmoid blends the two warped frames.
    """
    def __init__(self, feat_channels):
        super(SynthesisNetwork, self).__init__()
        input_channels = 6 + 1  # two warped RGB frames + occlusion map
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 2: fuses s0 with the warped stride-1 context features.
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 3: fuses s1 with the warped stride-2 context features.
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4))
        # Decoder: pixel-shuffle doubles resolution while quartering channels.
        self.conv_up1 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_up2 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1))
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # 4-channel head: RGB residual (3) + occlusion-logit residual (1).
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)
    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backward-warp the context features (and optionally the images)
        with the two halves of the bidirectional flow.

        bi_flow channels [:2] warp frame-0 inputs, [2:4] warp frame-1 inputs.
        """
        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if (i0 is None) and (i1 is None):
            return warped_c0, warped_c1
        else:
            warped_img0 = backwarp(i0, flow_t0)
            warped_img1 = backwarp(i1, flow_t1)
            return warped_img0, warped_img1, warped_c0, warped_c1
    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the intermediate frame.

        Args:
            i0, i1: the (normalized) input frames.
            c0_pyr, c1_pyr: 3-level context feature pyramids for each frame.
            bi_flow_pyr: bidirectional flow pyramid, finest level first.
            occ: occlusion logits at full resolution.

        Returns:
            Tuple (interp_img, occ_out, extra_dict) with the synthesized
            frame, updated occlusion logits, and intermediate tensors.
        """
        warped_img0, warped_img1, warped_c0, warped_c1 = \
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1)
        input_feat = torch.cat(
            (warped_img0, warped_img1, occ), 1)
        s0 = self.conv_down1(input_feat)
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None)
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None)
        # Decode with skip connections to the encoder activations.
        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))
        refine = self.conv_out(x)
        refine_res = refine[:, :3]  # RGB residual
        occ_res = refine[:, 3:]     # occlusion-logit residual
        occ_out = occ + occ_res
        # Sigmoid of the logits is the per-pixel blending weight of the
        # frame-0 warp; the RGB residual corrects remaining artifacts.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res
        interp_img = merged_img
        extra_dict = {}
        extra_dict["refine_res"] = refine_res
        extra_dict["refine_mask"] = occ_out
        extra_dict["warped_img0"] = warped_img0
        extra_dict["warped_img1"] = warped_img1
        extra_dict["merged_img"] = merged_img
        return interp_img, occ_out, extra_dict