Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing. - LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive - BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing - Vendored inference-only architecture from KAIST-VICLab/BiM-VFI - Auto-detection of CUDA version for cupy installation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
11
__init__.py
Normal file
11
__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""ComfyUI entry point: registers the BiM-VFI loader and interpolation nodes."""

from .nodes import LoadBIMVFIModel, BIMVFIInterpolate

# Node ID -> implementing class, as required by ComfyUI's plugin protocol.
NODE_CLASS_MAPPINGS = dict(
    LoadBIMVFIModel=LoadBIMVFIModel,
    BIMVFIInterpolate=BIMVFIInterpolate,
)

# Node ID -> human-readable title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoadBIMVFIModel="Load BIM-VFI Model",
    BIMVFIInterpolate="BIM-VFI Interpolate",
)
|
||||||
2
bim_vfi_arch/__init__.py
Normal file
2
bim_vfi_arch/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .bim_vfi import BiMVFI
|
||||||
|
from .backwarp import clear_backwarp_cache
|
||||||
42
bim_vfi_arch/arch.py
Normal file
42
bim_vfi_arch/arch.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two input layouts.

    ``channels_last`` expects ``(batch, height, width, channels)`` and
    normalizes the trailing dimension directly; ``channels_first`` expects
    ``(batch, channels, height, width)`` and is handled by permuting to
    channels-last, normalizing, and permuting back.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        # F.layer_norm needs a shape tuple, not a bare int.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Normalize over the channel dim by moving it to the end and back.
            x = x.permute(0, 2, 3, 1)
            out = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            return out.permute(0, 3, 1, 2)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Two-convolution residual block: ``x + conv2(leaky_relu(conv1(x)))``."""

    def __init__(self, feat_channels, kernel_size=3, padding_mode='zeros'):
        super().__init__()
        pad = (kernel_size - 1) // 2  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)
        self.act = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(feat_channels, feat_channels, kernel_size,
                               padding=pad, padding_mode=padding_mode)

    def forward(self, x):
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual
|
||||||
24
bim_vfi_arch/backwarp.py
Normal file
24
bim_vfi_arch/backwarp.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
# Sampling grids keyed by (dtype, device, H, W) so each flow geometry only
# builds its base grid once.
objBackwarpcache = {}


def clear_backwarp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    objBackwarpcache.clear()


def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor, mode='bilinear'):
    """Backward-warp `tenIn` along per-pixel flow `tenFlow` via grid_sample.

    `tenFlow` is a (N, 2, H, W) displacement in pixels; it is rescaled to the
    [-1, 1] normalized coordinates grid_sample expects.
    """
    key = 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])
    if key not in objBackwarpcache:
        hor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype,
                             device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
        ver = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype,
                             device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])
        objBackwarpcache[key] = torch.cat([hor, ver], 1)

    if tenFlow.shape[3] == tenFlow.shape[2]:
        # Square input: one scalar converts pixel units to [-1, 1] units.
        tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0))
    else:
        # Rectangular input: scale the x and y channels independently.
        tenFlow = tenFlow * torch.tensor(
            data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)],
            dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1)

    grid = objBackwarpcache[key] + tenFlow
    return torch.nn.functional.grid_sample(input=tenIn, grid=grid.permute(0, 2, 3, 1),
                                           mode=mode, padding_mode='zeros', align_corners=True)
|
||||||
114
bim_vfi_arch/bim_vfi.py
Normal file
114
bim_vfi_arch/bim_vfi.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .backwarp import backwarp
|
||||||
|
from .resnet_encoder import ResNetPyramid
|
||||||
|
from .caun import CAUN
|
||||||
|
from .bimfn import BiMFN
|
||||||
|
from .sn import SynthesisNetwork
|
||||||
|
|
||||||
|
from ..utils.padder import InputPadder
|
||||||
|
|
||||||
|
|
||||||
|
class BiMVFI(nn.Module):
    """BiM-VFI frame interpolation model (inference-only, vendored).

    Runs a coarse-to-fine pyramid: at each level, BiMFN estimates a
    bidirectional flow at the coarsest feature resolution, CAUN upsamples
    it with learned convex-combination kernels, and the synthesis network
    produces the interpolated frame.  Flow/occlusion from one level seed
    the next, finer level.
    """

    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        # pyr_level: number of coarse-to-fine passes in forward().
        # **kwargs is accepted (and ignored) so config dicts can be splatted in.
        super(BiMVFI, self).__init__()
        self.pyr_level = pyr_level
        self.mfe = ResNetPyramid(feat_channels)  # feature pyramid fed to BiMFN
        self.cfe = ResNetPyramid(feat_channels)  # context pyramid fed to CAUN / synthesis
        self.bimfn = BiMFN(feat_channels)        # coarse bidirectional flow estimator
        self.sn = SynthesisNetwork(feat_channels)  # renders the interpolated frame
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)          # learned flow upsampler

    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """One pyramid pass: estimate flow/occlusion and synthesize a frame.

        Returns (flow, occ, interp_img, extra_dict); flow and occ are reused
        as the next level's `last_flow` / `last_occ`.
        """
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)

        B, _, H, W = feat0_pyr[-1].shape

        # Inference path: prepare uniform BiM
        # r: spatially constant magnitude equal to the interpolation time step;
        # phi: constant angle pi, encoded as [cos(phi), sin(phi)].
        r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
        phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)

        # Bring the previous level's flow/occlusion down to the coarsest
        # feature resolution; flow values are halved to match the 0.5 scale.
        last_flow = F.interpolate(
            input=last_flow.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach().clone(), scale_factor=0.5,
            mode="bilinear", align_corners=False)

        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)

        # CAUN returns a flow pyramid [x4, x2, x1]; bi_flow_pyr[0] is the finest.
        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]

        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict.update({'flow_res': flow_res})
        return flow, occ, interp_img, extra_dict

    def forward(self, img0, img1, time_step,
                pyr_level=None, **kwargs):
        """Interpolate a frame at `time_step` between img0 and img1.

        img0/img1: (N, 3, H, W) image tensors — presumably in [0, 1]; TODO
        confirm against the calling node.  Returns a dict with
        'imgt_preds' (per-level predictions, coarse to fine, still padded)
        and 'imgt_pred' (the finest prediction, unpadded).
        """
        if pyr_level is None: pyr_level = self.pyr_level
        N, _, H, W = img0.shape
        interp_imgs = []

        # Padder guarantees H and W divisible by 2 ** (pyr_level + 1).
        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))

        # Normalize input images
        with torch.set_grad_enabled(False):
            # Joint mean/std over both frames so they share one normalization.
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()

            img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
            img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)

        # Pad images for downsampling
        img0, img1 = padder.pad(img0, img1)

        N, _, H, W = img0.shape

        # Coarse-to-fine: level pyr_level-1 (smallest) down to level 0 (full size).
        for level in list(range(pyr_level))[::-1]:
            # Downsample images if needed
            if level != 0:
                scale_factor = 1 / 2 ** level
                img0_this_lvl = F.interpolate(
                    input=img0, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_this_lvl = F.interpolate(
                    input=img1, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_this_lvl = img0
                img1_this_lvl = img1

            # Initialize zero flows for lowest pyramid level
            if level == pyr_level - 1:
                # 4 channels: two 2-channel flows (towards frame 0 and frame 1).
                last_flow = torch.zeros(
                    (N, 4, H // (2 ** (level + 1)), W // (2 ** (level + 1))), device=img0.device
                )
                last_occ = torch.zeros(N, 1, H // (2 ** (level + 1)), W // (2 ** (level + 1)), device=img0.device)
            else:
                # Seed this level with the previous (coarser) level's estimate.
                last_flow = flow
                last_occ = occ

            # Single pyramid level run
            flow, occ, interp_img, extra_dict = self.forward_one_lvl(
                img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)

            # Undo the input normalization so predictions are in image range.
            interp_imgs.append((interp_img) * (tenStd_ + 0.0000001) + tenMean_)

        result_dict = {
            "imgt_preds": interp_imgs,
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }

        return result_dict
|
||||||
115
bim_vfi_arch/bimfn.py
Normal file
115
bim_vfi_arch/bimfn.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .backwarp import backwarp
|
||||||
|
from .arch import LayerNorm, ResBlock
|
||||||
|
from .costvol import costvol_func
|
||||||
|
|
||||||
|
|
||||||
|
class BiMMConv(nn.Module):
    """Modulates a feature volume with BiM magnitude/angle embeddings.

    Output feature is ``fv + fv * r_emb * phi_emb`` where each branch is a
    LayerNorm followed by a 1x1 ResBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Three structurally identical branches; registration order matters
        # for state_dict compatibility, so keep FV, r, phi in this order.
        for branch in ('conv_FV', 'conv_r', 'conv_phi'):
            setattr(self, branch, nn.Sequential(
                LayerNorm(in_channels, data_format='channels_first'),
                ResBlock(out_channels, 1),
            ))

    def forward(self, FV, r, phi):
        fv = self.conv_FV(FV)
        r_emb = self.conv_r(r)
        phi_emb = self.conv_phi(phi)
        modulated = fv * r_emb * phi_emb + fv
        return modulated, r_emb, phi_emb
|
||||||
|
|
||||||
|
|
||||||
|
class BiMFN(nn.Module):
    """Bilateral Motion Field Network: coarse bidirectional-flow refiner.

    Warps both feature maps by the previous flow, correlates them with 9x9
    cost volumes in both directions, fuses everything into a feature volume
    modulated by BiM embeddings (magnitude ``r`` via ``dem``, angle ``phi``
    via ``aem``), and regresses a 4-channel flow residual added to
    ``last_flow``.  Below, f = feat_channels.
    """

    def __init__(self, feat_channels):
        super(BiMFN, self).__init__()
        # Embeds one 2-channel flow half (applied to [:, :2] and [:, 2:]).
        self.conv_flow = nn.Sequential(
            nn.Conv2d(2, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds the 1-channel occlusion map.
        self.conv_occ = nn.Sequential(
            nn.Conv2d(1, feat_channels * 2, 7, padding=3),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Embeds an 81-channel (9x9) cost volume.
        self.conv_corr = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(81, feat_channels * 2, 1, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Fusion trunk: 18f (concat of all embeddings + warped features)
        # squeezed down to 6f.
        self.conv0 = nn.Sequential(
            nn.Conv2d(feat_channels * 18, feat_channels * 12, 1, padding=0),
            nn.PReLU(feat_channels * 12),
            nn.Conv2d(feat_channels * 12, feat_channels * 12, 3, padding=1),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(feat_channels * 12, feat_channels * 8, 1, padding=0),
            nn.PReLU(feat_channels * 8),
            nn.Conv2d(feat_channels * 8, feat_channels * 8, 3, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 1, padding=0),
            nn.PReLU(feat_channels * 6),
            nn.Conv2d(feat_channels * 6, feat_channels * 6, 3, padding=1),
        )
        # Distance (magnitude) embedding of r: 1 -> 6f.
        self.dem = nn.Sequential(
            nn.Conv2d(1, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )
        # Angle embedding of phi = [cos, sin]: 2 -> 6f.
        self.aem = nn.Sequential(
            nn.Conv2d(2, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 6, 1, padding=0),
        )

        self.bim_mconv = BiMMConv(feat_channels * 6, feat_channels * 6)

        # Head: 6f -> 4-channel flow residual (two bidirectional 2-ch flows).
        self.conv_out = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 1, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, 4, 1, padding=0))

    def forward(self, feat0, feat1, r, phi, last_flow, last_occ):
        """Return (flow_low, flow_res) where flow_low = last_flow + flow_res."""
        # Pre-warp features with the previous estimate so the cost volume
        # only has to account for the residual motion.
        feat0_warp = backwarp(feat0, (last_flow[:, :2]))
        feat1_warp = backwarp(feat1, (last_flow[:, 2:]))
        # 9x9 correlation volumes in both directions (cupy CUDA kernel).
        volume0 = costvol_func.apply(feat0_warp, feat1_warp, 9)
        volume1 = costvol_func.apply(feat1_warp, feat0_warp, 9)
        corr0 = self.conv_corr(volume0)
        corr1 = self.conv_corr(volume1)
        flo0 = self.conv_flow(last_flow[:, :2])
        flo1 = self.conv_flow(last_flow[:, 2:])
        occ = self.conv_occ(last_occ)
        input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, flo0, flo1, occ], 1)
        FV = self.conv0(input_feat)
        FV = self.conv1(FV)
        FV = self.conv2(FV)
        FV0 = self.conv3(FV)
        r0 = self.dem(r)
        phi0 = self.aem(phi)
        # BiM modulation of the fused feature volume.
        bim_feat, _, _ = self.bim_mconv(FV0, r0, phi0)
        flow_res = self.conv_out(bim_feat)
        flow_low = flow_res + last_flow

        return flow_low, flow_res
|
||||||
72
bim_vfi_arch/caun.py
Normal file
72
bim_vfi_arch/caun.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .backwarp import backwarp
|
||||||
|
|
||||||
|
|
||||||
|
class CAUN(nn.Module):
    """Content-Aware Upsampling Network for the bidirectional flow.

    Upsamples the coarse flow by 2x and 4x with learned convex-combination
    kernels (softmax over each pixel's 3x3 source neighborhood),
    conditioned on context features warped by that flow.  Below,
    f = feat_channels.
    """

    def __init__(self, feat_channels):
        super(CAUN, self).__init__()
        # Encoder over warped context features at increasing resolutions;
        # pixel_shuffle carries information up between stages.
        self.enc0 = nn.Sequential(
            nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc1 = nn.Sequential(
            nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # Predict 2 * 1 * 9 mask logits: 9 taps for each of the two
        # flow halves (t->0 and t->1).
        self.kernel_x2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
        )
        self.kernel_x4 = nn.Sequential(
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
        )

    def upsample_input(self, inp, mask, upsample_factor):
        """Convex-combination upsampling: each output pixel is a
        softmax-weighted sum of the 3x3 neighborhood of its source pixel."""
        N, c, H, W = inp.shape
        mask = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
        mask = torch.softmax(mask, dim=2)  # convex weights over the 9 taps
        inp = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        up_inp = F.unfold(inp, [3, 3])  # (N, c*9, H*W) neighborhoods
        up_inp = up_inp.view(N, c, 9, 1, 1, H, W)

        up_inp = torch.sum(mask * up_inp, dim=2)  # (N, c, fac, fac, H, W)
        # Interleave the sub-pixel grid into the spatial dims.
        up_inp = up_inp.permute(0, 1, 4, 2, 5, 3)
        return up_inp.reshape(N, c, upsample_factor*H, upsample_factor*W)

    def forward(self, flow, feat0, feat1, last_occ):
        """ Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination """
        N, _, H, W = flow.shape
        feat0_warped_list, feat1_warped_list = [], []
        # Warp the three finest context-pyramid levels by the (rescaled) flow.
        # Flow values are multiplied by 2**i to match each level's resolution.
        for i in range(3):
            flow_bi = F.interpolate(flow * 2 ** i, scale_factor=2 ** i, mode='bilinear')
            feat0_warped = backwarp(feat0[2-i], flow_bi[:, :2])
            feat1_warped = backwarp(feat1[2-i], flow_bi[:, 2:])
            feat0_warped_list.append(feat0_warped)
            feat1_warped_list.append(feat1_warped)
        feature = torch.cat([feat0_warped_list[0], feat1_warped_list[0]], dim=1)
        feature0 = self.enc0(feature)
        feature1 = self.enc1(torch.cat([F.pixel_shuffle(feature0, 2), feat0_warped_list[1], feat1_warped_list[1]], dim=1))
        feature2 = self.enc2(torch.cat([F.pixel_shuffle(feature1, 2), feat0_warped_list[2], feat1_warped_list[2]], dim=1))
        mask_x2 = self.kernel_x2(feature1)
        mask_x4 = self.kernel_x4(feature2)
        # Rearrange mask logits into per-subpixel layout, then split into the
        # masks for the two flow halves.
        mask_x2 = mask_x2.view(N, 18, H, 2, W, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
        mask_x4 = mask_x4.view(N, 18, H, 4, W, 4).permute(0, 1, 3, 5, 2, 4).contiguous()
        mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)
        # Flow values scale with the upsampling factor (pixel units).
        up_flow_x2_0 = self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2)
        up_flow_x2_1 = self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2)
        up_flow_x4_0 = self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4)
        up_flow_x4_1 = self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4)
        up_flow_x2 = torch.cat([up_flow_x2_0, up_flow_x2_1], dim=1)
        up_flow_x4 = torch.cat([up_flow_x4_0, up_flow_x4_1], dim=1)
        # Occlusion is upsampled plainly (no learned kernel).
        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        # Finest first: [x4, x2, original].
        return [up_flow_x4, up_flow_x2, flow], up_occ
|
||||||
399
bim_vfi_arch/costvol.py
Normal file
399
bim_vfi_arch/costvol.py
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import cupy
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
# Cache of specialized kernel sources (plus the CUDA device name), keyed by a
# string derived from the kernel name and its argument metadata.
objCudacache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_int32(intIn: int):
    """Wrap a Python int as a cupy int32 kernel argument."""
    return cupy.int32(intIn)
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_float32(fltIn: float):
    """Wrap a Python float as a cupy float32 kernel argument."""
    return cupy.float32(fltIn)
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialize a CUDA kernel template for concrete arguments; return a cache key.

    Builds a key from the kernel name plus each variable's value/metadata
    (for tensors: dtype, shape, strides) and, on a cache miss, expands the
    template: ``{{name}}`` placeholders, ``{{type}}`` from the first tensor's
    dtype, and the ``SIZE_n(t)`` / ``OFFSET_n(t, ...)`` / ``VALUE_n(t, ...)``
    macros into literal sizes and stride arithmetic.  The expanded source is
    stored in ``objCudacache[key]`` for ``cuda_launch`` to compile.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()
    # end

    strKey = strFunction

    # Fold every variable's identity into the key so that two calls with
    # different shapes/strides/dtypes compile distinct kernels.
    for strVariable in objVariables:
        objValue = objVariables[strVariable]

        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) == int:
            strKey += str(objValue)

        elif type(objValue) == float:
            strKey += str(objValue)

        elif type(objValue) == bool:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        elif True:
            # Unsupported variable type: fail loudly.
            print(strVariable, type(objValue))
            assert(False)

        # end
    # end

    strKey += objCudacache['device']

    if strKey not in objCudacache:
        # Substitute scalar placeholders and derive the C scalar type
        # ({{type}}) from tensor dtypes.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace('{{type}}', 'unsigned char')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace('{{type}}', 'half')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace('{{type}}', 'float')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace('{{type}}', 'double')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace('{{type}}', 'int')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace('{{type}}', 'long')

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)
                assert(False)

            elif True:
                print(strVariable, type(objValue))
                assert(False)

            # end
        # end

        # Expand SIZE_n(tensor) into the literal size of dimension n.
        while True:
            objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
        # end

        # Expand OFFSET_n(tensor, i0, ..., in-1) into flat stride arithmetic.
        while True:
            objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            # Scan forward to the matching close paren of the macro call.
            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            # Each index expression is multiplied by its dimension's stride;
            # braces in the argument are rewritten back to parentheses.
            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
        # end

        # Expand VALUE_n(tensor, i0, ...) into a direct indexed element read.
        while True:
            objMatch = re.search('(VALUE_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
        # end

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end
|
||||||
|
|
||||||
|
|
||||||
|
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey: str):
    """Compile the cached kernel for `strKey` (memoized per device) and return its entry function."""
    # cupy's NVRTC path needs CUDA_HOME; fall back to the conventional prefix.
    os.environ.setdefault('CUDA_HOME', '/usr/local/cuda/')

    entry = objCudacache[strKey]
    return cupy.RawModule(code=entry['strKernel']).get_function(entry['strFunction'])
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
class costvol_func(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||||
|
def forward(self, tenOne, tenTwo, intKernelSize):
|
||||||
|
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
||||||
|
|
||||||
|
cuda_launch(cuda_kernel('costvol_out', '''
|
||||||
|
extern "C" __global__ void __launch_bounds__(512) costvol_out(
|
||||||
|
const int n,
|
||||||
|
const {{type}}* __restrict__ tenOne,
|
||||||
|
const {{type}}* __restrict__ tenTwo,
|
||||||
|
const int intKernelSize,
|
||||||
|
{{type}}* __restrict__ tenOut
|
||||||
|
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||||
|
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
|
||||||
|
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
|
||||||
|
const int intX = ( intIndex ) % SIZE_3(tenOut);
|
||||||
|
|
||||||
|
{{type}} fltOne[{{intChans}}];
|
||||||
|
|
||||||
|
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
|
||||||
|
fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
|
||||||
|
}
|
||||||
|
|
||||||
|
int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);
|
||||||
|
|
||||||
|
for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
|
||||||
|
for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
|
||||||
|
{{type}} fltValue = 0.0f;
|
||||||
|
|
||||||
|
if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
|
||||||
|
for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
|
||||||
|
fltValue += (fltOne[intValue] * VALUE_4(tenTwo, intN, intValue, intOy, intOx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tenOut[intOffset] = fltValue;
|
||||||
|
intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} }
|
||||||
|
''', {
|
||||||
|
'intChans': tenOne.shape[1],
|
||||||
|
'tenOne': tenOne,
|
||||||
|
'tenTwo': tenTwo,
|
||||||
|
'intKernelSize': intKernelSize,
|
||||||
|
'tenOut': tenOut
|
||||||
|
}))(
|
||||||
|
grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]),
|
||||||
|
block=tuple([512, 1, 1]),
|
||||||
|
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
||||||
|
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.save_for_backward(tenOne, tenTwo)
|
||||||
|
self.intKernelSize = intKernelSize
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
# end
|
||||||
|
|
||||||
|
@staticmethod
@torch.amp.custom_bwd(device_type='cuda')  # keep the backward math out of autocast; CUDA-only op
def backward(self, tenOutgrad):
    """Backward pass of the local cost-volume op.

    Uses the two feature maps saved in forward and pushes ``tenOutgrad``
    back through the intKernelSize x intKernelSize correlation window via
    two hand-written cupy kernels. Returns one gradient slot per forward
    input; the trailing Nones correspond to the non-tensor arguments.
    """
    tenOne, tenTwo = self.saved_tensors
    intKernelSize = self.intKernelSize

    # The kernels walk raw data pointers, so the incoming gradient must be
    # contiguous and resident on the GPU.
    tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

    # Allocate zero-filled gradient buffers only for inputs autograd asked for.
    tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None
    tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None

    if tenOnegrad is not None:
        # Gradient w.r.t. tenOne: each (n, y, x) thread owns its own output
        # location in tenOnegrad, so plain += (no atomics) is safe here.
        cuda_launch(cuda_kernel('costvol_onegrad', '''
            extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
                const int n,
                const {{type}}* __restrict__ tenOne,
                const {{type}}* __restrict__ tenTwo,
                const {{type}}* __restrict__ tenOutgrad,
                const int intKernelSize,
                {{type}}* __restrict__ tenOnegrad,
                {{type}}* __restrict__ tenTwograd
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
                const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad);
                const int intX = ( intIndex ) % SIZE_3(tenOnegrad);

                int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                    for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                        if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                            for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += tenOutgrad[intOffset] * VALUE_4(tenTwo, intN, intValue, intOy, intOx);
                            }
                        }

                        intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                    }
                }
            } }
        ''', {
            'intChans': tenOne.shape[1],
            'tenOne': tenOne,
            'tenTwo': tenTwo,
            'tenOutgrad': tenOutgrad,
            'intKernelSize': intKernelSize,
            'tenOnegrad': tenOnegrad,
            'tenTwograd': tenTwograd
        }))(
            # One thread per (batch, y, x) location; grid-sized for 512-wide blocks.
            grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
            # Launch on PyTorch's current CUDA stream to stay ordered with it.
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )
    # end

    if tenTwograd is not None:
        # Gradient w.r.t. tenTwo: many source pixels scatter into the same
        # (intOy, intOx) target, hence atomicAdd; this thread's tenOne pixel
        # is cached in registers (fltOne) before the window loop.
        cuda_launch(cuda_kernel('costvol_twograd', '''
            extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
                const int n,
                const {{type}}* __restrict__ tenOne,
                const {{type}}* __restrict__ tenTwo,
                const {{type}}* __restrict__ tenOutgrad,
                const int intKernelSize,
                {{type}}* __restrict__ tenOnegrad,
                {{type}}* __restrict__ tenTwograd
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
                const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd);
                const int intX = ( intIndex ) % SIZE_3(tenTwograd);

                {{type}} fltOne[{{intChans}}];

                for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                    fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
                }

                int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

                for (int intOy = intY - (intKernelSize - 1) / 2; intOy <= intY + (intKernelSize - 1) / 2; intOy += 1) {
                    for (int intOx = intX - (intKernelSize - 1) / 2; intOx <= intX + (intKernelSize - 1) / 2; intOx += 1) {
                        if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                            for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                                atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], tenOutgrad[intOffset] * fltOne[intValue]);
                            }
                        }

                        intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
                    }
                }
            } }
        ''', {
            'intChans': tenOne.shape[1],
            'tenOne': tenOne,
            'tenTwo': tenTwo,
            'tenOutgrad': tenOutgrad,
            'intKernelSize': intKernelSize,
            'tenOnegrad': tenOnegrad,
            'tenTwograd': tenTwograd
        }))(
            grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), intKernelSize, tenOnegrad.data_ptr(), tenTwograd.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )
    # end

    return tenOnegrad, tenTwograd, None, None, None
# end
# end
|
||||||
102
bim_vfi_arch/resnet_encoder.py
Normal file
102
bim_vfi_arch/resnet_encoder.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional, Callable
|
||||||
|
from torch import Tensor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Biased 3x3 convolution whose padding tracks the dilation.

    With stride 1 the spatial size is preserved for any dilation, since
    padding is set equal to the dilation.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=True, dilation=dilation)
|
||||||
|
|
||||||
|
|
||||||
|
def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Unpadded 2x2 convolution; with stride=2 it halves H and W exactly."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=2, stride=stride)
    return conv
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
    """Pre-activation residual block: norm -> conv -> lrelu -> norm -> conv.

    The first conv is 3x3 at stride 1, or an unpadded strided 2x2 conv when
    downsampling; in that case ``downsample`` must project the identity to
    the matching shape.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # Bug fix: the previous default was
            # partial(nn.InstanceNorm2d, data_format='channels_first'), but
            # nn.InstanceNorm2d accepts no ``data_format`` keyword, so the
            # default path raised TypeError the moment norm_layer was called.
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(inplanes)
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            # Strided path uses a 2x2 conv so the downsampling is exact.
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block; output shape equals the main path's."""
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)

        out = self.bn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # NOTE: the activation is applied before the residual add (not after),
        # preserved as-is to stay compatible with the pretrained weights.
        out = self.lrelu(out)
        out = out + identity

        return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetPyramid(nn.Module):
    """Three-scale ResNet feature pyramid.

    By default this pyramid is shared by the motion estimator and the
    synthesis network. For an RGB input it yields features at full, 1/2 and
    1/4 resolution with ``feat_channels``, ``2*feat_channels`` and
    ``4*feat_channels`` channels respectively.
    """

    def __init__(self, feat_channels: int):
        super(ResNetPyramid, self).__init__()
        ch0, ch1, ch2 = feat_channels, feat_channels * 2, feat_channels * 4
        # Stem: lift RGB to the base channel width at full resolution.
        self.conv = nn.Conv2d(3, ch0, kernel_size=3, stride=1, padding=1)
        self.layer0 = nn.Sequential(
            BasicBlock(ch0, ch0, norm_layer=nn.InstanceNorm2d),
        )
        # Each deeper level halves H and W with a strided 2x2 conv.
        self.layer1 = nn.Sequential(
            nn.Conv2d(ch0, ch1, 2, 2),
            BasicBlock(ch1, ch1, norm_layer=nn.InstanceNorm2d),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(ch1, ch2, 2, 2),
            BasicBlock(ch2, ch2, norm_layer=nn.InstanceNorm2d),
        )
        self.conv_last = nn.Conv2d(ch2, ch2, 3, 1, 1)

    def forward(self, img):
        """Return [full-res, half-res, quarter-res] feature maps."""
        level0 = self.layer0(self.conv(img))
        level1 = self.layer1(level0)
        level2 = self.conv_last(self.layer2(level1))
        return [level0, level1, level2]
|
||||||
95
bim_vfi_arch/sn.py
Normal file
95
bim_vfi_arch/sn.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .backwarp import backwarp
|
||||||
|
|
||||||
|
|
||||||
|
class SynthesisNetwork(nn.Module):
    """Frame-synthesis head.

    Backward-warps the two input images and their feature pyramids with the
    given bidirectional flows, fuses the warped evidence through a 3-level
    encoder/decoder (PixelShuffle upsampling), and predicts a residual RGB
    correction plus a refinement of the occlusion/blending logit.
    """

    def __init__(self, feat_channels):
        super(SynthesisNetwork, self).__init__()
        # 6 = two warped RGB images; +1 = occlusion logit channel.
        input_channels = 6 + 1
        # Encoder level 0 (full resolution): warped images + occ -> 2C.
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 1: s0 (2C) + warped level-0 feats (C + C) = 4C in.
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Encoder level 2: s1 (2C) + warped level-1 feats (2C + 2C) = 6C in.
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4))
        # Decoder step 1: s2 (4C) + warped level-2 feats (4C + 4C) = 12C in;
        # PixelShuffle(2) converts 8C into 2C at double resolution.
        self.conv_up1 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2))
        # Decoder step 2: x (2C) + skip s1 (2C) = 4C in; PixelShuffle -> C.
        self.conv_up2 = nn.Sequential(
            torch.nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1))
        # Decoder step 3: x (C) + skip s0 (2C) = 3C in.
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # 4 output channels: 3 for the RGB residual, 1 for the occ-logit residual.
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)

    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backward-warp features (and optionally images) with a bidirectional flow.

        ``bi_flow`` packs the t->0 flow in channels [0:2] and the t->1 flow
        in channels [2:4]. When i0/i1 are given, warped images come first
        in the returned tuple.
        """
        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if (i0 is None) and (i1 is None):
            return warped_c0, warped_c1
        else:
            warped_img0 = backwarp(i0, flow_t0)
            warped_img1 = backwarp(i1, flow_t1)
            return warped_img0, warped_img1, warped_c0, warped_c1

    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the intermediate frame.

        Args:
            i0, i1: the two input images.
            c0_pyr, c1_pyr: 3-level feature pyramids of i0 / i1.
            bi_flow_pyr: bidirectional flow per pyramid level (finest first).
            occ: occlusion/blending logit map at full resolution.

        Returns:
            (interp_img, occ_out, extra_dict) — the synthesized frame, the
            refined occlusion logit, and intermediates for losses/debugging.
        """
        # Full-resolution warp of both images and level-0 features.
        warped_img0, warped_img1, warped_c0, warped_c1 = \
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1)
        input_feat = torch.cat(
            (warped_img0, warped_img1, occ), 1)
        s0 = self.conv_down1(input_feat)
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None)
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None)

        # Decoder with skip connections from s1 and s0.
        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))

        refine = self.conv_out(x)
        refine_res = refine[:, :3]  # RGB residual
        occ_res = refine[:, 3:]     # occlusion-logit residual
        occ_out = occ + occ_res
        # Sigmoid of the refined logit blends the two warped images, then
        # the RGB residual is added on top.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res
        interp_img = merged_img

        extra_dict = {}
        extra_dict["refine_res"] = refine_res
        extra_dict["refine_mask"] = occ_out
        extra_dict["warped_img0"] = warped_img0
        extra_dict["warped_img1"] = warped_img1
        extra_dict["merged_img"] = merged_img

        return interp_img, occ_out, extra_dict
|
||||||
81
inference.py
Normal file
81
inference.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import torch
|
||||||
|
from .bim_vfi_arch import BiMVFI
|
||||||
|
|
||||||
|
|
||||||
|
class BiMVFIModel:
    """Inference-only facade over the BiMVFI network for ComfyUI.

    Owns checkpoint loading, device placement and single-pair frame
    interpolation; training concerns are deliberately absent.
    """

    def __init__(self, checkpoint_path, pyr_level=3, device="cpu"):
        self.pyr_level = pyr_level
        self.device = device

        self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.to(device)

    def _load_checkpoint(self, checkpoint_path):
        """Load weights, tolerating common checkpoint layouts.

        Accepts a raw state dict or one nested under "model"/"state_dict",
        and strips "module." (DDP) and "model." (wrapper) key prefixes.
        """
        # NOTE(review): weights_only=False unpickles arbitrary Python objects
        # from a possibly auto-downloaded file — only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        state_dict = checkpoint
        for container_key in ("model", "state_dict"):
            if container_key in checkpoint:
                state_dict = checkpoint[container_key]
                break

        cleaned = {
            key.removeprefix("module.").removeprefix("model."): value
            for key, value in state_dict.items()
        }
        self.model.load_state_dict(cleaned)

    def to(self, device):
        """Move the wrapped network to ``device``; returns self for chaining."""
        self.device = device
        self.model.to(device)
        return self

    @torch.no_grad()
    def interpolate_pair(self, frame0, frame1, time_step=0.5):
        """Interpolate a single frame between two input frames.

        Args:
            frame0: [1, C, H, W] float32 tensor in [0, 1].
            frame1: [1, C, H, W] float32 tensor in [0, 1].
            time_step: temporal position in (0, 1) of the output frame.

        Returns:
            Interpolated frame as a [1, C, H, W] float32 tensor clamped to [0, 1].
        """
        device = next(self.model.parameters()).device
        img0 = frame0.to(device)
        img1 = frame1.to(device)

        # Taller frames get deeper pyramids; below 540 px the configured
        # default depth applies.
        _, _, height, _ = img0.shape
        pyr_level = self.pyr_level
        for min_height, level in ((2160, 7), (1080, 6), (540, 5)):
            if height >= min_height:
                pyr_level = level
                break

        time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)

        result_dict = self.model(
            img0=img0, img1=img1,
            time_step=time_step_tensor,
            pyr_level=pyr_level,
        )

        # NOTE(review): inputs are not padded here — presumably BiMVFI copes
        # with H/W not divisible by 2**pyr_level; confirm for odd resolutions.
        return torch.clamp(result_dict["imgt_pred"], 0, 1)
|
||||||
47
install.py
Normal file
47
install.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_cupy_package():
    """Pick the cupy wheel name matching PyTorch's CUDA toolkit.

    Returns e.g. "cupy-cuda12x", or None when CUDA is unavailable or
    detection fails. Assumes the modern cupy naming scheme
    (cupy-cuda11x / cupy-cuda12x) — TODO confirm for CUDA < 11.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            print("[BIM-VFI] WARNING: CUDA not available. cupy requires CUDA.")
            return None
        if (cuda_version := torch.version.cuda) is None:
            print("[BIM-VFI] WARNING: PyTorch has no CUDA version info.")
            return None
        cupy_pkg = f"cupy-cuda{int(cuda_version.split('.')[0])}x"
        print(f"[BIM-VFI] Detected CUDA {cuda_version}, will use {cupy_pkg}")
        return cupy_pkg
    except Exception as e:
        print(f"[BIM-VFI] WARNING: Could not detect CUDA version: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def install():
    """Install requirements.txt dependencies, then a CUDA-matched cupy.

    cupy wheels are CUDA-version specific, so it is only installed when
    absent, and the wheel name comes from get_cupy_package().
    """
    requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-r", requirements_path]
    )

    try:
        import cupy  # noqa: F401
    except ImportError:
        cupy_pkg = get_cupy_package()
        if cupy_pkg:
            print(f"[BIM-VFI] Installing {cupy_pkg} to match PyTorch CUDA...")
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", cupy_pkg]
            )
    else:
        print("[BIM-VFI] cupy already installed, skipping.")
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running the installer directly: `python install.py`.
if __name__ == "__main__":
    install()
|
||||||
169
nodes.py
Normal file
169
nodes.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
|
from .inference import BiMVFIModel
|
||||||
|
from .bim_vfi_arch import clear_backwarp_cache
|
||||||
|
|
||||||
|
logger = logging.getLogger("BIM-VFI")
|
||||||
|
|
||||||
|
# Google Drive file ID for the pretrained model
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
MODEL_FILENAME = "bim_vfi.pth"

# Register the model folder with ComfyUI.
MODEL_DIR = os.path.join(folder_paths.models_dir, "bim-vfi")
# exist_ok=True already makes this a no-op when the directory exists; the
# previous os.path.exists() pre-check was redundant and racy (TOCTOU).
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_models():
    """Return sorted checkpoint filenames found in the bim-vfi model dir.

    Falls back to the default filename when nothing is found, which makes
    LoadBIMVFIModel trigger the auto-download.
    """
    checkpoint_exts = (".pth", ".pt", ".ckpt", ".safetensors")
    if os.path.isdir(MODEL_DIR):
        models = [name for name in os.listdir(MODEL_DIR) if name.endswith(checkpoint_exts)]
    else:
        models = []
    if not models:
        models = [MODEL_FILENAME]  # Will trigger auto-download
    return sorted(models)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_from_gdrive(file_id, dest_path):
    """Fetch a Google Drive file to ``dest_path`` via gdown.

    Raises:
        RuntimeError: if gdown is not installed, or the file did not appear
            on disk after the download.
    """
    try:
        import gdown
    except ImportError:
        raise RuntimeError(
            "gdown is required to auto-download the BIM-VFI model. "
            "Install it with: pip install gdown"
        )
    logger.info(f"Downloading BIM-VFI model to {dest_path}...")
    gdown.download(f"https://drive.google.com/uc?id={file_id}", dest_path, quiet=False)
    # gdown can fail silently (quota, bad ID), so verify the file landed.
    if not os.path.exists(dest_path):
        raise RuntimeError(f"Failed to download model to {dest_path}")
    logger.info("Download complete.")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBIMVFIModel:
    """ComfyUI node: load a BIM-VFI checkpoint, auto-downloading if absent."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_models(), {"default": MODEL_FILENAME}),
                "pyr_level": ("INT", {"default": 3, "min": 3, "max": 7, "step": 1}),
            }
        }

    RETURN_TYPES = ("BIM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/BIM-VFI"

    def load_model(self, model_path, pyr_level):
        """Resolve the checkpoint path (downloading when missing) and wrap it."""
        full_path = os.path.join(MODEL_DIR, model_path)
        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_model_from_gdrive(GDRIVE_FILE_ID, full_path)

        # Load on CPU; BIMVFIInterpolate moves the model per its VRAM settings.
        wrapper = BiMVFIModel(
            checkpoint_path=full_path, pyr_level=pyr_level, device="cpu"
        )
        logger.info(f"BIM-VFI model loaded (pyr_level={pyr_level})")
        return (wrapper,)
|
||||||
|
|
||||||
|
|
||||||
|
class BIMVFIInterpolate:
    """ComfyUI node: increase frame rate 2x/4x/8x with BIM-VFI.

    Runs recursive halving passes — each pass inserts the t=0.5 midpoint
    between every consecutive frame pair, so multiplier 2/4/8 maps to
    1/2/3 passes. Pairs are processed one at a time so peak VRAM is
    bounded regardless of clip length.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "model": ("BIM_VFI_MODEL",),
                "multiplier": ([2, 4, 8], {"default": 2}),
                # How often (in processed pairs) to flush the backwarp grid
                # cache and the CUDA allocator; ignored when all_on_gpu.
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1}),
                # Keep the model on the GPU between pairs (faster); when
                # False it is shuttled CPU<->GPU around every pair.
                "keep_device": ("BOOLEAN", {"default": True}),
                # Also store the frame buffer on the GPU (high-VRAM setups).
                "all_on_gpu": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/BIM-VFI"

    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
        """Interpolate the batch; returns frames in ComfyUI [B, H, W, C] layout."""
        # Nothing to interpolate with fewer than two frames.
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_passes = {2: 1, 4: 2, 8: 3}[multiplier]

        # all_on_gpu implies keep_device
        if all_on_gpu:
            keep_device = True

        # Where to store intermediate frames
        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        frames = images.permute(0, 3, 1, 2).to(storage_device)

        # After each 2x pass, frame count = 2*N - 1, so compute total pairs across passes
        n = frames.shape[0]
        total_steps = 0
        for _ in range(num_passes):
            total_steps += n - 1
            n = 2 * n - 1

        pbar = ProgressBar(total_steps)
        step = 0

        # Move the model to the GPU once up front when it stays resident.
        if keep_device:
            model.to(device)

        for pass_idx in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1

            for i in range(num_pairs):
                frame0 = frames[i:i+1]  # [1, C, H, W]
                frame1 = frames[i+1:i+2]  # [1, C, H, W]

                # Per-pair shuttling keeps the model off the GPU while idle.
                if not keep_device:
                    model.to(device)

                mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
                mid = mid.to(storage_device)

                if not keep_device:
                    model.to("cpu")

                # Interleave: original frame i, then its new midpoint.
                new_frames.append(frames[i:i+1])
                new_frames.append(mid)

                step += 1
                pbar.update_absolute(step, total_steps)

                # Periodic VRAM housekeeping; skipped in all_on_gpu mode where
                # the frame buffer intentionally lives on the GPU.
                if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
                    clear_backwarp_cache()
                    torch.cuda.empty_cache()

            # Append last frame
            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)

        # Final flush of the backwarp grid cache and CUDA allocator.
        if not all_on_gpu and torch.cuda.is_available():
            clear_backwarp_cache()
            torch.cuda.empty_cache()

        # Convert back to ComfyUI [B, H, W, C], on CPU for ComfyUI
        result = frames.cpu().permute(0, 2, 3, 1)
        return (result,)
|
||||||
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
gdown
|
||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/padder.py
Normal file
28
utils/padder.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class InputPadder:
    """Pads images on the right/bottom so H and W are divisible by ``divisor``."""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Distance to the next multiple of divisor (0 when already aligned).
        pad_ht = (-self.ht) % divisor
        pad_wd = (-self.wd) % divisor
        # F.pad order for the last two dims: (left, right, top, bottom) —
        # here: no left/top padding, pad only right and bottom.
        self._pad = [0, pad_wd, 0, pad_ht]

    def pad(self, *inputs):
        """Zero-pad tensors; a single input returns a tensor, several a list."""
        padded = [F.pad(tensor, self._pad, mode='constant') for tensor in inputs]
        return padded[0] if len(inputs) == 1 else padded

    def unpad(self, *inputs):
        """Crop tensors back to the original (pre-pad) spatial size."""
        cropped = [self._unpad(tensor) for tensor in inputs]
        return cropped[0] if len(inputs) == 1 else cropped

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        # Padding was applied only at the bottom/right, so crop from the end.
        top, left = self._pad[2], self._pad[0]
        return x[..., top:ht - self._pad[3], left:wd - self._pad[1]]
|
||||||
Reference in New Issue
Block a user