Files
ComfyUI-Tween/bim_vfi_arch/bim_vfi.py
Ethanfel db64fc195a Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 18:26:49 +01:00

115 lines
4.3 KiB
Python

import torch
import torch.nn.functional as F
import torch.nn as nn
from .backwarp import backwarp
from .resnet_encoder import ResNetPyramid
from .caun import CAUN
from .bimfn import BiMFN
from .sn import SynthesisNetwork
from ..utils.padder import InputPadder
class BiMVFI(nn.Module):
    """Inference-only BiM-VFI frame interpolation network.

    ``forward`` runs a coarse-to-fine loop over ``pyr_level`` image pyramid
    levels. Each level (``forward_one_lvl``) does the following:
      * extracts two feature pyramids per frame with ``mfe`` and ``cfe``;
      * estimates a low-resolution flow with ``bimfn`` from the coarsest
        ``mfe`` features;
      * expands it into a flow pyramid with ``caun``;
      * synthesizes the intermediate frame with ``sn``.
    The flow/occlusion produced at one level seeds the next, finer level.
    """

    def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
        """Build the sub-networks.

        Args:
            pyr_level: default number of pyramid levels used by ``forward``.
            feat_channels: channel width passed to every sub-network.
            **kwargs: accepted and ignored (presumably so a whole config dict
                can be splatted in — NOTE(review): confirm with the loader).
        """
        super().__init__()
        self.pyr_level = pyr_level
        self.mfe = ResNetPyramid(feat_channels)
        self.cfe = ResNetPyramid(feat_channels)
        self.bimfn = BiMFN(feat_channels)
        self.sn = SynthesisNetwork(feat_channels)
        self.feat_channels = feat_channels
        self.caun = CAUN(feat_channels)

    def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
        """Run a single pyramid level.

        Args:
            img0, img1: the two (normalized) frames at this level's resolution.
            last_flow: flow from the previous, coarser level. It is downsampled
                2x here and its values halved, since flow magnitudes scale with
                resolution.
            last_occ: occlusion estimate from the previous level; downsampled 2x.
            time_period: target time between the frames (default 0.5). Either
                a Python number or a tensor broadcastable to (B, 1, H, W).

        Returns:
            ``(flow, occ, interp_img, extra_dict)``:
              * ``flow`` is the first entry of the flow pyramid from ``caun``;
              * ``occ`` and ``interp_img`` come from ``sn``;
              * ``extra_dict`` is ``sn``'s dict plus ``'flow_res'`` from ``bimfn``.
        """
        feat0_pyr = self.mfe(img0)
        feat1_pyr = self.mfe(img1)
        cfeat0_pyr = self.cfe(img0)
        cfeat1_pyr = self.cfe(img1)

        coarsest = feat0_pyr[-1]
        B, _, H, W = coarsest.shape
        # Create the helper tensors with the features' dtype as well as their
        # device. Without an explicit dtype torch.ones uses the default dtype
        # (normally float32), which breaks reduced-precision (half/bf16) runs.
        device, dtype = coarsest.device, coarsest.dtype

        # Inference path: uniform BiM -- r equals time_period and phi equals
        # pi at every pixel; phi is handed to bimfn as its (cos, sin) pair.
        r = torch.ones((B, 1, H, W), device=device, dtype=dtype) * time_period
        phi = torch.ones((B, 1, H, W), device=device, dtype=dtype) * torch.pi
        phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)

        # Bring the previous level's estimates down 2x (no gradient flows back
        # into the coarser level). F.interpolate always returns a new tensor,
        # so the caller's tensors are never modified and no clone is needed.
        last_flow = F.interpolate(
            input=last_flow.detach(), scale_factor=0.5,
            mode="bilinear", align_corners=False) * 0.5
        last_occ = F.interpolate(
            input=last_occ.detach(), scale_factor=0.5,
            mode="bilinear", align_corners=False)

        flow_low, flow_res = self.bimfn(
            feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)
        bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
        flow = bi_flow_pyr[0]
        interp_img, occ, extra_dict = self.sn(
            img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
        extra_dict['flow_res'] = flow_res
        return flow, occ, interp_img, extra_dict

    def forward(self, img0, img1, time_step,
                pyr_level=None, **kwargs):
        """Interpolate the frame at ``time_step`` between ``img0`` and ``img1``.

        Args:
            img0, img1: input frames, 4-D tensors (N, C, H, W).
            time_step: passed unchanged to ``forward_one_lvl`` as
                ``time_period`` at every level.
            pyr_level: number of pyramid levels to run; defaults to
                ``self.pyr_level``. Must be >= 1.
            **kwargs: accepted and ignored.

        Returns:
            dict with two entries:
              * ``"imgt_preds"``: the per-level predictions in the order they
                were computed (coarsest level first), de-normalized but still
                padded;
              * ``"imgt_pred"``: the last (finest) prediction with the padding
                removed.

        Raises:
            ValueError: if ``pyr_level`` < 1, because then no level would run.
        """
        if pyr_level is None:
            pyr_level = self.pyr_level
        if pyr_level < 1:
            raise ValueError(f"pyr_level must be >= 1, got {pyr_level}")

        eps = 0.0000001
        # Pad so that H and W stay divisible through the coarsest level's
        # zero-initialised flow (H // 2**pyr_level) and its further 2x
        # downsampling inside forward_one_lvl.
        padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))

        # Normalize both frames with statistics pooled over the pair (per
        # sample). The pooled std combines each frame's variance with the
        # spread of the per-frame means around the joint mean.
        with torch.set_grad_enabled(False):
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()
            img0 = (img0 - tenMean_) / (tenStd_ + eps)
            img1 = (img1 - tenMean_) / (tenStd_ + eps)

        # Pad images for downsampling.
        img0, img1 = padder.pad(img0, img1)
        N, _, H, W = img0.shape

        interp_imgs = []
        for level in reversed(range(pyr_level)):
            # Downsample the frames for every level except the finest.
            if level != 0:
                scale_factor = 1 / 2 ** level
                img0_this_lvl = F.interpolate(
                    input=img0, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
                img1_this_lvl = F.interpolate(
                    input=img1, scale_factor=scale_factor,
                    mode="bilinear", align_corners=False, antialias=True)
            else:
                img0_this_lvl = img0
                img1_this_lvl = img1

            if level == pyr_level - 1:
                # Coarsest level: start from zero flow / occlusion, created in
                # the input's dtype so half-precision runs stay consistent.
                h0 = H // (2 ** (level + 1))
                w0 = W // (2 ** (level + 1))
                last_flow = torch.zeros(
                    (N, 4, h0, w0), device=img0.device, dtype=img0.dtype)
                last_occ = torch.zeros(
                    (N, 1, h0, w0), device=img0.device, dtype=img0.dtype)
            else:
                last_flow = flow
                last_occ = occ

            flow, occ, interp_img, _ = self.forward_one_lvl(
                img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)
            # Undo the input normalization on this level's prediction.
            interp_imgs.append(interp_img * (tenStd_ + eps) + tenMean_)

        return {
            "imgt_preds": interp_imgs,
            'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
        }