From db64fc195a4657f298311ece31e92d324a45fe0e Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 12 Feb 2026 18:26:49 +0100 Subject: [PATCH] Initial commit: ComfyUI BIM-VFI node for video frame interpolation Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing. - LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive - BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing - Vendored inference-only architecture from KAIST-VICLab/BiM-VFI - Auto-detection of CUDA version for cupy installation Co-Authored-By: Claude Opus 4.6 --- __init__.py | 11 + bim_vfi_arch/__init__.py | 2 + bim_vfi_arch/arch.py | 42 ++++ bim_vfi_arch/backwarp.py | 24 ++ bim_vfi_arch/bim_vfi.py | 114 ++++++++++ bim_vfi_arch/bimfn.py | 115 ++++++++++ bim_vfi_arch/caun.py | 72 ++++++ bim_vfi_arch/costvol.py | 399 +++++++++++++++++++++++++++++++++ bim_vfi_arch/resnet_encoder.py | 102 +++++++++ bim_vfi_arch/sn.py | 95 ++++++++ inference.py | 81 +++++++ install.py | 47 ++++ nodes.py | 169 ++++++++++++++ requirements.txt | 1 + utils/__init__.py | 0 utils/padder.py | 28 +++ 16 files changed, 1302 insertions(+) create mode 100644 __init__.py create mode 100644 bim_vfi_arch/__init__.py create mode 100644 bim_vfi_arch/arch.py create mode 100644 bim_vfi_arch/backwarp.py create mode 100644 bim_vfi_arch/bim_vfi.py create mode 100644 bim_vfi_arch/bimfn.py create mode 100644 bim_vfi_arch/caun.py create mode 100644 bim_vfi_arch/costvol.py create mode 100644 bim_vfi_arch/resnet_encoder.py create mode 100644 bim_vfi_arch/sn.py create mode 100644 inference.py create mode 100644 install.py create mode 100644 nodes.py create mode 100644 requirements.txt create mode 100644 utils/__init__.py create mode 100644 utils/padder.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """LayerNorm supporting two input layouts.

    ``channels_last`` expects (batch, height, width, channels);
    ``channels_first`` expects (batch, channels, height, width) and is
    handled by permuting to channels-last, normalizing, and permuting back.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Normalize over the channel axis by moving it last and back.
            y = F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape,
                             self.weight, self.bias, self.eps)
            return y.permute(0, 3, 1, 2)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


class ResBlock(nn.Module):
    """Residual block: x + conv2(leaky_relu(conv1(x))), channel-preserving."""

    def __init__(self, feat_channels, kernel_size=3, padding_mode='zeros'):
        super().__init__()
        pad = (kernel_size - 1) // 2  # "same" padding for odd kernels
        self.conv1 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)
        self.act = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)

    def forward(self, x):
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual
class BiMVFI(nn.Module):
    """BiM-VFI frame-interpolation network (inference path).

    Runs a coarse-to-fine pyramid: at every level the motion estimator
    (``BiMFN``) refines a bidirectional flow, ``CAUN`` upsamples it, and the
    synthesis network renders the intermediate frame. Inputs are jointly
    mean/std normalized before processing and de-normalized on output.
    """

    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        super(BiMVFI, self).__init__()
        self.pyr_level = pyr_level
        self.mfe = ResNetPyramid(feat_channels)   # motion feature encoder
        self.cfe = ResNetPyramid(feat_channels)   # context feature encoder
        self.bimfn = BiMFN(feat_channels)
        self.sn = SynthesisNetwork(feat_channels)
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)

    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """One pyramid step: estimate flow/occlusion and synthesize the frame.

        ``last_flow``/``last_occ`` come from the previous (coarser) level and
        are halved in resolution to match the motion features of this level.
        """
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)

        B, _, H, W = feat0_pyr[-1].shape

        # Inference path: uniform bilateral-motion descriptor (r, phi) —
        # r is the target time, phi a constant angle encoded as (cos, sin).
        r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
        phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)

        last_flow = F.interpolate(
            input=last_flow.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False)

        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)

        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]

        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict.update({'flow_res': flow_res})
        return flow, occ, interp_img, extra_dict

    def forward(self, img0, img1, time_step, pyr_level=None, **kwargs):
        """Interpolate a frame at ``time_step`` in (0, 1) between img0 and img1.

        Returns a dict with per-level predictions (``imgt_preds``) and the
        final, unpadded prediction (``imgt_pred``).
        """
        if pyr_level is None:
            pyr_level = self.pyr_level
        interp_imgs = []

        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))

        # Joint mean/std normalization of both inputs (no gradients needed).
        with torch.no_grad():
            stats = [img0, img1]
            mean = sum(t.mean([1, 2, 3], True) for t in stats) / len(stats)
            std = (sum(t.std([1, 2, 3], False, True).square()
                       + (mean - t.mean([1, 2, 3], True)).square()
                       for t in stats) / len(stats)).sqrt()

        eps = 0.0000001
        img0 = (img0 - mean) / (std + eps)
        img1 = (img1 - mean) / (std + eps)

        # Pad so every pyramid level downsamples evenly.
        img0, img1 = padder.pad(img0, img1)
        N, _, H, W = img0.shape

        for level in reversed(range(pyr_level)):
            if level != 0:
                scale = 1 / 2 ** level
                img0_lvl = F.interpolate(
                    input=img0, scale_factor=scale,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_lvl = F.interpolate(
                    input=img1, scale_factor=scale,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_lvl, img1_lvl = img0, img1

            if level == pyr_level - 1:
                # Coarsest level: start from zero flow and zero occlusion.
                h = H // (2 ** (level + 1))
                w = W // (2 ** (level + 1))
                last_flow = torch.zeros((N, 4, h, w), device=img0.device)
                last_occ = torch.zeros(N, 1, h, w, device=img0.device)
            else:
                last_flow, last_occ = flow, occ

            flow, occ, interp_img, extra_dict = self.forward_one_lvl(
                img0_lvl, img1_lvl, last_flow, last_occ, time_step)

            # De-normalize back to the input image statistics.
            interp_imgs.append(interp_img * (std + eps) + mean)

        return {
            "imgt_preds": interp_imgs,
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }
nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2), + ) + self.conv_occ = nn.Sequential( + nn.Conv2d(1, feat_channels * 2, 7, padding=3), + nn.PReLU(feat_channels * 2), + nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2), + ) + self.conv_corr = nn.Sequential( + nn.LeakyReLU(0.1), + nn.Conv2d(81, feat_channels * 2, 1, padding=0), + nn.PReLU(feat_channels * 2), + nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2), + ) + self.conv0 = nn.Sequential( + nn.Conv2d(feat_channels * 18, feat_channels * 12, 1, padding=0), + nn.PReLU(feat_channels * 12), + nn.Conv2d(feat_channels * 12, feat_channels * 12, 3, padding=1), + ) + self.conv1 = nn.Sequential( + nn.Conv2d(feat_channels * 12, feat_channels * 8, 1, padding=0), + nn.PReLU(feat_channels * 8), + nn.Conv2d(feat_channels * 8, feat_channels * 8, 3, padding=1), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(feat_channels * 8, feat_channels * 6, 1, padding=0), + nn.PReLU(feat_channels * 6), + nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1), + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(feat_channels * 6, feat_channels * 6, 1, padding=0), + nn.PReLU(feat_channels * 6), + nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1), + ) + self.dem = nn.Sequential( + nn.Conv2d(1, feat_channels * 4, 1, padding=0), + nn.PReLU(feat_channels * 4), + nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0), + ) + self.aem = nn.Sequential( + nn.Conv2d(2, feat_channels * 4, 1, padding=0), + nn.PReLU(feat_channels * 4), + nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0), + ) + + self.bim_mconv = BiMMConv(feat_channels * 6, feat_channels * 6) + + self.conv_out = nn.Sequential( + nn.Conv2d(feat_channels * 6, feat_channels * 4, 1, padding=0), + nn.PReLU(feat_channels * 4), + nn.Conv2d(feat_channels * 4, 4, 1, padding=0)) + + def forward(self, feat0, feat1, r, phi, last_flow, 
class CAUN(nn.Module):
    """Context-aware upsampling network.

    Refines a coarse bidirectional flow [H/4, W/4, 4ch] to x2 and x4
    resolution via learned convex combinations of each source pixel's 3x3
    neighbourhood, conditioned on warped context features. Occlusion is
    upsampled bilinearly.
    """

    def __init__(self, feat_channels):
        super(CAUN, self).__init__()
        self.enc0 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc1 = nn.Sequential(
            nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # 2 flow directions x 9-tap softmax kernel per output pixel.
        self.kernel_x2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
        )
        self.kernel_x4 = nn.Sequential(
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
        )

    def upsample_input(self, inp, mask, upsample_factor):
        """Convex upsampling: each output pixel is a softmax-weighted blend
        of the 3x3 neighbourhood of its source pixel."""
        N, c, H, W = inp.shape
        weights = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
        weights = torch.softmax(weights, dim=2)
        padded = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        patches = F.unfold(padded, [3, 3]).view(N, c, 9, 1, 1, H, W)
        blended = (weights * patches).sum(dim=2)
        blended = blended.permute(0, 1, 4, 2, 5, 3)
        return blended.reshape(N, c, upsample_factor * H, upsample_factor * W)

    def forward(self, flow, feat0, feat1, last_occ):
        """Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination."""
        N, _, H, W = flow.shape

        # Warp the context pyramids (finest feature first in feat*[2-lvl])
        # with the flow rescaled to each resolution.
        warped0, warped1 = [], []
        for lvl in range(3):
            scale = 2 ** lvl
            flow_lvl = F.interpolate(flow * scale, scale_factor=scale, mode='bilinear')
            warped0.append(backwarp(feat0[2 - lvl], flow_lvl[:, :2]))
            warped1.append(backwarp(feat1[2 - lvl], flow_lvl[:, 2:]))

        feature0 = self.enc0(torch.cat([warped0[0], warped1[0]], dim=1))
        feature1 = self.enc1(torch.cat(
            [F.pixel_shuffle(feature0, 2), warped0[1], warped1[1]], dim=1))
        feature2 = self.enc2(torch.cat(
            [F.pixel_shuffle(feature1, 2), warped0[2], warped1[2]], dim=1))

        # Predict per-pixel 9-tap kernels and split per flow direction.
        mask_x2 = self.kernel_x2(feature1).view(N, 18, H, 2, W, 2) \
            .permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x4 = self.kernel_x4(feature2).view(N, 18, H, 4, W, 4) \
            .permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
        mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)

        # Flow magnitudes scale with resolution (x2 / x4).
        up_flow_x2 = torch.cat([
            self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2),
            self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2),
        ], dim=1)
        up_flow_x4 = torch.cat([
            self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4),
            self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4),
        ], dim=1)

        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        return [up_flow_x4, up_flow_x2, flow], up_occ
objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for 
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
    """Compile (memoized per CUDA device) and return the raw kernel function
    previously registered in ``objCudacache`` under ``strKey``."""
    # cupy's NVRTC path needs CUDA_HOME to locate the toolkit; provide the
    # conventional default when the environment doesn't set it.
    os.environ.setdefault('CUDA_HOME', '/usr/local/cuda/')
    entry = objCudacache[strKey]
    return cupy.RawModule(code=entry['strKernel']).get_function(entry['strFunction'])
cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo, intKernelSize): + tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const int intKernelSize, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) { + for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } + + tenOut[intOffset] = fltValue; + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'intKernelSize': intKernelSize, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, 
tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + self.intKernelSize = intKernelSize + + return tenOut + # end + + @staticmethod + @torch.amp.custom_bwd(device_type='cuda') + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + intKernelSize = self.intKernelSize + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + const int intKernelSize, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) { + for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, 
intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx); + } + } + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'intKernelSize': intKernelSize, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + const int intKernelSize, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) { + for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) { + if 
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        dilation=dilation,
    )


def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """2x2 convolution (used for strided downsampling)"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=2, stride=stride)


class BasicBlock(nn.Module):
    """Pre-activation residual block: norm -> conv -> lrelu -> norm -> conv.

    NOTE: the residual branch is activated *before* the skip addition
    (``lrelu(out) + identity``), mirroring the vendored BiM-VFI encoder.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # Bug fix: the previous default,
            # partial(nn.InstanceNorm2d, data_format='channels_first'),
            # passed a kwarg InstanceNorm2d does not accept and raised
            # TypeError whenever the default was actually used.
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(inplanes)
        # Stride > 1 downsamples with a 2x2 conv instead of a strided 3x3.
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return the residual-block output; ``downsample`` (if given) adapts
        the identity path to the branch's shape."""
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)

        out = self.bn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.lrelu(out)
        return out + identity
+ """ + + def __init__(self, feat_channels: int): + super(ResNetPyramid, self).__init__() + self.conv = nn.Conv2d(3, feat_channels, kernel_size=3, stride=1, padding=1) + self.layer0 = nn.Sequential( + BasicBlock(feat_channels, feat_channels, norm_layer=nn.InstanceNorm2d), + ) + self.layer1 = nn.Sequential( + nn.Conv2d(feat_channels, feat_channels * 2, 2, 2), + BasicBlock(feat_channels * 2, feat_channels * 2, norm_layer=nn.InstanceNorm2d), + ) + self.layer2 = nn.Sequential( + nn.Conv2d(feat_channels * 2, feat_channels * 4, 2, 2), + BasicBlock(feat_channels * 4, feat_channels * 4, norm_layer=nn.InstanceNorm2d), + ) + self.conv_last = nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, 1, 1) + + def forward(self, img): + C0 = self.layer0(self.conv(img)) + C1 = self.layer1(C0) + C2 = self.conv_last(self.layer2(C1)) + return [C0, C1, C2] diff --git a/bim_vfi_arch/sn.py b/bim_vfi_arch/sn.py new file mode 100644 index 0000000..9aa268a --- /dev/null +++ b/bim_vfi_arch/sn.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn + +from .backwarp import backwarp + + +class SynthesisNetwork(nn.Module): + def __init__(self, feat_channels): + super(SynthesisNetwork, self).__init__() + input_channels = 6 + 1 + self.conv_down1 = nn.Sequential( + nn.Conv2d(input_channels, feat_channels, 7, padding=3), + nn.PReLU(feat_channels), + nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2)) + self.conv_down2 = nn.Sequential( + nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0), + nn.PReLU(feat_channels * 2), + nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2), + nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1), + nn.PReLU(feat_channels * 2)) + self.conv_down3 = nn.Sequential( + nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0), + nn.PReLU(feat_channels * 4), + nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1), + nn.PReLU(feat_channels * 4), + 
class SynthesisNetwork(nn.Module):
    """U-Net-style synthesis network.

    Warps the two input frames and their feature pyramids toward the target
    time step using the supplied bidirectional flows, then predicts a
    residual image plus a blending-mask correction to produce the
    interpolated frame.
    """

    def __init__(self, feat_channels):
        super(SynthesisNetwork, self).__init__()
        # 6 channels = two warped RGB frames; +1 = occlusion/blending logit.
        input_channels = 6 + 1
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Input 4*feat = s0 (2*feat) + two level-0 warped feature maps.
        # NOTE(review): assumes each level-0 pyramid feature has
        # `feat_channels` channels -- confirm against the encoder.
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Input 6*feat = s1 (2*feat) + two level-1 warped feature maps.
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4))
        # Decoder stages upsample with PixelShuffle(2), which doubles H/W
        # while dividing channels by 4 (e.g. 8*feat -> 2*feat below).
        self.conv_up1 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_up2 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1))
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # 4 output channels: 3 = RGB residual, 1 = blending-logit correction.
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)

    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backward-warp features (and optionally images) to the target time.

        `bi_flow` packs two flow fields: channels [0:2] warp content from
        frame 0, channels [2:4] from frame 1. When i0/i1 are provided, the
        warped images are returned ahead of the warped features.
        """
        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if (i0 is None) and (i1 is None):
            return warped_c0, warped_c1
        else:
            warped_img0 = backwarp(i0, flow_t0)
            warped_img1 = backwarp(i1, flow_t1)
            return warped_img0, warped_img1, warped_c0, warped_c1

    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the intermediate frame.

        Args:
            i0, i1: the two input frames.
            c0_pyr, c1_pyr: 3-level feature pyramids of each frame.
            bi_flow_pyr: per-level bidirectional flows (4 channels each).
            occ: occlusion/blending logit map (1 channel, full resolution).

        Returns:
            (interp_img, occ_out, extra_dict): interpolated frame (not
            clamped here), refined blending logit, and intermediate tensors
            for inspection/losses.
        """
        # warped_img0/warped_img1 from the finest level are kept around and
        # reused for the final blending at the bottom of this method.
        warped_img0, warped_img1, warped_c0, warped_c1 = \
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1)
        input_feat = torch.cat(
            (warped_img0, warped_img1, occ), 1)
        s0 = self.conv_down1(input_feat)
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        # warped_c0/warped_c1 are re-bound per pyramid level from here on.
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None)
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None)

        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))

        refine = self.conv_out(x)
        refine_res = refine[:, :3]
        occ_res = refine[:, 3:]
        occ_out = occ + occ_res
        # Sigmoid turns the refined logit into a soft mask between the two
        # warped frames; the learned residual is added on top.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res
        interp_img = merged_img

        extra_dict = {}
        extra_dict["refine_res"] = refine_res
        extra_dict["refine_mask"] = occ_out
        extra_dict["warped_img0"] = warped_img0
        extra_dict["warped_img1"] = warped_img1
        extra_dict["merged_img"] = merged_img

        return interp_img, occ_out, extra_dict
class BiMVFIModel:
    """Thin inference-only wrapper around the vendored BiMVFI network.

    Loads a checkpoint onto CPU, keeps the module in eval mode, and exposes
    a single-pair interpolation call for the ComfyUI nodes.
    """

    def __init__(self, checkpoint_path, pyr_level=3, device="cpu"):
        self.pyr_level = pyr_level
        self.device = device

        self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.to(device)

    def _load_checkpoint(self, checkpoint_path):
        """Load weights, tolerating several checkpoint layouts and key prefixes."""
        # NOTE(review): weights_only=False unpickles arbitrary objects; the
        # checkpoint comes from a pinned Google Drive id but is still
        # external input -- keep that in mind.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # The weights may be nested under "model" or "state_dict", or be the
        # checkpoint itself.
        state_dict = checkpoint
        for container_key in ("model", "state_dict"):
            if container_key in checkpoint:
                state_dict = checkpoint[container_key]
                break

        # Strip DDP ("module.") and wrapper ("model.") prefixes, in that
        # order, from every parameter name.
        cleaned = {}
        for name, tensor in state_dict.items():
            for prefix in ("module.", "model."):
                if name.startswith(prefix):
                    name = name[len(prefix):]
            cleaned[name] = tensor

        self.model.load_state_dict(cleaned)

    def to(self, device):
        """Move the wrapped network to `device`; returns self for chaining."""
        self.device = device
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_pair(self, frame0, frame1, time_step=0.5):
        """Interpolate a single frame between two input frames.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1), temporal position of the new frame

        Returns:
            Interpolated frame as a [1, C, H, W] float32 tensor in [0, 1].
        """
        target = next(self.model.parameters()).device
        img0 = frame0.to(target)
        img1 = frame1.to(target)

        # Deeper pyramids for larger frames; otherwise the configured level.
        height = img0.shape[2]
        if height >= 2160:
            level = 7
        elif height >= 1080:
            level = 6
        elif height >= 540:
            level = 5
        else:
            level = self.pyr_level

        t = torch.tensor([time_step], device=target).view(1, 1, 1, 1)

        result_dict = self.model(
            img0=img0, img1=img1,
            time_step=t,
            pyr_level=level,
        )

        return torch.clamp(result_dict["imgt_pred"], 0, 1)
cupy requires CUDA.") + return None + cuda_version = torch.version.cuda + if cuda_version is None: + print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.") + return None + major = cuda_version.split(".")[0] + major = int(major) + cupy_pkg = f"cupy-cuda{major}x" + print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}") + return cupy_pkg + except Exception as e: + print(f"[BIM-VFI] WARNING: Could not detect CUDA version: {e}") + return None + + +def install(): + requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-r", requirements_path + ]) + + # Install cupy matching the current CUDA version + try: + import cupy + print("[BIM-VFI] cupy already installed, skipping.") + except ImportError: + cupy_pkg = get_cupy_package() + if cupy_pkg: + print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...") + subprocess.check_call([ + sys.executable, "-m", "pip", "install", cupy_pkg + ]) + + +if __name__ == "__main__": + install() diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..b84a888 --- /dev/null +++ b/nodes.py @@ -0,0 +1,169 @@ +import os +import logging +import torch +import folder_paths +from comfy.utils import ProgressBar + +from .inference import BiMVFIModel +from .bim_vfi_arch import clear_backwarp_cache + +logger = logging.getLogger("BIM-VFI") + +# Google Drive file ID for the pretrained model +GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC" +MODEL_FILENAME = "bim_vfi.pth" + +# Register the model folder with ComfyUI +MODEL_DIR = os.path.join(folder_paths.models_dir, "bim-vfi") +if not os.path.exists(MODEL_DIR): + os.makedirs(MODEL_DIR, exist_ok=True) + + +def get_available_models(): + """List available checkpoint files in the bim-vfi model directory.""" + models = [] + if os.path.isdir(MODEL_DIR): + for f in os.listdir(MODEL_DIR): + if f.endswith((".pth", ".pt", ".ckpt", ".safetensors")): + models.append(f) + if not 
def get_available_models():
    """Return sorted checkpoint filenames found under MODEL_DIR.

    Falls back to the default filename when the folder is empty so the
    loader node can trigger an auto-download on first use.
    """
    extensions = (".pth", ".pt", ".ckpt", ".safetensors")
    found = []
    if os.path.isdir(MODEL_DIR):
        found = [name for name in os.listdir(MODEL_DIR) if name.endswith(extensions)]
    if not found:
        found.append(MODEL_FILENAME)  # Will trigger auto-download
    return sorted(found)


def download_model_from_gdrive(file_id, dest_path):
    """Fetch a checkpoint from Google Drive via gdown; raise on failure."""
    try:
        import gdown
    except ImportError:
        raise RuntimeError(
            "gdown is required to auto-download the BIM-VFI model. "
            "Install it with: pip install gdown"
        )
    url = f"https://drive.google.com/uc?id={file_id}"
    logger.info(f"Downloading BIM-VFI model to {dest_path}...")
    gdown.download(url, dest_path, quiet=False)
    # gdown can fail quietly; verify the file actually landed on disk.
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download model to {dest_path}")
    logger.info("Download complete.")


class LoadBIMVFIModel:
    """ComfyUI node: load (auto-downloading if needed) a BIM-VFI checkpoint."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_models(), {"default": MODEL_FILENAME}),
                "pyr_level": ("INT", {"default": 3, "min": 3, "max": 7, "step": 1}),
            }
        }

    RETURN_TYPES = ("BIM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/BIM-VFI"

    def load_model(self, model_path, pyr_level):
        """Resolve the checkpoint path, download when missing, build the CPU wrapper."""
        checkpoint = os.path.join(MODEL_DIR, model_path)

        if not os.path.exists(checkpoint):
            logger.info(f"Model not found at {checkpoint}, attempting download...")
            download_model_from_gdrive(GDRIVE_FILE_ID, checkpoint)

        # The wrapper starts on CPU; the interpolate node moves it to GPU.
        wrapper = BiMVFIModel(
            checkpoint_path=checkpoint,
            pyr_level=pyr_level,
            device="cpu",
        )

        logger.info(f"BIM-VFI model loaded (pyr_level={pyr_level})")
        return (wrapper,)
class BIMVFIInterpolate:
    """ComfyUI node: recursive 2x/4x/8x frame interpolation with BIM-VFI.

    Each multiplier is realised as repeated 2x passes; every pass runs the
    model on consecutive frame pairs and inserts the midpoint frame, with
    configurable model/frame placement to bound VRAM use.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "model": ("BIM_VFI_MODEL",),
                "multiplier": ([2, 4, 8], {"default": 2}),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1}),
                "keep_device": ("BOOLEAN", {"default": True}),
                "all_on_gpu": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/BIM-VFI"

    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
        """Run `multiplier`x interpolation over a [B, H, W, C] image batch."""
        if images.shape[0] < 2:
            return (images,)  # nothing to interpolate between

        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        passes = {2: 1, 4: 2, 8: 3}[multiplier]

        # Keeping every frame on the GPU only makes sense with a resident model.
        if all_on_gpu:
            keep_device = True

        # Intermediate frames live on the GPU only in all_on_gpu mode.
        store = run_device if all_on_gpu else torch.device("cpu")

        # ComfyUI images are [B, H, W, C]; the model expects [B, C, H, W].
        frames = images.permute(0, 3, 1, 2).to(store)

        # Each 2x pass turns N frames into 2N-1; pre-compute the pair total
        # across all passes for the progress bar.
        total_steps = 0
        count = frames.shape[0]
        for _ in range(passes):
            total_steps += count - 1
            count = 2 * count - 1

        pbar = ProgressBar(total_steps)
        done = 0

        if keep_device:
            model.to(run_device)

        for _pass in range(passes):
            out_frames = []

            for idx in range(frames.shape[0] - 1):
                left = frames[idx:idx + 1]       # [1, C, H, W]
                right = frames[idx + 1:idx + 2]  # [1, C, H, W]

                # Without keep_device, shuttle the model per pair to free VRAM.
                if not keep_device:
                    model.to(run_device)

                mid = model.interpolate_pair(left, right, time_step=0.5).to(store)

                if not keep_device:
                    model.to("cpu")

                out_frames.extend((left, mid))

                done += 1
                pbar.update_absolute(done, total_steps)

                # Periodically drop warp grids and cached blocks to bound VRAM growth.
                if not all_on_gpu and (idx + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
                    clear_backwarp_cache()
                    torch.cuda.empty_cache()

            out_frames.append(frames[-1:])
            frames = torch.cat(out_frames, dim=0)

        if not all_on_gpu and torch.cuda.is_available():
            clear_backwarp_cache()
            torch.cuda.empty_cache()

        # Back to ComfyUI layout [B, H, W, C] on CPU.
        return (frames.cpu().permute(0, 2, 3, 1),)
class InputPadder:
    """Zero-pads the bottom/right of images so H and W are multiples of `divisor`."""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to reach the next multiple; 0 when already divisible.
        pad_ht = (-self.ht) % divisor
        pad_wd = (-self.wd) % divisor
        # F.pad ordering for the last two dims: (left, right, top, bottom).
        self._pad = [0, pad_wd, 0, pad_ht]

    def pad(self, *inputs):
        """Pad one tensor (returned bare) or several (returned as a list)."""
        padded = [F.pad(t, self._pad, mode='constant') for t in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs):
        """Inverse of pad(): crop back to the original H x W."""
        cropped = [self._unpad(t) for t in inputs]
        return cropped[0] if len(cropped) == 1 else cropped

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], self._pad[3]
        left, right = self._pad[0], self._pad[1]
        return x[..., top:ht - bottom, left:wd - right]