Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long-video frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), a progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of the CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
73 lines
3.5 KiB
Python
73 lines
3.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .backwarp import backwarp
|
|
|
|
|
|
class CAUN(nn.Module):
    """Learned 2x / 4x upsampler for a coarse 4-channel flow field.

    Features from both inputs are warped with the flow at three scales,
    refined coarse-to-fine, and turned into per-pixel softmax weights over the
    3x3 neighbourhood of every coarse flow vector. Each upsampled flow vector
    is the resulting convex combination of those nine neighbours.

    Submodule names (``enc0`` ... ``kernel_x4``) and their ``nn.Sequential``
    indices are kept stable because they define the ``state_dict`` keys.
    """

    def __init__(self, feat_channels):
        """Build the three refinement encoders and the two kernel heads.

        Args:
            feat_channels: Base channel width; every layer width below is a
                multiple of it.
        """
        super().__init__()
        base = feat_channels

        # Input widths are fixed by what ``forward`` concatenates:
        #   enc0: the two warped ``feat[2]`` maps
        #   enc1: pixel-shuffled enc0 output + the two warped ``feat[1]`` maps
        #   enc2: pixel-shuffled enc1 output + the two warped ``feat[0]`` maps
        self.enc0 = self._conv_prelu(base * 8, base * 4)
        self.enc1 = self._conv_prelu(base * 5, base * 4)
        self.enc2 = self._conv_prelu(base * 3, base * 1)

        self.kernel_x2 = self._kernel_head(base * 4, base * 2)
        self.kernel_x4 = self._kernel_head(base * 1, base * 1)

    @staticmethod
    def _conv_prelu(in_channels, out_channels):
        """Return a 3x3 same-padding convolution followed by a per-channel PReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.PReLU(out_channels),
        )

    @staticmethod
    def _kernel_head(in_channels, mid_channels):
        """Return a conv-PReLU-conv head that emits the upsampling weights.

        The 18 output channels are split in ``forward`` into two groups of
        nine: one 3x3 kernel for each half of the flow field.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, padding=1),
            nn.PReLU(mid_channels),
            nn.Conv2d(mid_channels, 2 * 1 * 9, 3, padding=1),
        )

    def upsample_input(self, inp, mask, upsample_factor):
        """Upsample ``inp`` as a convex combination of its 3x3 neighbours.

        Args:
            inp: Tensor of shape ``(N, c, H, W)``.
            mask: Unnormalised weights that can be viewed as
                ``(N, 1, 9, f, f, H, W)``, i.e. nine tap weights for each of
                the ``f * f`` sub-pixels of every coarse location. The same
                weights are shared by all ``c`` channels.
            upsample_factor: Integer upsampling factor ``f``.

        Returns:
            Tensor of shape ``(N, c, f * H, f * W)``.
        """
        batch, channels, height, width = inp.shape
        factor = upsample_factor

        # Softmax over the nine taps makes the weights non-negative and
        # sum to one, so the blend below is a convex combination.
        weights = mask.view(batch, 1, 9, factor, factor, height, width).softmax(dim=2)

        # Replicate padding gives border pixels a full 3x3 neighbourhood.
        padded = F.pad(inp, [1, 1, 1, 1], mode='replicate')
        patches = F.unfold(padded, [3, 3]).view(batch, channels, 9, 1, 1, height, width)

        blended = torch.sum(weights * patches, dim=2)  # (N, c, f, f, H, W)

        # Interleave sub-pixel offsets with the coarse grid -> (N, c, H, f, W, f),
        # then merge each (coarse, offset) pair into one fine axis.
        interleaved = blended.permute(0, 1, 4, 2, 5, 3)
        return interleaved.reshape(batch, channels, factor * height, factor * width)

    def _upsample_flow(self, flow, logits, factor):
        """Upsample both halves of ``flow`` by ``factor`` using predicted weights.

        Args:
            flow: Coarse flow of shape ``(N, 4, H, W)``.
            logits: Kernel-head output of shape ``(N, 18, factor * H, factor * W)``.
            factor: Integer upsampling factor.

        Returns:
            Flow of shape ``(N, 4, factor * H, factor * W)``.
        """
        N, _, H, W = flow.shape

        # Pull the sub-pixel offsets out of the spatial axes:
        # (N, 18, f*H, f*W) -> (N, 18, H, f, W, f) -> (N, 18, f, f, H, W).
        logits = logits.view(N, 18, H, factor, W, factor)
        logits = logits.permute(0, 1, 3, 5, 2, 4).contiguous()
        logits_0, logits_1 = torch.chunk(logits, 2, dim=1)

        # Flow values are multiplied by the same factor as the resolution.
        up_0 = self.upsample_input(flow[:, :2] * factor, logits_0, factor)
        up_1 = self.upsample_input(flow[:, 2:] * factor, logits_1, factor)
        return torch.cat([up_0, up_1], dim=1)

    def forward(self, flow, feat0, feat1, last_occ):
        """Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination.

        Args:
            flow: Coarse flow ``(N, 4, H, W)``. Channels ``:2`` warp ``feat0``
                and channels ``2:`` warp ``feat1``.
            feat0: Indexable feature pyramid of the first input. Index 2 is
                warped with the flow as given, index 1 with the 2x flow and
                index 0 with the 4x flow.
            feat1: Feature pyramid of the second input, indexed like ``feat0``.
            last_occ: Tensor that is bilinearly upsampled by 4 and returned.

        Returns:
            Tuple ``([flow_x4, flow_x2, flow], up_occ)`` where ``flow`` is the
            unchanged input.
        """
        warped0, warped1 = [], []
        for level in range(3):
            scale = 2 ** level
            scaled_flow = F.interpolate(flow * scale, scale_factor=scale, mode='bilinear')
            source = 2 - level  # pyramid index paired with this flow scale
            warped0.append(backwarp(feat0[source], scaled_flow[:, :2]))
            warped1.append(backwarp(feat1[source], scaled_flow[:, 2:]))

        # Coarse-to-fine refinement: each pixel_shuffle doubles the resolution
        # (and quarters the channels) before the next level's features join.
        coarse = self.enc0(torch.cat([warped0[0], warped1[0]], dim=1))
        middle = self.enc1(
            torch.cat([F.pixel_shuffle(coarse, 2), warped0[1], warped1[1]], dim=1)
        )
        fine = self.enc2(
            torch.cat([F.pixel_shuffle(middle, 2), warped0[2], warped1[2]], dim=1)
        )

        up_flow_x2 = self._upsample_flow(flow, self.kernel_x2(middle), 2)
        up_flow_x4 = self._upsample_flow(flow, self.kernel_x4(fine), 4)

        up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
        return [up_flow_x4, up_flow_x2, flow], up_occ
|