From 6dd579dcc7d3667161bff04498168ac7a205aae7 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 13 Feb 2026 22:51:04 +0100 Subject: [PATCH] Add target FPS mode to all VFI models and remove concat preview MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add source_fps/target_fps inputs to all 4 Interpolate and Segment nodes (BIM, EMA, SGM, GIMM). When target_fps > 0, auto-computes optimal power-of-2 oversample, runs existing recursive t=0.5 interpolation, then selects frames at target timestamps. Handles downsampling (no model calls), same-fps passthrough, and high ratios (e.g. 3→30fps). Segment boundary logic uses global index computation for gap-free stitching. When target_fps=0, existing multiplier behavior is preserved. Remove video preview from TweenConcatVideos: drop preview input, delete web/js/tween_preview.js, remove WEB_DIRECTORY from __init__.py. Co-Authored-By: Claude Opus 4.6 --- __init__.py | 2 - nodes.py | 473 +++++++++++++++++++++++++++++++++++----- web/js/tween_preview.js | 72 ------ 3 files changed, 416 insertions(+), 131 deletions(-) delete mode 100644 web/js/tween_preview.js diff --git a/__init__.py b/__init__.py index 8aedf02..2981ddc 100644 --- a/__init__.py +++ b/__init__.py @@ -52,8 +52,6 @@ from .nodes import ( LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, ) -WEB_DIRECTORY = "./web" - NODE_CLASS_MAPPINGS = { "LoadBIMVFIModel": LoadBIMVFIModel, "BIMVFIInterpolate": BIMVFIInterpolate, diff --git a/nodes.py b/nodes.py index a25ae3f..9ba64fa 100644 --- a/nodes.py +++ b/nodes.py @@ -1,3 +1,4 @@ +import math import os import glob import logging @@ -16,6 +17,35 @@ from .gimm_vfi_arch import clear_gimm_caches logger = logging.getLogger("Tween") + +def _compute_target_fps_params(source_fps, target_fps): + """Compute oversampling parameters for target FPS mode. + + Returns (num_passes, mult) where mult = 2^num_passes is the power-of-2 + multiplier needed to oversample above the target ratio. + """ + ratio = target_fps / source_fps + if ratio <= 1.0: + return 0, 1 # no interpolation needed (downsampling or same fps) + num_passes = math.ceil(math.log2(ratio)) + mult = 2 ** num_passes + return num_passes, mult + + +def _select_target_fps_frames(frames, source_fps, target_fps, mult, num_input): + """Pick frames from oversampled [M,C,H,W] tensor to hit target FPS timing. + + For downsampling (mult=1, ratio<=1), selects from original input frames. + For upsampling, selects from the oversampled sequence at target timestamps. + """ + duration = (num_input - 1) / source_fps + num_output = int(math.floor(duration * target_fps)) + 1 + oversampled_fps = source_fps * mult + max_idx = frames.shape[0] - 1 + indices = [min(round(j / target_fps * oversampled_fps), max_idx) for j in range(num_output)] + return frames[indices] + + # Google Drive file ID for the pretrained BIM-VFI model GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC" MODEL_FILENAME = "bim_vfi.pth" @@ -161,6 +191,14 @@ class BIMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead. Result is identical to processing all at once.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -234,12 +272,25 @@ class BIMVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + # Downsampling or same fps — select from input directly + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -292,6 +343,11 @@ class BIMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -328,8 +384,10 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) @@ -340,9 +398,69 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate): return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + # Compute global output frame range for this segment + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + # Downsampling — select from segment input directly + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, @@ -389,11 +507,6 @@ class TweenConcatVideos: "tooltip": "Delete the individual segment files after successful concatenation. " "Useful to avoid leftover files that would pollute the next run.", }), - "preview": ("BOOLEAN", { - "default": True, - "tooltip": "Show the concatenated video as a preview on the node. " - "Disable to skip the preview widget.", - }), } } @@ -418,7 +531,7 @@ class TweenConcatVideos: ) return ffmpeg_path - def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments, preview): + def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments): # Resolve output directory — empty or relative paths are relative to ComfyUI output comfy_output = folder_paths.get_output_directory() out_dir = output_directory.strip() @@ -494,28 +607,7 @@ class TweenConcatVideos: if os.path.exists(concat_list_path): os.remove(concat_list_path) - result = {"result": (output_path,)} - - if preview: - # Preview only works when the file is inside ComfyUI's output tree - abs_out = os.path.abspath(out_dir) - abs_comfy = os.path.abspath(comfy_output) - if abs_out.startswith(abs_comfy + os.sep) or abs_out == abs_comfy: - subfolder = os.path.relpath(abs_out, abs_comfy) if abs_out != abs_comfy else "" - result["ui"] = { - "gifs": [{ - "filename": os.path.basename(output_path), - "subfolder": subfolder, - "type": "output", - "format": "video/mp4", - }] - } - else: - logger.warning( - f"Video preview skipped: {out_dir} is outside ComfyUI output directory" - ) - - return result + return {"result": (output_path,)} # --------------------------------------------------------------------------- @@ -634,6 +726,14 @@ class EMAVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -700,12 +800,24 @@ class EMAVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -758,6 +870,11 @@ class EMAVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -794,28 +911,87 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) @@ -941,6 +1117,14 @@ class SGMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -1007,12 +1191,24 @@ class SGMVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -1065,6 +1261,11 @@ class SGMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -1101,28 +1302,87 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) @@ -1265,6 +1525,14 @@ class GIMMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -1376,14 +1644,31 @@ class GIMMVFIInterpolate: def interpolate(self, images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, - batch_size, chunk_size): + batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + # Override multiplier for single_pass mode + multiplier = mult + else: + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if not single_pass: - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + if not single_pass or use_target_fps: + if use_target_fps: + num_passes_recursive = num_passes + else: + num_passes_recursive = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -1411,7 +1696,7 @@ class GIMMVFIInterpolate: if single_pass: total_steps = sum(ce - cs - 1 for cs, ce in chunks) else: - total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks) + total_steps = sum(self._count_steps(ce - cs, num_passes_recursive) for cs, ce in chunks) pbar = ProgressBar(total_steps) step_ref = [0] @@ -1430,7 +1715,7 @@ class GIMMVFIInterpolate: ) else: chunk_result = self._interpolate_frames( - chunk_frames, model, num_passes, batch_size, + chunk_frames, model, num_passes_recursive, batch_size, device, storage_device, keep_device, all_on_gpu, clear_cache_after_n_frames, pbar, step_ref, ) @@ -1446,6 +1731,11 @@ class GIMMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -1482,21 +1772,90 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate): def interpolate(self, images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, - batch_size, chunk_size, segment_index, segment_size): + batch_size, chunk_size, segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment directly + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + if single_pass: + total_steps = seg_frames.shape[0] - 1 + else: + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + if single_pass: + oversampled = self._interpolate_frames_single_pass( + seg_frames, model, mult, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + else: + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, @@ -1504,6 +1863,6 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate): ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) diff --git a/web/js/tween_preview.js b/web/js/tween_preview.js deleted file mode 100644 index fe91217..0000000 --- a/web/js/tween_preview.js +++ /dev/null @@ -1,72 +0,0 @@ -import { app } from "../../scripts/app.js"; -import { api } from "../../scripts/api.js"; - -function fitHeight(node) { - node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]); - node?.graph?.setDirtyCanvas(true); -} - -app.registerExtension({ - name: "Tween.VideoPreview", - async beforeRegisterNodeDef(nodeType, nodeData) { - if (nodeData?.name !== "TweenConcatVideos") return; - - const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = function () { - onNodeCreated?.apply(this, arguments); - - const container = document.createElement("div"); - const previewWidget = this.addDOMWidget("videopreview", "preview", container, { - serialize: false, - hideOnZoom: false, - getValue() { return container.value; }, - setValue(v) { container.value = v; }, - }); - - previewWidget.computeSize = function (width) { - if (this.aspectRatio && !this.videoEl.hidden) { - const height = (previewNode.size[0] - 20) / this.aspectRatio + 10; - return [width, height > 0 ? height : -4]; - } - return [width, -4]; - }; - - const previewNode = this; - - previewWidget.videoEl = document.createElement("video"); - previewWidget.videoEl.controls = true; - previewWidget.videoEl.loop = true; - previewWidget.videoEl.muted = true; - previewWidget.videoEl.style.width = "100%"; - previewWidget.videoEl.hidden = true; - - previewWidget.videoEl.addEventListener("loadedmetadata", () => { - previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; - fitHeight(previewNode); - }); - previewWidget.videoEl.addEventListener("error", () => { - previewWidget.videoEl.hidden = true; - fitHeight(previewNode); - }); - - container.appendChild(previewWidget.videoEl); - }; - - const onExecuted = nodeType.prototype.onExecuted; - nodeType.prototype.onExecuted = function (message) { - onExecuted?.apply(this, arguments); - - if (!message?.gifs?.length) return; - - const params = message.gifs[0]; - const previewWidget = this.widgets?.find((w) => w.name === "videopreview"); - if (!previewWidget) return; - - const query = new URLSearchParams(params); - query.set("timestamp", Date.now()); - previewWidget.videoEl.src = api.apiURL("/view?" + query); - previewWidget.videoEl.hidden = false; - previewWidget.videoEl.autoplay = true; - }; - }, -});