Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing. - LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive - BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing - Vendored inference-only architecture from KAIST-VICLab/BiM-VFI - Auto-detection of CUDA version for cupy installation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
114
bim_vfi_arch/bim_vfi.py
Normal file
114
bim_vfi_arch/bim_vfi.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
from .backwarp import backwarp
|
||||
from .resnet_encoder import ResNetPyramid
|
||||
from .caun import CAUN
|
||||
from .bimfn import BiMFN
|
||||
from .sn import SynthesisNetwork
|
||||
|
||||
from ..utils.padder import InputPadder
|
||||
|
||||
|
||||
class BiMVFI(nn.Module):
|
||||
def __init__(self, pyr_level=3, feat_channels=32, **kwargs):
|
||||
super(BiMVFI, self).__init__()
|
||||
self.pyr_level = pyr_level
|
||||
self.mfe = ResNetPyramid(feat_channels)
|
||||
self.cfe = ResNetPyramid(feat_channels)
|
||||
self.bimfn = BiMFN(feat_channels)
|
||||
self.sn = SynthesisNetwork(feat_channels)
|
||||
self.feat_channels = feat_channels
|
||||
self.caun = CAUN(feat_channels)
|
||||
|
||||
def forward_one_lvl(self, img0, img1, last_flow, last_occ, time_period=0.5):
|
||||
feat0_pyr = self.mfe(img0)
|
||||
feat1_pyr = self.mfe(img1)
|
||||
cfeat0_pyr = self.cfe(img0)
|
||||
cfeat1_pyr = self.cfe(img1)
|
||||
|
||||
B, _, H, W = feat0_pyr[-1].shape
|
||||
|
||||
# Inference path: prepare uniform BiM
|
||||
r = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * time_period
|
||||
phi = torch.ones((B, 1, H, W), device=feat0_pyr[-1].device) * torch.pi
|
||||
phi = torch.cat([torch.cos(phi), torch.sin(phi)], dim=1)
|
||||
|
||||
last_flow = F.interpolate(
|
||||
input=last_flow.detach().clone(), scale_factor=0.5,
|
||||
mode="bilinear", align_corners=False) * 0.5
|
||||
last_occ = F.interpolate(
|
||||
input=last_occ.detach().clone(), scale_factor=0.5,
|
||||
mode="bilinear", align_corners=False)
|
||||
|
||||
flow_low, flow_res = self.bimfn(
|
||||
feat0_pyr[-1], feat1_pyr[-1], r, phi, last_flow, last_occ)
|
||||
|
||||
bi_flow_pyr, occ = self.caun(flow_low, cfeat0_pyr, cfeat1_pyr, last_occ)
|
||||
flow = bi_flow_pyr[0]
|
||||
|
||||
interp_img, occ, extra_dict = self.sn(
|
||||
img0, img1, cfeat0_pyr, cfeat1_pyr, bi_flow_pyr, occ)
|
||||
extra_dict.update({'flow_res': flow_res})
|
||||
return flow, occ, interp_img, extra_dict
|
||||
|
||||
def forward(self, img0, img1, time_step,
|
||||
pyr_level=None, **kwargs):
|
||||
if pyr_level is None: pyr_level = self.pyr_level
|
||||
N, _, H, W = img0.shape
|
||||
interp_imgs = []
|
||||
|
||||
padder = InputPadder(img0.shape, divisor=int(2 ** (pyr_level + 1)))
|
||||
|
||||
# Normalize input images
|
||||
with torch.set_grad_enabled(False):
|
||||
tenStats = [img0, img1]
|
||||
tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
|
||||
tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
|
||||
tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()
|
||||
|
||||
img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
|
||||
img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)
|
||||
|
||||
# Pad images for downsampling
|
||||
img0, img1 = padder.pad(img0, img1)
|
||||
|
||||
N, _, H, W = img0.shape
|
||||
|
||||
for level in list(range(pyr_level))[::-1]:
|
||||
# Downsample images if needed
|
||||
if level != 0:
|
||||
scale_factor = 1 / 2 ** level
|
||||
img0_this_lvl = F.interpolate(
|
||||
input=img0, scale_factor=scale_factor,
|
||||
mode="bilinear", align_corners=False, antialias=True)
|
||||
img1_this_lvl = F.interpolate(
|
||||
input=img1, scale_factor=scale_factor,
|
||||
mode="bilinear", align_corners=False, antialias=True)
|
||||
else:
|
||||
img0_this_lvl = img0
|
||||
img1_this_lvl = img1
|
||||
|
||||
# Initialize zero flows for lowest pyramid level
|
||||
if level == pyr_level - 1:
|
||||
last_flow = torch.zeros(
|
||||
(N, 4, H // (2 ** (level + 1)), W // (2 ** (level + 1))), device=img0.device
|
||||
)
|
||||
last_occ = torch.zeros(N, 1, H // (2 ** (level + 1)), W // (2 ** (level + 1)), device=img0.device)
|
||||
else:
|
||||
last_flow = flow
|
||||
last_occ = occ
|
||||
|
||||
# Single pyramid level run
|
||||
flow, occ, interp_img, extra_dict = self.forward_one_lvl(
|
||||
img0_this_lvl, img1_this_lvl, last_flow, last_occ, time_step)
|
||||
|
||||
interp_imgs.append((interp_img) * (tenStd_ + 0.0000001) + tenMean_)
|
||||
|
||||
result_dict = {
|
||||
"imgt_preds": interp_imgs,
|
||||
'imgt_pred': padder.unpad(interp_imgs[-1].contiguous()),
|
||||
}
|
||||
|
||||
return result_dict
|
||||
Reference in New Issue
Block a user