Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing. - LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive - BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing - Vendored inference-only architecture from KAIST-VICLab/BiM-VFI - Auto-detection of CUDA version for cupy installation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
11
__init__.py
Normal file
11
__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# ComfyUI entry point: expose the two BIM-VFI nodes to the node registry.
from .nodes import LoadBIMVFIModel, BIMVFIInterpolate

# Maps the internal node identifier to its implementing class
# (ComfyUI scans custom-node packages for this dict).
NODE_CLASS_MAPPINGS = {
    "LoadBIMVFIModel": LoadBIMVFIModel,
    "BIMVFIInterpolate": BIMVFIInterpolate,
}

# Human-readable titles shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadBIMVFIModel": "Load BIM-VFI Model",
    "BIMVFIInterpolate": "BIM-VFI Interpolate",
}
|
||||
2
bim_vfi_arch/__init__.py
Normal file
2
bim_vfi_arch/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .bim_vfi import BiMVFI
|
||||
from .backwarp import clear_backwarp_cache
|
||||
42
bim_vfi_arch/arch.py
Normal file
42
bim_vfi_arch/arch.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two data layouts.

    ``channels_last`` normalizes inputs shaped (batch, height, width, channels)
    directly; ``channels_first`` accepts (batch, channels, height, width) and
    permutes to/from channels-last around the normalization.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Affine parameters start as identity (scale 1, shift 0).
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # NCHW -> NHWC, normalize over the trailing channel dim, back to NCHW.
            nhwc = x.permute(0, 2, 3, 1)
            normed = F.layer_norm(nhwc, self.normalized_shape, self.weight, self.bias, self.eps)
            return normed.permute(0, 3, 1, 2)
|
||||
class ResBlock(nn.Module):
    """Two-conv residual block: x + conv2(LeakyReLU(conv1(x)))."""

    def __init__(self, feat_channels, kernel_size=3, padding_mode='zeros'):
        super().__init__()
        # Same-padding so the residual add is shape-compatible.
        pad = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)
        self.act = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)

    def forward(self, x):
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual
||||
24
bim_vfi_arch/backwarp.py
Normal file
24
bim_vfi_arch/backwarp.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
# Cache of normalized sampling grids, keyed by dtype/device/spatial size,
# so the identity grid is built once per resolution instead of per call.
objBackwarpcache = {}


def clear_backwarp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    objBackwarpcache.clear()


def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor, mode: str = 'bilinear'):
    """Backward-warp ``tenIn`` by the pixel-space flow ``tenFlow``.

    Args:
        tenIn: (N, C, H, W) features/images to sample from.
        tenFlow: (N, 2, H, W) flow in pixels (x-displacement first, then y).
        mode: interpolation mode passed to ``grid_sample``.

    Returns:
        (N, C, H, W) warped tensor; out-of-frame samples are zero-padded.
    """
    # One cache entry per (dtype, device, H, W) combination.
    strKey = 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])

    if strKey not in objBackwarpcache:
        tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
        tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])
        objBackwarpcache[strKey] = torch.cat([tenHor, tenVer], 1)

    # Convert pixel displacements to grid_sample's [-1, 1] coordinate units.
    if tenFlow.shape[3] == tenFlow.shape[2]:
        # Square frame: one scalar scale covers both axes. (The upstream code
        # wrote `(shape[3] and shape[2]) - 1.0`, which for equal non-zero
        # shapes is just `shape[3] - 1.0`.)
        tenFlow = tenFlow * (2.0 / (tenFlow.shape[3] - 1.0))
    else:
        # Non-square frame: per-axis scales, broadcast over the 2 flow channels.
        tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1)

    return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache[strKey] + tenFlow).permute(0, 2, 3, 1), mode=mode, padding_mode='zeros', align_corners=True)
||||
114
bim_vfi_arch/bim_vfi.py
Normal file
114
bim_vfi_arch/bim_vfi.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
from .backwarp import backwarp
|
||||
from .resnet_encoder import ResNetPyramid
|
||||
from .caun import CAUN
|
||||
from .bimfn import BiMFN
|
||||
from .sn import SynthesisNetwork
|
||||
|
||||
from ..utils.padder import InputPadder
|
||||
|
||||
|
||||
class BiMVFI(nn.Module):
    """BiM-VFI interpolation network (inference-only vendored copy).

    Sub-networks: motion feature extractor (mfe), context feature extractor
    (cfe), bilateral-motion flow network (bimfn), context-aware upsampling
    network (caun), and synthesis network (sn). ``forward`` runs them
    coarse-to-fine over an image pyramid of ``pyr_level`` levels.
    """

    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        super(BiMVFI, self).__init__()
        self.pyr_level = pyr_level
        self.mfe = ResNetPyramid(feat_channels)
        self.cfe = ResNetPyramid(feat_channels)
        self.bimfn = BiMFN(feat_channels)
        self.sn = SynthesisNetwork(feat_channels)
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)

    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """Run one pyramid level: estimate flow/occlusion and synthesize the frame.

        Args:
            img0, img1: the two input frames at this level's resolution.
            last_flow: 4-channel bidirectional flow carried from the coarser level.
            last_occ: 1-channel occlusion map carried from the coarser level.
            time_period: target time t in (0, 1) between the frames.

        Returns:
            (flow, occ, interp_img, extra_dict) for this level.
        """
        # Shared extractors applied to both frames: motion features feed the
        # flow network, context features feed upsampling and synthesis.
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)

        B, _, H, W = feat0_pyr[-1].shape

        # Inference path: prepare uniform BiM — constant magnitude r = t and a
        # constant angle of pi encoded as (cos, sin) channels.
        r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
        phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)

        # Bring the coarser level's flow/occlusion down to the feature
        # resolution; flow values are halved along with the spatial size.
        last_flow = F.interpolate(
            input=last_flow.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False)

        # Refine the low-resolution bidirectional flow.
        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)

        # Context-aware upsampling back to image resolution; index 0 is the
        # finest flow in the returned pyramid (see CAUN.forward).
        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]

        # Synthesize the intermediate frame from both inputs and the flows.
        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict.update({'flow_res': flow_res})
        return flow, occ, interp_img, extra_dict

    def forward(self, img0, img1, time_step,
                pyr_level=None, **kwargs):
        """Interpolate a frame at ``time_step`` between img0 and img1.

        Returns a dict with 'imgt_preds' (per-level predictions, padded) and
        'imgt_pred' (final prediction, unpadded to the original size).
        """
        if pyr_level is None: pyr_level = self.pyr_level
        N, _, H, W = img0.shape
        interp_imgs = []

        # Pad so H and W divide evenly through all pyramid downsamplings.
        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))

        # Normalize input images: pooled mean/std over both frames
        # (per-frame variance plus squared mean offset = total variance).
        with torch.set_grad_enabled(False):
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                    tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()

        # Epsilon guards against division by ~zero std (flat frames).
        img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
        img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)

        # Pad images for downsampling
        img0, img1 = padder.pad(img0, img1)

        N, _, H, W = img0.shape

        # Coarse-to-fine: level pyr_level-1 (smallest) down to level 0 (full size).
        for level in list(range(pyr_level))[::-1]:
            # Downsample images if needed
            if level != 0:
                scale_factor = 1 / 2 ** level
                img0_this_lvl = F.interpolate(
                    input=img0, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_this_lvl = F.interpolate(
                    input=img1, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_this_lvl = img0
                img1_this_lvl = img1

            # Initialize zero flows for lowest pyramid level
            if level == pyr_level - 1:
                last_flow = torch.zeros(
                    (N, 4, H // (2 ** (level + 1)), W // (2 ** (level + 1))), device=img0.device
                )
                last_occ = torch.zeros(N, 1, H // (2 ** (level + 1)), W // (2 ** (level + 1)), device=img0.device)
            else:
                # Carry the previous (coarser) level's estimates forward.
                last_flow = flow
                last_occ = occ

            # Single pyramid level run
            flow, occ, interp_img, extra_dict = self.forward_one_lvl(
                img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)

            # De-normalize back to the original intensity range.
            interp_imgs.append((interp_img) * (tenStd_ + 0.0000001) + tenMean_)

        result_dict = {
            "imgt_preds": interp_imgs,
            # Final (level-0) prediction, cropped back to the caller's size.
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }

        return result_dict
|
||||
115
bim_vfi_arch/bimfn.py
Normal file
115
bim_vfi_arch/bimfn.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .backwarp import backwarp
|
||||
from .arch import LayerNorm, ResBlock
|
||||
from .costvol import costvol_func
|
||||
|
||||
|
||||
class BiMMConv(nn.Module):
    """Modulated convolution for BiM conditioning.

    Projects the feature volume and the two BiM inputs (magnitude r and angle
    embedding phi) through identical LayerNorm+ResBlock stacks, then modulates
    the features multiplicatively by the r/phi embeddings.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def make_branch():
            # Identical branch used for features, r, and phi.
            return nn.Sequential(
                LayerNorm(in_channels, data_format='channels_first'),
                ResBlock(out_channels, 1),
            )

        # Construction order matters for parameter-init reproducibility.
        self.conv_FV = make_branch()
        self.conv_r = make_branch()
        self.conv_phi = make_branch()

    def forward(self, FV, r, phi):
        feat = self.conv_FV(FV)
        r_emb = self.conv_r(r)
        phi_emb = self.conv_phi(phi)
        # Residual modulation: feat plus feat gated by the BiM embeddings.
        modulated = feat + feat * r_emb * phi_emb
        return modulated, r_emb, phi_emb
|
||||
|
||||
class BiMFN(nn.Module):
    """Bilateral-Motion Flow Network.

    Refines a 4-channel bidirectional flow at the coarsest feature resolution
    from warped features, two cost volumes, the previous flow/occlusion, and
    the BiM conditioning (r magnitude via ``dem``, phi angle via ``aem``).
    """

    def __init__(self, feat_channels):
        super(BiMFN, self).__init__()
        # Embeds each 2-channel flow half into 2*feat_channels features.
        self.conv_flow = nn.Sequential(
            nn.Conv2d(2, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds the 1-channel occlusion map.
        self.conv_occ = nn.Sequential(
            nn.Conv2d(1, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds an 81-channel cost volume (9x9 neighborhood from costvol_func).
        self.conv_corr = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(81, feat_channels * 2, 1, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Fusion trunk: 18C input = 2x corr(2C) + 2x warped feats (~2C each,
        # see forward concat) + 2x flow(2C) + occ(2C); progressively reduced.
        self.conv0 = nn.Sequential(
            nn.Conv2d(feat_channels * 18, feat_channels * 12, 1, padding=0),
            nn.PReLU(feat_channels * 12),
            nn.Conv2d(feat_channels * 12, feat_channels * 12, 3, padding=1),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(feat_channels * 12, feat_channels * 8, 1, padding=0),
            nn.PReLU(feat_channels * 8),
            nn.Conv2d(feat_channels * 8, feat_channels * 8, 3, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        # dem: embeds the scalar BiM magnitude r (1 channel in).
        self.dem = nn.Sequential(
            nn.Conv2d(1, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        # aem: embeds the (cos, sin) BiM angle phi (2 channels in).
        self.aem = nn.Sequential(
            nn.Conv2d(2, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )

        self.bim_mconv = BiMMConv(feat_channels * 6, feat_channels * 6)

        # Head producing the 4-channel flow residual.
        self.conv_out = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, 4, 1, padding=0))

    def forward(self, feat0, feat1, r, phi, last_flow, last_occ):
        """Return (flow_low, flow_res): refined flow and the residual added to last_flow."""
        # Warp each frame's features toward the target time by its flow half:
        # channels [:2] warp feat0, channels [2:] warp feat1.
        feat0_warp = backwarp(feat0, (last_flow[:, :2]))
        feat1_warp = backwarp(feat1, (last_flow[:, 2:]))
        # Symmetric 9x9 cost volumes between the warped features (CUDA/cupy op).
        volume0 = costvol_func.apply(feat0_warp, feat1_warp, 9)
        volume1 = costvol_func.apply(feat1_warp, feat0_warp, 9)
        corr0 = self.conv_corr(volume0)
        corr1 = self.conv_corr(volume1)
        flo0 = self.conv_flow(last_flow[:, :2])
        flo1 = self.conv_flow(last_flow[:, 2:])
        occ = self.conv_occ(last_occ)
        input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, flo0, flo1, occ], 1)
        FV = self.conv0(input_feat)
        FV = self.conv1(FV)
        FV = self.conv2(FV)
        FV0 = self.conv3(FV)
        # BiM conditioning embeddings modulate the fused features.
        r0 = self.dem(r)
        phi0 = self.aem(phi)
        bim_feat, _, _ = self.bim_mconv(FV0, r0, phi0)
        flow_res = self.conv_out(bim_feat)
        # Residual update on top of the carried-in flow.
        flow_low = flow_res + last_flow

        return flow_low, flow_res
||||
72
bim_vfi_arch/caun.py
Normal file
72
bim_vfi_arch/caun.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .backwarp import backwarp
|
||||
|
||||
|
||||
class CAUN(nn.Module):
    """Context-Aware Upsampling Network.

    Upsamples the coarse bidirectional flow by 2x and 4x using learned
    convex-combination masks (RAFT-style 3x3 neighborhood softmax weights),
    conditioned on context features warped by the current flow estimate.
    """

    def __init__(self, feat_channels):
        super(CAUN, self).__init__()
        # Encoders over concatenated warped features at the three pyramid scales;
        # each finer stage also takes the pixel-shuffled previous stage.
        self.enc0 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc1 = nn.Sequential(
            nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # Mask heads: 2 flow halves * 9 neighborhood weights = 18 channels.
        self.kernel_x2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
        )
        self.kernel_x4 = nn.Sequential(
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
        )

    def upsample_input(self, inp, mask, upsample_factor):
        """Upsample ``inp`` by ``upsample_factor`` via a softmax-weighted convex
        combination of each pixel's 3x3 neighborhood (weights from ``mask``)."""
        N, c, H, W = inp.shape
        # mask: one 9-way weight set per output sub-pixel position.
        mask = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
        mask = torch.softmax(mask, dim=2)
        # Replicate-pad so unfold yields a full 3x3 window at every pixel.
        inp = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        up_inp = F.unfold(inp, [3, 3])
        up_inp = up_inp.view(N, c, 9, 1, 1, H, W)

        # Weighted sum over the 9 neighbors, then interleave the sub-pixel
        # grid into the spatial dimensions.
        up_inp = torch.sum(mask * up_inp, dim=2)
        up_inp = up_inp.permute(0, 1, 4, 2, 5, 3)
        return up_inp.reshape(N, c, upsample_factor*H, upsample_factor*W)

    def forward(self, flow, feat0, feat1, last_occ):
        """ Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination """
        N, _, H, W = flow.shape
        feat0_warped_list, feat1_warped_list = [], []
        # Warp context features at scales 1x, 2x, 4x; flow values are scaled
        # by 2**i to match each resolution. feat pyramids are indexed [2-i]
        # (coarsest first here — assumes feat0/feat1 are fine-to-coarse lists).
        for i in range(3):
            flow_bi = F.interpolate(flow * 2 ** i, scale_factor=2 ** i, mode='bilinear')
            feat0_warped = backwarp(feat0[2-i], flow_bi[:, :2])
            feat1_warped = backwarp(feat1[2-i], flow_bi[:, 2:])
            feat0_warped_list.append(feat0_warped)
            feat1_warped_list.append(feat1_warped)
        # Coarse-to-fine mask feature chain; pixel_shuffle doubles resolution
        # between stages.
        feature = torch.cat([feat0_warped_list[0], feat1_warped_list[0]], dim=1)
        feature0 = self.enc0(feature)
        feature1 = self.enc1(torch.cat([F.pixel_shuffle(feature0, 2), feat0_warped_list[1], feat1_warped_list[1]], dim=1))
        feature2 = self.enc2(torch.cat([F.pixel_shuffle(feature1, 2), feat0_warped_list[2], feat1_warped_list[2]], dim=1))
        mask_x2 = self.kernel_x2(feature1)
        mask_x4 = self.kernel_x4(feature2)
        # Rearrange masks from (N, 18, f*H, f*W) layout into per-sub-pixel
        # weight grids, then split into the two flow halves.
        mask_x2 = mask_x2.view(N, 18, H, 2, W, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
        mask_x4 = mask_x4.view(N, 18, H, 4, W, 4).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)
        # Flow magnitudes scale with resolution (x2 / x4).
        up_flow_x2_0 = self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2)
        up_flow_x2_1 = self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2)
        up_flow_x4_0 = self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4)
        up_flow_x4_1 = self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4)
        up_flow_x2 = torch.cat([up_flow_x2_0, up_flow_x2_1], dim=1)
        up_flow_x4 = torch.cat([up_flow_x4_0, up_flow_x4_1], dim=1)
        # Occlusion is upsampled bilinearly, not via the learned masks.
        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        # Returned finest-first: [x4 (full res), x2, original coarse flow].
        return [up_flow_x4, up_flow_x2, flow], up_occ
|
||||
399
bim_vfi_arch/costvol.py
Normal file
399
bim_vfi_arch/costvol.py
Normal file
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import collections
|
||||
import cupy
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import typing
|
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
objCudacache = {}
|
||||
|
||||
|
||||
def cuda_int32(intIn:int):
    """Wrap a Python int as a cupy int32 for raw-kernel argument passing."""
    return cupy.int32(intIn)
# end
|
||||
|
||||
|
||||
def cuda_float32(fltIn:float):
    """Wrap a Python float as a cupy float32 for raw-kernel argument passing."""
    return cupy.float32(fltIn)
# end
|
||||
|
||||
|
||||
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialize a templated CUDA kernel source and cache it.

    Builds a cache key from the function name, the argument names, and each
    argument's value/dtype/shape/strides, then (on a cache miss) expands the
    ``{{placeholders}}`` and the SIZE_/OFFSET_/VALUE_ macros in ``strKernel``
    into literal constants and stride arithmetic. Returns the cache key to be
    passed to ``cuda_launch``.

    Fix vs. upstream: the regex patterns are raw strings — the original used
    ``'\\('``-style escapes in plain strings, which raises SyntaxWarning on
    Python 3.12+ and will become an error.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()
    # end

    strKey = strFunction

    # Key incorporates every argument; tensors contribute dtype/shape/strides
    # so a shape change recompiles rather than reusing stale constants.
    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        else:
            print(strVariable, type(objValue))
            assert(False)
        # end
    # end

    strKey += objCudacache['device']

    if strKey not in objCudacache:
        # Substitute scalar placeholders and derive the C type name from the
        # first tensor argument's dtype.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace('{{type}}', 'unsigned char')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace('{{type}}', 'half')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace('{{type}}', 'float')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace('{{type}}', 'double')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace('{{type}}', 'int')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace('{{type}}', 'long')

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)
                assert(False)

            else:
                print(strVariable, type(objValue))
                assert(False)
            # end
        # end

        # Expand SIZE_n(tensor) into the literal dimension size.
        while True:
            objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
        # end

        # Expand OFFSET_n(tensor, i0, ..., in-1) into stride arithmetic that
        # yields a flat element offset; parentheses are matched by counting.
        while True:
            objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
        # end

        # Expand VALUE_n(tensor, i0, ..., in-1) into a direct indexed access.
        while True:
            objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
        # end

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end
|
||||
|
||||
|
||||
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
    """Compile the cached kernel source for ``strKey`` and return its raw CUDA function."""
    # Default to the conventional CUDA install location when unset, so cupy's
    # NVRTC can find headers.
    os.environ.setdefault('CUDA_HOME', '/usr/local/cuda/')

    objEntry = objCudacache[strKey]
    return cupy.RawModule(code=objEntry['strKernel']).get_function(objEntry['strFunction'])
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
class costvol_func(torch.autograd.Function):
    """Autograd-wrapped CUDA cost-volume op (cupy raw kernels, CUDA only).

    ``forward(tenOne, tenTwo, intKernelSize)`` returns a
    (N, intKernelSize**2, H, W) volume where each channel is the per-pixel dot
    product between ``tenOne`` at (y, x) and ``tenTwo`` at one offset within
    the intKernelSize x intKernelSize neighborhood; out-of-frame offsets are 0.

    NOTE(review): uses the legacy autograd style of stashing state on the
    context via ``self`` (here the first positional arg) rather than ``ctx``.
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, tenOne, tenTwo, intKernelSize):
        # One output channel per neighborhood offset.
        tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])

        # Kernel: each thread handles one (n, y, x), loads tenOne's channel
        # vector into registers, then accumulates a dot product per offset.
        cuda_launch(cuda_kernel('costvol_out', '''
            extern "C" __global__ void __launch_bounds__(512) costvol_out(
                const int n,
                const {{type}}* __restrict__ tenOne,
                const {{type}}* __restrict__ tenTwo,
                const int intKernelSize,
                {{type}}* __restrict__ tenOut
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
                const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                const int intX = ( intIndex ) % SIZE_3(tenOut);

                {{type}} fltOne[{{intChans}}];

                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
                }

                int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);

                for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                    for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                        {{type}} fltValue = 0.0f;

                        if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
                            for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
                            }
                        }

                        tenOut[intOffset] = fltValue;
                        intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
                    }
                }
            } }
        ''', {
            'intChans': tenOne.shape[1],
            'tenOne': tenOne,
            'tenTwo': tenTwo,
            'intKernelSize': intKernelSize,
            'tenOut': tenOut
        }))(
            grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )

        # Save inputs (and the non-tensor kernel size) for backward.
        self.save_for_backward(tenOne, tenTwo)
        self.intKernelSize = intKernelSize

        return tenOut
    # end

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(self, tenOutgrad):
        tenOne, tenTwo = self.saved_tensors
        intKernelSize = self.intKernelSize

        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

        # Allocate gradients only for inputs that require them.
        tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None
        tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None

        # d(out)/d(tenOne): each output channel contributes tenTwo at the
        # matching offset, accumulated per (n, y, x) without atomics (each
        # thread owns its own tenOnegrad location).
        if tenOnegrad is not None:
            cuda_launch(cuda_kernel('costvol_onegrad', '''
                extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
                    const int n,
                    const {{type}}* __restrict__ tenOne,
                    const {{type}}* __restrict__ tenTwo,
                    const {{type}}* __restrict__ tenOutgrad,
                    const int intKernelSize,
                    {{type}}* __restrict__ tenOnegrad,
                    {{type}}* __restrict__ tenTwograd
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
                    const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad);
                    const int intX = ( intIndex ) % SIZE_3(tenOnegrad);

                    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                    tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx);
                                }
                            }
                            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                        }
                    }
                } }
            ''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        # d(out)/d(tenTwo): scatter into neighbor locations, so concurrent
        # threads must use atomicAdd.
        if tenTwograd is not None:
            cuda_launch(cuda_kernel('costvol_twograd', '''
                extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
                    const int n,
                    const {{type}}* __restrict__ tenOne,
                    const {{type}}* __restrict__ tenTwo,
                    const {{type}}* __restrict__ tenOutgrad,
                    const int intKernelSize,
                    {{type}}* __restrict__ tenOnegrad,
                    {{type}}* __restrict__ tenTwograd
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
                    const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd);
                    const int intX = ( intIndex ) % SIZE_3(tenTwograd);

                    {{type}} fltOne[{{intChans}}];

                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
                    }

                    int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                    for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                        for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                            if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                    atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], tenOutgrad[intOffset] * fltOne[intValue]);
                                }
                            }
                            intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                        }
                    }
                } }
            ''', {
                'intChans': tenOne.shape[1],
                'tenOne': tenOne,
                'tenTwo': tenTwo,
                'tenOutgrad': tenOutgrad,
                'intKernelSize': intKernelSize,
                'tenOnegrad': tenOnegrad,
                'tenTwograd': tenTwograd
            }))(
                grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        # Three forward args (tenOne, tenTwo, intKernelSize) -> three grads;
        # the extra Nones match the legacy-style invocation arity.
        return tenOnegrad, tenTwograd, None, None, None
    # end
# end
|
||||
102
bim_vfi_arch/resnet_encoder.py
Normal file
102
bim_vfi_arch/resnet_encoder.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Callable
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a 3x3 convolution whose padding equals the dilation rate.

    With padding == dilation the spatial size is preserved at stride 1.
    Unlike torchvision's ResNet helper this variant keeps a bias term —
    presumably because the surrounding blocks use InstanceNorm (affine=False
    by default, so no learned shift); TODO confirm against training config.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """2x2 convolution, no padding — used as the strided downsampling conv."""
    return nn.Conv2d(in_planes, out_planes, (2, 2), stride=stride)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Pre-activation ResNet basic block with LeakyReLU and bias convolutions.

    Differences from torchvision's BasicBlock:
      * normalization runs *before* each convolution (pre-activation order),
      * downsampling (stride > 1) uses a 2x2 convolution instead of a
        strided 3x3 one,
      * LeakyReLU replaces ReLU, and the final activation is applied
        *before* the residual addition, not after.

    Submodule attribute names (bn1, conv1, lrelu, bn2, conv2, downsample)
    are state_dict keys of the pretrained checkpoint — do not rename.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # BUGFIX: the previous default was
            # partial(nn.InstanceNorm2d, data_format='channels_first'), but
            # nn.InstanceNorm2d accepts no `data_format` argument, so any
            # construction without an explicit norm_layer raised TypeError.
            # Plain InstanceNorm2d matches what every caller in this file
            # passes explicitly, so behavior is unchanged for them.
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(inplanes)
        # Strided path downsamples with a 2x2 conv; unstrided path keeps 3x3.
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the pre-activation residual block to an NCHW tensor."""
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)

        out = self.bn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Activation precedes the residual addition (see class docstring).
        out = self.lrelu(out)
        out = out + identity

        return out
|
||||
|
||||
|
||||
class ResNetPyramid(nn.Module):
    """Three-level shared feature pyramid.

    Emits features at full, half and quarter resolution with channel widths
    feat_channels, 2*feat_channels and 4*feat_channels respectively; by
    default the same pyramid is shared by the motion estimator and the
    synthesis network. Attribute names are state_dict keys — do not rename.
    """

    def __init__(self, feat_channels: int):
        super(ResNetPyramid, self).__init__()
        ch = feat_channels
        # Stem: lift RGB to `ch` channels at full resolution.
        self.conv = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.layer0 = nn.Sequential(
            BasicBlock(ch, ch, norm_layer=nn.InstanceNorm2d),
        )
        # Each level below halves spatial size (2x2 stride-2 conv) and
        # doubles the channel count before a residual block.
        self.layer1 = nn.Sequential(
            nn.Conv2d(ch, ch * 2, 2, 2),
            BasicBlock(ch * 2, ch * 2, norm_layer=nn.InstanceNorm2d),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(ch * 2, ch * 4, 2, 2),
            BasicBlock(ch * 4, ch * 4, norm_layer=nn.InstanceNorm2d),
        )
        self.conv_last = nn.Conv2d(ch * 4, ch * 4, 3, 1, 1)

    def forward(self, img):
        """Return pyramid features as [full, half, quarter] resolution maps."""
        level0 = self.layer0(self.conv(img))
        level1 = self.layer1(level0)
        level2 = self.conv_last(self.layer2(level1))
        return [level0, level1, level2]
|
||||
95
bim_vfi_arch/sn.py
Normal file
95
bim_vfi_arch/sn.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .backwarp import backwarp
|
||||
|
||||
|
||||
class SynthesisNetwork(nn.Module):
    """U-Net-style synthesis network that fuses backward-warped inputs.

    Consumes the two source images, their feature pyramids, a pyramid of
    bidirectional flows (channels 0:2 = flow t->0, 2:4 = flow t->1, as used
    in get_warped_representations), and an occlusion logit map; produces the
    interpolated frame, a refined occlusion logit, and a dict of
    intermediate tensors for inspection.

    Attribute names are state_dict keys of the pretrained checkpoint —
    do not rename.
    """
    def __init__(self, feat_channels):
        super(SynthesisNetwork, self).__init__()
        # 3 + 3 channels for the two warped RGB images, plus 1 occlusion logit.
        input_channels = 6 + 1
        # Encoder level 0: full resolution, width 2*feat_channels out.
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 1: input is s0 (2*fc) concat warped level-0 features
        # (fc + fc) = 4*fc; 2x2 stride-2 conv halves spatial size.
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 2: s1 (2*fc) concat warped level-1 features (2*fc + 2*fc) = 6*fc.
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4))
        # Decoder: PixelShuffle(2) upsampling; 8*fc -> 2*fc after shuffle.
        self.conv_up1 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        self.conv_up2 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1))
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            )
        # Head: 3 channels of RGB residual + 1 channel of occlusion-logit residual.
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)

    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backward-warp features (and optionally images) toward time t.

        bi_flow packs flow t->0 in channels 0:2 and flow t->1 in 2:4.
        Returns (warped_c0, warped_c1), or
        (warped_img0, warped_img1, warped_c0, warped_c1) when i0/i1 given.
        """
        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if (i0 is None) and (i1 is None):
            return warped_c0, warped_c1
        else:
            warped_img0 = backwarp(i0, flow_t0)
            warped_img1 = backwarp(i1, flow_t1)
            return warped_img0, warped_img1, warped_c0, warped_c1

    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the intermediate frame.

        Args (assumed shapes — pyramids are lists indexed coarse->fine at
        [0]=full res; TODO confirm against caller):
            i0, i1: source images.
            c0_pyr, c1_pyr: per-image feature pyramids.
            bi_flow_pyr: bidirectional flow pyramid matching the features.
            occ: occlusion logit map at full resolution.

        Returns (interp_img, occ_out, extra_dict).
        """
        # Full-resolution warps of both images and their finest features.
        warped_img0, warped_img1, warped_c0, warped_c1 = \
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1)
        input_feat = torch.cat(
            (warped_img0, warped_img1, occ), 1)
        s0 = self.conv_down1(input_feat)
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None)
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None)

        # Decoder with skip connections from the encoder stages.
        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))

        # Split the head output into an RGB residual and an occlusion residual.
        refine = self.conv_out(x)
        refine_res = refine[:, :3]
        occ_res = refine[:, 3:]
        occ_out = occ + occ_res
        # Sigmoid turns the refined logit into a 0..1 blend weight for img0.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res
        interp_img = merged_img

        # Intermediates exposed for debugging/visualization.
        extra_dict = {}
        extra_dict["refine_res"] = refine_res
        extra_dict["refine_mask"] = occ_out
        extra_dict["warped_img0"] = warped_img0
        extra_dict["warped_img1"] = warped_img1
        extra_dict["merged_img"] = merged_img

        return interp_img, occ_out, extra_dict
|
||||
81
inference.py
Normal file
81
inference.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from .bim_vfi_arch import BiMVFI
|
||||
|
||||
|
||||
class BiMVFIModel:
    """Clean inference wrapper around BiMVFI for ComfyUI integration."""

    def __init__(self, checkpoint_path, pyr_level=3, device="cpu"):
        self.pyr_level = pyr_level
        self.device = device

        self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.to(device)

    def _load_checkpoint(self, checkpoint_path):
        """Load weights into self.model, tolerating common checkpoint layouts."""
        # NOTE(review): weights_only=False deserializes arbitrary pickled
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # The weights may be nested under a wrapper key ("model" has priority,
        # then "state_dict"); otherwise the checkpoint IS the state dict.
        state_dict = checkpoint
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

        # Strip common prefixes: "module." (DDP), then "model." (wrapper) —
        # applied in that order, so both may be removed from one key.
        cleaned = {}
        for name, tensor in state_dict.items():
            for prefix in ("module.", "model."):
                if name.startswith(prefix):
                    name = name[len(prefix):]
            cleaned[name] = tensor

        self.model.load_state_dict(cleaned)

    def to(self, device):
        """Move the wrapped model to `device`; returns self for chaining."""
        self.device = device
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_pair(self, frame0, frame1, time_step=0.5):
        """Interpolate a single frame between two input frames.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1), temporal position of interpolated frame

        Returns:
            Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
        """
        # Run on whatever device the model currently lives on.
        device = next(self.model.parameters()).device
        img0 = frame0.to(device)
        img1 = frame1.to(device)

        # Deeper pyramids for higher resolutions; small inputs fall back to
        # the configured level.
        height = img0.shape[2]
        if height >= 2160:
            pyr_level = 7
        elif height >= 1080:
            pyr_level = 6
        elif height >= 540:
            pyr_level = 5
        else:
            pyr_level = self.pyr_level

        # Broadcastable [1, 1, 1, 1] time-step tensor expected by the model.
        time_step_tensor = torch.full((1, 1, 1, 1), time_step, device=device)

        result_dict = self.model(
            img0=img0, img1=img1,
            time_step=time_step_tensor,
            pyr_level=pyr_level,
        )

        return result_dict["imgt_pred"].clamp(0, 1)
|
||||
47
install.py
Normal file
47
install.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def get_cupy_package():
    """Detect PyTorch's CUDA version and return the matching cupy package name.

    Returns a pip package name like "cupy-cuda12x", or None when no CUDA
    runtime is usable (a warning is printed in that case).
    """
    try:
        import torch

        if not torch.cuda.is_available():
            print("[BIM-VFI] WARNING: CUDA not available. cupy requires CUDA.")
            return None
        cuda_version = torch.version.cuda
        if cuda_version is None:
            print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
            return None
        # cupy wheels are published per CUDA major version: cupy-cuda11x, 12x, ...
        major = int(cuda_version.split(".")[0])
        cupy_pkg = f"cupy-cuda{major}x"
        print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}")
        return cupy_pkg
    except Exception as e:
        print(f"[BIM-VFI] WARNING: Could not detect CUDA version: {e}")
        return None
|
||||
|
||||
|
||||
def install():
    """Install requirements.txt, then a cupy build matching torch's CUDA."""
    requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
    pip_base = [sys.executable, "-m", "pip", "install"]
    subprocess.check_call(pip_base + ["-r", requirements_path])

    # Install cupy matching the current CUDA version; skip if already present.
    try:
        import cupy  # noqa: F401 — import only to probe availability
        print("[BIM-VFI] cupy already installed, skipping.")
    except ImportError:
        cupy_pkg = get_cupy_package()
        if cupy_pkg:
            print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...")
            subprocess.check_call(pip_base + [cupy_pkg])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
install()
|
||||
169
nodes.py
Normal file
169
nodes.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import folder_paths
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
from .inference import BiMVFIModel
|
||||
from .bim_vfi_arch import clear_backwarp_cache
|
||||
|
||||
logger = logging.getLogger("BIM-VFI")

# Google Drive file ID for the pretrained model (used for auto-download
# when the checkpoint file is missing; see download_model_from_gdrive).
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
# Default checkpoint filename inside MODEL_DIR.
MODEL_FILENAME = "bim_vfi.pth"

# Checkpoints live under <ComfyUI models dir>/bim-vfi; the directory is
# created at import time so the dropdown has somewhere to scan.
MODEL_DIR = os.path.join(folder_paths.models_dir, "bim-vfi")
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def get_available_models():
    """List available checkpoint files in the bim-vfi model directory."""
    checkpoint_exts = (".pth", ".pt", ".ckpt", ".safetensors")
    models = []
    if os.path.isdir(MODEL_DIR):
        models = [name for name in os.listdir(MODEL_DIR) if name.endswith(checkpoint_exts)]
    if not models:
        # Offer the default name; selecting it triggers the auto-download.
        models = [MODEL_FILENAME]
    return sorted(models)
|
||||
|
||||
|
||||
def download_model_from_gdrive(file_id, dest_path):
    """Download a file from Google Drive using gdown.

    Args:
        file_id: Google Drive file ID.
        dest_path: Local filesystem path to write the downloaded file to.

    Raises:
        RuntimeError: if gdown is not installed, or the download did not
            produce a file at dest_path.
    """
    try:
        import gdown
    except ImportError as err:
        # Chain the original ImportError so tracebacks show the root cause.
        raise RuntimeError(
            "gdown is required to auto-download the BIM-VFI model. "
            "Install it with: pip install gdown"
        ) from err
    url = f"https://drive.google.com/uc?id={file_id}"
    logger.info(f"Downloading BIM-VFI model to {dest_path}...")
    gdown.download(url, dest_path, quiet=False)
    # gdown can fail silently (quota, permissions) — verify the file landed.
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download model to {dest_path}")
    logger.info("Download complete.")
|
||||
|
||||
|
||||
class LoadBIMVFIModel:
    """ComfyUI node: load a BIM-VFI checkpoint, auto-downloading if absent."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_models(), {"default": MODEL_FILENAME}),
                "pyr_level": ("INT", {"default": 3, "min": 3, "max": 7, "step": 1}),
            }
        }

    RETURN_TYPES = ("BIM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/BIM-VFI"

    def load_model(self, model_path, pyr_level):
        """Resolve the checkpoint path (downloading on first use) and wrap it."""
        full_path = os.path.join(MODEL_DIR, model_path)

        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_model_from_gdrive(GDRIVE_FILE_ID, full_path)

        # Loaded on CPU; the interpolate node moves it to the GPU as needed.
        wrapper = BiMVFIModel(checkpoint_path=full_path, pyr_level=pyr_level, device="cpu")

        logger.info(f"BIM-VFI model loaded (pyr_level={pyr_level})")
        return (wrapper,)
|
||||
|
||||
|
||||
class BIMVFIInterpolate:
    """ComfyUI node: interpolate a frame sequence 2x/4x/8x via recursive
    2x passes, processing one frame pair at a time to bound VRAM use."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "model": ("BIM_VFI_MODEL",),
                # 4x/8x are realized as 2/3 successive 2x passes.
                "multiplier": ([2, 4, 8], {"default": 2}),
                # How often (in processed pairs) to clear CUDA caches; only
                # relevant when intermediates are stored on CPU.
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1}),
                # Keep the model on the GPU between pairs (faster); when off,
                # the model is shuttled to CPU after every pair (lowest VRAM).
                "keep_device": ("BOOLEAN", {"default": True}),
                # Store ALL intermediate frames on the GPU too (high-VRAM setups).
                "all_on_gpu": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/BIM-VFI"

    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
        """Run `num_passes` sequential 2x interpolation passes over `images`.

        images is a ComfyUI IMAGE batch — assumed [B, H, W, C] float in [0, 1];
        returns the interpolated batch in the same layout on CPU.
        """
        # A single frame (or empty batch) has no pairs to interpolate.
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_passes = {2: 1, 4: 2, 8: 3}[multiplier]

        # all_on_gpu implies keep_device (no point shuttling the model off
        # the GPU while all frames stay resident there).
        if all_on_gpu:
            keep_device = True

        # Where to store intermediate frames between passes.
        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        frames = images.permute(0, 3, 1, 2).to(storage_device)

        # After each 2x pass, frame count = 2*N - 1, so compute total pairs across passes
        # up front for an accurate progress bar.
        n = frames.shape[0]
        total_steps = 0
        for _ in range(num_passes):
            total_steps += n - 1
            n = 2 * n - 1

        pbar = ProgressBar(total_steps)
        step = 0

        # With keep_device, move the model to the GPU once, up front.
        if keep_device:
            model.to(device)

        for pass_idx in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1

            for i in range(num_pairs):
                frame0 = frames[i:i+1]  # [1, C, H, W]
                frame1 = frames[i+1:i+2]  # [1, C, H, W]

                # Low-VRAM mode: model lives on GPU only for this one pair.
                if not keep_device:
                    model.to(device)

                mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
                # Move the result off-GPU immediately unless all_on_gpu.
                mid = mid.to(storage_device)

                if not keep_device:
                    model.to("cpu")

                # Interleave: original frame i, then its new midpoint.
                new_frames.append(frames[i:i+1])
                new_frames.append(mid)

                step += 1
                pbar.update_absolute(step, total_steps)

                # Periodic VRAM cleanup: drop cached backwarp grids and
                # release CUDA allocator blocks.
                if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
                    clear_backwarp_cache()
                    torch.cuda.empty_cache()

            # Append last frame (it has no following pair).
            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)

        # Final cleanup after all passes.
        if not all_on_gpu and torch.cuda.is_available():
            clear_backwarp_cache()
            torch.cuda.empty_cache()

        # Convert back to ComfyUI [B, H, W, C], on CPU for ComfyUI
        result = frames.cpu().permute(0, 2, 3, 1)
        return (result,)
|
||||
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
gdown
|
||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/padder.py
Normal file
28
utils/padder.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class InputPadder:
    """Pads images on the right/bottom so H and W are divisible by `divisor`.

    Padding only the right and bottom edges keeps the pixel coordinates of
    the original content unchanged, so unpad() is a simple top-left crop.
    """

    def __init__(self, dims, divisor=16):
        """dims: any shape-like sequence whose last two entries are (H, W)."""
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # F.pad order for the last two dims: (left, right, top, bottom).
        self._pad = [0, pad_wd, 0, pad_ht]

    def pad(self, *inputs):
        """Pad the given tensor(s); one input returns a tensor, many a list."""
        padded = [F.pad(t, self._pad, mode='constant') for t in inputs]
        return padded[0] if len(inputs) == 1 else padded

    def unpad(self, *inputs):
        """Crop tensor(s) back to the original (H, W); mirrors pad()'s return."""
        cropped = [self._unpad(t) for t in inputs]
        return cropped[0] if len(inputs) == 1 else cropped

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||
Reference in New Issue
Block a user