Files
ComfyUI-Tween/bim_vfi_arch/resnet_encoder.py
Ethanfel db64fc195a Initial commit: ComfyUI BIM-VFI node for video frame interpolation
Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video
frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair
  GPU processing, configurable VRAM management (all_on_gpu for high-VRAM
  setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 18:26:49 +01:00

103 lines
3.2 KiB
Python

import torch.nn as nn
from typing import Optional, Callable
from torch import Tensor
from functools import partial
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a biased 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; the same value is used as the padding.

    Returns:
        The freshly constructed convolution layer.
    """
    # Everything except the channel counts, gathered in one place.
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv2x2(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build an unpadded 2x2 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The freshly constructed convolution layer (all other options are
        left at the ``nn.Conv2d`` defaults).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=2,
        stride=stride,
    )
    return layer
class BasicBlock(nn.Module):
    """Residual block that normalises before each convolution.

    The main branch is ``norm -> conv -> LeakyReLU -> norm -> conv ->
    LeakyReLU``; its result is then added to the skip branch. Unlike the
    torchvision ``BasicBlock``, the second activation is applied *before*
    the residual addition.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution. With ``stride == 1`` a
            padded 3x3 convolution is used, otherwise an unpadded 2x2
            convolution with that stride.
        downsample: Optional module applied to the input to produce the
            skip branch. When omitted the raw input is used, so the caller
            must ensure it can be added to the main-branch output.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Callable taking a channel count and returning a
            normalisation module. Defaults to ``nn.InstanceNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    # Channel multiplier of the block output relative to ``planes``.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # nn.InstanceNorm2d accepts no ``data_format`` keyword, so the
            # previous partial(..., data_format='channels_first') default
            # raised TypeError as soon as it was called. Use the plain layer.
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.bn1 = norm_layer(inplanes)
        # A strided block reduces resolution with a 2x2 convolution rather
        # than a strided 3x3 one.
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            self.conv1 = conv2x2(inplanes, planes, stride)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return main branch plus skip branch."""
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.lrelu(out)

        out = self.bn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Activation happens before the residual sum, not after it.
        out = self.lrelu(out)
        out = out + identity
        return out
class ResNetPyramid(nn.Module):
    """Three-scale convolutional feature encoder for 3-channel images.

    According to the original description, this pyramid is by default
    shared by the motion estimator and the synthesis network.

    Level 0 keeps ``feat_channels`` channels. Levels 1 and 2 each begin
    with a 2x2, stride-2 convolution and carry ``2 * feat_channels`` and
    ``4 * feat_channels`` channels respectively. A final 3x3 convolution
    is applied to the coarsest level only.
    """

    def __init__(self, feat_channels: int):
        super().__init__()
        width0 = feat_channels
        width1 = feat_channels * 2
        width2 = feat_channels * 4

        # Stem: lift the 3 input channels to the base width.
        self.conv = nn.Conv2d(3, width0, kernel_size=3, stride=1, padding=1)

        self.layer0 = nn.Sequential(
            BasicBlock(width0, width0, norm_layer=nn.InstanceNorm2d),
        )
        # Each deeper level: strided 2x2 conv, then one residual block.
        self.layer1 = nn.Sequential(
            nn.Conv2d(width0, width1, 2, 2),
            BasicBlock(width1, width1, norm_layer=nn.InstanceNorm2d),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(width1, width2, 2, 2),
            BasicBlock(width2, width2, norm_layer=nn.InstanceNorm2d),
        )

        self.conv_last = nn.Conv2d(width2, width2, 3, 1, 1)

    def forward(self, img):
        """Return ``[C0, C1, C2]``, the features ordered finest to coarsest."""
        stem = self.conv(img)
        level0 = self.layer0(stem)
        level1 = self.layer1(level0)
        level2 = self.layer2(level1)
        level2 = self.conv_last(level2)
        return [level0, level1, level2]