Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
96 lines
4.2 KiB
Python
96 lines
4.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from .backwarp import backwarp
|
|
|
|
|
|
class SynthesisNetwork(nn.Module):
    """Encoder-decoder that fuses two backward-warped frames into one image.

    Both input frames and their feature pyramids are warped with the supplied
    bidirectional flow. Three encoder stages consume the warped data (the
    last two at stride 2), three decoder stages restore resolution with
    skip connections, and a final conv predicts an RGB residual plus an
    occlusion-logit residual used to blend the two warped frames.

    Sub-module names and the positions inside every ``nn.Sequential`` define
    the ``state_dict`` keys; keep them stable so existing checkpoints load.
    """

    def __init__(self, feat_channels):
        """Build all convolutional stages.

        Args:
            feat_channels: Base channel width. The input widths of the
                encoder convs imply that pyramid level ``k`` contributes
                ``feat_channels * 2**k`` channels per frame -- presumably the
                width produced by the feature extractor; confirm against it.
        """
        super().__init__()
        # forward() concatenates warped_img0, warped_img1 and occ here;
        # presumably 3 + 3 + 1 channels -- confirm against the caller.
        input_channels = 6 + 1

        # Encoder, full resolution: 7 -> C -> 2C.
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(input_channels, feat_channels, 7, padding=3),
            nn.PReLU(feat_channels),
            nn.Conv2d(feat_channels, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Encoder, 1/2 resolution: (s0 + level-0 features) 4C -> 2C.
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 2, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Encoder, 1/4 resolution: (s1 + level-1 features) 6C -> 4C.
        self.conv_down3 = nn.Sequential(
            nn.Conv2d(feat_channels * 6, feat_channels * 4, 2, stride=2, padding=0),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PReLU(feat_channels * 4),
        )
        # Decoder: (s2 + level-2 features) 12C -> 8C; PixelShuffle(2) trades
        # 4x channels for 2x spatial size, leaving 2C.
        self.conv_up1 = nn.Sequential(
            nn.Conv2d(feat_channels * 12, feat_channels * 8, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # Decoder: (x + s1) 4C -> 4C; PixelShuffle(2) leaves C.
        self.conv_up2 = nn.Sequential(
            nn.Conv2d(feat_channels * 4, feat_channels * 4, 3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(feat_channels * 1),
            nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
            nn.PReLU(feat_channels * 1),
        )
        # Decoder, full resolution: (x + s0) 3C -> 2C.
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(feat_channels * 3, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
            nn.Conv2d(feat_channels * 2, feat_channels * 2, 3, padding=1),
            nn.PReLU(feat_channels * 2),
        )
        # 4 output channels: forward() uses the first 3 as an image residual
        # and the remainder as an occlusion-logit residual.
        self.conv_out = nn.Conv2d(feat_channels * 2, 4, 3, padding=1)

    def get_warped_representations(self, bi_flow, c0, c1, i0=None, i1=None):
        """Backward-warp features (and optionally images) with ``bi_flow``.

        Args:
            bi_flow: Flow tensor; channels ``0:2`` warp the frame-0 inputs
                and channels ``2:4`` warp the frame-1 inputs.
            c0: Features of frame 0, warped with channels ``0:2``.
            c1: Features of frame 1, warped with channels ``2:4``.
            i0: Optional image of frame 0; must be given together with ``i1``.
            i1: Optional image of frame 1; must be given together with ``i0``.

        Returns:
            ``(warped_c0, warped_c1)`` when both images are omitted, otherwise
            ``(warped_img0, warped_img1, warped_c0, warped_c1)``.

        Raises:
            ValueError: If exactly one of ``i0`` / ``i1`` is ``None``.
        """
        # Reject a half-specified image pair up front; it would otherwise
        # hand None to backwarp and fail with an unrelated-looking error.
        if (i0 is None) != (i1 is None):
            raise ValueError("i0 and i1 must both be provided or both be None")

        flow_t0 = bi_flow[:, :2]
        flow_t1 = bi_flow[:, 2:4]
        warped_c0 = backwarp(c0, flow_t0)
        warped_c1 = backwarp(c1, flow_t1)
        if i0 is None:
            return warped_c0, warped_c1

        warped_img0 = backwarp(i0, flow_t0)
        warped_img1 = backwarp(i1, flow_t1)
        return warped_img0, warped_img1, warped_c0, warped_c1

    def forward(self, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, occ):
        """Synthesize the interpolated frame.

        Args:
            i0: Image of frame 0.
            i1: Image of frame 1.
            c0_pyr: Indexable feature pyramid of frame 0; levels 0-2 are read.
            c1_pyr: Indexable feature pyramid of frame 1; levels 0-2 are read.
            bi_flow_pyr: Indexable flow pyramid; levels 0-2 are read. Level 0
                warps the images and level-0 features. Levels 1 and 2 must
                match the spatial size of the stride-2 encoder outputs they
                are concatenated with.
            occ: Occlusion logits concatenated with the warped images, so
                they must share the images' batch and spatial size.

        Returns:
            A tuple ``(interp_img, occ_out, extra_dict)``: the blended image
            plus residual, the refined occlusion logits (before the sigmoid),
            and a dict of intermediates under the keys ``refine_res``,
            ``refine_mask``, ``warped_img0``, ``warped_img1``, ``merged_img``.
        """
        warped_img0, warped_img1, warped_c0, warped_c1 = (
            self.get_warped_representations(
                bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1))

        # Encoder: each stage ingests the warped features of its own level.
        s0 = self.conv_down1(torch.cat((warped_img0, warped_img1, occ), 1))
        s1 = self.conv_down2(torch.cat((s0, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[1], c0_pyr[1], c1_pyr[1])
        s2 = self.conv_down3(torch.cat((s1, warped_c0, warped_c1), 1))
        warped_c0, warped_c1 = self.get_warped_representations(
            bi_flow_pyr[2], c0_pyr[2], c1_pyr[2])

        # Decoder with skip connections from the encoder stages.
        x = self.conv_up1(torch.cat((s2, warped_c0, warped_c1), 1))
        x = self.conv_up2(torch.cat((x, s1), 1))
        x = self.conv_up3(torch.cat((x, s0), 1))

        refine = self.conv_out(x)
        refine_res = refine[:, :3]
        occ_out = occ + refine[:, 3:]

        # Sigmoid of the refined logits is the per-pixel weight of frame 0;
        # frame 1 receives the complement.
        blending_mask = torch.sigmoid(occ_out)
        merged_img = (warped_img0 * blending_mask + warped_img1 * (1 - blending_mask)) + refine_res

        extra_dict = {
            "refine_res": refine_res,
            "refine_mask": occ_out,
            "warped_img0": warped_img0,
            "warped_img1": warped_img1,
            "merged_img": merged_img,
        }
        return merged_img, occ_out, extra_dict
|