Add target FPS mode to all VFI models and remove concat preview

Add source_fps/target_fps inputs to all 4 Interpolate and Segment nodes
(BIM, EMA, SGM, GIMM). When target_fps > 0, auto-computes optimal
power-of-2 oversample, runs existing recursive t=0.5 interpolation,
then selects frames at target timestamps. Handles downsampling (no
model calls), same-fps passthrough, and high ratios (e.g. 3→30fps).
Segment boundary logic uses global index computation for gap-free
stitching. When target_fps=0, existing multiplier behavior is preserved.

Remove video preview from TweenConcatVideos: drop preview input,
delete web/js/tween_preview.js, remove WEB_DIRECTORY from __init__.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 22:51:04 +01:00
parent e253cb244e
commit 6dd579dcc7
3 changed files with 416 additions and 131 deletions

473
nodes.py
View File

@@ -1,3 +1,4 @@
import math
import os
import glob
import logging
@@ -16,6 +17,35 @@ from .gimm_vfi_arch import clear_gimm_caches
logger = logging.getLogger("Tween")
def _compute_target_fps_params(source_fps, target_fps):
"""Compute oversampling parameters for target FPS mode.
Returns (num_passes, mult) where mult = 2^num_passes is the power-of-2
multiplier needed to oversample above the target ratio.
"""
ratio = target_fps / source_fps
if ratio <= 1.0:
return 0, 1 # no interpolation needed (downsampling or same fps)
num_passes = math.ceil(math.log2(ratio))
mult = 2 ** num_passes
return num_passes, mult
def _select_target_fps_frames(frames, source_fps, target_fps, mult, num_input):
"""Pick frames from oversampled [M,C,H,W] tensor to hit target FPS timing.
For downsampling (mult=1, ratio<=1), selects from original input frames.
For upsampling, selects from the oversampled sequence at target timestamps.
"""
duration = (num_input - 1) / source_fps
num_output = int(math.floor(duration * target_fps)) + 1
oversampled_fps = source_fps * mult
max_idx = frames.shape[0] - 1
indices = [min(round(j / target_fps * oversampled_fps), max_idx) for j in range(num_output)]
return frames[indices]
# Google Drive file ID for the pretrained BIM-VFI model
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
MODEL_FILENAME = "bim_vfi.pth"
@@ -161,6 +191,14 @@ class BIMVFIInterpolate:
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead. Result is identical to processing all at once.",
}),
"source_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Input frame rate. Required when target_fps > 0.",
}),
"target_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
}
@@ -234,12 +272,25 @@ class BIMVFIInterpolate:
return total
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size):
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
if images.shape[0] < 2:
return (images,)
# Target FPS mode: auto-compute multiplier from fps ratio
use_target_fps = target_fps > 0 and source_fps > 0
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
if num_passes == 0:
# Downsampling or same fps — select from input directly
all_frames = images.permute(0, 3, 1, 2)
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
return (result.cpu().permute(0, 2, 3, 1),)
else:
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
mult = multiplier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
@@ -292,6 +343,11 @@ class BIMVFIInterpolate:
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Target FPS: select frames from oversampled result
if use_target_fps:
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
@@ -328,8 +384,10 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size):
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
@@ -340,9 +398,69 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
# Compute global output frame range for this segment
seg_start_time = start / source_fps
seg_end_time = (end - 1) / source_fps
duration = (total_input - 1) / source_fps
total_output = int(math.floor(duration * target_fps)) + 1
if segment_index == 0:
j_start = 0
else:
j_start = int(math.floor(seg_start_time * target_fps)) + 1
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
if j_start > j_end:
return (images[:1], model)
if num_passes == 0:
# Downsampling — select from segment input directly
oversampled_fps = source_fps * mult
all_seg = segment_images.permute(0, 3, 1, 2)
out_frames = []
for j in range(j_start, j_end + 1):
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
local_idx = global_idx - start
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
out_frames.append(all_seg[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Oversample segment using computed num_passes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
oversampled = self._interpolate_frames(
seg_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
oversampled_fps = source_fps * mult
out_frames = []
for j in range(j_start, j_end + 1):
global_oversamp_idx = round(j / target_fps * oversampled_fps)
local_idx = global_oversamp_idx - start * mult
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
out_frames.append(oversampled[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Standard multiplier mode
is_continuation = segment_index > 0
(result,) = super().interpolate(
segment_images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
@@ -389,11 +507,6 @@ class TweenConcatVideos:
"tooltip": "Delete the individual segment files after successful concatenation. "
"Useful to avoid leftover files that would pollute the next run.",
}),
"preview": ("BOOLEAN", {
"default": True,
"tooltip": "Show the concatenated video as a preview on the node. "
"Disable to skip the preview widget.",
}),
}
}
@@ -418,7 +531,7 @@ class TweenConcatVideos:
)
return ffmpeg_path
def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments, preview):
def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments):
# Resolve output directory — empty or relative paths are relative to ComfyUI output
comfy_output = folder_paths.get_output_directory()
out_dir = output_directory.strip()
@@ -494,28 +607,7 @@ class TweenConcatVideos:
if os.path.exists(concat_list_path):
os.remove(concat_list_path)
result = {"result": (output_path,)}
if preview:
# Preview only works when the file is inside ComfyUI's output tree
abs_out = os.path.abspath(out_dir)
abs_comfy = os.path.abspath(comfy_output)
if abs_out.startswith(abs_comfy + os.sep) or abs_out == abs_comfy:
subfolder = os.path.relpath(abs_out, abs_comfy) if abs_out != abs_comfy else ""
result["ui"] = {
"gifs": [{
"filename": os.path.basename(output_path),
"subfolder": subfolder,
"type": "output",
"format": "video/mp4",
}]
}
else:
logger.warning(
f"Video preview skipped: {out_dir} is outside ComfyUI output directory"
)
return result
return {"result": (output_path,)}
# ---------------------------------------------------------------------------
@@ -634,6 +726,14 @@ class EMAVFIInterpolate:
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
}),
"source_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Input frame rate. Required when target_fps > 0.",
}),
"target_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
}
@@ -700,12 +800,24 @@ class EMAVFIInterpolate:
return total
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size):
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
if images.shape[0] < 2:
return (images,)
# Target FPS mode: auto-compute multiplier from fps ratio
use_target_fps = target_fps > 0 and source_fps > 0
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
if num_passes == 0:
all_frames = images.permute(0, 3, 1, 2)
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
return (result.cpu().permute(0, 2, 3, 1),)
else:
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
mult = multiplier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
@@ -758,6 +870,11 @@ class EMAVFIInterpolate:
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Target FPS: select frames from oversampled result
if use_target_fps:
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
@@ -794,28 +911,87 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size):
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
end = min(start + segment_size, total_input)
if start >= total_input - 1:
# Past the end — return empty single frame + model
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
seg_start_time = start / source_fps
seg_end_time = (end - 1) / source_fps
duration = (total_input - 1) / source_fps
total_output = int(math.floor(duration * target_fps)) + 1
if segment_index == 0:
j_start = 0
else:
j_start = int(math.floor(seg_start_time * target_fps)) + 1
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
if j_start > j_end:
return (images[:1], model)
if num_passes == 0:
oversampled_fps = source_fps * mult
all_seg = segment_images.permute(0, 3, 1, 2)
out_frames = []
for j in range(j_start, j_end + 1):
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
local_idx = global_idx - start
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
out_frames.append(all_seg[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Oversample segment using computed num_passes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
oversampled = self._interpolate_frames(
seg_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
oversampled_fps = source_fps * mult
out_frames = []
for j in range(j_start, j_end + 1):
global_oversamp_idx = round(j / target_fps * oversampled_fps)
local_idx = global_oversamp_idx - start * mult
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
out_frames.append(oversampled[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Standard multiplier mode
is_continuation = segment_index > 0
(result,) = super().interpolate(
segment_images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
)
if is_continuation:
result = result[1:] # skip duplicate boundary frame
result = result[1:]
return (result, model)
@@ -941,6 +1117,14 @@ class SGMVFIInterpolate:
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
}),
"source_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Input frame rate. Required when target_fps > 0.",
}),
"target_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
}
@@ -1007,12 +1191,24 @@ class SGMVFIInterpolate:
return total
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size):
keep_device, all_on_gpu, batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
if images.shape[0] < 2:
return (images,)
# Target FPS mode: auto-compute multiplier from fps ratio
use_target_fps = target_fps > 0 and source_fps > 0
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
if num_passes == 0:
all_frames = images.permute(0, 3, 1, 2)
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
return (result.cpu().permute(0, 2, 3, 1),)
else:
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
mult = multiplier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
@@ -1065,6 +1261,11 @@ class SGMVFIInterpolate:
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Target FPS: select frames from oversampled result
if use_target_fps:
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
@@ -1101,28 +1302,87 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size):
segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
end = min(start + segment_size, total_input)
if start >= total_input - 1:
# Past the end — return empty single frame + model
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
seg_start_time = start / source_fps
seg_end_time = (end - 1) / source_fps
duration = (total_input - 1) / source_fps
total_output = int(math.floor(duration * target_fps)) + 1
if segment_index == 0:
j_start = 0
else:
j_start = int(math.floor(seg_start_time * target_fps)) + 1
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
if j_start > j_end:
return (images[:1], model)
if num_passes == 0:
oversampled_fps = source_fps * mult
all_seg = segment_images.permute(0, 3, 1, 2)
out_frames = []
for j in range(j_start, j_end + 1):
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
local_idx = global_idx - start
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
out_frames.append(all_seg[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Oversample segment using computed num_passes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
oversampled = self._interpolate_frames(
seg_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
oversampled_fps = source_fps * mult
out_frames = []
for j in range(j_start, j_end + 1):
global_oversamp_idx = round(j / target_fps * oversampled_fps)
local_idx = global_oversamp_idx - start * mult
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
out_frames.append(oversampled[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Standard multiplier mode
is_continuation = segment_index > 0
(result,) = super().interpolate(
segment_images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
)
if is_continuation:
result = result[1:] # skip duplicate boundary frame
result = result[1:]
return (result, model)
@@ -1265,6 +1525,14 @@ class GIMMVFIInterpolate:
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
}),
"source_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Input frame rate. Required when target_fps > 0.",
}),
"target_fps": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
}),
}
}
@@ -1376,14 +1644,31 @@ class GIMMVFIInterpolate:
def interpolate(self, images, model, multiplier, single_pass,
clear_cache_after_n_frames, keep_device, all_on_gpu,
batch_size, chunk_size):
batch_size, chunk_size,
source_fps=0.0, target_fps=0.0):
if images.shape[0] < 2:
return (images,)
# Target FPS mode: auto-compute multiplier from fps ratio
use_target_fps = target_fps > 0 and source_fps > 0
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
if num_passes == 0:
all_frames = images.permute(0, 3, 1, 2)
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
return (result.cpu().permute(0, 2, 3, 1),)
# Override multiplier for single_pass mode
multiplier = mult
else:
mult = multiplier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not single_pass:
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if not single_pass or use_target_fps:
if use_target_fps:
num_passes_recursive = num_passes
else:
num_passes_recursive = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
@@ -1411,7 +1696,7 @@ class GIMMVFIInterpolate:
if single_pass:
total_steps = sum(ce - cs - 1 for cs, ce in chunks)
else:
total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
total_steps = sum(self._count_steps(ce - cs, num_passes_recursive) for cs, ce in chunks)
pbar = ProgressBar(total_steps)
step_ref = [0]
@@ -1430,7 +1715,7 @@ class GIMMVFIInterpolate:
)
else:
chunk_result = self._interpolate_frames(
chunk_frames, model, num_passes, batch_size,
chunk_frames, model, num_passes_recursive, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
@@ -1446,6 +1731,11 @@ class GIMMVFIInterpolate:
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Target FPS: select frames from oversampled result
if use_target_fps:
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
@@ -1482,21 +1772,90 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
def interpolate(self, images, model, multiplier, single_pass,
clear_cache_after_n_frames, keep_device, all_on_gpu,
batch_size, chunk_size, segment_index, segment_size):
batch_size, chunk_size, segment_index, segment_size,
source_fps=0.0, target_fps=0.0):
total_input = images.shape[0]
use_target_fps = target_fps > 0 and source_fps > 0
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
end = min(start + segment_size, total_input)
if start >= total_input - 1:
# Past the end — return empty single frame + model
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
if use_target_fps:
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
seg_start_time = start / source_fps
seg_end_time = (end - 1) / source_fps
duration = (total_input - 1) / source_fps
total_output = int(math.floor(duration * target_fps)) + 1
if segment_index == 0:
j_start = 0
else:
j_start = int(math.floor(seg_start_time * target_fps)) + 1
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
if j_start > j_end:
return (images[:1], model)
if num_passes == 0:
oversampled_fps = source_fps * mult
all_seg = segment_images.permute(0, 3, 1, 2)
out_frames = []
for j in range(j_start, j_end + 1):
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
local_idx = global_idx - start
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
out_frames.append(all_seg[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Oversample segment directly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
if single_pass:
total_steps = seg_frames.shape[0] - 1
else:
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
if single_pass:
oversampled = self._interpolate_frames_single_pass(
seg_frames, model, mult,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
else:
oversampled = self._interpolate_frames(
seg_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
oversampled_fps = source_fps * mult
out_frames = []
for j in range(j_start, j_end + 1):
global_oversamp_idx = round(j / target_fps * oversampled_fps)
local_idx = global_oversamp_idx - start * mult
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
out_frames.append(oversampled[local_idx:local_idx + 1])
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
return (result, model)
# Standard multiplier mode
is_continuation = segment_index > 0
(result,) = super().interpolate(
segment_images, model, multiplier, single_pass,
clear_cache_after_n_frames, keep_device, all_on_gpu,
@@ -1504,6 +1863,6 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
)
if is_continuation:
result = result[1:] # skip duplicate boundary frame
result = result[1:]
return (result, model)