Files
ComfyUI-Tween/bim_vfi_arch/caun.py
Ethanfel db64fc195a Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 18:26:49 +01:00

73 lines
3.5 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from .backwarp import backwarp
class CAUN(nn.Module):
def __init__(self, feat_channels):
super(CAUN, self).__init__()
self.enc0 = nn.Sequential(
nn.Conv2d(feat_channels * 8, feat_channels * 4, 3, padding=1),
nn.PReLU(feat_channels * 4),
)
self.enc1 = nn.Sequential(
nn.Conv2d(feat_channels * 5, feat_channels * 4, 3, padding=1),
nn.PReLU(feat_channels * 4),
)
self.enc2 = nn.Sequential(
nn.Conv2d(feat_channels * 3, feat_channels * 1, 3, padding=1),
nn.PReLU(feat_channels * 1),
)
self.kernel_x2 = nn.Sequential(
nn.Conv2d(feat_channels * 4, feat_channels * 2, 3, padding=1),
nn.PReLU(feat_channels * 2),
nn.Conv2d(feat_channels * 2, 2 * 1 * 9, 3, padding=1)
)
self.kernel_x4 = nn.Sequential(
nn.Conv2d(feat_channels * 1, feat_channels * 1, 3, padding=1),
nn.PReLU(feat_channels * 1),
nn.Conv2d(feat_channels * 1, 2 * 1 * 9, 3, padding=1)
)
def upsample_input(self, inp, mask, upsample_factor):
N, c, H, W = inp.shape
mask = mask.view(N, 1, 9, upsample_factor, upsample_factor, H, W)
mask = torch.softmax(mask, dim=2)
inp = F.pad(inp, [1, 1, 1, 1], mode='replicate')
up_inp = F.unfold(inp, [3, 3])
up_inp = up_inp.view(N, c, 9, 1, 1, H, W)
up_inp = torch.sum(mask * up_inp, dim=2)
up_inp = up_inp.permute(0, 1, 4, 2, 5, 3)
return up_inp.reshape(N, c, upsample_factor*H, upsample_factor*W)
def forward(self, flow, feat0, feat1, last_occ):
""" Upsample flow field [H/4, W/4, 4] -> [H, W, 4] using convex combination """
N, _, H, W = flow.shape
feat0_warped_list, feat1_warped_list = [], []
for i in range(3):
flow_bi = F.interpolate(flow * 2 ** i, scale_factor=2 ** i, mode='bilinear')
feat0_warped = backwarp(feat0[2-i], flow_bi[:, :2])
feat1_warped = backwarp(feat1[2-i], flow_bi[:, 2:])
feat0_warped_list.append(feat0_warped)
feat1_warped_list.append(feat1_warped)
feature = torch.cat([feat0_warped_list[0], feat1_warped_list[0]], dim=1)
feature0 = self.enc0(feature)
feature1 = self.enc1(torch.cat([F.pixel_shuffle(feature0, 2), feat0_warped_list[1], feat1_warped_list[1]], dim=1))
feature2 = self.enc2(torch.cat([F.pixel_shuffle(feature1, 2), feat0_warped_list[2], feat1_warped_list[2]], dim=1))
mask_x2 = self.kernel_x2(feature1)
mask_x4 = self.kernel_x4(feature2)
mask_x2 = mask_x2.view(N, 18, H, 2, W, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
mask_x2_0, mask_x2_1 = torch.chunk(mask_x2, 2, dim=1)
mask_x4 = mask_x4.view(N, 18, H, 4, W, 4).permute(0, 1, 3, 5, 2, 4).contiguous()
mask_x4_0, mask_x4_1 = torch.chunk(mask_x4, 2, dim=1)
up_flow_x2_0 = self.upsample_input(flow[:, :2] * 2, mask_x2_0, 2)
up_flow_x2_1 = self.upsample_input(flow[:, 2:] * 2, mask_x2_1, 2)
up_flow_x4_0 = self.upsample_input(flow[:, :2] * 4, mask_x4_0, 4)
up_flow_x4_1 = self.upsample_input(flow[:, 2:] * 4, mask_x4_1, 4)
up_flow_x2 = torch.cat([up_flow_x2_0, up_flow_x2_1], dim=1)
up_flow_x4 = torch.cat([up_flow_x4_0, up_flow_x4_1], dim=1)
up_occ = F.interpolate(last_occ, scale_factor=4, mode='bilinear')
return [up_flow_x4, up_flow_x2, flow], up_occ