Initial commit: ComfyUI BIM-VFI node for video frame interpolation

Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 18:26:49 +01:00
commit db64fc195a
16 changed files with 1302 additions and 0 deletions

11
__init__.py Normal file
View File

@@ -0,0 +1,11 @@
# Registration tables read by ComfyUI at startup to discover this package's nodes.
from .nodes import LoadBIMVFIModel, BIMVFIInterpolate

# Maps each node's internal type name to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "LoadBIMVFIModel": LoadBIMVFIModel,
    "BIMVFIInterpolate": BIMVFIInterpolate,
}

# Maps each internal type name to the label shown in the ComfyUI node menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadBIMVFIModel": "Load BIM-VFI Model",
    "BIMVFIInterpolate": "BIM-VFI Interpolate",
}

2
bim_vfi_arch/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .bim_vfi import BiMVFI
from .backwarp import clear_backwarp_cache

42
bim_vfi_arch/arch.py Normal file
View File

@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default)
    or channels_first.

    channels_last expects inputs shaped (batch_size, height, width,
    channels); channels_first expects (batch_size, channels, height, width)
    and is handled by permuting to channels_last, normalizing over the
    trailing dimension, and permuting back.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Move channels to the trailing dim, normalize, then move back.
            normed = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape,
                self.weight, self.bias, self.eps)
            return normed.permute(0, 3, 1, 2)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
class ResBlock(nn.Module):
    """Two-convolution residual block: out = x + conv2(lrelu(conv1(x)))."""

    def __init__(self, feat_channels, kernel_size=3, padding_mode='zeros'):
        super().__init__()
        pad = (kernel_size - 1) // 2  # keep spatial size unchanged
        self.conv1 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)
        self.act = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.act(residual)
        residual = self.conv2(residual)
        return x + residual

24
bim_vfi_arch/backwarp.py Normal file
View File

@@ -0,0 +1,24 @@
import torch
# Cache of base sampling grids keyed by (dtype, device, H, W) so repeated
# warps at the same resolution reuse a single grid tensor.
objBackwarpcache = {}


def clear_backwarp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    objBackwarpcache.clear()


def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor, mode='bilinear'):
    """Sample `tenIn` backward along the pixel-space flow `tenFlow`.

    tenIn:   (N, C, H, W) tensor to sample from.
    tenFlow: (N, 2, H, W) flow in pixels; channel 0 is horizontal (x),
             channel 1 is vertical (y).
    mode:    interpolation mode forwarded to grid_sample.
    Returns the warped (N, C, H, W) tensor; out-of-bounds samples are zero.
    """
    # Build the cache key once instead of re-concatenating it at every use.
    strKey = 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])

    if strKey not in objBackwarpcache:
        # Identity sampling grid in normalized [-1, 1] coordinates.
        tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype,
                                device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
        tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype,
                                device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])
        objBackwarpcache[strKey] = torch.cat([tenHor, tenVer], 1)

    # Convert the pixel-space flow to normalized grid offsets. This per-axis
    # form also covers square inputs, replacing the original's redundant
    # `(shape[3] and shape[2])` special case with identical math.
    tenFlow = tenFlow * torch.tensor(
        data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)],
        dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1)

    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(objBackwarpcache[strKey] + tenFlow).permute(0, 2, 3, 1),
        mode=mode, padding_mode='zeros', align_corners=True)

114
bim_vfi_arch/bim_vfi.py Normal file
View File

@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from .backwarp import backwarp
from .resnet_encoder import ResNetPyramid
from .caun import CAUN
from .bimfn import BiMFN
from .sn import SynthesisNetwork
from ..utils.padder import InputPadder
class BiMVFI(nn.Module):
    """BiM-VFI frame-interpolation network (inference-oriented wrapper).

    Combines a motion feature extractor (mfe), a context feature extractor
    (cfe), the bilateral-motion flow network (bimfn), a content-aware
    upsampling network (caun) and a synthesis network (sn), run
    coarse-to-fine over an image pyramid.
    """

    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        super(BiMVFI, self).__init__()
        self.pyr_level = pyr_level
        self.mfe = ResNetPyramid(feat_channels)  # motion feature extractor
        self.cfe = ResNetPyramid(feat_channels)  # context feature extractor
        self.bimfn = BiMFN(feat_channels)
        self.sn = SynthesisNetwork(feat_channels)
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)

    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """Run one pyramid level: refine flow/occlusion, synthesize a frame.

        img0/img1: normalized frames at this level's resolution.
        last_flow: (N, 4, H, W) bidirectional flow from the coarser level.
        last_occ:  (N, 1, H, W) occlusion estimate from the coarser level.
        time_period: target time step t of the interpolated frame.
        Returns (flow, occ, interp_img, extra_dict).
        """
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)
        B, _, H, W = feat0_pyr[-1].shape
        # Inference path: prepare uniform BiM maps — a constant ratio map
        # r = time_period and a constant angle phi = pi, encoded as (cos, sin).
        r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
        phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)
        # Bring the coarser level's estimates down to the flow network's
        # working resolution (half of this level); flow magnitudes scale with
        # spatial size, hence the extra * 0.5 on the flow.
        last_flow = F.interpolate(
            input=last_flow.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False)
        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)
        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]  # finest (full-resolution) flow of this level
        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict.update({'flow_res': flow_res})
        return flow, occ, interp_img, extra_dict

    def forward(self, img0, img1, time_step,
                pyr_level=None, **kwargs):
        """Interpolate a frame between img0 and img1 at `time_step`.

        Returns a dict with per-level predictions ('imgt_preds', coarse to
        fine) and the final unpadded prediction ('imgt_pred').
        """
        if pyr_level is None: pyr_level = self.pyr_level
        N, _, H, W = img0.shape
        interp_imgs = []
        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))
        # Normalize input images to zero mean / unit std, with statistics
        # computed jointly over both frames so they stay comparable.
        with torch.set_grad_enabled(False):
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()
            img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
            img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)
        # Pad images so H and W divide evenly through all downsamplings.
        img0, img1 = padder.pad(img0, img1)
        N, _, H, W = img0.shape
        for level in list(range(pyr_level))[::-1]:
            # Downsample images if needed (coarser levels see smaller inputs).
            if level != 0:
                scale_factor = 1 / 2 ** level
                img0_this_lvl = F.interpolate(
                    input=img0, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_this_lvl = F.interpolate(
                    input=img1, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_this_lvl = img0
                img1_this_lvl = img1
            # Initialize zero flows for lowest pyramid level
            if level == pyr_level - 1:
                last_flow = torch.zeros(
                    (N, 4, H // (2 ** (level + 1)), W // (2 ** (level + 1))), device=img0.device
                )
                last_occ = torch.zeros(N, 1, H // (2 ** (level + 1)), W // (2 ** (level + 1)), device=img0.device)
            else:
                # Carry the previous (coarser) level's estimates forward.
                last_flow = flow
                last_occ = occ
            # Single pyramid level run
            flow, occ, interp_img, extra_dict = self.forward_one_lvl(
                img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)
            # De-normalize the prediction back to the input value range.
            interp_imgs.append((interp_img) * (tenStd_ + 0.0000001) + tenMean_)
        result_dict = {
            "imgt_preds": interp_imgs,
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }
        return result_dict

115
bim_vfi_arch/bimfn.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
from .backwarp import backwarp
from .arch import LayerNorm, ResBlock
from .costvol import costvol_func
class BiMMConv(nn.Module):
    """Feature modulation block conditioning the flow feature volume FV on
    the BiM embeddings r (ratio) and phi (angle)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def make_branch():
            # Each input stream gets its own LayerNorm + 1x1 ResBlock.
            return nn.Sequential(
                LayerNorm(in_channels, data_format='channels_first'),
                ResBlock(out_channels, 1),
            )

        self.conv_FV = make_branch()
        self.conv_r = make_branch()
        self.conv_phi = make_branch()

    def forward(self, FV, r, phi):
        fv_emb = self.conv_FV(FV)
        r_emb = self.conv_r(r)
        phi_emb = self.conv_phi(phi)
        # Residual multiplicative modulation by the (r, phi) embeddings.
        modulated = fv_emb + fv_emb * r_emb * phi_emb
        return modulated, r_emb, phi_emb
class BiMFN(nn.Module):
    """Bilateral-Motion Flow Network: refines the bidirectional flow at the
    coarsest working resolution.

    Encodes warped features, 9x9 cost volumes and the previous
    flow/occlusion estimates into one feature volume, modulates it with the
    (r, phi) embeddings via BiMMConv, and regresses a 4-channel flow
    residual that is added to the previous flow.
    """

    def __init__(self, feat_channels):
        super(BiMFN, self).__init__()
        # Embeds one 2-channel flow direction.
        self.conv_flow = nn.Sequential(
            nn.Conv2d(2, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds the 1-channel occlusion map.
        self.conv_occ = nn.Sequential(
            nn.Conv2d(1, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds an 81-channel (9x9) cost volume.
        self.conv_corr = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(81, feat_channels * 2, 1, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Fusion trunk, narrowing channels: 18f -> 12f -> 8f -> 6f -> 6f.
        self.conv0 = nn.Sequential(
            nn.Conv2d(feat_channels * 18, feat_channels * 12, 1, padding=0),
            nn.PReLU(feat_channels * 12),
            nn.Conv2d(feat_channels * 12, feat_channels * 12, 3, padding=1),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(feat_channels * 12, feat_channels * 8, 1, padding=0),
            nn.PReLU(feat_channels * 8),
            nn.Conv2d(feat_channels * 8, feat_channels * 8, 3, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        # dem/aem embed the 1-channel ratio map and the 2-channel (cos, sin)
        # angle map of the BiM — presumably "distance/angle embedding
        # modules"; confirm naming against the BiM-VFI paper.
        self.dem = nn.Sequential(
            nn.Conv2d(1, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        self.aem = nn.Sequential(
            nn.Conv2d(2, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        self.bim_mconv = BiMMConv(feat_channels * 6, feat_channels * 6)
        # Projects the modulated features down to a 4-channel flow residual.
        self.conv_out = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, 4, 1, padding=0))

    def forward(self, feat0, feat1, r, phi, last_flow, last_occ):
        """Return (flow_low, flow_res): the refined bidirectional flow and
        the residual that was added to `last_flow` to produce it.
        """
        # Pre-align both feature maps with the previous flow estimate
        # (channels 0-1 warp feat0, channels 2-3 warp feat1).
        feat0_warp = backwarp(feat0, (last_flow[:, :2]))
        feat1_warp = backwarp(feat1, (last_flow[:, 2:]))
        # 9x9 local correlation volumes (81 channels) in both directions.
        volume0 = costvol_func.apply(feat0_warp, feat1_warp, 9)
        volume1 = costvol_func.apply(feat1_warp, feat0_warp, 9)
        corr0 = self.conv_corr(volume0)
        corr1 = self.conv_corr(volume1)
        flo0 = self.conv_flow(last_flow[:, :2])
        flo1 = self.conv_flow(last_flow[:, 2:])
        occ = self.conv_occ(last_occ)
        # 18f channels total: 2f+2f (corr) + 2f+2f (flows) + 2f (occ), with
        # the two warped feature maps contributing the remaining 4f each.
        input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, flo0, flo1, occ], 1)
        FV = self.conv0(input_feat)
        FV = self.conv1(FV)
        FV = self.conv2(FV)
        FV0 = self.conv3(FV)
        r0 = self.dem(r)
        phi0 = self.aem(phi)
        bim_feat, _, _ = self.bim_mconv(FV0, r0, phi0)
        flow_res = self.conv_out(bim_feat)
        flow_low = flow_res + last_flow
        return flow_low, flow_res

72
bim_vfi_arch/caun.py Normal file
View File

@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backwarp import backwarp
class CAUN(nn.Module):
    """Content-Aware Upsampling Network.

    Upsamples the coarse bidirectional flow x2 and x4 with learned
    convex-combination masks (computed from context features warped by the
    flow) and returns a 3-level flow pyramid plus the upsampled occlusion.
    """

    def __init__(self, feat_channels):
        super(CAUN, self).__init__()
        # Encoder over warped context features; each stage is upscaled with
        # pixel-shuffle before fusing with the next-finer pyramid level.
        self.enc0 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc1 = nn.Sequential(
            nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # Predict 9-way convex weights per output pixel for each of the two
        # flow directions (2 * 1 * 9 output channels) at x2 and x4 scales.
        self.kernel_x2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
        )
        self.kernel_x4 = nn.Sequential(
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
        )

    def upsample_input(self, inp, mask, upsample_factor):
        """Convex upsampling: each output pixel is a softmax-weighted
        combination of its 3x3 coarse-grid neighborhood."""
        N, c, H, W = inp.shape
        mask = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
        mask = torch.softmax(mask, dim=2)  # convex weights over 9 neighbors
        inp = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        up_inp = F.unfold(inp, [3, 3])
        up_inp = up_inp.view(N, c, 9, 1, 1, H, W)
        up_inp = torch.sum(mask * up_inp, dim=2)
        # Interleave the per-pixel sub-grid back into a dense image.
        up_inp = up_inp.permute(0, 1, 4, 2, 5, 3)
        return up_inp.reshape(N, c, upsample_factor*H, upsample_factor*W)

    def forward(self, flow, feat0, feat1, last_occ):
        """ Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination """
        N, _, H, W = flow.shape
        # Warp each pyramid level of both context stacks with the flow
        # rescaled to that level (flow values scale with resolution).
        feat0_warped_list, feat1_warped_list = [], []
        for i in range(3):
            flow_bi = F.interpolate(flow * 2 ** i, scale_factor=2 ** i, mode='bilinear')
            feat0_warped = backwarp(feat0[2-i], flow_bi[:, :2])
            feat1_warped = backwarp(feat1[2-i], flow_bi[:, 2:])
            feat0_warped_list.append(feat0_warped)
            feat1_warped_list.append(feat1_warped)
        feature = torch.cat([feat0_warped_list[0], feat1_warped_list[0]], dim=1)
        feature0 = self.enc0(feature)
        feature1 = self.enc1(torch.cat([F.pixel_shuffle(feature0, 2), feat0_warped_list[1], feat1_warped_list[1]], dim=1))
        feature2 = self.enc2(torch.cat([F.pixel_shuffle(feature1, 2), feat0_warped_list[2], feat1_warped_list[2]], dim=1))
        mask_x2 = self.kernel_x2(feature1)
        mask_x4 = self.kernel_x4(feature2)
        # Rearrange masks to (N, 18, factor, factor, H, W) and split into the
        # two flow directions (9 weights each).
        mask_x2 = mask_x2.view(N, 18, H, 2, W, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
        mask_x4 = mask_x4.view(N, 18, H, 4, W, 4).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)
        # Flow values are multiplied by the same factor as the resolution.
        up_flow_x2_0 = self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2)
        up_flow_x2_1 = self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2)
        up_flow_x4_0 = self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4)
        up_flow_x4_1 = self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4)
        up_flow_x2 = torch.cat([up_flow_x2_0, up_flow_x2_1], dim=1)
        up_flow_x4 = torch.cat([up_flow_x4_0, up_flow_x4_1], dim=1)
        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        # Finest flow first: [x4, x2, original].
        return [up_flow_x4, up_flow_x2, flow], up_occ

399
bim_vfi_arch/costvol.py Normal file
View File

@@ -0,0 +1,399 @@
#!/usr/bin/env python
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
# Cache of specialized kernel sources (and the current device name), keyed
# by the specialization string built in cuda_kernel().
objCudacache = {}


def cuda_int32(intIn:int):
    """Wrap a Python int as a cupy int32 scalar for kernel arguments."""
    return cupy.int32(intIn)
# end


def cuda_float32(fltIn:float):
    """Wrap a Python float as a cupy float32 scalar for kernel arguments."""
    return cupy.float32(fltIn)
# end
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialize a CUDA kernel template for the given variables.

    Replaces `{{name}}` placeholders and the SIZE_/OFFSET_/VALUE_ macros in
    `strKernel` with concrete sizes/strides taken from the tensors in
    `objVariables`, caches the specialized source in `objCudacache`, and
    returns the cache key for cuda_launch().

    FIX: the SIZE_/OFFSET_/VALUE_ regex literals are now raw strings; the
    previous non-raw `'\\('` escapes raise SyntaxWarning on Python >= 3.12.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()
    # end

    # Build a cache key from the function name plus each variable's
    # value/dtype/shape/stride so every distinct specialization compiles once.
    strKey = strFunction

    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        else:
            print(strVariable, type(objValue))
            assert(False)
        # end
    # end

    strKey += objCudacache['device']

    if strKey not in objCudacache:
        # Substitute scalar placeholders and map the tensor dtype onto the
        # matching CUDA C type via the shared '{{type}}' placeholder.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace('{{type}}', 'unsigned char')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace('{{type}}', 'half')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace('{{type}}', 'float')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace('{{type}}', 'double')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace('{{type}}', 'int')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace('{{type}}', 'long')

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)
                assert(False)

            else:
                print(strVariable, type(objValue))
                assert(False)
            # end
        # end

        # SIZE_n(tensor) -> literal dimension size.
        while True:
            objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
        # end

        # OFFSET_n(tensor, i0, ..., i(n-1)) -> flat element offset expression
        # built from the tensor's strides.
        while True:
            objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            # Scan forward to the matching closing parenthesis by hand,
            # since the arguments may themselves contain parentheses.
            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
        # end

        # VALUE_n(tensor, i0, ..., i(n-1)) -> direct element access expression.
        while True:
            objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
        # end

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
    """Compile (memoized per device) and return the kernel cached under `strKey`."""
    if 'CUDA_HOME' not in os.environ:
        # cupy needs CUDA_HOME to locate the toolkit; fall back to the
        # conventional install path when the environment doesn't set it.
        os.environ['CUDA_HOME'] = '/usr/local/cuda/'
    # end

    return cupy.RawModule(code=objCudacache[strKey]['strKernel']).get_function(objCudacache[strKey]['strFunction'])
# end
##########################################################
class costvol_func(torch.autograd.Function):
    """Local cost volume between two feature maps, as a custom autograd op.

    forward() returns a (N, K*K, H, W) volume whose channel c holds the dot
    product between each pixel of `tenOne` and the c-th offset inside the
    K x K neighborhood of the same location in `tenTwo` (zero outside the
    image bounds). Both passes run JIT-compiled cupy kernels, so the op is
    CUDA-only.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, tenOne, tenTwo, intKernelSize):
        # One output channel per offset in the K x K search window.
        tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])

        # Each thread handles one (n, y, x) location and fills all K*K
        # correlation channels for it.
        cuda_launch(cuda_kernel('costvol_out', '''
extern "C" __global__ void __launch_bounds__(512) costvol_out(
const int n,
const {{type}}* __restrict__ tenOne,
const {{type}}* __restrict__ tenTwo,
const int intKernelSize,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
const int intX = ( intIndex ) % SIZE_3(tenOut);
{{type}} fltOne[{{intChans}}];
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
}
int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
{{type}} fltValue = 0.0f;
if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
}
}
tenOut[intOffset] = fltValue;
intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
}
}
} }
''', {
            'intChans': tenOne.shape[1],
            'tenOne': tenOne,
            'tenTwo': tenTwo,
            'intKernelSize': intKernelSize,
            'tenOut': tenOut
        }))(
            grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )

        # Stash the inputs and kernel size for the backward pass.
        self.save_for_backward(tenOne, tenTwo)
        self.intKernelSize = intKernelSize

        return tenOut
    # end

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(self, tenOutgrad):
        tenOne, tenTwo = self.saved_tensors
        intKernelSize = self.intKernelSize

        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

        # Allocate gradient buffers only for inputs that require them.
        tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None
        tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None

        if tenOnegrad is not None:
            # d(out)/d(tenOne): each thread accumulates into its own (y, x)
            # location, so plain += writes are race-free.
            cuda_launch(cuda_kernel('costvol_onegrad', '''
extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
const int n,
const {{type}}* __restrict__ tenOne,
const {{type}}* __restrict__ tenTwo,
const {{type}}* __restrict__ tenOutgrad,
const int intKernelSize,
{{type}}* __restrict__ tenOnegrad,
{{type}}* __restrict__ tenTwograd
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad);
const int intX = ( intIndex ) % SIZE_3(tenOnegrad);
int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx);
}
}
intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
}
}
} }
''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        if tenTwograd is not None:
            # d(out)/d(tenTwo): threads scatter into neighboring locations,
            # hence the atomicAdd to avoid write races.
            cuda_launch(cuda_kernel('costvol_twograd', '''
extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
const int n,
const {{type}}* __restrict__ tenOne,
const {{type}}* __restrict__ tenTwo,
const {{type}}* __restrict__ tenOutgrad,
const int intKernelSize,
{{type}}* __restrict__ tenOnegrad,
{{type}}* __restrict__ tenTwograd
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd);
const int intX = ( intIndex ) % SIZE_3(tenTwograd);
{{type}} fltOne[{{intChans}}];
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
}
int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], tenOutgrad[intOffset] * fltOne[intValue]);
}
}
intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
}
}
} }
''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        # No gradient for intKernelSize. NOTE(review): forward takes three
        # inputs but five values are returned here — the trailing Nones
        # appear harmless; confirm against the autograd wrapper's arity.
        return tenOnegrad, tenTwograd, None, None, None
    # end
# end

View File

@@ -0,0 +1,102 @@
import torch.nn as nn
from typing import Optional, Callable
from torch import Tensor
from functools import partial
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding (padding == dilation keeps spatial size at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        dilation=dilation,
    )


def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """2x2 convolution (no padding; used for strided downsampling)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=2, stride=stride)


class BasicBlock(nn.Module):
    """Residual block: norm -> conv -> lrelu -> norm -> conv -> lrelu, then
    adds the identity (optionally transformed by `downsample`).

    BUGFIX: the default norm_layer used to be
    `partial(nn.InstanceNorm2d, data_format='channels_first')`, but
    nn.InstanceNorm2d accepts no `data_format` argument, so constructing a
    BasicBlock without an explicit norm_layer raised a TypeError. The
    default is now plain nn.InstanceNorm2d.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(inplanes)
        # A strided block downsamples with a 2x2 kernel instead of 3x3.
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)
        out = self.bn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.lrelu(out)
        out = out + identity
        return out


class ResNetPyramid(nn.Module):
    """A 3-level feature pyramid, which by default is shared by the motion
    estimator and synthesis network.

    forward() returns [C0, C1, C2] with channel counts
    (feat_channels, 2 * feat_channels, 4 * feat_channels) at strides 1, 2, 4.
    """

    def __init__(self, feat_channels: int):
        super(ResNetPyramid, self).__init__()
        self.conv = nn.Conv2d(3, feat_channels, kernel_size=3, stride=1, padding=1)
        self.layer0 = nn.Sequential(
            BasicBlock(feat_channels, feat_channels, norm_layer=nn.InstanceNorm2d),
        )
        self.layer1 = nn.Sequential(
            nn.Conv2d(feat_channels, feat_channels * 2, 2, 2),
            BasicBlock(feat_channels * 2, feat_channels * 2, norm_layer=nn.InstanceNorm2d),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(feat_channels * 2, feat_channels * 4, 2, 2),
            BasicBlock(feat_channels * 4, feat_channels * 4, norm_layer=nn.InstanceNorm2d),
        )
        self.conv_last = nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, 1, 1)

    def forward(self, img):
        """Return the 3-level feature pyramid [C0, C1, C2] for an RGB image."""
        C0 = self.layer0(self.conv(img))
        C1 = self.layer1(C0)
        C2 = self.conv_last(self.layer2(C1))
        return [C0, C1, C2]

95
bim_vfi_arch/sn.py Normal file
View File

@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from .backwarp import backwarp
class SynthesisNetwork(nn.Module):
    """U-Net style synthesis network.

    Fuses the two flow-warped input frames, the warped context pyramids and
    the occlusion estimate into a refined interpolated frame plus an
    updated occlusion logit map.
    """

    def __init__(self, feat_channels):
        super(SynthesisNetwork, self).__init__()
        input_channels = 6 + 1  # two warped RGB frames + 1-channel occlusion
        # Encoder: three downsampling stages, each also consuming the warped
        # context features of the matching pyramid level.
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4))
        # Decoder: upsampling via pixel-shuffle, with skip connections from
        # the encoder stages (s1, s0).
        self.conv_up1 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_up2 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1))
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # 4 output channels: 3-channel RGB residual + 1-channel occlusion residual.
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)

    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backwarp the context features (and optionally the images) with
        the bidirectional flow: channels 0-1 warp c0/i0, channels 2-3 warp
        c1/i1."""
        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if (i0 is None) and (i1 is None):
            return warped_c0, warped_c1
        else:
            warped_img0 = backwarp(i0, flow_t0)
            warped_img1 = backwarp(i1, flow_t1)
            return warped_img0, warped_img1, warped_c0, warped_c1

    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the interpolated frame.

        i0/i1: input frames; c0_pyr/c1_pyr: context pyramids (finest first);
        bi_flow_pyr: bidirectional flow pyramid (finest first); occ:
        occlusion estimate at full resolution.
        Returns (interp_img, occ_out, extra_dict).
        """
        # Full-resolution warps of both images and their context features.
        warped_img0, warped_img1, warped_c0, warped_c1 = \
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1)
        input_feat = torch.cat(
            (warped_img0, warped_img1, occ), 1)
        s0 = self.conv_down1(input_feat)
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None)
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None)
        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))
        refine = self.conv_out(x)
        refine_res = refine[:, :3]   # RGB correction added after blending
        occ_res = refine[:, 3:]      # occlusion logit correction
        occ_out = occ + occ_res
        # Sigmoid of the occlusion logits blends the two warped frames.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res
        interp_img = merged_img
        extra_dict = {}
        extra_dict["refine_res"] = refine_res
        extra_dict["refine_mask"] = occ_out
        extra_dict["warped_img0"] = warped_img0
        extra_dict["warped_img1"] = warped_img1
        extra_dict["merged_img"] = merged_img
        return interp_img, occ_out, extra_dict

81
inference.py Normal file
View File

@@ -0,0 +1,81 @@
import torch
from .bim_vfi_arch import BiMVFI
class BiMVFIModel:
    """Inference-only wrapper around BiMVFI for ComfyUI integration.

    Loads a checkpoint once at construction, then serves per-pair frame
    interpolation via :meth:`interpolate_pair`.
    """

    def __init__(self, checkpoint_path, pyr_level=3, device="cpu"):
        self.pyr_level = pyr_level
        self.device = device
        self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.to(device)

    def _load_checkpoint(self, checkpoint_path):
        """Load weights, tolerating common wrapper formats and key prefixes."""
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # acceptable because the checkpoint comes from the trusted release.
        raw = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        # Checkpoints may nest the weights under "model" or "state_dict".
        weights = raw
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in raw:
                weights = raw[wrapper_key]
                break
        # Strip common prefixes ("module." from DDP, then "model." from a
        # training wrapper) — applied in that order, as either or both may occur.
        cleaned = {}
        for name, tensor in weights.items():
            for prefix in ("module.", "model."):
                if name.startswith(prefix):
                    name = name[len(prefix):]
            cleaned[name] = tensor
        self.model.load_state_dict(cleaned)

    def to(self, device):
        """Move the wrapped network to ``device``; returns self for chaining."""
        self.device = device
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_pair(self, frame0, frame1, time_step=0.5):
        """Interpolate a single frame between two input frames.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1), temporal position of interpolated frame

        Returns:
            Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
        """
        run_device = next(self.model.parameters()).device
        img0 = frame0.to(run_device)
        img1 = frame1.to(run_device)
        # Deeper pyramids for higher resolutions; fall back to the configured
        # depth for small inputs.
        height = img0.shape[2]
        pyr = self.pyr_level
        for min_height, level in ((2160, 7), (1080, 6), (540, 5)):
            if height >= min_height:
                pyr = level
                break
        t = torch.full((1, 1, 1, 1), time_step, device=run_device)
        result_dict = self.model(
            img0=img0, img1=img1,
            time_step=t,
            pyr_level=pyr,
        )
        return result_dict["imgt_pred"].clamp(0, 1)

47
install.py Normal file
View File

@@ -0,0 +1,47 @@
import subprocess
import sys
import os
def get_cupy_package():
    """Detect PyTorch's CUDA version and return the matching cupy package name.

    Returns a wheel name like ``cupy-cuda12x`` (keyed on the CUDA major
    version), or ``None`` when CUDA is unavailable or detection fails.
    """
    try:
        import torch
        if not torch.cuda.is_available():
            print("[BIM-VFI] WARNING: CUDA not available. cupy requires CUDA.")
            return None
        detected = torch.version.cuda
        if detected is None:
            print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
            return None
        # cupy wheels are published per CUDA major version: cupy-cuda11x, cupy-cuda12x, ...
        package = f"cupy-cuda{int(detected.split('.')[0])}x"
        print(f"[BIM-VFI] Detected CUDA {detected}, will use {package}")
        return package
    except Exception as e:
        # Best-effort detection: any failure (missing torch, odd version
        # string) degrades to "don't install cupy" rather than crashing.
        print(f"[BIM-VFI] WARNING: Could not detect CUDA version: {e}")
        return None
def install():
    """Install requirements.txt, then a CUDA-matched cupy if it is missing."""
    req_file = os.path.join(os.path.dirname(__file__), "requirements.txt")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "-r", req_file
    ])
    # cupy is handled separately because its wheel name depends on the local
    # CUDA major version.
    try:
        import cupy  # noqa: F401
    except ImportError:
        pkg = get_cupy_package()
        if pkg is None:
            return
        print(f"[BIM-VFI] Installing {pkg} to match PyTorch CUDA...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", pkg
        ])
    else:
        print("[BIM-VFI] cupy already installed, skipping.")
# Allow running this installer directly: `python install.py`.
if __name__ == "__main__":
    install()

169
nodes.py Normal file
View File

@@ -0,0 +1,169 @@
import os
import logging
import torch
import folder_paths
from comfy.utils import ProgressBar
from .inference import BiMVFIModel
from .bim_vfi_arch import clear_backwarp_cache
logger = logging.getLogger("BIM-VFI")
# Google Drive file ID for the pretrained BiM-VFI checkpoint.
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
MODEL_FILENAME = "bim_vfi.pth"
# Checkpoints live under ComfyUI's models directory.  Create the folder up
# front so listing/downloading never has to special-case a missing directory.
# (exist_ok=True already tolerates a pre-existing directory, so no separate
# os.path.exists() check is needed — the old check was a redundant TOCTOU.)
MODEL_DIR = os.path.join(folder_paths.models_dir, "bim-vfi")
os.makedirs(MODEL_DIR, exist_ok=True)
def get_available_models():
    """List available checkpoint files in the bim-vfi model directory.

    Returns a sorted list of filenames; when nothing is found the default
    checkpoint name is returned so selecting it triggers the auto-download.
    """
    weight_exts = (".pth", ".pt", ".ckpt", ".safetensors")
    found = []
    if os.path.isdir(MODEL_DIR):
        found = [name for name in os.listdir(MODEL_DIR) if name.endswith(weight_exts)]
    return sorted(found) if found else [MODEL_FILENAME]
def download_model_from_gdrive(file_id, dest_path):
    """Download a file from Google Drive into ``dest_path`` using gdown.

    Raises RuntimeError when gdown is not installed or the download leaves
    no file behind.
    """
    try:
        import gdown
    except ImportError:
        raise RuntimeError(
            "gdown is required to auto-download the BIM-VFI model. "
            "Install it with: pip install gdown"
        )
    logger.info(f"Downloading BIM-VFI model to {dest_path}...")
    gdown.download(f"https://drive.google.com/uc?id={file_id}", dest_path, quiet=False)
    # gdown can fail silently (quota, bad share link) — verify the file landed.
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download model to {dest_path}")
    logger.info("Download complete.")
class LoadBIMVFIModel:
    """ComfyUI node: load a BiM-VFI checkpoint, downloading it if absent."""

    RETURN_TYPES = ("BIM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/BIM-VFI"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_models(), {"default": MODEL_FILENAME}),
                "pyr_level": ("INT", {"default": 3, "min": 3, "max": 7, "step": 1}),
            }
        }

    def load_model(self, model_path, pyr_level):
        """Resolve the checkpoint path (auto-downloading on a miss) and wrap it."""
        full_path = os.path.join(MODEL_DIR, model_path)
        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_model_from_gdrive(GDRIVE_FILE_ID, full_path)
        # Load on CPU; the interpolate node moves the model to the GPU later.
        wrapper = BiMVFIModel(checkpoint_path=full_path, pyr_level=pyr_level, device="cpu")
        logger.info(f"BIM-VFI model loaded (pyr_level={pyr_level})")
        return (wrapper,)
class BIMVFIInterpolate:
    """ComfyUI node: multiply a video's frame rate by 2/4/8 using BiM-VFI.

    Interpolation is recursive: each pass inserts one frame between every
    adjacent pair (N frames -> 2*N - 1 frames); a 4x/8x multiplier runs 2/3
    such passes.  Pairs are processed one at a time so peak VRAM stays
    bounded regardless of clip length.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "model": ("BIM_VFI_MODEL",),
                "multiplier": ([2, 4, 8], {"default": 2}),
                # Pairs processed between CUDA cache flushes (ignored when
                # all_on_gpu is set).
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1}),
                # keep_device: leave the model on the GPU between pairs
                # (faster; uses more VRAM).
                "keep_device": ("BOOLEAN", {"default": True}),
                # all_on_gpu: also store all frames on the GPU (high-VRAM setups).
                "all_on_gpu": ("BOOLEAN", {"default": False}),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/BIM-VFI"
    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
        """Return (images,) with interpolated frames inserted between each original pair."""
        # Fewer than two frames: nothing to interpolate with.
        if images.shape[0] < 2:
            return (images,)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
        # all_on_gpu implies keep_device: with frames resident on the GPU,
        # shuttling the model off per pair would save nothing.
        if all_on_gpu:
            keep_device = True
        # Where to store intermediate frames between passes.
        storage_device = device if all_on_gpu else torch.device("cpu")
        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        frames = images.permute(0, 3, 1, 2).to(storage_device)
        # After each 2x pass, frame count = 2*N - 1, so compute total pairs
        # across all passes up front for an accurate progress bar.
        n = frames.shape[0]
        total_steps = 0
        for _ in range(num_passes):
            total_steps += n - 1
            n = 2 * n - 1
        pbar = ProgressBar(total_steps)
        step = 0
        if keep_device:
            model.to(device)
        for pass_idx in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1
            for i in range(num_pairs):
                frame0 = frames[i:i+1]  # [1, C, H, W]
                frame1 = frames[i+1:i+2]  # [1, C, H, W]
                # Low-VRAM path: move the model on/off the GPU per pair.
                if not keep_device:
                    model.to(device)
                mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
                mid = mid.to(storage_device)
                if not keep_device:
                    model.to("cpu")
                new_frames.append(frames[i:i+1])
                new_frames.append(mid)
                step += 1
                pbar.update_absolute(step, total_steps)
                # Periodically drop cached backwarp grids and CUDA allocator blocks.
                if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
                    clear_backwarp_cache()
                    torch.cuda.empty_cache()
            # Append last frame (the pair loop only emits left frames + mids).
            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)
        # Final cleanup of warp-grid cache and CUDA allocator.
        if not all_on_gpu and torch.cuda.is_available():
            clear_backwarp_cache()
            torch.cuda.empty_cache()
        # Convert back to ComfyUI [B, H, W, C], on CPU for ComfyUI
        result = frames.cpu().permute(0, 2, 3, 1)
        return (result,)

1
requirements.txt Normal file
View File

@@ -0,0 +1 @@
gdown

0
utils/__init__.py Normal file
View File

28
utils/padder.py Normal file
View File

@@ -0,0 +1,28 @@
import torch.nn.functional as F
class InputPadder:
    """Pads images so both spatial dimensions become divisible by ``divisor``.

    Padding is applied only on the right/bottom edges, so the original
    content stays anchored at the top-left corner and ``unpad`` can crop it
    back out exactly.
    """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to reach the next multiple of divisor (0 if already aligned).
        pad_ht = (divisor - self.ht % divisor) % divisor
        pad_wd = (divisor - self.wd % divisor) % divisor
        # F.pad order for the last two dims: (left, right, top, bottom).
        self._pad = [0, pad_wd, 0, pad_ht]

    def pad(self, *inputs):
        """Pad one or more tensors; a single input returns a tensor, several return a list."""
        padded = [F.pad(x, self._pad, mode='constant') for x in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs):
        """Crop the padding off one or more tensors (inverse of :meth:`pad`)."""
        cropped = [self._unpad(x) for x in inputs]
        return cropped[0] if len(cropped) == 1 else cropped

    def _unpad(self, x):
        # Crop bottom/right by the amounts recorded at construction time.
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]