diff --git a/core/__init__.py b/core/__init__.py index 29527c4..cbff440 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -23,7 +23,7 @@ from .models import ( DatabaseError, ) from .database import DatabaseManager -from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, OPTICAL_FLOW_PRESETS +from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, BimVfiEnv, OPTICAL_FLOW_PRESETS from .manager import SymlinkManager from .video import encode_image_sequence, encode_from_file_list, find_ffmpeg @@ -54,6 +54,7 @@ __all__ = [ 'RifeDownloader', 'PracticalRifeEnv', 'FilmEnv', + 'BimVfiEnv', 'SymlinkManager', 'OPTICAL_FLOW_PRESETS', 'encode_image_sequence', diff --git a/core/bim_vfi_worker.py b/core/bim_vfi_worker.py new file mode 100644 index 0000000..bcdc998 --- /dev/null +++ b/core/bim_vfi_worker.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +"""BiM-VFI interpolation worker - runs in isolated venv with PyTorch. + +This script is executed via subprocess from the main application. +It handles frame interpolation using KAIST VICLab's BiM-VFI model +(Bidirectional Motion Field-Guided Frame Interpolation). + +BiM-VFI is designed for non-uniform motions and produces high-quality +results, especially for complex scenes (CVPR 2025). + +Supports two modes: +1. Single frame: --output with --timestep +2. Batch mode: --output-dir with --frame-count (generates all frames at once) +""" + +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +# Checkpoint filename +BIM_VFI_CHECKPOINT = "bim_vfi.pth" + + +def load_image(path: Path, device: torch.device) -> torch.Tensor: + """Load image as tensor. + + Args: + path: Path to image file. + device: Device to load tensor to. + + Returns: + Image tensor (1, 3, H, W) normalized to [0, 1]. + """ + img = Image.open(path).convert('RGB') + arr = np.array(img).astype(np.float32) / 255.0 + tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) + return tensor.to(device) + + +def save_image(tensor: torch.Tensor, path: Path) -> None: + """Save tensor as image. + + Args: + tensor: Image tensor (1, 3, H, W) or (3, H, W) normalized to [0, 1]. + path: Output path. + """ + if tensor.dim() == 4: + tensor = tensor.squeeze(0) + arr = tensor.permute(1, 2, 0).cpu().numpy() + arr = (arr * 255).clip(0, 255).astype(np.uint8) + Image.fromarray(arr).save(path) + + +# Global model cache +_model_cache: dict = {} + + +def get_model(repo_dir: Path, model_dir: Path, device: torch.device): + """Get or load BiM-VFI model (cached). + + Args: + repo_dir: Path to the cloned BiM-VFI repository. + model_dir: Directory containing the checkpoint. + device: Device to run on. + + Returns: + BiM-VFI model instance. + """ + cache_key = f"bim_vfi_{device}" + if cache_key not in _model_cache: + # Add repo to sys.path so we can import BiM-VFI modules + repo_str = str(repo_dir) + if repo_str not in sys.path: + sys.path.insert(0, repo_str) + + checkpoint_path = model_dir / BIM_VFI_CHECKPOINT + + if not checkpoint_path.exists(): + raise FileNotFoundError( + f"BiM-VFI checkpoint not found at {checkpoint_path}. " + "Please download it from Google Drive and place it there." + ) + + print(f"Loading BiM-VFI model from {checkpoint_path}...", file=sys.stderr) + + # Import BiM-VFI's component registry + from modules.components.components import make_components + + # Create model with default config + cfg = {'name': 'bim_vfi', 'args': {'pyr_level': 3, 'feat_channels': 32}} + model = make_components(cfg) + + # Load checkpoint + ckpt = torch.load(str(checkpoint_path), map_location='cpu', weights_only=False) + model.load_state_dict(ckpt['model']) + + model.eval() + model.to(device) + _model_cache[cache_key] = model + print("Model loaded.", file=sys.stderr) + + return _model_cache[cache_key] + + +def get_pyr_level(height: int) -> int: + """Get appropriate pyramid level based on image height. + + Args: + height: Image height in pixels. + + Returns: + Recommended pyramid level. + """ + if height >= 2160: + return 7 + elif height >= 1080: + return 6 + else: + return 5 + + +def get_scale_factor(height: int) -> float: + """Get appropriate scale factor based on image height. + + Args: + height: Image height in pixels. + + Returns: + Recommended scale factor. + """ + if height >= 2160: + return 0.25 + elif height >= 1080: + return 0.5 + else: + return 1.0 + + +@torch.no_grad() +def interpolate_single( + model, img0: torch.Tensor, img1: torch.Tensor, t: float +) -> torch.Tensor: + """Perform single frame interpolation using BiM-VFI. + + Args: + model: BiM-VFI model instance. + img0: First frame tensor (1, 3, H, W) normalized to [0, 1]. + img1: Second frame tensor (1, 3, H, W) normalized to [0, 1]. + t: Interpolation timestep (0.0 to 1.0). + + Returns: + Interpolated frame tensor. + """ + h = img0.shape[2] + pyr_level = get_pyr_level(h) + scale_factor = get_scale_factor(h) + + time_step = torch.tensor([t]).view(1, 1, 1, 1).to(img0.device) + + # Distance weights for better quality + dis0 = torch.ones((1, 1, h, img0.shape[3]), device=img0.device) * t + dis1 = 1 - dis0 + + results_dict = model( + img0=img0, img1=img1, + time_step=time_step, + dis0=dis0, dis1=dis1, + scale_factor=scale_factor, + ratio=(1.0 / scale_factor), + pyr_level=pyr_level, + nr_lvl_skipped=0 + ) + + return results_dict['imgt_pred'].clamp(0, 1) + + +@torch.no_grad() +def interpolate_batch( + model, img0: torch.Tensor, img1: torch.Tensor, frame_count: int +) -> list[torch.Tensor]: + """Generate multiple interpolated frames using BiM-VFI. + + Args: + model: BiM-VFI model instance. + img0: First frame tensor (1, 3, H, W) normalized to [0, 1]. + img1: Second frame tensor (1, 3, H, W) normalized to [0, 1]. + frame_count: Number of frames to generate between img0 and img1. + + Returns: + List of interpolated frame tensors in order. + """ + frames = [] + for i in range(frame_count): + t = (i + 1) / (frame_count + 1) + frame = interpolate_single(model, img0, img1, t) + frames.append(frame) + return frames + + +def main(): + parser = argparse.ArgumentParser(description='BiM-VFI frame interpolation worker') + parser.add_argument('--input0', required=True, help='Path to first input image') + parser.add_argument('--input1', required=True, help='Path to second input image') + parser.add_argument('--output', help='Path to output image (single frame mode)') + parser.add_argument('--output-dir', help='Output directory (batch mode)') + parser.add_argument('--output-pattern', default='frame_{:04d}.png', + help='Output filename pattern for batch mode') + parser.add_argument('--timestep', type=float, default=0.5, + help='Interpolation timestep 0-1 (single frame mode)') + parser.add_argument('--frame-count', type=int, + help='Number of frames to generate (batch mode)') + parser.add_argument('--repo-dir', required=True, help='Path to BiM-VFI repo clone') + parser.add_argument('--model-dir', required=True, help='Model cache directory') + parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use') + + args = parser.parse_args() + + # Validate arguments + batch_mode = args.output_dir is not None and args.frame_count is not None + single_mode = args.output is not None + + if not batch_mode and not single_mode: + print("Error: Must specify either --output (single) or --output-dir + --frame-count (batch)", + file=sys.stderr) + return 1 + + try: + # Select device + if args.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + # Load model + repo_dir = Path(args.repo_dir) + model_dir = Path(args.model_dir) + model = get_model(repo_dir, model_dir, device) + + # Load images + img0 = load_image(Path(args.input0), device) + img1 = load_image(Path(args.input1), device) + + if batch_mode: + # Batch mode - generate all frames at once + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Generating {args.frame_count} frames...", file=sys.stderr) + frames = interpolate_batch(model, img0, img1, args.frame_count) + + for i, frame in enumerate(frames): + output_path = output_dir / args.output_pattern.format(i) + save_image(frame, output_path) + print(f"Saved {output_path.name}", file=sys.stderr) + + print(f"Success: Generated {len(frames)} frames", file=sys.stderr) + else: + # Single frame mode + result = interpolate_single(model, img0, img1, args.timestep) + save_image(result, Path(args.output)) + print("Success", file=sys.stderr) + + return 0 + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/core/blender.py b/core/blender.py index 451c5c2..df91ec3 100644 --- a/core/blender.py +++ b/core/blender.py @@ -477,6 +477,302 @@ class FilmEnv: return False, str(e), [] +class BimVfiEnv: + """Manages BiM-VFI frame interpolation using shared venv with RIFE.""" + + VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE + REPO_DIR = CACHE_DIR / 'BiM-VFI' + MODEL_CACHE_DIR = CACHE_DIR / 'bim-vfi' + CHECKPOINT_FILENAME = 'bim_vfi.pth' + REPO_URL = 'https://github.com/KAIST-VICLab/BiM-VFI.git' + # Google Drive file ID for the checkpoint + GDRIVE_FILE_ID = '18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC' + + # Extra pip packages needed beyond the base torch venv + EXTRA_PACKAGES = [ + 'basicsr-fixed', 'imageio', 'pyyaml', 'opencv-python', + 'lpips', 'ptflops', + ] + + @classmethod + def get_venv_python(cls) -> Optional[Path]: + """Get path to venv Python executable.""" + if cls.VENV_DIR.exists(): + if sys.platform == 'win32': + return cls.VENV_DIR / 'Scripts' / 'python.exe' + return cls.VENV_DIR / 'bin' / 'python' + return None + + @classmethod + def get_checkpoint_path(cls) -> Path: + """Get path to the BiM-VFI checkpoint.""" + return cls.MODEL_CACHE_DIR / cls.CHECKPOINT_FILENAME + + @classmethod + def is_setup(cls) -> bool: + """Check if venv exists, repo is cloned, and checkpoint is present.""" + python = cls.get_venv_python() + if not python or not python.exists(): + return False + if not cls.REPO_DIR.exists(): + return False + return cls.get_checkpoint_path().exists() + + @classmethod + def setup_bim_vfi(cls, progress_callback=None, cancelled_check=None) -> bool: + """Clone repo, install deps, and download checkpoint. + + Args: + progress_callback: Optional callback(message, percent) for progress. + cancelled_check: Optional callable that returns True if cancelled. + + Returns: + True if setup was successful. + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False + + try: + # Step 1: Clone repo if needed + if not cls.REPO_DIR.exists(): + if progress_callback: + progress_callback("Cloning BiM-VFI repository...", 10) + if cancelled_check and cancelled_check(): + return False + + result = subprocess.run( + ['git', 'clone', '--depth', '1', cls.REPO_URL, str(cls.REPO_DIR)], + capture_output=True, text=True, timeout=300 + ) + if result.returncode != 0: + print(f"[BiM-VFI] git clone failed: {result.stderr}", file=sys.stderr) + return False + + # Step 2: Install extra packages + if progress_callback: + progress_callback("Installing BiM-VFI dependencies...", 30) + if cancelled_check and cancelled_check(): + return False + + pip = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'pip' + result = subprocess.run( + [str(pip), 'install', '--quiet'] + cls.EXTRA_PACKAGES, + capture_output=True, text=True, timeout=600 + ) + if result.returncode != 0: + print(f"[BiM-VFI] pip install failed: {result.stderr}", file=sys.stderr) + return False + + # Step 2b: Install cupy (needs CUDA-specific wheel) + if progress_callback: + progress_callback("Installing cupy (CUDA support)...", 50) + if cancelled_check and cancelled_check(): + return False + + # Try cupy-cuda12x first (CUDA 12), fall back to cupy-cuda11x + cupy_installed = False + for cupy_pkg in ['cupy-cuda12x', 'cupy-cuda11x']: + result = subprocess.run( + [str(pip), 'install', '--quiet', cupy_pkg], + capture_output=True, text=True, timeout=600 + ) + if result.returncode == 0: + cupy_installed = True + break + + if not cupy_installed: + print("[BiM-VFI] Warning: cupy install failed, trying generic cupy", file=sys.stderr) + subprocess.run( + [str(pip), 'install', '--quiet', 'cupy'], + capture_output=True, text=True, timeout=600 + ) + + # Step 3: Download checkpoint + checkpoint_path = cls.get_checkpoint_path() + if not checkpoint_path.exists(): + if progress_callback: + progress_callback("Downloading BiM-VFI checkpoint (~300MB)...", 60) + if cancelled_check and cancelled_check(): + return False + + cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + # Use gdown to download from Google Drive + result = subprocess.run( + [str(pip), 'install', '--quiet', 'gdown'], + capture_output=True, text=True, timeout=120 + ) + + gdown_bin = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'gdown' + tmp_path = checkpoint_path.with_suffix('.tmp') + result = subprocess.run( + [str(gdown_bin), '--id', cls.GDRIVE_FILE_ID, + '--output', str(tmp_path)], + capture_output=True, text=True, timeout=600 + ) + if result.returncode == 0 and tmp_path.exists(): + tmp_path.rename(checkpoint_path) + else: + tmp_path.unlink(missing_ok=True) + error = result.stderr.strip() if result.stderr else "unknown error" + print(f"[BiM-VFI] Download failed: {error}", file=sys.stderr) + print("[BiM-VFI] Please download manually from Google Drive and place at " + f"{checkpoint_path}", file=sys.stderr) + return False + + if progress_callback: + progress_callback("BiM-VFI setup complete!", 100) + + return cls.is_setup() + + except Exception as e: + print(f"[BiM-VFI] Setup error: {e}", file=sys.stderr) + return False + + @classmethod + def get_worker_script(cls) -> Path: + """Get path to the BiM-VFI worker script.""" + return Path(__file__).parent / 'bim_vfi_worker.py' + + @classmethod + def run_interpolation( + cls, + img_a_path: Path, + img_b_path: Path, + output_path: Path, + t: float + ) -> tuple[bool, str]: + """Run BiM-VFI interpolation via subprocess in venv. + + Args: + img_a_path: Path to first input image. + img_b_path: Path to second input image. + output_path: Path to output image. + t: Timestep for interpolation (0.0 to 1.0). + + Returns: + Tuple of (success, error_message). + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False, "venv python not found" + + script = cls.get_worker_script() + if not script.exists(): + return False, f"worker script not found: {script}" + + cmd = [ + str(python), str(script), + '--input0', str(img_a_path), + '--input1', str(img_b_path), + '--output', str(output_path), + '--timestep', str(t), + '--repo-dir', str(cls.REPO_DIR), + '--model-dir', str(cls.MODEL_CACHE_DIR) + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300 # 5 minute timeout per frame (BiM-VFI can be slow) + ) + if result.returncode == 0 and output_path.exists(): + return True, "" + else: + error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}" + return False, error + except subprocess.TimeoutExpired: + return False, "timeout (300s)" + except Exception as e: + return False, str(e) + + @classmethod + def run_batch_interpolation( + cls, + img_a_path: Path, + img_b_path: Path, + output_dir: Path, + frame_count: int, + output_pattern: str = 'frame_{:04d}.png' + ) -> tuple[bool, str, list[Path]]: + """Run BiM-VFI batch interpolation via subprocess in venv. + + Args: + img_a_path: Path to first input image. + img_b_path: Path to second input image. + output_dir: Directory to save output frames. + frame_count: Number of frames to generate. + output_pattern: Filename pattern for output frames. + + Returns: + Tuple of (success, error_message, list_of_output_paths). + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False, "venv python not found", [] + + script = cls.get_worker_script() + if not script.exists(): + return False, f"worker script not found: {script}", [] + + output_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + str(python), str(script), + '--input0', str(img_a_path), + '--input1', str(img_b_path), + '--output-dir', str(output_dir), + '--frame-count', str(frame_count), + '--output-pattern', output_pattern, + '--repo-dir', str(cls.REPO_DIR), + '--model-dir', str(cls.MODEL_CACHE_DIR) + ] + + try: + timeout = max(300, frame_count * 45) # At least 5 min, +45s per frame + + print(f"[BiM-VFI] Running batch interpolation: {frame_count} frames", file=sys.stderr) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout + ) + + output_paths = [ + output_dir / output_pattern.format(i) + for i in range(frame_count) + ] + existing_paths = [p for p in output_paths if p.exists()] + + if result.returncode == 0 and len(existing_paths) == frame_count: + print(f"[BiM-VFI] Success: generated {len(existing_paths)} frames", file=sys.stderr) + return True, "", output_paths + else: + error_parts = [] + if result.returncode != 0: + error_parts.append(f"returncode={result.returncode}") + if result.stderr and result.stderr.strip(): + error_parts.append(f"stderr: {result.stderr.strip()}") + if len(existing_paths) != frame_count: + error_parts.append(f"expected {frame_count} frames, got {len(existing_paths)}") + + error = "; ".join(error_parts) if error_parts else "unknown error" + print(f"[BiM-VFI] Failed: {error}", file=sys.stderr) + return False, error, existing_paths + + except subprocess.TimeoutExpired: + print(f"[BiM-VFI] Timeout after {timeout}s", file=sys.stderr) + return False, f"timeout ({timeout}s)", [] + except Exception as e: + print(f"[BiM-VFI] Exception: {e}", file=sys.stderr) + return False, str(e), [] + + class RifeDownloader: """Handles automatic download and caching of rife-ncnn-vulkan binary.""" @@ -1100,6 +1396,53 @@ class ImageBlender: # Fall back to Practical-RIFE return ImageBlender.practical_rife_blend(img_a, img_b, t) + @staticmethod + def bim_vfi_blend( + img_a: Image.Image, + img_b: Image.Image, + t: float + ) -> Image.Image: + """Blend using BiM-VFI for high-quality interpolation. + + BiM-VFI (Bidirectional Motion Field-Guided Frame Interpolation) + handles non-uniform motions well (CVPR 2025). + + Args: + img_a: First PIL Image (source frame). + img_b: Second PIL Image (target frame). + t: Interpolation factor 0.0 (100% A) to 1.0 (100% B). + + Returns: + AI-interpolated blended PIL Image. + """ + if not BimVfiEnv.is_setup(): + print("[BiM-VFI] Not set up, falling back to FILM", file=sys.stderr) + return ImageBlender.film_blend(img_a, img_b, t) + + try: + with tempfile.TemporaryDirectory() as tmpdir: + tmp = Path(tmpdir) + input_a = tmp / 'a.png' + input_b = tmp / 'b.png' + output_file = tmp / 'out.png' + + img_a.convert('RGB').save(input_a) + img_b.convert('RGB').save(input_b) + + success, error_msg = BimVfiEnv.run_interpolation( + input_a, input_b, output_file, t + ) + + if success and output_file.exists(): + return Image.open(output_file).copy() + else: + print(f"[BiM-VFI] Interpolation failed: {error_msg}, falling back to FILM", file=sys.stderr) + + except Exception as e: + print(f"[BiM-VFI] Exception: {e}, falling back to FILM", file=sys.stderr) + + return ImageBlender.film_blend(img_a, img_b, t) + @staticmethod def blend_images( img_a_path: Path, @@ -1652,7 +1995,13 @@ class TransitionGenerator: img_a_path, img_b_path, frame_count, dest, base_seq_num ) - # For RIFE (or FILM fallback), generate frames one at a time + # For BiM-VFI, use batch mode + if method == DirectInterpolationMethod.BIM_VFI and BimVfiEnv.is_setup(): + return self._generate_bim_vfi_frames_batch( + img_a_path, img_b_path, frame_count, dest, base_seq_num + ) + + # For RIFE (or fallback), generate frames one at a time # Load source images img_a = Image.open(img_a_path) img_b = Image.open(img_b_path) @@ -1674,6 +2023,8 @@ class TransitionGenerator: # Generate interpolated frame if method == DirectInterpolationMethod.FILM: blended = ImageBlender.film_blend(img_a, img_b, t) + elif method == DirectInterpolationMethod.BIM_VFI: + blended = ImageBlender.bim_vfi_blend(img_a, img_b, t) else: # RIFE blended = ImageBlender.practical_rife_blend( img_a, img_b, t, @@ -1844,3 +2195,110 @@ class TransitionGenerator: )) return results + + def _generate_bim_vfi_frames_batch( + self, + img_a_path: Path, + img_b_path: Path, + frame_count: int, + dest: Path, + base_seq_num: int + ) -> list[BlendResult]: + """Generate BiM-VFI frames using batch mode. + + Args: + img_a_path: Path to last frame of first sequence. + img_b_path: Path to first frame of second sequence. + frame_count: Number of interpolated frames to generate. + dest: Destination directory for generated frames. + base_seq_num: Starting sequence number for continuous naming. + + Returns: + List of BlendResult objects. + """ + results = [] + + temp_pattern = 'bimvfi_temp_{:04d}.png' + + success, error, temp_paths = BimVfiEnv.run_batch_interpolation( + img_a_path, + img_b_path, + dest, + frame_count, + temp_pattern + ) + + if not success: + for i in range(frame_count): + t = (i + 1) / (frame_count + 1) + ext = f".{self.settings.output_format.lower()}" + seq_num = base_seq_num + i + output_name = f"seq_{seq_num:05d}{ext}" + output_path = dest / output_name + + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=error + )) + return results + + for i, temp_path in enumerate(temp_paths): + t = (i + 1) / (frame_count + 1) + ext = f".{self.settings.output_format.lower()}" + seq_num = base_seq_num + i + output_name = f"seq_{seq_num:05d}{ext}" + output_path = dest / output_name + + try: + if temp_path.exists(): + frame = Image.open(temp_path) + + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + frame = frame.convert('RGB') + + save_kwargs = {} + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + save_kwargs['quality'] = self.settings.output_quality + elif self.settings.output_format.lower() == 'webp': + save_kwargs['lossless'] = True + save_kwargs['method'] = self.settings.webp_method + elif self.settings.output_format.lower() == 'png': + save_kwargs['compress_level'] = 6 + + frame.save(output_path, **save_kwargs) + frame.close() + + if temp_path != output_path: + temp_path.unlink(missing_ok=True) + + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=True + )) + else: + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=f"Temp file not found: {temp_path}" + )) + except Exception as e: + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=str(e) + )) + + return results diff --git a/core/models.py b/core/models.py index 91ffec4..4a3c32c 100644 --- a/core/models.py +++ b/core/models.py @@ -36,6 +36,7 @@ class DirectInterpolationMethod(Enum): """Method for direct frame interpolation between sequences.""" RIFE = 'rife' FILM = 'film' + BIM_VFI = 'bim_vfi' # --- Data Classes --- diff --git a/ui/main_window.py b/ui/main_window.py index 6164fad..e85c1ea 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -67,6 +67,7 @@ from core import ( find_ffmpeg, PracticalRifeEnv, FilmEnv, + BimVfiEnv, SymlinkManager, OPTICAL_FLOW_PRESETS, ) @@ -236,8 +237,12 @@ class DirectTransitionDialog(QDialog): self.method_combo = QComboBox() self.method_combo.addItem("RIFE (Fast, small motion)", DirectInterpolationMethod.RIFE) self.method_combo.addItem("FILM (Slow, large motion)", DirectInterpolationMethod.FILM) - if method == DirectInterpolationMethod.FILM: - self.method_combo.setCurrentIndex(1) + self.method_combo.addItem("BiM-VFI (Best quality, slowest)", DirectInterpolationMethod.BIM_VFI) + # Set current method + for i in range(self.method_combo.count()): + if self.method_combo.itemData(i) == method: + self.method_combo.setCurrentIndex(i) + break form_layout.addRow("Method:", self.method_combo) # Frame count @@ -268,7 +273,8 @@ class DirectTransitionDialog(QDialog): # Explanation explain = QLabel( "RIFE: Fast AI interpolation, best for small motion and color shifts.\n" - "FILM: Google Research model, better for large motion and scene gaps.\n\n" + "FILM: Google Research model, better for large motion and scene gaps.\n" + "BiM-VFI: CVPR 2025, best quality for non-uniform/complex motions.\n\n" "Generated frames bridge the gap between the last frame of this\n" "sequence and the first frame of the next MAIN sequence." ) @@ -305,6 +311,7 @@ class DirectTransitionDialog(QDialog): rife_ready = PracticalRifeEnv.is_setup() film_ready = FilmEnv.is_setup() if rife_ready else False + bim_ready = BimVfiEnv.is_setup() if rife_ready else False if method == DirectInterpolationMethod.RIFE: if rife_ready: @@ -316,13 +323,13 @@ class DirectTransitionDialog(QDialog): self.status_label.setStyleSheet("color: orange; font-size: 10px;") self.setup_btn.setVisible(True) self.setup_btn.setText("Setup PyTorch Environment") - else: # FILM + elif method == DirectInterpolationMethod.FILM: if film_ready: self.status_label.setText("FILM: Ready") self.status_label.setStyleSheet("color: green; font-size: 10px;") self.setup_btn.setVisible(False) elif rife_ready: - self.status_label.setText("FILM: Package not installed") + self.status_label.setText("FILM: Model not downloaded") self.status_label.setStyleSheet("color: orange; font-size: 10px;") self.setup_btn.setVisible(True) self.setup_btn.setText("Install FILM Package") @@ -331,6 +338,21 @@ class DirectTransitionDialog(QDialog): self.status_label.setStyleSheet("color: orange; font-size: 10px;") self.setup_btn.setVisible(True) self.setup_btn.setText("Setup PyTorch Environment") + else: # BIM_VFI + if bim_ready: + self.status_label.setText("BiM-VFI: Ready") + self.status_label.setStyleSheet("color: green; font-size: 10px;") + self.setup_btn.setVisible(False) + elif rife_ready: + self.status_label.setText("BiM-VFI: Not installed (repo + model needed)") + self.status_label.setStyleSheet("color: orange; font-size: 10px;") + self.setup_btn.setVisible(True) + self.setup_btn.setText("Install BiM-VFI") + else: + self.status_label.setText("BiM-VFI: Not installed (PyTorch required first)") + self.status_label.setStyleSheet("color: orange; font-size: 10px;") + self.setup_btn.setVisible(True) + self.setup_btn.setText("Setup PyTorch Environment") def _on_setup(self) -> None: """Handle setup button click.""" @@ -397,6 +419,38 @@ class DirectTransitionDialog(QDialog): ) return + # If BiM-VFI selected and we need to install it + if method == DirectInterpolationMethod.BIM_VFI and not BimVfiEnv.is_setup(): + progress = QProgressDialog( + "Setting up BiM-VFI...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("Setup") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(0) + + def progress_cb(msg, pct): + progress.setLabelText(msg) + progress.setValue(pct) + + def cancelled_check(): + QApplication.processEvents() + return progress.wasCanceled() + + success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check) + progress.close() + + if not success: + if not cancelled_check(): + QMessageBox.warning( + self, "Setup Failed", + "Failed to set up BiM-VFI.\n\n" + "If the checkpoint download failed, you can download it " + "manually from Google Drive and place it at:\n" + f"{BimVfiEnv.get_checkpoint_path()}" + ) + return + self._update_status() def _on_remove(self) -> None: