Files
ComfyUI-Tween/bim_vfi_arch/bimfn.py
Ethanfel db64fc195a Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 18:26:49 +01:00

116 lines
4.4 KiB
Python

import torch
import torch.nn as nn
from .backwarp import backwarp
from .arch import LayerNorm, ResBlock
from .costvol import costvol_func
class BiMMConv(nn.Module):
    """Combine a feature map with two auxiliary maps by multiplicative gating.

    Each of the three inputs (``FV``, ``r``, ``phi``) is passed through its
    own branch consisting of a channels-first ``LayerNorm`` followed by a
    ``ResBlock``. The processed ``FV`` is then combined with the processed
    ``r`` and ``phi`` as ``FV + FV * r * phi``.

    Args:
        in_channels: Channel count handed to every ``LayerNorm``.
        out_channels: First argument handed to every ``ResBlock``.
            NOTE(review): the norm feeds directly into the ResBlock, so this
            presumably has to equal ``in_channels`` -- confirm against
            ``ResBlock`` in ``arch.py``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def make_branch():
            # One normalisation + residual branch; the three branches share
            # this layout but own separate parameters.
            return nn.Sequential(
                LayerNorm(in_channels, data_format='channels_first'),
                ResBlock(out_channels, 1),
            )

        # Attribute names and creation order are kept stable because they
        # define the parameter names stored in checkpoints.
        self.conv_FV = make_branch()
        self.conv_r = make_branch()
        self.conv_phi = make_branch()

    def forward(self, FV, r, phi):
        """Run the three branches and gate ``FV`` with ``r`` and ``phi``.

        Args:
            FV: Input to the ``conv_FV`` branch.
            r: Input to the ``conv_r`` branch.
            phi: Input to the ``conv_phi`` branch.

        Returns:
            A 3-tuple ``(gated, r_feat, phi_feat)`` where ``gated`` is
            ``fv_feat + fv_feat * r_feat * phi_feat`` and the other two
            entries are the raw outputs of the ``r`` and ``phi`` branches.
        """
        fv_feat = self.conv_FV(FV)
        r_feat = self.conv_r(r)
        phi_feat = self.conv_phi(phi)
        # Multiplication is kept left-to-right, i.e. (fv * r) * phi.
        gated = fv_feat + fv_feat * r_feat * phi_feat
        return gated, r_feat, phi_feat
class BiMFN(nn.Module):
    """Predict a 4-channel residual that refines a previous flow estimate.

    The forward pass warps two feature maps with the previous flow, builds
    cost volumes between the warped maps in both directions, encodes the
    previous flow and occlusion map, fuses everything through a stack of
    convolutions, modulates the result with embeddings of ``r`` and ``phi``
    via ``BiMMConv``, and finally projects to 4 channels.

    Args:
        feat_channels: Base width; every internal channel count is a
            multiple of it. ``conv0`` expects ``18 * feat_channels`` input
            channels, of which the correlation, flow and occlusion encoders
            supply ``10 * feat_channels``; the two warped feature maps must
            therefore supply the remaining ``8 * feat_channels`` together.
            NOTE(review): presumably each feature map has
            ``4 * feat_channels`` channels and ``backwarp`` preserves the
            channel count -- confirm against the caller.
    """

    def __init__(self, feat_channels):
        super().__init__()
        c = feat_channels
        wide = c * 2

        def encoder(in_ch):
            # 7x7 conv -> PReLU -> 3x3 conv -> PReLU, all producing `wide`.
            return nn.Sequential(
                nn.Conv2d(in_ch, wide, kernel_size=7, padding=3),
                nn.PReLU(wide),
                nn.Conv2d(wide, wide, kernel_size=3, padding=1),
                nn.PReLU(wide),
            )

        def fuse(in_ch, out_ch):
            # 1x1 channel change -> PReLU -> 3x3 conv (no trailing activation).
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0),
                nn.PReLU(out_ch),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            )

        def embed(in_ch):
            # Pointwise (1x1) two-layer embedding up to 6 * feat_channels.
            return nn.Sequential(
                nn.Conv2d(in_ch, c * 4, kernel_size=1, padding=0),
                nn.PReLU(c * 4),
                nn.Conv2d(c * 4, c * 6, kernel_size=1, padding=0),
            )

        # Attribute names, assignment order and the layer order inside each
        # Sequential define the parameter names stored in checkpoints, so
        # they must stay exactly as they are.
        self.conv_flow = encoder(2)
        self.conv_occ = encoder(1)
        # 81 input channels matches the cost volume built with argument 9 in
        # forward(); presumably a 9 x 9 window -- confirm against costvol.
        self.conv_corr = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(81, wide, kernel_size=1, padding=0),
            nn.PReLU(wide),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            nn.PReLU(wide),
        )
        self.conv0 = fuse(c * 18, c * 12)
        self.conv1 = fuse(c * 12, c * 8)
        self.conv2 = fuse(c * 8, c * 6)
        self.conv3 = fuse(c * 6, c * 6)
        self.dem = embed(1)
        self.aem = embed(2)
        self.bim_mconv = BiMMConv(c * 6, c * 6)
        self.conv_out = nn.Sequential(
            nn.Conv2d(c * 6, c * 4, kernel_size=1, padding=0),
            nn.PReLU(c * 4),
            nn.Conv2d(c * 4, 4, kernel_size=1, padding=0),
        )

    def forward(self, feat0, feat1, r, phi, last_flow, last_occ):
        """Compute the refined flow and the residual that produced it.

        Args:
            feat0: Feature map warped with ``last_flow[:, :2]``.
            feat1: Feature map warped with ``last_flow[:, 2:]``.
            r: Input to the ``dem`` embedding (1 input channel).
            phi: Input to the ``aem`` embedding (2 input channels).
            last_flow: Previous flow estimate; channels ``[:2]`` and ``[2:]``
                are each fed to a 2-channel encoder, and the 4-channel
                residual is added to the whole tensor.
            last_occ: Previous occlusion map (1 input channel).

        Returns:
            ``(flow_low, flow_res)`` where ``flow_res`` is the predicted
            residual and ``flow_low`` is ``flow_res + last_flow``.
        """
        flow_a = last_flow[:, :2]
        flow_b = last_flow[:, 2:]

        warped0 = backwarp(feat0, flow_a)
        warped1 = backwarp(feat1, flow_b)

        # Cost volumes in both directions share the same encoder weights.
        corr0 = self.conv_corr(costvol_func.apply(warped0, warped1, 9))
        corr1 = self.conv_corr(costvol_func.apply(warped1, warped0, 9))

        # Both flow halves share the same encoder weights as well.
        flo0 = self.conv_flow(flow_a)
        flo1 = self.conv_flow(flow_b)
        occ = self.conv_occ(last_occ)

        fused = torch.cat(
            (corr0, corr1, warped0, warped1, flo0, flo1, occ), dim=1
        )
        for stage in (self.conv0, self.conv1, self.conv2, self.conv3):
            fused = stage(fused)

        r_embed = self.dem(r)
        phi_embed = self.aem(phi)
        # Only the modulated feature map is needed; the processed r / phi
        # returned by BiMMConv are discarded.
        modulated, _, _ = self.bim_mconv(fused, r_embed, phi_embed)

        flow_res = self.conv_out(modulated)
        flow_low = flow_res + last_flow
        return flow_low, flow_res