Add EMA-VFI (CVPR 2023) frame interpolation support

Integrate EMA-VFI alongside existing BIM-VFI with three new ComfyUI nodes:
Load EMA-VFI Model, EMA-VFI Interpolate, and EMA-VFI Segment Interpolate.

Architecture files vendored from MCG-NJU/EMA-VFI with device-awareness
fixes (removed hardcoded .cuda() calls), warp cache management, and
relative imports. InputPadder extended to support EMA-VFI's replicate
center-symmetric padding. Auto-installs timm dependency on first load.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 22:30:06 +01:00
parent 0133f61d47
commit 1de086569c
11 changed files with 1334 additions and 18 deletions

View File

@@ -1,5 +1,15 @@
import logging
from functools import partial
import torch
import torch.nn as nn
from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
from .ema_vfi_arch import MultiScaleFlow as EMAMultiScaleFlow
from .utils.padder import InputPadder
logger = logging.getLogger("BIM-VFI")
class BiMVFIModel:
@@ -112,3 +122,163 @@ class BiMVFIModel:
interp = result_dict["imgt_pred"]
interp = torch.clamp(interp, 0, 1)
return interp
# ---------------------------------------------------------------------------
# EMA-VFI model wrapper
# ---------------------------------------------------------------------------
def _ema_init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]):
"""Build EMA-VFI model config dicts (backbone + multiscale)."""
return {
'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
'num_heads': [8*F//32, 16*F//32],
'mlp_ratios': [4, 4],
'qkv_bias': True,
'norm_layer': partial(nn.LayerNorm, eps=1e-6),
'depths': depth,
'window_sizes': [W, W]
}, {
'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
'depths': depth,
'num_heads': [8*F//32, 16*F//32],
'window_sizes': [W, W],
'scales': [4, 8, 16],
'hidden_dims': [4*F, 4*F],
'c': F
}
def _ema_detect_variant(filename):
"""Auto-detect model variant and timestep support from filename.
Returns (F, depth, supports_arbitrary_t).
"""
name = filename.lower()
is_small = "small" in name
supports_t = "_t." in name or "_t_" in name or name.endswith("_t")
if is_small:
return 16, [2, 2, 2, 2, 2], supports_t
else:
return 32, [2, 2, 2, 4, 4], supports_t
class EMAVFIModel:
"""Clean inference wrapper around EMA-VFI for ComfyUI integration."""
def __init__(self, checkpoint_path, variant="auto", tta=False, device="cpu"):
import os
filename = os.path.basename(checkpoint_path)
if variant == "auto":
F_dim, depth, self.supports_arbitrary_t = _ema_detect_variant(filename)
elif variant == "small":
F_dim, depth = 16, [2, 2, 2, 2, 2]
self.supports_arbitrary_t = "_t." in filename.lower() or "_t_" in filename.lower()
else: # large
F_dim, depth = 32, [2, 2, 2, 4, 4]
self.supports_arbitrary_t = "_t." in filename.lower() or "_t_" in filename.lower()
self.tta = tta
self.device = device
self.variant_name = "small" if F_dim == 16 else "large"
backbone_cfg, multiscale_cfg = _ema_init_model_config(F=F_dim, depth=depth)
backbone = ema_feature_extractor(**backbone_cfg)
self.model = EMAMultiScaleFlow(backbone, **multiscale_cfg)
self._load_checkpoint(checkpoint_path)
self.model.eval()
self.model.to(device)
def _load_checkpoint(self, checkpoint_path):
"""Load checkpoint with module prefix stripping and buffer filtering."""
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
# Handle wrapped checkpoint formats
if isinstance(state_dict, dict):
if "model" in state_dict:
state_dict = state_dict["model"]
elif "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
# Strip "module." prefix and filter out attn_mask/HW buffers
cleaned = {}
for k, v in state_dict.items():
if "attn_mask" in k or k.endswith(".HW"):
continue
key = k
if key.startswith("module."):
key = key[len("module."):]
cleaned[key] = v
self.model.load_state_dict(cleaned)
def to(self, device):
"""Move model to device (returns self for chaining)."""
self.device = device
self.model.to(device)
return self
@torch.no_grad()
def _inference(self, img0, img1, timestep=0.5):
"""Run single inference pass. Inputs already padded, on device."""
B = img0.shape[0]
imgs = torch.cat((img0, img1), 1)
if self.tta:
imgs_ = imgs.flip(2).flip(3)
input_batch = torch.cat((imgs, imgs_), 0)
_, _, _, preds = self.model(input_batch, timestep=timestep)
return (preds[:B] + preds[B:].flip(2).flip(3)) / 2.
else:
_, _, _, pred = self.model(imgs, timestep=timestep)
return pred
@torch.no_grad()
def interpolate_pair(self, frame0, frame1, time_step=0.5):
"""Interpolate a single frame between two input frames.
Args:
frame0: [1, C, H, W] tensor, float32, range [0, 1]
frame1: [1, C, H, W] tensor, float32, range [0, 1]
time_step: float in (0, 1)
Returns:
Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
"""
device = next(self.model.parameters()).device
img0 = frame0.to(device)
img1 = frame1.to(device)
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
img0, img1 = padder.pad(img0, img1)
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)
@torch.no_grad()
def interpolate_batch(self, frames0, frames1, time_step=0.5):
"""Interpolate multiple frame pairs at once.
Args:
frames0: [B, C, H, W] tensor, float32, range [0, 1]
frames1: [B, C, H, W] tensor, float32, range [0, 1]
time_step: float in (0, 1)
Returns:
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
"""
device = next(self.model.parameters()).device
img0 = frames0.to(device)
img1 = frames1.to(device)
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
img0, img1 = padder.pad(img0, img1)
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)