Add batch processing support for faster frame interpolation
Processes multiple frame pairs simultaneously instead of one-by-one. A new
batch_size input (1-64) lets users trade VRAM for speed. The pyr_level
selection logic is refactored into a shared _get_pyr_level() helper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
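In practice the change swaps a per-pair loop for one batched forward pass per group of pairs. A minimal sketch of the call pattern, assuming a loaded BiMVFIModel instance named model and an [N, C, H, W] frame tensor (both names illustrative):

    # Before: one forward pass per frame pair
    mids = [model.interpolate_pair(frames[i:i+1], frames[i+1:i+2], time_step=0.5)
            for i in range(frames.shape[0] - 1)]

    # After: all N-1 consecutive pairs in a single batched call
    mids = model.interpolate_batch(frames[:-1], frames[1:], time_step=0.5)

The node itself chunks the pairs into groups of batch_size rather than batching everything at once, which is what bounds peak VRAM.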
inference.py | 55
@@ -43,6 +43,18 @@ class BiMVFIModel:
         self.model.to(device)
         return self
 
+    def _get_pyr_level(self, h):
+        if self.auto_pyr_level:
+            if h >= 2160:
+                return 7
+            elif h >= 1080:
+                return 6
+            elif h >= 540:
+                return 5
+            else:
+                return 3
+        return self.pyr_level
+
     @torch.no_grad()
     def interpolate_pair(self, frame0, frame1, time_step=0.5):
         """Interpolate a single frame between two input frames.
@@ -59,19 +71,36 @@ class BiMVFIModel:
         img0 = frame0.to(device)
         img1 = frame1.to(device)
 
-        if self.auto_pyr_level:
-            _, _, h, _ = img0.shape
-            if h >= 2160:
-                pyr_level = 7
-            elif h >= 1080:
-                pyr_level = 6
-            elif h >= 540:
-                pyr_level = 5
-            else:
-                pyr_level = 3
-        else:
-            pyr_level = self.pyr_level
+        pyr_level = self._get_pyr_level(img0.shape[2])
+        time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
+
+        result_dict = self.model(
+            img0=img0, img1=img1,
+            time_step=time_step_tensor,
+            pyr_level=pyr_level,
+        )
+
+        interp = result_dict["imgt_pred"]
+        interp = torch.clamp(interp, 0, 1)
+        return interp
+
+    @torch.no_grad()
+    def interpolate_batch(self, frames0, frames1, time_step=0.5):
+        """Interpolate multiple frame pairs at once.
+
+        Args:
+            frames0: [B, C, H, W] tensor, float32, range [0, 1]
+            frames1: [B, C, H, W] tensor, float32, range [0, 1]
+            time_step: float in (0, 1), temporal position of interpolated frames
+
+        Returns:
+            Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
+        """
+        device = next(self.model.parameters()).device
+        img0 = frames0.to(device)
+        img1 = frames1.to(device)
+        pyr_level = self._get_pyr_level(img0.shape[2])
 
         time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
 
         result_dict = self.model(
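The extracted helper preserves the old resolution-to-pyramid-level mapping. A doctest-style illustration, assuming an instance with auto_pyr_level enabled (with it disabled, the helper simply returns self.pyr_level):

    # Frame height -> pyramid level, per _get_pyr_level above
    for h, level in [(2160, 7), (1440, 6), (1080, 6), (720, 5), (540, 5), (480, 3)]:
        assert model._get_pyr_level(h) == level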
nodes.py | 30
@@ -119,6 +119,10 @@ class BIMVFIInterpolate:
                     "default": False,
                     "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                 }),
+                "batch_size": ("INT", {
+                    "default": 1, "min": 1, "max": 64, "step": 1,
+                    "tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full. Recommended: 1 for 8GB, 2-4 for 24GB, 4-16 for 48GB+.",
+                }),
             }
         }
 
@@ -127,7 +131,7 @@ class BIMVFIInterpolate:
     FUNCTION = "interpolate"
     CATEGORY = "video/BIM-VFI"
 
-    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
+    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size):
         if images.shape[0] < 2:
             return (images,)
 
@@ -161,26 +165,32 @@ class BIMVFIInterpolate:
         new_frames = []
         num_pairs = frames.shape[0] - 1
 
-        for i in range(num_pairs):
-            frame0 = frames[i:i+1]  # [1, C, H, W]
-            frame1 = frames[i+1:i+2]  # [1, C, H, W]
+        for i in range(0, num_pairs, batch_size):
+            batch_end = min(i + batch_size, num_pairs)
+            actual_batch = batch_end - i
+
+            # Gather batch of pairs
+            frames0 = torch.cat([frames[j:j+1] for j in range(i, batch_end)], dim=0)
+            frames1 = torch.cat([frames[j+1:j+2] for j in range(i, batch_end)], dim=0)
 
             if not keep_device:
                 model.to(device)
 
-            mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
-            mid = mid.to(storage_device)
+            mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
+            mids = mids.to(storage_device)
 
             if not keep_device:
                 model.to("cpu")
 
-            new_frames.append(frames[i:i+1])
-            new_frames.append(mid)
+            # Interleave: original frame, then interpolated frame
+            for j in range(actual_batch):
+                new_frames.append(frames[i + j:i + j + 1])
+                new_frames.append(mids[j:j+1])
 
-            step += 1
+            step += actual_batch
             pbar.update_absolute(step, total_steps)
 
-            if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
+            if not all_on_gpu and (batch_end) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
                 clear_backwarp_cache()
                 torch.cuda.empty_cache()
 
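The batched loop keeps the original output ordering: for each pair, the source frame is appended, then its interpolated midpoint, so four inputs [f0, f1, f2, f3] yield [f0, m01, f1, m12, f2, m23] inside the loop, with the trailing frame presumably appended after the loop (outside the hunks shown). A standalone sketch of that interleave, with hypothetical names:

    import torch

    def interleave(frames, mids):
        # frames: [N, C, H, W] originals; mids: [N-1, C, H, W] interpolated midpoints
        out = []
        for j in range(mids.shape[0]):
            out.append(frames[j:j+1])  # original frame f_j
            out.append(mids[j:j+1])    # midpoint between f_j and f_{j+1}
        out.append(frames[-1:])        # trailing frame (appended after the loop in the node)
        return torch.cat(out, dim=0)   # [2N - 1, C, H, W]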