diff --git a/inference.py b/inference.py
index d0cdcbb..599b068 100644
--- a/inference.py
+++ b/inference.py
@@ -43,6 +43,18 @@ class BiMVFIModel:
         self.model.to(device)
         return self
 
+    def _get_pyr_level(self, h):  # pyramid depth from input height (shared by pair/batch paths)
+        if self.auto_pyr_level:
+            if h >= 2160:
+                return 7  # 4K
+            elif h >= 1080:
+                return 6  # 1080p
+            elif h >= 540:
+                return 5  # 540p
+            else:
+                return 3
+        return self.pyr_level
+
     @torch.no_grad()
     def interpolate_pair(self, frame0, frame1, time_step=0.5):
         """Interpolate a single frame between two input frames.
@@ -59,19 +71,36 @@ class BiMVFIModel:
         img0 = frame0.to(device)
         img1 = frame1.to(device)
 
-        if self.auto_pyr_level:
-            _, _, h, _ = img0.shape
-            if h >= 2160:
-                pyr_level = 7
-            elif h >= 1080:
-                pyr_level = 6
-            elif h >= 540:
-                pyr_level = 5
-            else:
-                pyr_level = 3
-        else:
-            pyr_level = self.pyr_level
-
+        pyr_level = self._get_pyr_level(img0.shape[2])
+        time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
+
+        result_dict = self.model(
+            img0=img0, img1=img1,
+            time_step=time_step_tensor,
+            pyr_level=pyr_level,
+        )
+
+        interp = result_dict["imgt_pred"]
+        interp = torch.clamp(interp, 0, 1)
+        return interp
+
+    @torch.no_grad()
+    def interpolate_batch(self, frames0, frames1, time_step=0.5):
+        """Interpolate multiple frame pairs at once.
+
+        Args:
+            frames0: [B, C, H, W] tensor, float32, range [0, 1]
+            frames1: [B, C, H, W] tensor, float32, range [0, 1]
+            time_step: float in (0, 1), temporal position of interpolated frames
+
+        Returns:
+            Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
+        """
+        device = next(self.model.parameters()).device
+        img0 = frames0.to(device)
+        img1 = frames1.to(device)
+
+        pyr_level = self._get_pyr_level(img0.shape[2])
         time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
 
         result_dict = self.model(
diff --git a/nodes.py b/nodes.py
index 5df0eb3..a2bce66 100644
--- a/nodes.py
+++ b/nodes.py
@@ -119,6 +119,10 @@ class BIMVFIInterpolate:
                     "default": False,
                     "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                 }),
+                "batch_size": ("INT", {
+                    "default": 1, "min": 1, "max": 64, "step": 1,
+                    "tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full. Recommended: 1 for 8GB, 2-4 for 24GB, 4-16 for 48GB+.",
+                }),
             }
         }
 
@@ -127,7 +131,7 @@ class BIMVFIInterpolate:
     FUNCTION = "interpolate"
     CATEGORY = "video/BIM-VFI"
 
-    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu):
+    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size):
         if images.shape[0] < 2:
             return (images,)
 
@@ -161,26 +165,32 @@ class BIMVFIInterpolate:
         new_frames = []
         num_pairs = frames.shape[0] - 1
 
-        for i in range(num_pairs):
-            frame0 = frames[i:i+1]  # [1, C, H, W]
-            frame1 = frames[i+1:i+2]  # [1, C, H, W]
+        for i in range(0, num_pairs, batch_size):
+            batch_end = min(i + batch_size, num_pairs)
+            actual_batch = batch_end - i
+
+            # Pairs are contiguous along dim 0: slice views, no per-frame cat/copy
+            frames0 = frames[i:batch_end]          # [B, C, H, W]
+            frames1 = frames[i + 1:batch_end + 1]  # [B, C, H, W]
 
             if not keep_device:
                 model.to(device)
 
-            mid = model.interpolate_pair(frame0, frame1, time_step=0.5)
-            mid = mid.to(storage_device)
+            mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
+            mids = mids.to(storage_device)
 
             if not keep_device:
                 model.to("cpu")
 
-            new_frames.append(frames[i:i+1])
-            new_frames.append(mid)
+            # Interleave: original frame, then interpolated frame
+            for j in range(actual_batch):
+                new_frames.append(frames[i + j:i + j + 1])
+                new_frames.append(mids[j:j+1])
 
-            step += 1
+            step += actual_batch
             pbar.update_absolute(step, total_steps)
 
-            if not all_on_gpu and (i + 1) % clear_cache_after_n_frames == 0 and torch.cuda.is_available():
+            if not all_on_gpu and batch_end // clear_cache_after_n_frames > i // clear_cache_after_n_frames and torch.cuda.is_available():  # clear once per threshold crossing, even if batch_size doesn't divide it
                 clear_backwarp_cache()
                 torch.cuda.empty_cache()
 