diff --git a/__init__.py b/__init__.py index 8aedf02..2981ddc 100644 --- a/__init__.py +++ b/__init__.py @@ -52,8 +52,6 @@ from .nodes import ( LoadGIMMVFIModel, GIMMVFIInterpolate, GIMMVFISegmentInterpolate, ) -WEB_DIRECTORY = "./web" - NODE_CLASS_MAPPINGS = { "LoadBIMVFIModel": LoadBIMVFIModel, "BIMVFIInterpolate": BIMVFIInterpolate, diff --git a/nodes.py b/nodes.py index a25ae3f..9ba64fa 100644 --- a/nodes.py +++ b/nodes.py @@ -1,3 +1,4 @@ +import math import os import glob import logging @@ -16,6 +17,35 @@ from .gimm_vfi_arch import clear_gimm_caches logger = logging.getLogger("Tween") + +def _compute_target_fps_params(source_fps, target_fps): + """Compute oversampling parameters for target FPS mode. + + Returns (num_passes, mult) where mult = 2^num_passes is the power-of-2 + multiplier needed to oversample above the target ratio. + """ + ratio = target_fps / source_fps + if ratio <= 1.0: + return 0, 1 # no interpolation needed (downsampling or same fps) + num_passes = math.ceil(math.log2(ratio)) + mult = 2 ** num_passes + return num_passes, mult + + +def _select_target_fps_frames(frames, source_fps, target_fps, mult, num_input): + """Pick frames from oversampled [M,C,H,W] tensor to hit target FPS timing. + + For downsampling (mult=1, ratio<=1), selects from original input frames. + For upsampling, selects from the oversampled sequence at target timestamps. + """ + duration = (num_input - 1) / source_fps + num_output = int(math.floor(duration * target_fps)) + 1 + oversampled_fps = source_fps * mult + max_idx = frames.shape[0] - 1 + indices = [min(round(j / target_fps * oversampled_fps), max_idx) for j in range(num_output)] + return frames[indices] + + # Google Drive file ID for the pretrained BIM-VFI model GDRIVE_FILE_ID = "18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC" MODEL_FILENAME = "bim_vfi.pth" @@ -161,6 +191,14 @@ class BIMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead. Result is identical to processing all at once.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -234,12 +272,25 @@ class BIMVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + # Downsampling or same fps — select from input directly + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -292,6 +343,11 @@ class BIMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -328,8 +384,10 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) @@ -340,9 +398,69 @@ class BIMVFISegmentInterpolate(BIMVFIInterpolate): return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + # Compute global output frame range for this segment + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + # Downsampling — select from segment input directly + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, @@ -389,11 +507,6 @@ class TweenConcatVideos: "tooltip": "Delete the individual segment files after successful concatenation. " "Useful to avoid leftover files that would pollute the next run.", }), - "preview": ("BOOLEAN", { - "default": True, - "tooltip": "Show the concatenated video as a preview on the node. " - "Disable to skip the preview widget.", - }), } } @@ -418,7 +531,7 @@ class TweenConcatVideos: ) return ffmpeg_path - def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments, preview): + def concat(self, model, output_directory, filename_prefix, output_filename, delete_segments): # Resolve output directory — empty or relative paths are relative to ComfyUI output comfy_output = folder_paths.get_output_directory() out_dir = output_directory.strip() @@ -494,28 +607,7 @@ class TweenConcatVideos: if os.path.exists(concat_list_path): os.remove(concat_list_path) - result = {"result": (output_path,)} - - if preview: - # Preview only works when the file is inside ComfyUI's output tree - abs_out = os.path.abspath(out_dir) - abs_comfy = os.path.abspath(comfy_output) - if abs_out.startswith(abs_comfy + os.sep) or abs_out == abs_comfy: - subfolder = os.path.relpath(abs_out, abs_comfy) if abs_out != abs_comfy else "" - result["ui"] = { - "gifs": [{ - "filename": os.path.basename(output_path), - "subfolder": subfolder, - "type": "output", - "format": "video/mp4", - }] - } - else: - logger.warning( - f"Video preview skipped: {out_dir} is outside ComfyUI output directory" - ) - - return result + return {"result": (output_path,)} # --------------------------------------------------------------------------- @@ -634,6 +726,14 @@ class EMAVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -700,12 +800,24 @@ class EMAVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -758,6 +870,11 @@ class EMAVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -794,28 +911,87 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) @@ -941,6 +1117,14 @@ class SGMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -1007,12 +1191,24 @@ class SGMVFIInterpolate: return total def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, - keep_device, all_on_gpu, batch_size, chunk_size): + keep_device, all_on_gpu, batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + else: + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -1065,6 +1261,11 @@ class SGMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -1101,28 +1302,87 @@ class SGMVFISegmentInterpolate(SGMVFIInterpolate): def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, - segment_index, segment_size): + segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment using computed num_passes + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, clear_cache_after_n_frames, keep_device, all_on_gpu, batch_size, chunk_size, ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) @@ -1265,6 +1525,14 @@ class GIMMVFIInterpolate: "default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", }), + "source_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Input frame rate. Required when target_fps > 0.", + }), + "target_fps": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, + "tooltip": "Target output FPS. When > 0, overrides multiplier and auto-computes the optimal power-of-2 oversample then selects frames. 0 = use multiplier.", + }), } } @@ -1376,14 +1644,31 @@ class GIMMVFIInterpolate: def interpolate(self, images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, - batch_size, chunk_size): + batch_size, chunk_size, + source_fps=0.0, target_fps=0.0): if images.shape[0] < 2: return (images,) + # Target FPS mode: auto-compute multiplier from fps ratio + use_target_fps = target_fps > 0 and source_fps > 0 + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + if num_passes == 0: + all_frames = images.permute(0, 3, 1, 2) + result = _select_target_fps_frames(all_frames, source_fps, target_fps, mult, all_frames.shape[0]) + return (result.cpu().permute(0, 2, 3, 1),) + # Override multiplier for single_pass mode + multiplier = mult + else: + mult = multiplier + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if not single_pass: - num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + if not single_pass or use_target_fps: + if use_target_fps: + num_passes_recursive = num_passes + else: + num_passes_recursive = {2: 1, 4: 2, 8: 3}[multiplier] if all_on_gpu: keep_device = True @@ -1411,7 +1696,7 @@ class GIMMVFIInterpolate: if single_pass: total_steps = sum(ce - cs - 1 for cs, ce in chunks) else: - total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks) + total_steps = sum(self._count_steps(ce - cs, num_passes_recursive) for cs, ce in chunks) pbar = ProgressBar(total_steps) step_ref = [0] @@ -1430,7 +1715,7 @@ class GIMMVFIInterpolate: ) else: chunk_result = self._interpolate_frames( - chunk_frames, model, num_passes, batch_size, + chunk_frames, model, num_passes_recursive, batch_size, device, storage_device, keep_device, all_on_gpu, clear_cache_after_n_frames, pbar, step_ref, ) @@ -1446,6 +1731,11 @@ class GIMMVFIInterpolate: result_chunks.append(chunk_result) result = torch.cat(result_chunks, dim=0) + + # Target FPS: select frames from oversampled result + if use_target_fps: + result = _select_target_fps_frames(result, source_fps, target_fps, mult, total_input) + # Convert back to ComfyUI [B, H, W, C], on CPU result = result.cpu().permute(0, 2, 3, 1) return (result,) @@ -1482,21 +1772,90 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate): def interpolate(self, images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, - batch_size, chunk_size, segment_index, segment_size): + batch_size, chunk_size, segment_index, segment_size, + source_fps=0.0, target_fps=0.0): total_input = images.shape[0] + use_target_fps = target_fps > 0 and source_fps > 0 # Compute segment boundaries (1-frame overlap) start = segment_index * (segment_size - 1) end = min(start + segment_size, total_input) if start >= total_input - 1: - # Past the end — return empty single frame + model return (images[:1], model) segment_images = images[start:end] - is_continuation = segment_index > 0 - # Delegate to the parent interpolation logic + if use_target_fps: + num_passes, mult = _compute_target_fps_params(source_fps, target_fps) + + seg_start_time = start / source_fps + seg_end_time = (end - 1) / source_fps + duration = (total_input - 1) / source_fps + total_output = int(math.floor(duration * target_fps)) + 1 + + if segment_index == 0: + j_start = 0 + else: + j_start = int(math.floor(seg_start_time * target_fps)) + 1 + j_end = min(int(math.floor(seg_end_time * target_fps)), total_output - 1) + + if j_start > j_end: + return (images[:1], model) + + if num_passes == 0: + oversampled_fps = source_fps * mult + all_seg = segment_images.permute(0, 3, 1, 2) + out_frames = [] + for j in range(j_start, j_end + 1): + global_idx = min(round(j / target_fps * oversampled_fps), total_input - 1) + local_idx = global_idx - start + local_idx = max(0, min(local_idx, all_seg.shape[0] - 1)) + out_frames.append(all_seg[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Oversample segment directly + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if all_on_gpu: + keep_device = True + storage_device = device if all_on_gpu else torch.device("cpu") + seg_frames = segment_images.permute(0, 3, 1, 2).to(storage_device) + + if single_pass: + total_steps = seg_frames.shape[0] - 1 + else: + total_steps = self._count_steps(seg_frames.shape[0], num_passes) + pbar = ProgressBar(total_steps) + step_ref = [0] + if keep_device: + model.to(device) + + if single_pass: + oversampled = self._interpolate_frames_single_pass( + seg_frames, model, mult, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + else: + oversampled = self._interpolate_frames( + seg_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + oversampled_fps = source_fps * mult + + out_frames = [] + for j in range(j_start, j_end + 1): + global_oversamp_idx = round(j / target_fps * oversampled_fps) + local_idx = global_oversamp_idx - start * mult + local_idx = max(0, min(local_idx, oversampled.shape[0] - 1)) + out_frames.append(oversampled[local_idx:local_idx + 1]) + result = torch.cat(out_frames, dim=0).cpu().permute(0, 2, 3, 1) + return (result, model) + + # Standard multiplier mode + is_continuation = segment_index > 0 (result,) = super().interpolate( segment_images, model, multiplier, single_pass, clear_cache_after_n_frames, keep_device, all_on_gpu, @@ -1504,6 +1863,6 @@ class GIMMVFISegmentInterpolate(GIMMVFIInterpolate): ) if is_continuation: - result = result[1:] # skip duplicate boundary frame + result = result[1:] return (result, model) diff --git a/web/js/tween_preview.js b/web/js/tween_preview.js deleted file mode 100644 index fe91217..0000000 --- a/web/js/tween_preview.js +++ /dev/null @@ -1,72 +0,0 @@ -import { app } from "../../scripts/app.js"; -import { api } from "../../scripts/api.js"; - -function fitHeight(node) { - node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]); - node?.graph?.setDirtyCanvas(true); -} - -app.registerExtension({ - name: "Tween.VideoPreview", - async beforeRegisterNodeDef(nodeType, nodeData) { - if (nodeData?.name !== "TweenConcatVideos") return; - - const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = function () { - onNodeCreated?.apply(this, arguments); - - const container = document.createElement("div"); - const previewWidget = this.addDOMWidget("videopreview", "preview", container, { - serialize: false, - hideOnZoom: false, - getValue() { return container.value; }, - setValue(v) { container.value = v; }, - }); - - previewWidget.computeSize = function (width) { - if (this.aspectRatio && !this.videoEl.hidden) { - const height = (previewNode.size[0] - 20) / this.aspectRatio + 10; - return [width, height > 0 ? height : -4]; - } - return [width, -4]; - }; - - const previewNode = this; - - previewWidget.videoEl = document.createElement("video"); - previewWidget.videoEl.controls = true; - previewWidget.videoEl.loop = true; - previewWidget.videoEl.muted = true; - previewWidget.videoEl.style.width = "100%"; - previewWidget.videoEl.hidden = true; - - previewWidget.videoEl.addEventListener("loadedmetadata", () => { - previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; - fitHeight(previewNode); - }); - previewWidget.videoEl.addEventListener("error", () => { - previewWidget.videoEl.hidden = true; - fitHeight(previewNode); - }); - - container.appendChild(previewWidget.videoEl); - }; - - const onExecuted = nodeType.prototype.onExecuted; - nodeType.prototype.onExecuted = function (message) { - onExecuted?.apply(this, arguments); - - if (!message?.gifs?.length) return; - - const params = message.gifs[0]; - const previewWidget = this.widgets?.find((w) => w.name === "videopreview"); - if (!previewWidget) return; - - const query = new URLSearchParams(params); - query.set("timestamp", Date.now()); - previewWidget.videoEl.src = api.apiURL("/view?" + query); - previewWidget.videoEl.hidden = false; - previewWidget.videoEl.autoplay = true; - }; - }, -});