Add batch processing support for faster frame interpolation

Process multiple frame pairs simultaneously instead of one by one.
Add a batch_size input (1-64) that lets users trade VRAM for speed.
Refactor the pyr_level logic into a shared _get_pyr_level() helper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-12 18:54:40 +01:00
parent 69a4aebfe7
commit 993a3a72b1
2 changed files with 62 additions and 23 deletions

View File

@@ -119,6 +119,10 @@ class BIMVFIInterpolate:
"default": False,
"tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
}),
"batch_size": ("INT", {
"default": 1, "min": 1, "max": 64, "step": 1,
"tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full. Recommended: 1 for 8GB, 2-4 for 24GB, 4-16 for 48GB+.",
}),
}
}
@@ -127,7 +131,7 @@ class BIMVFIInterpolate:
FUNCTION = "interpolate"
CATEGORY = "video/BIM-VFI"
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size):
if images.shape[0] < 2:
return (images,)
@@ -161,26 +165,32 @@ class BIMVFIInterpolate:
new_frames = []
num_pairs = frames.shape[0] - 1
for i in range(num_pairs):
frame0 = frames[i:i+1] # [1, C, H, W]
frame1 = frames[i+1:i+2] # [1, C, H, W]
for i in range(0, num_pairs, batch_size):
batch_end = min(i + batch_size, num_pairs)
actual_batch = batch_end - i
# Gather batch of pairs
frames0 = torch.cat([frames[j:j+1] for j in range(i, batch_end)], dim=0)
frames1 = torch.cat([frames[j+1:j+2] for j in range(i, batch_end)], dim=0)
if not keep_device:
model.to(device)
mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
mid = mid.to(storage_device)
mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
mids = mids.to(storage_device)
if not keep_device:
model.to("cpu")
new_frames.append(frames[i:i+1])
new_frames.append(mid)
# Interleave: original frame, then interpolated frame
for j in range(actual_batch):
new_frames.append(frames[i + j:i + j + 1])
new_frames.append(mids[j:j+1])
step += 1
step += actual_batch
pbar.update_absolute(step, total_steps)
if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
if not all_on_gpu and (batch_end) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
clear_backwarp_cache()
torch.cuda.empty_cache()