Add batch processing support for faster frame interpolation
Processes multiple frame pairs simultaneously instead of one-by-one. New batch_size input (1-64) lets users trade VRAM for speed. Refactored pyr_level logic into shared _get_pyr_level() helper. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
55
inference.py
55
inference.py
@@ -43,6 +43,18 @@ class BiMVFIModel:
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def _get_pyr_level(self, h):
|
||||
if self.auto_pyr_level:
|
||||
if h >= 2160:
|
||||
return 7
|
||||
elif h >= 1080:
|
||||
return 6
|
||||
elif h >= 540:
|
||||
return 5
|
||||
else:
|
||||
return 3
|
||||
return self.pyr_level
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_pair(self, frame0, frame1, time_step=0.5):
|
||||
"""Interpolate a single frame between two input frames.
|
||||
@@ -59,19 +71,36 @@ class BiMVFIModel:
|
||||
img0 = frame0.to(device)
|
||||
img1 = frame1.to(device)
|
||||
|
||||
if self.auto_pyr_level:
|
||||
_, _, h, _ = img0.shape
|
||||
if h >= 2160:
|
||||
pyr_level = 7
|
||||
elif h >= 1080:
|
||||
pyr_level = 6
|
||||
elif h >= 540:
|
||||
pyr_level = 5
|
||||
else:
|
||||
pyr_level = 3
|
||||
else:
|
||||
pyr_level = self.pyr_level
|
||||
|
||||
pyr_level = self._get_pyr_level(img0.shape[2])
|
||||
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
||||
|
||||
result_dict = self.model(
|
||||
img0=img0, img1=img1,
|
||||
time_step=time_step_tensor,
|
||||
pyr_level=pyr_level,
|
||||
)
|
||||
|
||||
interp = result_dict["imgt_pred"]
|
||||
interp = torch.clamp(interp, 0, 1)
|
||||
return interp
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
||||
"""Interpolate multiple frame pairs at once.
|
||||
|
||||
Args:
|
||||
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
||||
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
||||
time_step: float in (0, 1), temporal position of interpolated frames
|
||||
|
||||
Returns:
|
||||
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
||||
"""
|
||||
device = next(self.model.parameters()).device
|
||||
img0 = frames0.to(device)
|
||||
img1 = frames1.to(device)
|
||||
|
||||
pyr_level = self._get_pyr_level(img0.shape[2])
|
||||
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
||||
|
||||
result_dict = self.model(
|
||||
|
||||
Reference in New Issue
Block a user