diff --git a/core/__init__.py b/core/__init__.py index bf3852b..952a1b2 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -4,8 +4,10 @@ from .models import ( BlendCurve, BlendMethod, FolderType, + DirectInterpolationMethod, TransitionSettings, PerTransitionSettings, + DirectTransitionSettings, BlendResult, TransitionSpec, LinkResult, @@ -19,15 +21,17 @@ from .models import ( DatabaseError, ) from .database import DatabaseManager -from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, OPTICAL_FLOW_PRESETS +from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, OPTICAL_FLOW_PRESETS from .manager import SymlinkManager __all__ = [ 'BlendCurve', 'BlendMethod', 'FolderType', + 'DirectInterpolationMethod', 'TransitionSettings', 'PerTransitionSettings', + 'DirectTransitionSettings', 'BlendResult', 'TransitionSpec', 'LinkResult', @@ -44,6 +48,7 @@ __all__ = [ 'TransitionGenerator', 'RifeDownloader', 'PracticalRifeEnv', + 'FilmEnv', 'SymlinkManager', 'OPTICAL_FLOW_PRESETS', ] diff --git a/core/blender.py b/core/blender.py index a46b32f..ef6998e 100644 --- a/core/blender.py +++ b/core/blender.py @@ -23,6 +23,8 @@ from .models import ( PerTransitionSettings, BlendResult, TransitionSpec, + DirectInterpolationMethod, + DirectTransitionSettings, ) @@ -251,6 +253,230 @@ class PracticalRifeEnv: return False, str(e) +class FilmEnv: + """Manages FILM frame interpolation using shared venv with RIFE.""" + + VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE + MODEL_CACHE_DIR = CACHE_DIR / 'film' + MODEL_FILENAME = 'film_net_fp32.pt' + MODEL_URL = 'https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.2/film_net_fp32.pt' + + # Keep REPO_DIR for backward compat (but unused now - model is downloaded directly) + REPO_DIR = CACHE_DIR / 'frame-interpolation-pytorch' + + @classmethod + def get_venv_python(cls) -> Optional[Path]: + """Get path to venv Python executable.""" + if cls.VENV_DIR.exists(): + if sys.platform == 'win32': + return cls.VENV_DIR / 'Scripts' / 'python.exe' + return cls.VENV_DIR / 'bin' / 'python' + return None + + @classmethod + def get_model_path(cls) -> Path: + """Get path to the FILM TorchScript model.""" + return cls.MODEL_CACHE_DIR / cls.MODEL_FILENAME + + @classmethod + def is_setup(cls) -> bool: + """Check if venv exists and FILM model is downloaded.""" + python = cls.get_venv_python() + if not python or not python.exists(): + return False + # Check if model is downloaded + return cls.get_model_path().exists() + + @classmethod + def setup_film(cls, progress_callback=None, cancelled_check=None) -> bool: + """Download FILM model and ensure venv is ready. + + Args: + progress_callback: Optional callback(message, percent) for progress. + cancelled_check: Optional callable that returns True if cancelled. + + Returns: + True if setup was successful. + """ + python = cls.get_venv_python() + if not python or not python.exists(): + # Need to set up base venv first via PracticalRifeEnv + return False + + try: + model_path = cls.get_model_path() + + if not model_path.exists(): + if progress_callback: + progress_callback("Downloading FILM model (~380MB)...", 30) + if cancelled_check and cancelled_check(): + return False + + # Download the pre-trained TorchScript model + cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True) + urllib.request.urlretrieve(cls.MODEL_URL, model_path) + + if progress_callback: + progress_callback("FILM setup complete!", 100) + + return cls.is_setup() + + except Exception as e: + print(f"[FILM] Setup error: {e}", file=sys.stderr) + return False + + @classmethod + def get_worker_script(cls) -> Path: + """Get path to the FILM worker script.""" + return Path(__file__).parent / 'film_worker.py' + + @classmethod + def run_interpolation( + cls, + img_a_path: Path, + img_b_path: Path, + output_path: Path, + t: float + ) -> tuple[bool, str]: + """Run FILM interpolation via subprocess in venv. + + Args: + img_a_path: Path to first input image. + img_b_path: Path to second input image. + output_path: Path to output image. + t: Timestep for interpolation (0.0 to 1.0). + + Returns: + Tuple of (success, error_message). + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False, "venv python not found" + + script = cls.get_worker_script() + if not script.exists(): + return False, f"worker script not found: {script}" + + cmd = [ + str(python), str(script), + '--input0', str(img_a_path), + '--input1', str(img_b_path), + '--output', str(output_path), + '--timestep', str(t), + '--repo-dir', str(cls.REPO_DIR), + '--model-dir', str(cls.MODEL_CACHE_DIR) + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=180 # 3 minute timeout per frame (FILM is slower) + ) + if result.returncode == 0 and output_path.exists(): + return True, "" + else: + error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}" + return False, error + except subprocess.TimeoutExpired: + return False, "timeout (180s)" + except Exception as e: + return False, str(e) + + @classmethod + def run_batch_interpolation( + cls, + img_a_path: Path, + img_b_path: Path, + output_dir: Path, + frame_count: int, + output_pattern: str = 'frame_{:04d}.png' + ) -> tuple[bool, str, list[Path]]: + """Run FILM batch interpolation via subprocess in venv. + + Generates all frames at once using FILM's recursive approach, + which produces better results than generating frames independently. + + Args: + img_a_path: Path to first input image. + img_b_path: Path to second input image. + output_dir: Directory to save output frames. + frame_count: Number of frames to generate. + output_pattern: Filename pattern for output frames. + + Returns: + Tuple of (success, error_message, list_of_output_paths). + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False, "venv python not found", [] + + script = cls.get_worker_script() + if not script.exists(): + return False, f"worker script not found: {script}", [] + + output_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + str(python), str(script), + '--input0', str(img_a_path), + '--input1', str(img_b_path), + '--output-dir', str(output_dir), + '--frame-count', str(frame_count), + '--output-pattern', output_pattern, + '--repo-dir', str(cls.REPO_DIR), + '--model-dir', str(cls.MODEL_CACHE_DIR) + ] + + try: + # Longer timeout for batch - scale with frame count + timeout = max(300, frame_count * 30) # At least 5 min, +30s per frame + + print(f"[FILM] Running batch interpolation: {frame_count} frames", file=sys.stderr) + print(f"[FILM] Command: {' '.join(cmd)}", file=sys.stderr) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout + ) + + # Collect output paths + output_paths = [ + output_dir / output_pattern.format(i) + for i in range(frame_count) + ] + existing_paths = [p for p in output_paths if p.exists()] + + if result.returncode == 0 and len(existing_paths) == frame_count: + print(f"[FILM] Success: generated {len(existing_paths)} frames", file=sys.stderr) + return True, "", output_paths + else: + # Combine stdout and stderr for better error reporting + error_parts = [] + if result.returncode != 0: + error_parts.append(f"returncode={result.returncode}") + if result.stdout and result.stdout.strip(): + error_parts.append(f"stdout: {result.stdout.strip()}") + if result.stderr and result.stderr.strip(): + error_parts.append(f"stderr: {result.stderr.strip()}") + if len(existing_paths) != frame_count: + error_parts.append(f"expected {frame_count} frames, got {len(existing_paths)}") + + error = "; ".join(error_parts) if error_parts else "unknown error" + print(f"[FILM] Failed: {error}", file=sys.stderr) + return False, error, existing_paths + + except subprocess.TimeoutExpired: + print(f"[FILM] Timeout after {timeout}s", file=sys.stderr) + return False, f"timeout ({timeout}s)", [] + except Exception as e: + print(f"[FILM] Exception: {e}", file=sys.stderr) + return False, str(e), [] + + class RifeDownloader: """Handles automatic download and caching of rife-ncnn-vulkan binary.""" @@ -824,6 +1050,56 @@ class ImageBlender: # Fall back to ncnn RIFE or optical flow return ImageBlender.rife_blend(img_a, img_b, t) + @staticmethod + def film_blend( + img_a: Image.Image, + img_b: Image.Image, + t: float + ) -> Image.Image: + """Blend using FILM for large motion interpolation. + + FILM (Frame Interpolation for Large Motion) is Google Research's + high-quality frame interpolation model, better for large motion. + + Args: + img_a: First PIL Image (source frame). + img_b: Second PIL Image (target frame). + t: Interpolation factor 0.0 (100% A) to 1.0 (100% B). + + Returns: + AI-interpolated blended PIL Image. + """ + if not FilmEnv.is_setup(): + print("[FILM] Not set up, falling back to Practical-RIFE", file=sys.stderr) + return ImageBlender.practical_rife_blend(img_a, img_b, t) + + try: + with tempfile.TemporaryDirectory() as tmpdir: + tmp = Path(tmpdir) + input_a = tmp / 'a.png' + input_b = tmp / 'b.png' + output_file = tmp / 'out.png' + + # Save input images + img_a.convert('RGB').save(input_a) + img_b.convert('RGB').save(input_b) + + # Run FILM via subprocess + success, error_msg = FilmEnv.run_interpolation( + input_a, input_b, output_file, t + ) + + if success and output_file.exists(): + return Image.open(output_file).copy() + else: + print(f"[FILM] Interpolation failed: {error_msg}, falling back to Practical-RIFE", file=sys.stderr) + + except Exception as e: + print(f"[FILM] Exception: {e}, falling back to Practical-RIFE", file=sys.stderr) + + # Fall back to Practical-RIFE + return ImageBlender.practical_rife_blend(img_a, img_b, t) + @staticmethod def blend_images( img_a_path: Path, @@ -1312,3 +1588,241 @@ class TransitionGenerator: return self.generate_asymmetric_blend_frames( spec, dest, folder_idx_main, base_file_idx ) + + def generate_direct_interpolation_frames( + self, + img_a_path: Path, + img_b_path: Path, + frame_count: int, + method: DirectInterpolationMethod, + dest: Path, + folder_idx: int, + base_file_idx: int, + practical_rife_model: str = 'v4.25', + practical_rife_ensemble: bool = False + ) -> list[BlendResult]: + """Generate AI-interpolated frames between two images. + + Used for direct transitions between MAIN sequences without + a transition folder. + + For FILM: Uses batch mode to generate all frames at once (better quality). + For RIFE: Generates frames one at a time (RIFE handles arbitrary timesteps well). + + Args: + img_a_path: Path to last frame of first sequence. + img_b_path: Path to first frame of second sequence. + frame_count: Number of interpolated frames to generate. + method: Interpolation method (RIFE or FILM). + dest: Destination directory for generated frames. + folder_idx: Folder index for sequence naming. + base_file_idx: Starting file index for sequence naming. + practical_rife_model: Practical-RIFE model version. + practical_rife_ensemble: Enable Practical-RIFE ensemble mode. + + Returns: + List of BlendResult objects. + """ + results = [] + dest.mkdir(parents=True, exist_ok=True) + + # For FILM, use batch mode to generate all frames at once + if method == DirectInterpolationMethod.FILM and FilmEnv.is_setup(): + return self._generate_film_frames_batch( + img_a_path, img_b_path, frame_count, dest, folder_idx, base_file_idx + ) + + # For RIFE (or FILM fallback), generate frames one at a time + # Load source images + img_a = Image.open(img_a_path) + img_b = Image.open(img_b_path) + + # Handle different sizes - resize B to match A + if img_a.size != img_b.size: + img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS) + + # Normalize to RGBA + if img_a.mode != 'RGBA': + img_a = img_a.convert('RGBA') + if img_b.mode != 'RGBA': + img_b = img_b.convert('RGBA') + + for i in range(frame_count): + # Evenly space t values between 0 and 1 (exclusive) + t = (i + 1) / (frame_count + 1) + + # Generate interpolated frame + if method == DirectInterpolationMethod.FILM: + blended = ImageBlender.film_blend(img_a, img_b, t) + else: # RIFE + blended = ImageBlender.practical_rife_blend( + img_a, img_b, t, + practical_rife_model, practical_rife_ensemble + ) + + # Generate output filename + ext = f".{self.settings.output_format.lower()}" + file_idx = base_file_idx + i + output_name = f"seq{folder_idx + 1:02d}_trans_{file_idx:04d}{ext}" + output_path = dest / output_name + + # Save the blended frame + try: + # Convert back to RGB if saving to JPEG + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + blended = blended.convert('RGB') + + # Save with appropriate options + save_kwargs = {} + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + save_kwargs['quality'] = self.settings.output_quality + elif self.settings.output_format.lower() == 'webp': + save_kwargs['lossless'] = True + save_kwargs['method'] = self.settings.webp_method + elif self.settings.output_format.lower() == 'png': + save_kwargs['compress_level'] = 6 + + blended.save(output_path, **save_kwargs) + + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=True + )) + except Exception as e: + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=str(e) + )) + + # Close loaded images + img_a.close() + img_b.close() + + return results + + def _generate_film_frames_batch( + self, + img_a_path: Path, + img_b_path: Path, + frame_count: int, + dest: Path, + folder_idx: int, + base_file_idx: int + ) -> list[BlendResult]: + """Generate FILM frames using batch mode for better quality. + + FILM works best when generating all frames at once using its + recursive approach, rather than generating arbitrary timesteps. + + Args: + img_a_path: Path to last frame of first sequence. + img_b_path: Path to first frame of second sequence. + frame_count: Number of interpolated frames to generate. + dest: Destination directory for generated frames. + folder_idx: Folder index for sequence naming. + base_file_idx: Starting file index for sequence naming. + + Returns: + List of BlendResult objects. + """ + results = [] + + # Generate frames using FILM batch mode + # Use a temp pattern, then rename to final names + temp_pattern = 'film_temp_{:04d}.png' + + success, error, temp_paths = FilmEnv.run_batch_interpolation( + img_a_path, + img_b_path, + dest, + frame_count, + temp_pattern + ) + + if not success: + # Return error results for all frames + for i in range(frame_count): + t = (i + 1) / (frame_count + 1) + ext = f".{self.settings.output_format.lower()}" + file_idx = base_file_idx + i + output_name = f"seq{folder_idx + 1:02d}_trans_{file_idx:04d}{ext}" + output_path = dest / output_name + + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=error + )) + return results + + # Rename temp files to final names and convert format if needed + for i, temp_path in enumerate(temp_paths): + t = (i + 1) / (frame_count + 1) + ext = f".{self.settings.output_format.lower()}" + file_idx = base_file_idx + i + output_name = f"seq{folder_idx + 1:02d}_trans_{file_idx:04d}{ext}" + output_path = dest / output_name + + try: + if temp_path.exists(): + # Load the temp frame + frame = Image.open(temp_path) + + # Convert format if needed + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + frame = frame.convert('RGB') + + # Save with appropriate options + save_kwargs = {} + if self.settings.output_format.lower() in ('jpg', 'jpeg'): + save_kwargs['quality'] = self.settings.output_quality + elif self.settings.output_format.lower() == 'webp': + save_kwargs['lossless'] = True + save_kwargs['method'] = self.settings.webp_method + elif self.settings.output_format.lower() == 'png': + save_kwargs['compress_level'] = 6 + + frame.save(output_path, **save_kwargs) + frame.close() + + # Remove temp file if different from output + if temp_path != output_path: + temp_path.unlink(missing_ok=True) + + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=True + )) + else: + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=f"Temp file not found: {temp_path}" + )) + except Exception as e: + results.append(BlendResult( + output_path=output_path, + source_a=img_a_path, + source_b=img_b_path, + blend_factor=t, + success=False, + error=str(e) + )) + + return results diff --git a/core/film_worker.py b/core/film_worker.py new file mode 100644 index 0000000..fa41edb --- /dev/null +++ b/core/film_worker.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python +"""FILM interpolation worker - runs in isolated venv with PyTorch. + +This script is executed via subprocess from the main application. +It handles frame interpolation using Google Research's FILM model +(Frame Interpolation for Large Motion) via the frame-interpolation-pytorch repo. + +FILM is better than RIFE for large motion and scene gaps, but slower. + +Supports two modes: +1. Single frame: --output with --timestep +2. Batch mode: --output-dir with --frame-count (generates all frames at once) +""" + +import argparse +import sys +import urllib.request +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +# Model download URL +FILM_MODEL_URL = "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.2/film_net_fp32.pt" +FILM_MODEL_FILENAME = "film_net_fp32.pt" + + +def load_image(path: Path, device: torch.device) -> torch.Tensor: + """Load image as tensor. + + Args: + path: Path to image file. + device: Device to load tensor to. + + Returns: + Image tensor (1, 3, H, W) normalized to [0, 1]. + """ + img = Image.open(path).convert('RGB') + arr = np.array(img).astype(np.float32) / 255.0 + tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) + return tensor.to(device) + + +def save_image(tensor: torch.Tensor, path: Path) -> None: + """Save tensor as image. + + Args: + tensor: Image tensor (1, 3, H, W) or (3, H, W) normalized to [0, 1]. + path: Output path. + """ + if tensor.dim() == 4: + tensor = tensor.squeeze(0) + arr = tensor.permute(1, 2, 0).cpu().numpy() + arr = (arr * 255).clip(0, 255).astype(np.uint8) + Image.fromarray(arr).save(path) + + +# Global model cache +_model_cache: dict = {} + + +def download_model(model_dir: Path) -> Path: + """Download FILM model if not present. + + Args: + model_dir: Directory to store the model. + + Returns: + Path to the downloaded model file. + """ + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / FILM_MODEL_FILENAME + + if not model_path.exists(): + print(f"Downloading FILM model to {model_path}...", file=sys.stderr) + urllib.request.urlretrieve(FILM_MODEL_URL, model_path) + print("Download complete.", file=sys.stderr) + + return model_path + + +def get_model(model_dir: Path, device: torch.device): + """Get or load FILM model (cached). + + Args: + model_dir: Model cache directory (for model downloads). + device: Device to run on. + + Returns: + FILM TorchScript model instance. + """ + cache_key = f"film_{device}" + if cache_key not in _model_cache: + # Download model if needed + model_path = download_model(model_dir) + + # Load pre-trained TorchScript model + print(f"Loading FILM model from {model_path}...", file=sys.stderr) + model = torch.jit.load(str(model_path), map_location='cpu') + model.eval() + model.to(device) + _model_cache[cache_key] = model + print("Model loaded.", file=sys.stderr) + + return _model_cache[cache_key] + + +@torch.no_grad() +def interpolate_single(model, img0: torch.Tensor, img1: torch.Tensor, t: float) -> torch.Tensor: + """Perform single frame interpolation using FILM. + + Args: + model: FILM TorchScript model instance. + img0: First frame tensor (1, 3, H, W) normalized to [0, 1]. + img1: Second frame tensor (1, 3, H, W) normalized to [0, 1]. + t: Interpolation timestep (0.0 to 1.0). + + Returns: + Interpolated frame tensor. + """ + # FILM TorchScript model expects dt as tensor of shape (1, 1) + dt = img0.new_full((1, 1), t) + + result = model(img0, img1, dt) + + if isinstance(result, tuple): + result = result[0] + + return result.clamp(0, 1) + + +@torch.no_grad() +def interpolate_batch(model, img0: torch.Tensor, img1: torch.Tensor, frame_count: int) -> list[torch.Tensor]: + """Generate multiple interpolated frames using FILM's recursive approach. + + FILM works best when generating frames recursively - it first generates + the middle frame, then fills in the gaps. This produces more consistent + results than generating arbitrary timesteps independently. + + Args: + model: FILM model instance. + img0: First frame tensor (1, 3, H, W) normalized to [0, 1]. + img1: Second frame tensor (1, 3, H, W) normalized to [0, 1]. + frame_count: Number of frames to generate between img0 and img1. + + Returns: + List of interpolated frame tensors in order. + """ + # Calculate timesteps for evenly spaced frames + timesteps = [(i + 1) / (frame_count + 1) for i in range(frame_count)] + + # Try to use the model's batch/recursive interpolation if available + try: + # Some implementations have an interpolate_recursively method + if hasattr(model, 'interpolate_recursively'): + # This generates 2^n - 1 frames, so we need to handle arbitrary counts + results = model.interpolate_recursively(img0, img1, frame_count) + if len(results) >= frame_count: + return results[:frame_count] + except (AttributeError, TypeError): + pass + + # Fall back to recursive binary interpolation for better quality + # This mimics FILM's natural recursive approach + frames = {} # timestep -> tensor + + def recursive_interpolate(t_left: float, t_right: float, img_left: torch.Tensor, img_right: torch.Tensor, depth: int = 0): + """Recursively interpolate to fill the gap.""" + if depth > 10: # Prevent infinite recursion + return + + t_mid = (t_left + t_right) / 2 + + # Check if we need a frame near t_mid + need_frame = False + for t in timesteps: + if t not in frames and abs(t - t_mid) < 0.5 / (frame_count + 1): + need_frame = True + break + + if not need_frame: + # Check if any remaining timesteps are in this range + remaining = [t for t in timesteps if t not in frames and t_left < t < t_right] + if not remaining: + return + + # Generate middle frame + mid_frame = interpolate_single(model, img_left, img_right, 0.5) + + # Assign to nearest needed timestep + for t in timesteps: + if t not in frames and abs(t - t_mid) < 0.5 / (frame_count + 1): + frames[t] = mid_frame + break + + # Recurse into left and right halves + recursive_interpolate(t_left, t_mid, img_left, mid_frame, depth + 1) + recursive_interpolate(t_mid, t_right, mid_frame, img_right, depth + 1) + + # Start recursive interpolation + recursive_interpolate(0.0, 1.0, img0, img1) + + # Fill any remaining timesteps with direct interpolation + for t in timesteps: + if t not in frames: + frames[t] = interpolate_single(model, img0, img1, t) + + # Return frames in order + return [frames[t] for t in timesteps] + + +def main(): + parser = argparse.ArgumentParser(description='FILM frame interpolation worker') + parser.add_argument('--input0', required=True, help='Path to first input image') + parser.add_argument('--input1', required=True, help='Path to second input image') + parser.add_argument('--output', help='Path to output image (single frame mode)') + parser.add_argument('--output-dir', help='Output directory (batch mode)') + parser.add_argument('--output-pattern', default='frame_{:04d}.png', + help='Output filename pattern for batch mode') + parser.add_argument('--timestep', type=float, default=0.5, + help='Interpolation timestep 0-1 (single frame mode)') + parser.add_argument('--frame-count', type=int, + help='Number of frames to generate (batch mode)') + parser.add_argument('--repo-dir', help='Unused (kept for backward compat)') + parser.add_argument('--model-dir', required=True, help='Model cache directory') + parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use') + + args = parser.parse_args() + + # Validate arguments + batch_mode = args.output_dir is not None and args.frame_count is not None + single_mode = args.output is not None + + if not batch_mode and not single_mode: + print("Error: Must specify either --output (single) or --output-dir + --frame-count (batch)", + file=sys.stderr) + return 1 + + try: + # Select device + if args.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + # Load model + model_dir = Path(args.model_dir) + model = get_model(model_dir, device) + + # Load images + img0 = load_image(Path(args.input0), device) + img1 = load_image(Path(args.input1), device) + + if batch_mode: + # Batch mode - generate all frames at once + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Generating {args.frame_count} frames...", file=sys.stderr) + frames = interpolate_batch(model, img0, img1, args.frame_count) + + for i, frame in enumerate(frames): + output_path = output_dir / args.output_pattern.format(i) + save_image(frame, output_path) + print(f"Saved {output_path.name}", file=sys.stderr) + + print(f"Success: Generated {len(frames)} frames", file=sys.stderr) + else: + # Single frame mode + result = interpolate_single(model, img0, img1, args.timestep) + save_image(result, Path(args.output)) + print("Success", file=sys.stderr) + + return 0 + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/core/models.py b/core/models.py index 69e279f..a66a71d 100644 --- a/core/models.py +++ b/core/models.py @@ -32,6 +32,12 @@ class FolderType(Enum): TRANSITION = 'transition' +class DirectInterpolationMethod(Enum): + """Method for direct frame interpolation between sequences.""" + RIFE = 'rife' + FILM = 'film' + + # --- Data Classes --- @dataclass @@ -68,6 +74,15 @@ class PerTransitionSettings: right_overlap: int = 16 # frames from trans folder start +@dataclass +class DirectTransitionSettings: + """Settings for direct AI interpolation between sequences (no transition folder).""" + after_folder: Path # The folder after which this transition occurs + frame_count: int = 16 # Number of interpolated frames to generate + method: DirectInterpolationMethod = DirectInterpolationMethod.FILM + enabled: bool = True + + @dataclass class BlendResult: """Result of an image blend operation.""" diff --git a/ui/main_window.py b/ui/main_window.py index b214709..a633b9b 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -49,14 +49,17 @@ from core import ( BlendCurve, BlendMethod, FolderType, + DirectInterpolationMethod, TransitionSettings, PerTransitionSettings, + DirectTransitionSettings, TransitionSpec, SymlinkError, DatabaseManager, TransitionGenerator, RifeDownloader, PracticalRifeEnv, + FilmEnv, SymlinkManager, OPTICAL_FLOW_PRESETS, ) @@ -196,6 +199,216 @@ class OverlapDialog(QDialog): return self.left_spin.value(), self.right_spin.value() +class DirectTransitionDialog(QDialog): + """Dialog for configuring direct frame interpolation between MAIN sequences.""" + + def __init__( + self, + parent: Optional[QWidget], + folder_name: str, + frame_count: int = 16, + method: DirectInterpolationMethod = DirectInterpolationMethod.FILM, + enabled: bool = True + ) -> None: + super().__init__(parent) + self.setWindowTitle("Direct Interpolation Settings") + self.setMinimumWidth(350) + + layout = QVBoxLayout(self) + + # Info label + info_label = QLabel(f"Interpolate after: {folder_name}") + info_label.setWordWrap(True) + layout.addWidget(info_label) + + # Form for settings + form_layout = QFormLayout() + + # Method selection + self.method_combo = QComboBox() + self.method_combo.addItem("RIFE (Fast, small motion)", DirectInterpolationMethod.RIFE) + self.method_combo.addItem("FILM (Slow, large motion)", DirectInterpolationMethod.FILM) + if method == DirectInterpolationMethod.FILM: + self.method_combo.setCurrentIndex(1) + form_layout.addRow("Method:", self.method_combo) + + # Frame count + self.frame_spin = QSpinBox() + self.frame_spin.setRange(1, 60) + self.frame_spin.setValue(frame_count) + self.frame_spin.setToolTip("Number of interpolated frames to generate") + form_layout.addRow("Frames:", self.frame_spin) + + # Enable checkbox + self.enabled_check = QCheckBox("Enabled") + self.enabled_check.setChecked(enabled) + form_layout.addRow("", self.enabled_check) + + layout.addLayout(form_layout) + + # Status label for setup state + self.status_label = QLabel() + self.status_label.setStyleSheet("font-size: 10px;") + layout.addWidget(self.status_label) + + # Setup button (for installing RIFE/FILM) + self.setup_btn = QPushButton("Setup PyTorch Environment") + self.setup_btn.setToolTip("Install PyTorch and required packages") + self.setup_btn.clicked.connect(self._on_setup) + layout.addWidget(self.setup_btn) + + # Explanation + explain = QLabel( + "RIFE: Fast AI interpolation, best for small motion and color shifts.\n" + "FILM: Google Research model, better for large motion and scene gaps.\n\n" + "Generated frames bridge the gap between the last frame of this\n" + "sequence and the first frame of the next MAIN sequence." + ) + explain.setStyleSheet("color: gray; font-size: 10px;") + explain.setWordWrap(True) + layout.addWidget(explain) + + # Buttons + button_layout = QHBoxLayout() + + self.remove_btn = QPushButton("Remove") + self.remove_btn.setToolTip("Remove this direct transition") + button_layout.addWidget(self.remove_btn) + + button_layout.addStretch() + + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + button_layout.addWidget(buttons) + + layout.addLayout(button_layout) + + self._removed = False + self.remove_btn.clicked.connect(self._on_remove) + self.method_combo.currentIndexChanged.connect(self._update_status) + self._update_status() + + def _update_status(self) -> None: + """Update the status label and setup button based on current method.""" + method = self.method_combo.currentData() + + rife_ready = PracticalRifeEnv.is_setup() + film_ready = FilmEnv.is_setup() if rife_ready else False + + if method == DirectInterpolationMethod.RIFE: + if rife_ready: + self.status_label.setText("RIFE: Ready") + self.status_label.setStyleSheet("color: green; font-size: 10px;") + self.setup_btn.setVisible(False) + else: + self.status_label.setText("RIFE: Not installed (PyTorch required)") + self.status_label.setStyleSheet("color: orange; font-size: 10px;") + self.setup_btn.setVisible(True) + self.setup_btn.setText("Setup PyTorch Environment") + else: # FILM + if film_ready: + self.status_label.setText("FILM: Ready") + self.status_label.setStyleSheet("color: green; font-size: 10px;") + self.setup_btn.setVisible(False) + elif rife_ready: + self.status_label.setText("FILM: Package not installed") + self.status_label.setStyleSheet("color: orange; font-size: 10px;") + self.setup_btn.setVisible(True) + self.setup_btn.setText("Install FILM Package") + else: + self.status_label.setText("FILM: Not installed (PyTorch required first)") + self.status_label.setStyleSheet("color: orange; font-size: 10px;") + self.setup_btn.setVisible(True) + self.setup_btn.setText("Setup PyTorch Environment") + + def _on_setup(self) -> None: + """Handle setup button click.""" + method = self.method_combo.currentData() + rife_ready = PracticalRifeEnv.is_setup() + + if not rife_ready: + # Need to set up PyTorch venv first + progress = QProgressDialog( + "Setting up PyTorch environment...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("Setup") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(0) + + cancelled = [False] + + def progress_cb(msg, pct): + progress.setLabelText(msg) + progress.setValue(pct) + + def cancelled_check(): + QApplication.processEvents() + return progress.wasCanceled() + + success = PracticalRifeEnv.setup_venv(progress_cb, cancelled_check) + progress.close() + + if not success: + if not cancelled_check(): + QMessageBox.warning( + self, "Setup Failed", + "Failed to set up PyTorch environment." + ) + return + + # If FILM selected and we need to install FILM package + if method == DirectInterpolationMethod.FILM and not FilmEnv.is_setup(): + progress = QProgressDialog( + "Installing FILM package...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("Setup") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(0) + + def progress_cb(msg, pct): + progress.setLabelText(msg) + progress.setValue(pct) + + def cancelled_check(): + QApplication.processEvents() + return progress.wasCanceled() + + success = FilmEnv.setup_film(progress_cb, cancelled_check) + progress.close() + + if not success: + if not cancelled_check(): + QMessageBox.warning( + self, "Setup Failed", + "Failed to install FILM package." + ) + return + + self._update_status() + + def _on_remove(self) -> None: + """Handle remove button click.""" + self._removed = True + self.reject() + + def was_removed(self) -> bool: + """Check if the user clicked Remove.""" + return self._removed + + def get_values(self) -> tuple[DirectInterpolationMethod, int, bool]: + """Get the dialog values.""" + return ( + self.method_combo.currentData(), + self.frame_spin.value(), + self.enabled_check.isChecked() + ) + + class SequenceLinkerUI(QWidget): """PyQt6 GUI for the Video Montage Linker.""" @@ -210,6 +423,7 @@ class SequenceLinkerUI(QWidget): self._folder_type_overrides: dict[Path, FolderType] = {} self._transition_settings = TransitionSettings() self._per_transition_settings: dict[Path, PerTransitionSettings] = {} + self._direct_transitions: dict[Path, DirectTransitionSettings] = {} self._current_session_id: Optional[int] = None self.db = DatabaseManager() self.manager = SymlinkManager(self.db) @@ -831,6 +1045,8 @@ class SequenceLinkerUI(QWidget): # Sequence table selection - show image self.sequence_table.currentItemChanged.connect(self._on_sequence_table_selected) + # Also handle clicks on non-selectable items (direct interpolation rows) + self.sequence_table.itemClicked.connect(self._on_sequence_table_clicked) # Update sequence table when transitions setting changes self.transition_group.toggled.connect(self._update_sequence_table) @@ -1276,6 +1492,18 @@ class SequenceLinkerUI(QWidget): trans_at_main_end[trans.main_folder] = trans trans_at_trans_start[trans.trans_folder] = trans + # Find consecutive MAIN folders (for direct interpolation) + consecutive_main_pairs: list[tuple[int, int]] = [] + for i in range(len(self.source_folders) - 1): + folder_a = self.source_folders[i] + folder_b = self.source_folders[i + 1] + type_a = self._get_effective_folder_type(i, folder_a) + type_b = self._get_effective_folder_type(i + 1, folder_b) + # Two consecutive MAIN folders with no transition between them + if type_a == FolderType.MAIN and type_b == FolderType.MAIN: + if folder_a not in trans_at_main_end: + consecutive_main_pairs.append((i, i + 1)) + # Process each folder for folder_idx, folder in enumerate(self.source_folders): folder_files = files_by_folder.get(folder, []) @@ -1339,9 +1567,56 @@ class SequenceLinkerUI(QWidget): self.sequence_table.addTopLevelItem(item) + # Check if this folder starts a direct interpolation gap + # (current MAIN followed by another MAIN with no transition) + for pair_idx_a, pair_idx_b in consecutive_main_pairs: + if folder_idx == pair_idx_a: + # Add direct interpolation row after this folder's files + self._add_direct_interpolation_row(folder, pair_idx_b) + # Update timeline display after rebuilding sequence table self._update_timeline_display() + def _add_direct_interpolation_row(self, after_folder: Path, next_folder_idx: int) -> None: + """Add a clickable direct interpolation row between MAIN sequences. + + Args: + after_folder: The folder after which interpolation occurs. + next_folder_idx: Index of the next MAIN folder. + """ + direct_settings = self._direct_transitions.get(after_folder) + + if direct_settings and direct_settings.enabled: + # Configured: show green row with settings + placeholder frames + method_name = direct_settings.method.value.upper() + frame_count = direct_settings.frame_count + + # Header row (clickable to edit) + header_text = f" [{method_name}: {frame_count} frames] (click to edit)" + header_item = QTreeWidgetItem([header_text, ""]) + header_item.setData(0, Qt.ItemDataRole.UserRole, ('direct_header', after_folder)) + header_item.setForeground(0, QColor(50, 180, 100)) # Green + header_item.setFlags(header_item.flags() & ~Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsEnabled) + self.sequence_table.addTopLevelItem(header_item) + + # Add placeholder rows for each interpolated frame + for i in range(frame_count): + placeholder_text = f" [{method_name} {i + 1}/{frame_count}]" + placeholder_item = QTreeWidgetItem([placeholder_text, ""]) + placeholder_item.setData(0, Qt.ItemDataRole.UserRole, ('direct_placeholder', after_folder, i)) + placeholder_item.setForeground(0, QColor(100, 180, 220)) # Light blue + # Make placeholders non-selectable + placeholder_item.setFlags(placeholder_item.flags() & ~Qt.ItemFlag.ItemIsSelectable) + self.sequence_table.addTopLevelItem(placeholder_item) + else: + # Unconfigured: show grey "+" row + add_text = " [+ Add RIFE/FILM transition] (click to configure)" + add_item = QTreeWidgetItem([add_text, ""]) + add_item.setData(0, Qt.ItemDataRole.UserRole, ('direct_add', after_folder)) + add_item.setForeground(0, QColor(150, 150, 150)) # Grey + add_item.setFlags(add_item.flags() & ~Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsEnabled) + self.sequence_table.addTopLevelItem(add_item) + def _on_sequence_table_selected(self, current, previous) -> None: """Handle sequence table row selection - show image in preview.""" if current is None: @@ -1355,6 +1630,25 @@ class SequenceLinkerUI(QWidget): if not data: return + # Handle direct interpolation rows + if isinstance(data, tuple) and len(data) >= 2: + if data[0] == 'direct_add': + # "+" row - only open dialog if not playing (skip during playback) + if not self.sequence_playing: + self._show_direct_transition_dialog(data[1]) + return + elif data[0] == 'direct_header': + # Header row - only open dialog if not playing (skip during playback) + if not self.sequence_playing: + self._show_direct_transition_dialog(data[1]) + return + elif data[0] == 'direct_placeholder': + # Show preview of interpolated frame + after_folder = data[1] + frame_index = data[2] + self._show_direct_interpolation_preview(after_folder, frame_index) + return + frame_type = data[4] if len(data) > 4 else 'symlink' # For blend frames, generate cross-dissolve preview @@ -1389,6 +1683,29 @@ class SequenceLinkerUI(QWidget): seq_name = f"seq{data[2] + 1:02d}_{data[3]:04d}" self.image_name_label.setText(f"{seq_name} ({filename})") + def _on_sequence_table_clicked(self, item, column: int) -> None: + """Handle clicks on sequence table items, including non-selectable ones.""" + if item is None: + return + + data = item.data(0, Qt.ItemDataRole.UserRole) + if not data: + return + + # Handle direct interpolation rows + if isinstance(data, tuple) and len(data) >= 2: + if data[0] == 'direct_add': + # Clicked on "+" row to add direct transition + self._show_direct_transition_dialog(data[1]) + elif data[0] == 'direct_header': + # Clicked on configured direct transition header + self._show_direct_transition_dialog(data[1]) + elif data[0] == 'direct_placeholder': + # Clicked on placeholder row - show preview of interpolated frame + after_folder = data[1] + frame_index = data[2] + self._show_direct_interpolation_preview(after_folder, frame_index) + def _show_blend_preview(self, item, data0, data1) -> None: """Show a cross-dissolve preview for a blend frame.""" from PIL import Image @@ -1516,6 +1833,267 @@ class SequenceLinkerUI(QWidget): self.image_name_label.setText("") self._current_pixmap = None + def _show_direct_interpolation_preview(self, after_folder: Path, frame_index: int) -> None: + """Generate and show a preview for a direct interpolation placeholder frame. + + For RIFE: Generates one frame at a time (RIFE handles arbitrary timesteps well). + For FILM: Generates ALL frames at once on first click (FILM works best this way), + then caches all frames for instant subsequent access. + + Args: + after_folder: The folder after which the interpolation occurs. + frame_index: The index of the interpolated frame (0-based). + """ + from PIL import Image + from PIL.ImageQt import ImageQt + from core import ImageBlender + + # Get direct transition settings + direct_settings = self._direct_transitions.get(after_folder) + if not direct_settings or not direct_settings.enabled: + self.image_label.setText("Direct interpolation not configured") + self.image_name_label.setText("") + self._current_pixmap = None + return + + # Find the folder index and next folder + try: + folder_idx = self.source_folders.index(after_folder) + except ValueError: + self.image_label.setText("Folder not found in sequence") + self.image_name_label.setText("") + self._current_pixmap = None + return + + if folder_idx >= len(self.source_folders) - 1: + self.image_label.setText("No next folder for interpolation") + self.image_name_label.setText("") + self._current_pixmap = None + return + + next_folder = self.source_folders[folder_idx + 1] + + # Get files for both folders + files = self._get_files_in_order() + files_by_folder: dict[Path, list[str]] = {} + for source_dir, filename, f_idx, file_idx in files: + if source_dir not in files_by_folder: + files_by_folder[source_dir] = [] + files_by_folder[source_dir].append(filename) + + after_files = files_by_folder.get(after_folder, []) + next_files = files_by_folder.get(next_folder, []) + + if not after_files or not next_files: + self.image_label.setText("Missing frames for interpolation") + self.image_name_label.setText("") + self._current_pixmap = None + return + + # Get last frame of after_folder and first frame of next_folder + last_frame_path = after_folder / after_files[-1] + first_frame_path = next_folder / next_files[0] + + if not last_frame_path.exists() or not first_frame_path.exists(): + self.image_label.setText(f"Frame files not found") + self.image_name_label.setText("") + self._current_pixmap = None + return + + # Calculate timestep + frame_count = direct_settings.frame_count + t = (frame_index + 1) / (frame_count + 1) # Evenly spaced between 0 and 1 + + # Create cache key - include frame_count so changing count invalidates cache + cache_key = f"direct|{after_folder}|{frame_index}|{direct_settings.method.value}|{frame_count}" + + try: + # Check cache first + if cache_key in self._blend_preview_cache: + pixmap = self._blend_preview_cache[cache_key] + elif direct_settings.method == DirectInterpolationMethod.FILM and FilmEnv.is_setup(): + # FILM: Generate ALL frames at once for better quality + # Check if we need to generate (first frame not cached means none are) + first_cache_key = f"direct|{after_folder}|0|{direct_settings.method.value}|{frame_count}" + if first_cache_key not in self._blend_preview_cache: + # Generate all frames at once + error_msg = self._generate_all_film_preview_frames( + after_folder, last_frame_path, first_frame_path, frame_count + ) + if error_msg: + # Error already displayed in image_label by the method + self._current_pixmap = None + return + + # Now retrieve the specific frame from cache + if cache_key in self._blend_preview_cache: + pixmap = self._blend_preview_cache[cache_key] + else: + # Fallback if batch generation failed + self.image_label.setText("FILM batch generation failed - check console for details") + self.image_name_label.setText("") + self._current_pixmap = None + return + else: + # RIFE (or FILM not set up): Generate one frame at a time + # Load images + img_a = Image.open(last_frame_path) + img_b = Image.open(first_frame_path) + + # Resize B to match A if needed + if img_a.size != img_b.size: + img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS) + + # Convert to RGBA + if img_a.mode != 'RGBA': + img_a = img_a.convert('RGBA') + if img_b.mode != 'RGBA': + img_b = img_b.convert('RGBA') + + # Generate interpolated frame + if direct_settings.method == DirectInterpolationMethod.FILM: + # FILM not set up, use fallback + blended = ImageBlender.film_blend(img_a, img_b, t) + else: # RIFE + settings = self._get_transition_settings() + blended = ImageBlender.practical_rife_blend( + img_a, img_b, t, + settings.practical_rife_model, + settings.practical_rife_ensemble + ) + + # Convert to QPixmap + qim = ImageQt(blended.convert('RGBA')) + pixmap = QPixmap.fromImage(qim) + + # Store in cache + self._blend_preview_cache[cache_key] = pixmap + + img_a.close() + img_b.close() + + self._current_pixmap = pixmap + self._apply_zoom() + + # Update labels + method_name = direct_settings.method.value.upper() + self.image_name_label.setText( + f"[{method_name} {frame_index + 1}/{frame_count}] @ t={t:.2f}" + ) + + # Find the item index in the table for image_index_label + for i in range(self.sequence_table.topLevelItemCount()): + item = self.sequence_table.topLevelItem(i) + item_data = item.data(0, Qt.ItemDataRole.UserRole) + if (isinstance(item_data, tuple) and len(item_data) >= 3 and + item_data[0] == 'direct_placeholder' and + item_data[1] == after_folder and + item_data[2] == frame_index): + total = self.sequence_table.topLevelItemCount() + self.image_index_label.setText(f"{i + 1} / {total}") + break + + except Exception as e: + self.image_label.setText(f"Error generating interpolation preview:\n{e}") + self.image_name_label.setText("") + self._current_pixmap = None + + def _generate_all_film_preview_frames( + self, + after_folder: Path, + last_frame_path: Path, + first_frame_path: Path, + frame_count: int + ) -> Optional[str]: + """Generate all FILM preview frames at once and cache them. + + FILM works best when generating all frames at once using its + recursive approach. This method generates all frames and stores + them in the preview cache. + + Args: + after_folder: The folder after which the interpolation occurs. + last_frame_path: Path to the last frame of the current sequence. + first_frame_path: Path to the first frame of the next sequence. + frame_count: Number of frames to generate. + + Returns: + None on success, error message string on failure. + """ + from PIL import Image + from PIL.ImageQt import ImageQt + import tempfile + + # Show progress dialog + progress = QProgressDialog( + f"Generating {frame_count} FILM frames...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("FILM Interpolation") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(10) + QApplication.processEvents() + + try: + # Use a temp directory for FILM batch output + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + + progress.setLabelText("Running FILM batch interpolation...") + progress.setValue(20) + QApplication.processEvents() + + # Run batch interpolation + success, error, output_paths = FilmEnv.run_batch_interpolation( + last_frame_path, + first_frame_path, + tmp_path, + frame_count, + 'frame_{:04d}.png' + ) + + if not success: + progress.close() + error_msg = f"FILM error: {error}" + self.image_label.setText(error_msg) + self.image_name_label.setText("") + return error_msg + + progress.setLabelText("Loading generated frames...") + progress.setValue(70) + QApplication.processEvents() + + # Load all frames and cache them + for i, output_path in enumerate(output_paths): + if progress.wasCanceled(): + break + + if output_path.exists(): + frame = Image.open(output_path) + qim = ImageQt(frame.convert('RGBA')) + pixmap = QPixmap.fromImage(qim) + + # Cache with the standard key format (include frame_count) + cache_key = f"direct|{after_folder}|{i}|film|{frame_count}" + self._blend_preview_cache[cache_key] = pixmap + + frame.close() + + # Update progress + pct = 70 + int(30 * (i + 1) / frame_count) + progress.setValue(pct) + QApplication.processEvents() + + progress.close() + return None # Success + + except Exception as e: + progress.close() + error_msg = f"FILM batch error: {e}" + self.image_label.setText(error_msg) + self.image_name_label.setText("") + return error_msg + def _update_timeline_display(self) -> None: """Update the timeline duration display based on frame count and FPS.""" frame_count = self.sequence_table.topLevelItemCount() @@ -1575,12 +2153,21 @@ class SequenceLinkerUI(QWidget): current_idx = self.sequence_table.indexOfTopLevelItem(current_item) total = self.sequence_table.topLevelItemCount() - if current_idx < total - 1: - next_item = self.sequence_table.topLevelItem(current_idx + 1) + # Find next valid frame (skip direct_add and direct_header rows) + next_idx = current_idx + 1 + while next_idx < total: + next_item = self.sequence_table.topLevelItem(next_idx) + data = next_item.data(0, Qt.ItemDataRole.UserRole) + # Skip non-frame rows (direct_add, direct_header) + if isinstance(data, tuple) and len(data) >= 1 and data[0] in ('direct_add', 'direct_header'): + next_idx += 1 + continue + # Found a valid frame self.sequence_table.setCurrentItem(next_item) - else: - # Reached end - stop playback - self._stop_sequence_play() + return + + # Reached end - stop playback + self._stop_sequence_play() def _browse_trans_destination(self) -> None: """Select transition destination folder via file dialog.""" @@ -1728,7 +2315,7 @@ class SequenceLinkerUI(QWidget): return def _remove_source_folder(self) -> None: - """Remove selected source folder(s).""" + """Remove selected source folder(s), preserving sequence order of remaining files.""" result = self._get_selected_folder() if result is None: return @@ -1739,13 +2326,72 @@ class SequenceLinkerUI(QWidget): del self._folder_type_overrides[folder] if folder in self._per_transition_settings: del self._per_transition_settings[folder] + if folder in self._folder_trim_settings: + del self._folder_trim_settings[folder] + if folder in self._folder_file_counts: + del self._folder_file_counts[folder] del self.source_folders[idx] self._sync_dual_lists() - self._refresh_files() + + # Remove only files from the deleted folder, preserving order of others + self._remove_files_from_folder(folder) + + # Renumber sequence names to reflect new folder indices + self._recalculate_sequence_names() + + # Update the sequence table (With Transitions tab) + self._update_sequence_table() + self._update_flow_arrows() + def _remove_files_from_folder(self, folder: Path) -> None: + """Remove all files from a specific folder without affecting order of other files.""" + folder_str = str(folder) + rows_to_remove = [] + + for i in range(self.file_list.topLevelItemCount()): + item = self.file_list.topLevelItem(i) + if item and item.text(2) == folder_str: + rows_to_remove.append(i) + + # Remove in reverse order to preserve indices + for row in reversed(rows_to_remove): + self.file_list.takeTopLevelItem(row) + + # Clean up separators (remove consecutive or leading/trailing separators) + self._cleanup_separators() + + # Update slider range + total = self.file_list.topLevelItemCount() + self.image_slider.setRange(0, max(0, total - 1)) + + def _cleanup_separators(self) -> None: + """Remove unnecessary separators (consecutive, leading, or trailing).""" + rows_to_remove = [] + prev_was_separator = True # Treat start as "separator" to remove leading ones + + for i in range(self.file_list.topLevelItemCount()): + item = self.file_list.topLevelItem(i) + is_separator = self._is_separator_item(item) + + if is_separator and prev_was_separator: + rows_to_remove.append(i) + prev_was_separator = is_separator + + # Check if last item is a separator + if self.file_list.topLevelItemCount() > 0: + last_item = self.file_list.topLevelItem(self.file_list.topLevelItemCount() - 1) + if self._is_separator_item(last_item): + last_idx = self.file_list.topLevelItemCount() - 1 + if last_idx not in rows_to_remove: + rows_to_remove.append(last_idx) + + # Remove in reverse order + for row in sorted(rows_to_remove, reverse=True): + self.file_list.takeTopLevelItem(row) + def _remove_selected_files(self) -> None: """Remove selected files from the file list.""" selected = self.file_list.selectedItems() @@ -1756,6 +2402,9 @@ class SequenceLinkerUI(QWidget): for row in rows: self.file_list.takeTopLevelItem(row) + # Update the With Transitions tab to reflect the removal + self._update_sequence_table() + def _get_path_history_file(self) -> Path: """Get the path to the history JSON file.""" cache_dir = Path.home() / '.cache' / 'video-montage-linker' @@ -2154,6 +2803,38 @@ class SequenceLinkerUI(QWidget): self._sync_dual_lists() self._update_sequence_table() + def _show_direct_transition_dialog(self, after_folder: Path) -> None: + """Show dialog to configure direct frame interpolation between sequences.""" + existing = self._direct_transitions.get(after_folder) + if existing: + frame_count = existing.frame_count + method = existing.method + enabled = existing.enabled + else: + frame_count = 16 + method = DirectInterpolationMethod.FILM + enabled = True + + dialog = DirectTransitionDialog( + self, after_folder.name, frame_count, method, enabled + ) + result = dialog.exec() + + if dialog.was_removed(): + # User clicked Remove + if after_folder in self._direct_transitions: + del self._direct_transitions[after_folder] + self._update_sequence_table() + elif result == QDialog.DialogCode.Accepted: + new_method, new_count, new_enabled = dialog.get_values() + self._direct_transitions[after_folder] = DirectTransitionSettings( + after_folder=after_folder, + frame_count=new_count, + method=new_method, + enabled=new_enabled + ) + self._update_sequence_table() + def _set_folder_type(self, folder: Path, folder_type: FolderType) -> None: """Set the folder type override for a folder.""" if folder_type == FolderType.AUTO: @@ -2223,6 +2904,7 @@ class SequenceLinkerUI(QWidget): self._folder_file_counts = {folder: len(files) for folder, files in files_by_folder.items()} folder_file_counts: dict[Path, int] = {} + is_first_folder = True for folder in self.source_folders: if folder not in files_by_folder: continue @@ -2237,8 +2919,17 @@ class SequenceLinkerUI(QWidget): end_idx = total_in_folder - trim_end trimmed_files = folder_files[trim_start:end_idx] + if not trimmed_files: + continue + folder_idx = folder_to_index.get(folder, 0) + # Add separator between folders (not before first) + if not is_first_folder: + separator = self._create_folder_separator(folder_idx) + self.file_list.addTopLevelItem(separator) + is_first_folder = False + for filename in trimmed_files: file_idx = folder_file_counts.get(folder, 0) folder_file_counts[folder] = file_idx + 1 @@ -2261,6 +2952,22 @@ class SequenceLinkerUI(QWidget): self._update_trim_slider_for_selected_folder() self._update_sequence_table() + def _create_folder_separator(self, next_folder_idx: int) -> QTreeWidgetItem: + """Create a visual separator item between folders.""" + separator = QTreeWidgetItem(["", f"── Sequence {next_folder_idx + 1} ──", ""]) + separator.setData(0, Qt.ItemDataRole.UserRole, None) # No data = separator + # Light grey background + grey = QColor(220, 220, 220) + for col in range(3): + separator.setBackground(col, grey) + # Make it non-selectable and non-draggable + separator.setFlags(Qt.ItemFlag.NoItemFlags) + return separator + + def _is_separator_item(self, item: QTreeWidgetItem) -> bool: + """Check if an item is a folder separator.""" + return item.data(0, Qt.ItemDataRole.UserRole) is None + def _get_files_in_order(self) -> list[tuple[Path, str, int, int]]: """Get files in the current list order with sequence info.""" files = [] @@ -2278,6 +2985,7 @@ class SequenceLinkerUI(QWidget): folder_to_index = {folder: i for i, folder in enumerate(self.source_folders)} folder_file_counts: dict[Path, int] = {} + last_folder_idx = -1 for i in range(self.file_list.topLevelItemCount()): item = self.file_list.topLevelItem(i) @@ -2293,6 +3001,21 @@ class SequenceLinkerUI(QWidget): seq_name = f"seq{folder_idx + 1:02d}_{file_idx:04d}{ext}" item.setText(0, seq_name) item.setData(0, Qt.ItemDataRole.UserRole, (source_dir, filename, folder_idx, file_idx)) + last_folder_idx = folder_idx + elif self._is_separator_item(item): + # Update separator label based on next file's folder + # Look ahead to find the next file's folder index + next_folder_idx = last_folder_idx + 1 + for j in range(i + 1, self.file_list.topLevelItemCount()): + next_item = self.file_list.topLevelItem(j) + next_data = next_item.data(0, Qt.ItemDataRole.UserRole) + if next_data: + next_folder_idx = folder_to_index.get(next_data[0], last_folder_idx + 1) + break + item.setText(1, f"── Sequence {next_folder_idx + 1} ──") + + # Update the With Transitions tab to reflect the new order + self._update_sequence_table() # --- Video Preview Methods --- @@ -2773,7 +3496,12 @@ class SequenceLinkerUI(QWidget): trans_at_main_end[trans.main_folder] = trans trans_at_trans_start[trans.trans_folder] = trans + # Count total files including direct interpolation frames total_files = sum(len(f) for f in files_by_folder.values()) + for folder, direct_settings in self._direct_transitions.items(): + if direct_settings.enabled: + total_files += direct_settings.frame_count + progress = QProgressDialog("Generating sequence...", "Cancel", 0, total_files, self) progress.setWindowTitle("Cross-Dissolve Generation") progress.setWindowModality(Qt.WindowModality.WindowModal) @@ -2892,6 +3620,55 @@ class SequenceLinkerUI(QWidget): current_op += 1 progress.setValue(current_op) + # Check for direct interpolation after this folder + if folder in self._direct_transitions: + direct_settings = self._direct_transitions[folder] + if direct_settings.enabled: + # Find next folder and get its first frame + next_folder_idx = folder_idx + 1 + if next_folder_idx < len(self.source_folders): + next_folder = self.source_folders[next_folder_idx] + next_files = files_by_folder.get(next_folder, []) + if next_files and folder_files: + # Get last frame of current folder and first of next + last_frame = folder / folder_files[-1] + first_frame = next_folder / next_files[0] + + progress.setLabelText( + f"Generating {direct_settings.method.value.upper()} frames..." + ) + + # Generate direct interpolation frames + direct_results = generator.generate_direct_interpolation_frames( + last_frame, + first_frame, + direct_settings.frame_count, + direct_settings.method, + trans_dest, + folder_idx, + output_seq, + settings.practical_rife_model, + settings.practical_rife_ensemble + ) + + for result in direct_results: + if result.success: + blend_count += 1 + self.db.record_symlink( + session_id, + str(result.source_a.resolve()), + str(result.output_path), + result.output_path.name, + output_seq + ) + else: + errors.append( + f"Direct interp {result.output_path.name}: {result.error}" + ) + output_seq += 1 + + progress.setLabelText("Generating sequence...") + progress.close() if progress.wasCanceled():