Add EMA-VFI (CVPR 2023) frame interpolation support

Integrate EMA-VFI alongside existing BIM-VFI with three new ComfyUI nodes:
Load EMA-VFI Model, EMA-VFI Interpolate, and EMA-VFI Segment Interpolate.

Architecture files vendored from MCG-NJU/EMA-VFI with device-awareness
fixes (removed hardcoded .cuda() calls), warp cache management, and
relative imports. InputPadder extended to support EMA-VFI's replicate
center-symmetric padding. Auto-installs timm dependency on first load.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 22:30:06 +01:00
parent 0133f61d47
commit 1de086569c
11 changed files with 1334 additions and 18 deletions

317
nodes.py
View File

@@ -8,20 +8,29 @@ import torch
import folder_paths
from comfy.utils import ProgressBar
from .inference import BiMVFIModel
from .inference import BiMVFIModel, EMAVFIModel
from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
logger = logging.getLogger("BIM-VFI")
# Google Drive file ID for the pretrained model
# Google Drive file ID for the pretrained BIM-VFI model
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
MODEL_FILENAME = "bim_vfi.pth"
# Register the model folder with ComfyUI
# Google Drive folder ID for EMA-VFI pretrained models
EMA_GDRIVE_FOLDER_ID = "16jUa3HkQ85Z5lb5gce1yoaWkP-rdCd0o"
EMA_DEFAULT_MODEL = "ours_t.pkl"
# Register model folders with ComfyUI
MODEL_DIR = os.path.join(folder_paths.models_dir, "bim-vfi")
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR, exist_ok=True)
EMA_MODEL_DIR = os.path.join(folder_paths.models_dir, "ema-vfi")
if not os.path.exists(EMA_MODEL_DIR):
os.makedirs(EMA_MODEL_DIR, exist_ok=True)
def get_available_models():
"""List available checkpoint files in the bim-vfi model directory."""
@@ -456,3 +465,305 @@ class BIMVFIConcatVideos:
os.remove(concat_list_path)
return (output_path,)
# ---------------------------------------------------------------------------
# EMA-VFI nodes
# ---------------------------------------------------------------------------
def get_available_ema_models():
"""List available checkpoint files in the ema-vfi model directory."""
models = []
if os.path.isdir(EMA_MODEL_DIR):
for f in os.listdir(EMA_MODEL_DIR):
if f.endswith((".pkl", ".pth", ".pt", ".ckpt", ".safetensors")):
models.append(f)
if not models:
models.append(EMA_DEFAULT_MODEL) # Will trigger auto-download
return sorted(models)
def download_ema_model_from_gdrive(folder_id, dest_path):
"""Download EMA-VFI model from Google Drive folder using gdown."""
try:
import gdown
except ImportError:
raise RuntimeError(
"gdown is required to auto-download the EMA-VFI model. "
"Install it with: pip install gdown"
)
filename = os.path.basename(dest_path)
url = f"https://drive.google.com/drive/folders/{folder_id}"
logger.info(f"Downloading {filename} from Google Drive folder to {dest_path}...")
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
gdown.download_folder(url, output=os.path.dirname(dest_path), quiet=False, remaining_ok=True)
if not os.path.exists(dest_path):
raise RuntimeError(
f"Failed to download {filename}. Please download manually from "
f"https://drive.google.com/drive/folders/{folder_id} "
f"and place it in {os.path.dirname(dest_path)}"
)
logger.info("Download complete.")
class LoadEMAVFIModel:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_path": (get_available_ema_models(), {
"default": EMA_DEFAULT_MODEL,
"tooltip": "Checkpoint file from models/ema-vfi/. Auto-downloads on first use if missing. "
"Variant (large/small) and timestep support are auto-detected from filename.",
}),
"tta": ("BOOLEAN", {
"default": False,
"tooltip": "Test-time augmentation: flip input and average with unflipped result. "
"~2x slower but slightly better quality. Recommended for large model only.",
}),
}
}
RETURN_TYPES = ("EMA_VFI_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "video/EMA-VFI"
def load_model(self, model_path, tta):
full_path = os.path.join(EMA_MODEL_DIR, model_path)
if not os.path.exists(full_path):
logger.info(f"Model not found at {full_path}, attempting download...")
download_ema_model_from_gdrive(EMA_GDRIVE_FOLDER_ID, full_path)
wrapper = EMAVFIModel(
checkpoint_path=full_path,
variant="auto",
tta=tta,
device="cpu",
)
t_mode = "arbitrary" if wrapper.supports_arbitrary_t else "fixed (0.5)"
logger.info(f"EMA-VFI model loaded (variant={wrapper.variant_name}, timestep={t_mode}, tta={tta})")
return (wrapper,)
class EMAVFIInterpolate:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
}),
"model": ("EMA_VFI_MODEL", {
"tooltip": "EMA-VFI model from the Load EMA-VFI Model node.",
}),
"multiplier": ([2, 4, 8], {
"default": 2,
"tooltip": "Frame rate multiplier. 2x=one interpolation pass, 4x=two recursive passes, 8x=three. Higher = more frames but longer processing.",
}),
"clear_cache_after_n_frames": ("INT", {
"default": 10, "min": 1, "max": 100, "step": 1,
"tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
}),
"keep_device": ("BOOLEAN", {
"default": True,
"tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
}),
"all_on_gpu": ("BOOLEAN", {
"default": False,
"tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
}),
"batch_size": ("INT", {
"default": 1, "min": 1, "max": 64, "step": 1,
"tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full.",
}),
"chunk_size": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "interpolate"
CATEGORY = "video/EMA-VFI"
def _interpolate_frames(self, frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref):
"""Run all interpolation passes on a chunk of frames."""
for pass_idx in range(num_passes):
new_frames = []
num_pairs = frames.shape[0] - 1
pairs_since_clear = 0
for i in range(0, num_pairs, batch_size):
batch_end = min(i + batch_size, num_pairs)
actual_batch = batch_end - i
frames0 = frames[i:batch_end]
frames1 = frames[i + 1:batch_end + 1]
if not keep_device:
model.to(device)
mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
mids = mids.to(storage_device)
if not keep_device:
model.to("cpu")
for j in range(actual_batch):
new_frames.append(frames[i + j:i + j + 1])
new_frames.append(mids[j:j+1])
step_ref[0] += actual_batch
pbar.update_absolute(step_ref[0])
pairs_since_clear += actual_batch
if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
clear_ema_warp_cache()
torch.cuda.empty_cache()
pairs_since_clear = 0
new_frames.append(frames[-1:])
frames = torch.cat(new_frames, dim=0)
if not all_on_gpu and torch.cuda.is_available():
clear_ema_warp_cache()
torch.cuda.empty_cache()
return frames
@staticmethod
def _count_steps(num_frames, num_passes):
"""Count total interpolation steps for a given input frame count."""
n = num_frames
total = 0
for _ in range(num_passes):
total += n - 1
n = 2 * n - 1
return total
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size):
if images.shape[0] < 2:
return (images,)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
# Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
all_frames = images.permute(0, 3, 1, 2).to(storage_device)
total_input = all_frames.shape[0]
# Build chunk boundaries (1-frame overlap between consecutive chunks)
if chunk_size < 2 or chunk_size >= total_input:
chunks = [(0, total_input)]
else:
chunks = []
start = 0
while start < total_input - 1:
end = min(start + chunk_size, total_input)
chunks.append((start, end))
start = end - 1 # overlap by 1 frame
if end == total_input:
break
# Calculate total progress steps across all chunks
total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
result_chunks = []
for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
chunk_frames = all_frames[chunk_start:chunk_end].clone()
chunk_result = self._interpolate_frames(
chunk_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
# Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
if chunk_idx > 0:
chunk_result = chunk_result[1:]
# Move completed chunk to CPU to bound memory when chunking
if len(chunks) > 1:
chunk_result = chunk_result.cpu()
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
class EMAVFISegmentInterpolate(EMAVFIInterpolate):
"""Process a numbered segment of the input batch for EMA-VFI.
Chain multiple instances with Save nodes between them to bound peak RAM.
The model pass-through output forces sequential execution so each segment
saves and frees from RAM before the next starts.
"""
@classmethod
def INPUT_TYPES(cls):
base = EMAVFIInterpolate.INPUT_TYPES()
base["required"]["segment_index"] = ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
"unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
"Chain the model output to the next Segment Interpolate to force sequential execution.",
})
base["required"]["segment_size"] = ("INT", {
"default": 500, "min": 2, "max": 10000, "step": 1,
"tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
"Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
})
return base
RETURN_TYPES = ("IMAGE", "EMA_VFI_MODEL")
RETURN_NAMES = ("images", "model")
FUNCTION = "interpolate"
CATEGORY = "video/EMA-VFI"
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size):
total_input = images.shape[0]
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
end = min(start + segment_size, total_input)
if start >= total_input - 1:
# Past the end — return empty single frame + model
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
(result,) = super().interpolate(
segment_images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
)
if is_continuation:
result = result[1:] # skip duplicate boundary frame
return (result, model)