Wraps BiM-VFI (CVPR 2025) as a ComfyUI custom node for long video frame interpolation with memory-safe sequential processing.

- LoadBIMVFIModel: checkpoint loader with auto-download from Google Drive
- BIMVFIInterpolate: 2x/4x/8x recursive interpolation with per-pair GPU processing, configurable VRAM management (all_on_gpu for high-VRAM setups), progress bar, and backwarp cache clearing
- Vendored inference-only architecture from KAIST-VICLab/BiM-VFI
- Auto-detection of CUDA version for cupy installation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
25 lines
1.5 KiB
Python
25 lines
1.5 KiB
Python
import torch
|
|
|
|
# Module-level cache of the base sampling grids built by backwarp(), so that a
# grid is constructed only once per flow dtype / device / spatial size.
# Emptied in place by clear_backwarp_cache().
objBackwarpcache = {}
|
|
|
|
|
|
def clear_backwarp_cache():
    """Drop every sampling grid stored in the module-level backwarp cache.

    The cache dictionary is emptied in place, so any other reference to the
    same dict observes the change. Releasing the cached grid tensors lets
    their memory be reclaimed once nothing else refers to them; intended to
    be called between frame pairs to keep VRAM usage down.
    """
    objBackwarpcache.clear()
|
|
|
|
|
|
def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor, mode='bilinear') -> torch.Tensor:
    """Backward-warp ``tenIn`` by sampling it at positions displaced by ``tenFlow``.

    A base grid of normalized coordinates in ``[-1, 1]`` is built for the
    spatial size of ``tenFlow`` and cached in ``objBackwarpcache``. The flow is
    rescaled from pixel units to that normalized range, added to the grid, and
    the result is passed to ``torch.nn.functional.grid_sample`` with
    ``padding_mode='zeros'`` and ``align_corners=True``.

    Args:
        tenIn: 4-D tensor to sample from (``grid_sample`` input).
        tenFlow: 4-D flow tensor laid out as ``(N, 2, H, W)``; channel 0 is the
            horizontal displacement (scaled by ``W - 1``) and channel 1 the
            vertical displacement (scaled by ``H - 1``).
        mode: Interpolation mode forwarded to ``grid_sample``.

    Returns:
        The warped tensor produced by ``grid_sample``; its spatial size is that
        of ``tenFlow``.
    """
    height, width = tenFlow.shape[2], tenFlow.shape[3]

    # A tuple key keeps every component distinct. The previous key concatenated
    # the parts as strings, so e.g. 2x34 and 23x4 both ended in "234" and the
    # second call reused a grid of the wrong shape.
    cache_key = ('grid', tenFlow.dtype, tenFlow.device, height, width)

    grid = objBackwarpcache.get(cache_key)
    if grid is None:
        tenHor = torch.linspace(
            start=-1.0, end=1.0, steps=width, dtype=tenFlow.dtype, device=tenFlow.device
        ).view(1, 1, 1, -1).repeat(1, 1, height, 1)
        tenVer = torch.linspace(
            start=-1.0, end=1.0, steps=height, dtype=tenFlow.dtype, device=tenFlow.device
        ).view(1, 1, -1, 1).repeat(1, 1, 1, width)
        grid = torch.cat([tenHor, tenVer], 1)
        objBackwarpcache[cache_key] = grid

    # Pixel -> normalized scale per axis. With align_corners=True an axis of
    # size 1 always samples index 0, so any finite scale is valid there; the
    # max() only avoids a division by zero and leaves sizes >= 2 unchanged.
    scale_hor = 2.0 / max(width - 1.0, 1.0)
    scale_ver = 2.0 / max(height - 1.0, 1.0)

    if width == height:
        # Square input: one scalar serves both channels (no extra tensor needed).
        tenFlow = tenFlow * scale_hor
    else:
        tenFlow = tenFlow * torch.tensor(
            data=[scale_hor, scale_ver], dtype=tenFlow.dtype, device=tenFlow.device
        ).view(1, 2, 1, 1)

    # grid_sample expects the sampling grid as (N, H, W, 2).
    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(grid + tenFlow).permute(0, 2, 3, 1),
        mode=mode,
        padding_mode='zeros',
        align_corners=True,
    )
|