Add target FPS mode to all VFI models and remove concat preview
Add source_fps/target_fps inputs to all 4 Interpolate and Segment nodes (BIM, EMA, SGM, GIMM). When target_fps > 0, auto-computes optimal power-of-2 oversample, runs existing recursive t=0.5 interpolation, then selects frames at target timestamps. Handles downsampling (no model calls), same-fps passthrough, and high ratios (e.g. 3→30fps). Segment boundary logic uses global index computation for gap-free stitching. When target_fps=0, existing multiplier behavior is preserved. Remove video preview from TweenConcatVideos: drop preview input, delete web/js/tween_preview.js, remove WEB_DIRECTORY from __init__.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -52,8 +52,6 @@ from .nodes import (
|
||||
LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate,
|
||||
)
|
||||
|
||||
WEB_DIRECTORY = "./web"
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LoadBIMVFIModel": LoadBIMVFIModel,
|
||||
"BIMVFIInterpolate": BIMVFIInterpolate,
|
||||
|
||||
473
nodes.py
473
nodes.py
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
@@ -16,6 +17,35 @@ from .gimm_vfi_arch import clear_gimm_caches
|
||||
|
||||
logger = logging.getLogger("Tween")
|
||||
|
||||
|
||||
def _compute_target_fps_params(source_fps, target_fps):
|
||||
"""Compute oversampling parameters for target FPS mode.
|
||||
|
||||
Returns (num_passes, mult) where mult = 2^num_passes is the power-of-2
|
||||
multiplier needed to oversample above the target ratio.
|
||||
"""
|
||||
ratio = target_fps / source_fps
|
||||
if ratio <= 1.0:
|
||||
return 0, 1 # no interpolation needed (downsampling or same fps)
|
||||
num_passes = math.ceil(math.log2(ratio))
|
||||
mult = 2 ** num_passes
|
||||
return num_passes, mult
|
||||
|
||||
|
||||
def _select_target_fps_frames(frames, source_fps, target_fps, mult, num_input):
|
||||
"""Pick frames from oversampled [M,C,H,W] tensor to hit target FPS timing.
|
||||
|
||||
For downsampling (mult=1, ratio<=1), selects from original input frames.
|
||||
For upsampling, selects from the oversampled sequence at target timestamps.
|
||||
"""
|
||||
duration = (num_input - 1) / source_fps
|
||||
num_output = int(math.floor(duration * target_fps)) + 1
|
||||
oversampled_fps = source_fps * mult
|
||||
max_idx = frames.shape[0] - 1
|
||||
indices = [min(round(j / target_fps * oversampled_fps), max_idx) for j in range(num_output)]
|
||||
return frames[indices]
|
||||
|
||||
|
||||
# Google Drive file ID for the pretrained BIM-VFI model
|
||||
GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC"
|
||||
MODEL_FILENAME = "bim_vfi.pth"
|
||||
@@ -161,6 +191,14 @@ class BIMVFIInterpolate:
|
||||
"default": 0, "min": 0, "max": 10000, "step": 1,
|
||||
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead. Result is identical to processing all at once.",
|
||||
}),
|
||||
"source_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Input frame rate. Required when target_fps > 0.",
|
||||
}),
|
||||
"target_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,12 +272,25 @@ class BIMVFIInterpolate:
|
||||
return total
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size):
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
if images.shape[0] < 2:
|
||||
return (images,)
|
||||
|
||||
# Target FPS mode: auto-compute multiplier from fps ratio
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
if num_passes == 0:
|
||||
# Downsampling or same fps — select from input directly
|
||||
all_frames = images.permute(0, 3, 1, 2)
|
||||
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
|
||||
return (result.cpu().permute(0, 2, 3, 1),)
|
||||
else:
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
mult = multiplier
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
@@ -292,6 +343,11 @@ class BIMVFIInterpolate:
|
||||
result_chunks.append(chunk_result)
|
||||
|
||||
result = torch.cat(result_chunks, dim=0)
|
||||
|
||||
# Target FPS: select frames from oversampled result
|
||||
if use_target_fps:
|
||||
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
|
||||
|
||||
# Convert back to ComfyUI [B, H, W, C], on CPU
|
||||
result = result.cpu().permute(0, 2, 3, 1)
|
||||
return (result,)
|
||||
@@ -328,8 +384,10 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size):
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
# Compute segment boundaries (1-frame overlap)
|
||||
start = segment_index * (segment_size - 1)
|
||||
@@ -340,9 +398,69 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate):
|
||||
return (images[:1], model)
|
||||
|
||||
segment_images = images[start:end]
|
||||
is_continuation = segment_index > 0
|
||||
|
||||
# Delegate to the parent interpolation logic
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
|
||||
# Compute global output frame range for this segment
|
||||
seg_start_time = start / source_fps
|
||||
seg_end_time = (end - 1) / source_fps
|
||||
duration = (total_input - 1) / source_fps
|
||||
total_output = int(math.floor(duration * target_fps)) + 1
|
||||
|
||||
if segment_index == 0:
|
||||
j_start = 0
|
||||
else:
|
||||
j_start = int(math.floor(seg_start_time * target_fps)) + 1
|
||||
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
|
||||
|
||||
if j_start > j_end:
|
||||
return (images[:1], model)
|
||||
|
||||
if num_passes == 0:
|
||||
# Downsampling — select from segment input directly
|
||||
oversampled_fps = source_fps * mult
|
||||
all_seg = segment_images.permute(0, 3, 1, 2)
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
|
||||
local_idx = global_idx - start
|
||||
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
|
||||
out_frames.append(all_seg[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Oversample segment using computed num_passes
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
storage_device = device if all_on_gpu else torch.device("cpu")
|
||||
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
|
||||
|
||||
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
if keep_device:
|
||||
model.to(device)
|
||||
|
||||
oversampled = self._interpolate_frames(
|
||||
seg_frames, model, num_passes, batch_size,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
oversampled_fps = source_fps * mult
|
||||
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_oversamp_idx = round(j / target_fps * oversampled_fps)
|
||||
local_idx = global_oversamp_idx - start * mult
|
||||
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
|
||||
out_frames.append(oversampled[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Standard multiplier mode
|
||||
is_continuation = segment_index > 0
|
||||
(result,) = super().interpolate(
|
||||
segment_images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
@@ -389,11 +507,6 @@ class TweenConcatVideos:
|
||||
"tooltip": "Delete the individual segment files after successful concatenation. "
|
||||
"Useful to avoid leftover files that would pollute the next run.",
|
||||
}),
|
||||
"preview": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Show the concatenated video as a preview on the node. "
|
||||
"Disable to skip the preview widget.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,7 +531,7 @@ class TweenConcatVideos:
|
||||
)
|
||||
return ffmpeg_path
|
||||
|
||||
def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments, preview):
|
||||
def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments):
|
||||
# Resolve output directory — empty or relative paths are relative to ComfyUI output
|
||||
comfy_output = folder_paths.get_output_directory()
|
||||
out_dir = output_directory.strip()
|
||||
@@ -494,28 +607,7 @@ class TweenConcatVideos:
|
||||
if os.path.exists(concat_list_path):
|
||||
os.remove(concat_list_path)
|
||||
|
||||
result = {"result": (output_path,)}
|
||||
|
||||
if preview:
|
||||
# Preview only works when the file is inside ComfyUI's output tree
|
||||
abs_out = os.path.abspath(out_dir)
|
||||
abs_comfy = os.path.abspath(comfy_output)
|
||||
if abs_out.startswith(abs_comfy + os.sep) or abs_out == abs_comfy:
|
||||
subfolder = os.path.relpath(abs_out, abs_comfy) if abs_out != abs_comfy else ""
|
||||
result["ui"] = {
|
||||
"gifs": [{
|
||||
"filename": os.path.basename(output_path),
|
||||
"subfolder": subfolder,
|
||||
"type": "output",
|
||||
"format": "video/mp4",
|
||||
}]
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Video preview skipped: {out_dir} is outside ComfyUI output directory"
|
||||
)
|
||||
|
||||
return result
|
||||
return {"result": (output_path,)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -634,6 +726,14 @@ class EMAVFIInterpolate:
|
||||
"default": 0, "min": 0, "max": 10000, "step": 1,
|
||||
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
|
||||
}),
|
||||
"source_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Input frame rate. Required when target_fps > 0.",
|
||||
}),
|
||||
"target_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,12 +800,24 @@ class EMAVFIInterpolate:
|
||||
return total
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size):
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
if images.shape[0] < 2:
|
||||
return (images,)
|
||||
|
||||
# Target FPS mode: auto-compute multiplier from fps ratio
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
if num_passes == 0:
|
||||
all_frames = images.permute(0, 3, 1, 2)
|
||||
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
|
||||
return (result.cpu().permute(0, 2, 3, 1),)
|
||||
else:
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
mult = multiplier
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
@@ -758,6 +870,11 @@ class EMAVFIInterpolate:
|
||||
result_chunks.append(chunk_result)
|
||||
|
||||
result = torch.cat(result_chunks, dim=0)
|
||||
|
||||
# Target FPS: select frames from oversampled result
|
||||
if use_target_fps:
|
||||
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
|
||||
|
||||
# Convert back to ComfyUI [B, H, W, C], on CPU
|
||||
result = result.cpu().permute(0, 2, 3, 1)
|
||||
return (result,)
|
||||
@@ -794,28 +911,87 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size):
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
# Compute segment boundaries (1-frame overlap)
|
||||
start = segment_index * (segment_size - 1)
|
||||
end = min(start + segment_size, total_input)
|
||||
|
||||
if start >= total_input - 1:
|
||||
# Past the end — return empty single frame + model
|
||||
return (images[:1], model)
|
||||
|
||||
segment_images = images[start:end]
|
||||
is_continuation = segment_index > 0
|
||||
|
||||
# Delegate to the parent interpolation logic
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
|
||||
seg_start_time = start / source_fps
|
||||
seg_end_time = (end - 1) / source_fps
|
||||
duration = (total_input - 1) / source_fps
|
||||
total_output = int(math.floor(duration * target_fps)) + 1
|
||||
|
||||
if segment_index == 0:
|
||||
j_start = 0
|
||||
else:
|
||||
j_start = int(math.floor(seg_start_time * target_fps)) + 1
|
||||
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
|
||||
|
||||
if j_start > j_end:
|
||||
return (images[:1], model)
|
||||
|
||||
if num_passes == 0:
|
||||
oversampled_fps = source_fps * mult
|
||||
all_seg = segment_images.permute(0, 3, 1, 2)
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
|
||||
local_idx = global_idx - start
|
||||
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
|
||||
out_frames.append(all_seg[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Oversample segment using computed num_passes
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
storage_device = device if all_on_gpu else torch.device("cpu")
|
||||
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
|
||||
|
||||
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
if keep_device:
|
||||
model.to(device)
|
||||
|
||||
oversampled = self._interpolate_frames(
|
||||
seg_frames, model, num_passes, batch_size,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
oversampled_fps = source_fps * mult
|
||||
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_oversamp_idx = round(j / target_fps * oversampled_fps)
|
||||
local_idx = global_oversamp_idx - start * mult
|
||||
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
|
||||
out_frames.append(oversampled[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Standard multiplier mode
|
||||
is_continuation = segment_index > 0
|
||||
(result,) = super().interpolate(
|
||||
segment_images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
)
|
||||
|
||||
if is_continuation:
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
result = result[1:]
|
||||
|
||||
return (result, model)
|
||||
|
||||
@@ -941,6 +1117,14 @@ class SGMVFIInterpolate:
|
||||
"default": 0, "min": 0, "max": 10000, "step": 1,
|
||||
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
|
||||
}),
|
||||
"source_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Input frame rate. Required when target_fps > 0.",
|
||||
}),
|
||||
"target_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1007,12 +1191,24 @@ class SGMVFIInterpolate:
|
||||
return total
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size):
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
if images.shape[0] < 2:
|
||||
return (images,)
|
||||
|
||||
# Target FPS mode: auto-compute multiplier from fps ratio
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
if num_passes == 0:
|
||||
all_frames = images.permute(0, 3, 1, 2)
|
||||
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
|
||||
return (result.cpu().permute(0, 2, 3, 1),)
|
||||
else:
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
mult = multiplier
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
@@ -1065,6 +1261,11 @@ class SGMVFIInterpolate:
|
||||
result_chunks.append(chunk_result)
|
||||
|
||||
result = torch.cat(result_chunks, dim=0)
|
||||
|
||||
# Target FPS: select frames from oversampled result
|
||||
if use_target_fps:
|
||||
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
|
||||
|
||||
# Convert back to ComfyUI [B, H, W, C], on CPU
|
||||
result = result.cpu().permute(0, 2, 3, 1)
|
||||
return (result,)
|
||||
@@ -1101,28 +1302,87 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate):
|
||||
|
||||
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
segment_index, segment_size):
|
||||
segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
# Compute segment boundaries (1-frame overlap)
|
||||
start = segment_index * (segment_size - 1)
|
||||
end = min(start + segment_size, total_input)
|
||||
|
||||
if start >= total_input - 1:
|
||||
# Past the end — return empty single frame + model
|
||||
return (images[:1], model)
|
||||
|
||||
segment_images = images[start:end]
|
||||
is_continuation = segment_index > 0
|
||||
|
||||
# Delegate to the parent interpolation logic
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
|
||||
seg_start_time = start / source_fps
|
||||
seg_end_time = (end - 1) / source_fps
|
||||
duration = (total_input - 1) / source_fps
|
||||
total_output = int(math.floor(duration * target_fps)) + 1
|
||||
|
||||
if segment_index == 0:
|
||||
j_start = 0
|
||||
else:
|
||||
j_start = int(math.floor(seg_start_time * target_fps)) + 1
|
||||
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
|
||||
|
||||
if j_start > j_end:
|
||||
return (images[:1], model)
|
||||
|
||||
if num_passes == 0:
|
||||
oversampled_fps = source_fps * mult
|
||||
all_seg = segment_images.permute(0, 3, 1, 2)
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
|
||||
local_idx = global_idx - start
|
||||
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
|
||||
out_frames.append(all_seg[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Oversample segment using computed num_passes
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
storage_device = device if all_on_gpu else torch.device("cpu")
|
||||
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
|
||||
|
||||
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
if keep_device:
|
||||
model.to(device)
|
||||
|
||||
oversampled = self._interpolate_frames(
|
||||
seg_frames, model, num_passes, batch_size,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
oversampled_fps = source_fps * mult
|
||||
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_oversamp_idx = round(j / target_fps * oversampled_fps)
|
||||
local_idx = global_oversamp_idx - start * mult
|
||||
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
|
||||
out_frames.append(oversampled[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Standard multiplier mode
|
||||
is_continuation = segment_index > 0
|
||||
(result,) = super().interpolate(
|
||||
segment_images, model, multiplier, clear_cache_after_n_frames,
|
||||
keep_device, all_on_gpu, batch_size, chunk_size,
|
||||
)
|
||||
|
||||
if is_continuation:
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
result = result[1:]
|
||||
|
||||
return (result, model)
|
||||
|
||||
@@ -1265,6 +1525,14 @@ class GIMMVFIInterpolate:
|
||||
"default": 0, "min": 0, "max": 10000, "step": 1,
|
||||
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
|
||||
}),
|
||||
"source_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Input frame rate. Required when target_fps > 0.",
|
||||
}),
|
||||
"target_fps": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01,
|
||||
"tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1376,14 +1644,31 @@ class GIMMVFIInterpolate:
|
||||
|
||||
def interpolate(self, images, model, multiplier, single_pass,
|
||||
clear_cache_after_n_frames, keep_device, all_on_gpu,
|
||||
batch_size, chunk_size):
|
||||
batch_size, chunk_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
if images.shape[0] < 2:
|
||||
return (images,)
|
||||
|
||||
# Target FPS mode: auto-compute multiplier from fps ratio
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
if num_passes == 0:
|
||||
all_frames = images.permute(0, 3, 1, 2)
|
||||
result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0])
|
||||
return (result.cpu().permute(0, 2, 3, 1),)
|
||||
# Override multiplier for single_pass mode
|
||||
multiplier = mult
|
||||
else:
|
||||
mult = multiplier
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if not single_pass:
|
||||
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
if not single_pass or use_target_fps:
|
||||
if use_target_fps:
|
||||
num_passes_recursive = num_passes
|
||||
else:
|
||||
num_passes_recursive = {2: 1, 4: 2, 8: 3}[multiplier]
|
||||
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
@@ -1411,7 +1696,7 @@ class GIMMVFIInterpolate:
|
||||
if single_pass:
|
||||
total_steps = sum(ce - cs - 1 for cs, ce in chunks)
|
||||
else:
|
||||
total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
|
||||
total_steps = sum(self._count_steps(ce - cs, num_passes_recursive) for cs, ce in chunks)
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
|
||||
@@ -1430,7 +1715,7 @@ class GIMMVFIInterpolate:
|
||||
)
|
||||
else:
|
||||
chunk_result = self._interpolate_frames(
|
||||
chunk_frames, model, num_passes, batch_size,
|
||||
chunk_frames, model, num_passes_recursive, batch_size,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
@@ -1446,6 +1731,11 @@ class GIMMVFIInterpolate:
|
||||
result_chunks.append(chunk_result)
|
||||
|
||||
result = torch.cat(result_chunks, dim=0)
|
||||
|
||||
# Target FPS: select frames from oversampled result
|
||||
if use_target_fps:
|
||||
result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input)
|
||||
|
||||
# Convert back to ComfyUI [B, H, W, C], on CPU
|
||||
result = result.cpu().permute(0, 2, 3, 1)
|
||||
return (result,)
|
||||
@@ -1482,21 +1772,90 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
|
||||
|
||||
def interpolate(self, images, model, multiplier, single_pass,
|
||||
clear_cache_after_n_frames, keep_device, all_on_gpu,
|
||||
batch_size, chunk_size, segment_index, segment_size):
|
||||
batch_size, chunk_size, segment_index, segment_size,
|
||||
source_fps=0.0, target_fps=0.0):
|
||||
total_input = images.shape[0]
|
||||
use_target_fps = target_fps > 0 and source_fps > 0
|
||||
|
||||
# Compute segment boundaries (1-frame overlap)
|
||||
start = segment_index * (segment_size - 1)
|
||||
end = min(start + segment_size, total_input)
|
||||
|
||||
if start >= total_input - 1:
|
||||
# Past the end — return empty single frame + model
|
||||
return (images[:1], model)
|
||||
|
||||
segment_images = images[start:end]
|
||||
is_continuation = segment_index > 0
|
||||
|
||||
# Delegate to the parent interpolation logic
|
||||
if use_target_fps:
|
||||
num_passes, mult = _compute_target_fps_params(source_fps, target_fps)
|
||||
|
||||
seg_start_time = start / source_fps
|
||||
seg_end_time = (end - 1) / source_fps
|
||||
duration = (total_input - 1) / source_fps
|
||||
total_output = int(math.floor(duration * target_fps)) + 1
|
||||
|
||||
if segment_index == 0:
|
||||
j_start = 0
|
||||
else:
|
||||
j_start = int(math.floor(seg_start_time * target_fps)) + 1
|
||||
j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1)
|
||||
|
||||
if j_start > j_end:
|
||||
return (images[:1], model)
|
||||
|
||||
if num_passes == 0:
|
||||
oversampled_fps = source_fps * mult
|
||||
all_seg = segment_images.permute(0, 3, 1, 2)
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1)
|
||||
local_idx = global_idx - start
|
||||
local_idx = max(0, min(local_idx, all_seg.shape[0] - 1))
|
||||
out_frames.append(all_seg[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Oversample segment directly
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if all_on_gpu:
|
||||
keep_device = True
|
||||
storage_device = device if all_on_gpu else torch.device("cpu")
|
||||
seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device)
|
||||
|
||||
if single_pass:
|
||||
total_steps = seg_frames.shape[0] - 1
|
||||
else:
|
||||
total_steps = self._count_steps(seg_frames.shape[0], num_passes)
|
||||
pbar = ProgressBar(total_steps)
|
||||
step_ref = [0]
|
||||
if keep_device:
|
||||
model.to(device)
|
||||
|
||||
if single_pass:
|
||||
oversampled = self._interpolate_frames_single_pass(
|
||||
seg_frames, model, mult,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
else:
|
||||
oversampled = self._interpolate_frames(
|
||||
seg_frames, model, num_passes, batch_size,
|
||||
device, storage_device, keep_device, all_on_gpu,
|
||||
clear_cache_after_n_frames, pbar, step_ref,
|
||||
)
|
||||
oversampled_fps = source_fps * mult
|
||||
|
||||
out_frames = []
|
||||
for j in range(j_start, j_end + 1):
|
||||
global_oversamp_idx = round(j / target_fps * oversampled_fps)
|
||||
local_idx = global_oversamp_idx - start * mult
|
||||
local_idx = max(0, min(local_idx, oversampled.shape[0] - 1))
|
||||
out_frames.append(oversampled[local_idx:local_idx + 1])
|
||||
result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1)
|
||||
return (result, model)
|
||||
|
||||
# Standard multiplier mode
|
||||
is_continuation = segment_index > 0
|
||||
(result,) = super().interpolate(
|
||||
segment_images, model, multiplier, single_pass,
|
||||
clear_cache_after_n_frames, keep_device, all_on_gpu,
|
||||
@@ -1504,6 +1863,6 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate):
|
||||
)
|
||||
|
||||
if is_continuation:
|
||||
result = result[1:] # skip duplicate boundary frame
|
||||
result = result[1:]
|
||||
|
||||
return (result, model)
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
|
||||
function fitHeight(node) {
|
||||
node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]);
|
||||
node?.graph?.setDirtyCanvas(true);
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Tween.VideoPreview",
|
||||
async beforeRegisterNodeDef(nodeType, nodeData) {
|
||||
if (nodeData?.name !== "TweenConcatVideos") return;
|
||||
|
||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
onNodeCreated?.apply(this, arguments);
|
||||
|
||||
const container = document.createElement("div");
|
||||
const previewWidget = this.addDOMWidget("videopreview", "preview", container, {
|
||||
serialize: false,
|
||||
hideOnZoom: false,
|
||||
getValue() { return container.value; },
|
||||
setValue(v) { container.value = v; },
|
||||
});
|
||||
|
||||
previewWidget.computeSize = function (width) {
|
||||
if (this.aspectRatio && !this.videoEl.hidden) {
|
||||
const height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
|
||||
return [width, height > 0 ? height : -4];
|
||||
}
|
||||
return [width, -4];
|
||||
};
|
||||
|
||||
const previewNode = this;
|
||||
|
||||
previewWidget.videoEl = document.createElement("video");
|
||||
previewWidget.videoEl.controls = true;
|
||||
previewWidget.videoEl.loop = true;
|
||||
previewWidget.videoEl.muted = true;
|
||||
previewWidget.videoEl.style.width = "100%";
|
||||
previewWidget.videoEl.hidden = true;
|
||||
|
||||
previewWidget.videoEl.addEventListener("loadedmetadata", () => {
|
||||
previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
|
||||
fitHeight(previewNode);
|
||||
});
|
||||
previewWidget.videoEl.addEventListener("error", () => {
|
||||
previewWidget.videoEl.hidden = true;
|
||||
fitHeight(previewNode);
|
||||
});
|
||||
|
||||
container.appendChild(previewWidget.videoEl);
|
||||
};
|
||||
|
||||
const onExecuted = nodeType.prototype.onExecuted;
|
||||
nodeType.prototype.onExecuted = function (message) {
|
||||
onExecuted?.apply(this, arguments);
|
||||
|
||||
if (!message?.gifs?.length) return;
|
||||
|
||||
const params = message.gifs[0];
|
||||
const previewWidget = this.widgets?.find((w) => w.name === "videopreview");
|
||||
if (!previewWidget) return;
|
||||
|
||||
const query = new URLSearchParams(params);
|
||||
query.set("timestamp", Date.now());
|
||||
previewWidget.videoEl.src = api.apiURL("/view?" + query);
|
||||
previewWidget.videoEl.hidden = false;
|
||||
previewWidget.videoEl.autoplay = true;
|
||||
};
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user