diff --git a/.gitignore b/.gitignore index 9460c99..aef96bd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__/ *.pyc *.pyo .env +venv-rife/ diff --git a/core/__init__.py b/core/__init__.py index 02ccc3a..7d46cd4 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -19,7 +19,7 @@ from .models import ( DatabaseError, ) from .database import DatabaseManager -from .blender import ImageBlender, TransitionGenerator, RifeDownloader +from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv from .manager import SymlinkManager __all__ = [ @@ -43,5 +43,6 @@ __all__ = [ 'ImageBlender', 'TransitionGenerator', 'RifeDownloader', + 'PracticalRifeEnv', 'SymlinkManager', ] diff --git a/core/blender.py b/core/blender.py index d87051f..9058e62 100644 --- a/core/blender.py +++ b/core/blender.py @@ -29,6 +29,214 @@ from .models import ( # Cache directory for downloaded binaries CACHE_DIR = Path.home() / '.cache' / 'video-montage-linker' RIFE_GITHUB_API = 'https://api.github.com/repos/nihui/rife-ncnn-vulkan/releases/latest' +PRACTICAL_RIFE_VENV_DIR = Path('./venv-rife') + + +class PracticalRifeEnv: + """Manages isolated Python environment for Practical-RIFE.""" + + VENV_DIR = PRACTICAL_RIFE_VENV_DIR + MODEL_CACHE_DIR = CACHE_DIR / 'practical-rife' + REQUIRED_PACKAGES = ['torch', 'torchvision', 'numpy'] + + # Available Practical-RIFE models + AVAILABLE_MODELS = ['v4.26', 'v4.25', 'v4.22', 'v4.20', 'v4.18', 'v4.15'] + + @classmethod + def get_venv_python(cls) -> Optional[Path]: + """Get path to venv Python executable.""" + if cls.VENV_DIR.exists(): + if sys.platform == 'win32': + return cls.VENV_DIR / 'Scripts' / 'python.exe' + return cls.VENV_DIR / 'bin' / 'python' + return None + + @classmethod + def is_setup(cls) -> bool: + """Check if venv exists and has required packages.""" + python = cls.get_venv_python() + if not python or not python.exists(): + return False + # Check if torch is importable + result = subprocess.run( + [str(python), '-c', 'import torch; print(torch.__version__)'], + capture_output=True + ) + return result.returncode == 0 + + @classmethod + def get_torch_version(cls) -> Optional[str]: + """Get installed torch version in venv.""" + python = cls.get_venv_python() + if not python or not python.exists(): + return None + result = subprocess.run( + [str(python), '-c', 'import torch; print(torch.__version__)'], + capture_output=True, + text=True + ) + if result.returncode == 0: + return result.stdout.strip() + return None + + @classmethod + def setup_venv(cls, progress_callback=None, cancelled_check=None) -> bool: + """Create venv and install PyTorch. + + Args: + progress_callback: Optional callback(message, percent) for progress. + cancelled_check: Optional callable that returns True if cancelled. + + Returns: + True if setup was successful. + """ + import venv + + try: + # 1. Create venv + if progress_callback: + progress_callback("Creating virtual environment...", 10) + if cancelled_check and cancelled_check(): + return False + + # Remove old venv if exists + if cls.VENV_DIR.exists(): + shutil.rmtree(cls.VENV_DIR) + + venv.create(cls.VENV_DIR, with_pip=True) + + # 2. Get pip path + python = cls.get_venv_python() + if not python: + return False + + # 3. Upgrade pip + if progress_callback: + progress_callback("Upgrading pip...", 20) + if cancelled_check and cancelled_check(): + return False + + subprocess.run( + [str(python), '-m', 'pip', 'install', '--upgrade', 'pip'], + capture_output=True, + check=True + ) + + # 4. Install PyTorch (this is the big download) + if progress_callback: + progress_callback("Installing PyTorch (this may take a while)...", 30) + if cancelled_check and cancelled_check(): + return False + + # Try to install with CUDA support first, fall back to CPU + # Use pip index to get the right version + result = subprocess.run( + [str(python), '-m', 'pip', 'install', 'torch', 'torchvision'], + capture_output=True, + text=True + ) + + if result.returncode != 0: + # Try CPU-only version + subprocess.run( + [str(python), '-m', 'pip', 'install', + 'torch', 'torchvision', + '--index-url', 'https://download.pytorch.org/whl/cpu'], + capture_output=True, + check=True + ) + + if progress_callback: + progress_callback("Installing numpy...", 90) + if cancelled_check and cancelled_check(): + return False + + # numpy is usually a dependency of torch but ensure it's there + subprocess.run( + [str(python), '-m', 'pip', 'install', 'numpy'], + capture_output=True + ) + + if progress_callback: + progress_callback("Setup complete!", 100) + + return cls.is_setup() + + except Exception as e: + # Cleanup on error + if cls.VENV_DIR.exists(): + try: + shutil.rmtree(cls.VENV_DIR) + except Exception: + pass + return False + + @classmethod + def get_available_models(cls) -> list[str]: + """Return list of available model versions.""" + return cls.AVAILABLE_MODELS.copy() + + @classmethod + def get_worker_script(cls) -> Path: + """Get path to the RIFE worker script.""" + return Path(__file__).parent / 'rife_worker.py' + + @classmethod + def run_interpolation( + cls, + img_a_path: Path, + img_b_path: Path, + output_path: Path, + t: float, + model: str = 'v4.25', + ensemble: bool = False + ) -> bool: + """Run RIFE interpolation via subprocess in venv. + + Args: + img_a_path: Path to first input image. + img_b_path: Path to second input image. + output_path: Path to output image. + t: Timestep for interpolation (0.0 to 1.0). + model: Model version to use. + ensemble: Enable ensemble mode. + + Returns: + True if interpolation succeeded. + """ + python = cls.get_venv_python() + if not python or not python.exists(): + return False + + script = cls.get_worker_script() + if not script.exists(): + return False + + cmd = [ + str(python), str(script), + '--input0', str(img_a_path), + '--input1', str(img_b_path), + '--output', str(output_path), + '--timestep', str(t), + '--model', model, + '--model-dir', str(cls.MODEL_CACHE_DIR) + ] + + if ensemble: + cmd.append('--ensemble') + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=120 # 2 minute timeout per frame + ) + return result.returncode == 0 and output_path.exists() + except subprocess.TimeoutExpired: + return False + except Exception: + return False class RifeDownloader: @@ -398,7 +606,10 @@ class ImageBlender: img_b: Image.Image, t: float, binary_path: Optional[Path] = None, - auto_download: bool = True + auto_download: bool = True, + model: str = 'rife-v4.6', + uhd: bool = False, + tta: bool = False ) -> Image.Image: """Blend using RIFE AI frame interpolation. @@ -411,27 +622,30 @@ class ImageBlender: t: Interpolation factor 0.0 (100% A) to 1.0 (100% B). binary_path: Optional path to rife-ncnn-vulkan binary. auto_download: Whether to auto-download RIFE if not found. + model: RIFE model to use (e.g., 'rife-v4.6', 'rife-anime'). + uhd: Enable UHD mode for high resolution images. + tta: Enable TTA mode for better quality (slower). Returns: AI-interpolated blended PIL Image. """ # Try NCNN binary first (specified path) if binary_path and binary_path.exists(): - result = ImageBlender._rife_ncnn(img_a, img_b, t, binary_path) + result = ImageBlender._rife_ncnn(img_a, img_b, t, binary_path, model, uhd, tta) if result is not None: return result # Try to find rife-ncnn-vulkan in PATH ncnn_path = shutil.which('rife-ncnn-vulkan') if ncnn_path: - result = ImageBlender._rife_ncnn(img_a, img_b, t, Path(ncnn_path)) + result = ImageBlender._rife_ncnn(img_a, img_b, t, Path(ncnn_path), model, uhd, tta) if result is not None: return result # Try cached binary cached = RifeDownloader.get_cached_binary() if cached: - result = ImageBlender._rife_ncnn(img_a, img_b, t, cached) + result = ImageBlender._rife_ncnn(img_a, img_b, t, cached, model, uhd, tta) if result is not None: return result @@ -439,7 +653,7 @@ class ImageBlender: if auto_download: downloaded = RifeDownloader.ensure_binary() if downloaded: - result = ImageBlender._rife_ncnn(img_a, img_b, t, downloaded) + result = ImageBlender._rife_ncnn(img_a, img_b, t, downloaded, model, uhd, tta) if result is not None: return result @@ -451,7 +665,10 @@ class ImageBlender: img_a: Image.Image, img_b: Image.Image, t: float, - binary: Path + binary: Path, + model: str = 'rife-v4.6', + uhd: bool = False, + tta: bool = False ) -> Optional[Image.Image]: """Use rife-ncnn-vulkan binary for interpolation. @@ -460,6 +677,9 @@ class ImageBlender: img_b: Second PIL Image. t: Interpolation timestep (0.0 to 1.0). binary: Path to rife-ncnn-vulkan binary. + model: RIFE model to use. + uhd: Enable UHD mode. + tta: Enable TTA mode. Returns: Interpolated PIL Image, or None if failed. @@ -485,6 +705,19 @@ class ImageBlender: '-o', str(output_file), ] + # Add model path (models are in same directory as binary) + model_path = binary.parent / model + if model_path.exists(): + cmd.extend(['-m', str(model_path)]) + + # Add UHD mode flag + if uhd: + cmd.append('-u') + + # Add TTA mode flag (spatial) + if tta: + cmd.append('-x') + # Some versions support -s for timestep # Try with timestep first, fall back to simple interpolation try: @@ -492,7 +725,7 @@ class ImageBlender: cmd + ['-s', str(t)], check=True, capture_output=True, - timeout=30 + timeout=60 # Increased timeout for TTA mode ) except subprocess.CalledProcessError: # Try without timestep (generates middle frame at t=0.5) @@ -500,7 +733,7 @@ class ImageBlender: cmd, check=True, capture_output=True, - timeout=30 + timeout=60 ) if output_file.exists(): @@ -511,6 +744,58 @@ class ImageBlender: return None + @staticmethod + def practical_rife_blend( + img_a: Image.Image, + img_b: Image.Image, + t: float, + model: str = 'v4.25', + ensemble: bool = False + ) -> Image.Image: + """Blend using Practical-RIFE Python/PyTorch implementation. + + Runs RIFE interpolation via subprocess in an isolated venv. + Falls back to ncnn RIFE or optical flow if unavailable. + + Args: + img_a: First PIL Image (source frame). + img_b: Second PIL Image (target frame). + t: Interpolation factor 0.0 (100% A) to 1.0 (100% B). + model: Practical-RIFE model version (e.g., 'v4.25', 'v4.26'). + ensemble: Enable ensemble mode for better quality (slower). + + Returns: + AI-interpolated blended PIL Image. + """ + if not PracticalRifeEnv.is_setup(): + # Fall back to ncnn RIFE or optical flow + return ImageBlender.rife_blend(img_a, img_b, t) + + try: + with tempfile.TemporaryDirectory() as tmpdir: + tmp = Path(tmpdir) + input_a = tmp / 'a.png' + input_b = tmp / 'b.png' + output_file = tmp / 'out.png' + + # Save input images + img_a.convert('RGB').save(input_a) + img_b.convert('RGB').save(input_b) + + # Run Practical-RIFE via subprocess + success = PracticalRifeEnv.run_interpolation( + input_a, input_b, output_file, t, model, ensemble + ) + + if success and output_file.exists(): + return Image.open(output_file).copy() + + except Exception: + pass + + # Fall back to ncnn RIFE or optical flow + return ImageBlender.rife_blend(img_a, img_b, t) + @staticmethod def blend_images( img_a_path: Path, @@ -521,7 +806,12 @@ class ImageBlender: output_quality: int = 95, webp_method: int = 4, blend_method: BlendMethod = BlendMethod.ALPHA, - rife_binary_path: Optional[Path] = None + rife_binary_path: Optional[Path] = None, + rife_model: str = 'rife-v4.6', + rife_uhd: bool = False, + rife_tta: bool = False, + practical_rife_model: str = 'v4.25', + practical_rife_ensemble: bool = False ) -> BlendResult: """Blend two images together. @@ -533,8 +823,13 @@ class ImageBlender: output_format: Output format (png, jpeg, webp). output_quality: Quality for JPEG output (1-100). webp_method: WebP compression method (0-6, higher = smaller but slower). - blend_method: The blending method to use (alpha, optical_flow, or rife). + blend_method: The blending method to use (alpha, optical_flow, rife, rife_practical). rife_binary_path: Optional path to rife-ncnn-vulkan binary. + rife_model: RIFE ncnn model to use (e.g., 'rife-v4.6'). + rife_uhd: Enable RIFE ncnn UHD mode. + rife_tta: Enable RIFE ncnn TTA mode. + practical_rife_model: Practical-RIFE model version (e.g., 'v4.25'). + practical_rife_ensemble: Enable Practical-RIFE ensemble mode. Returns: BlendResult with operation status. @@ -557,7 +852,13 @@ class ImageBlender: if blend_method == BlendMethod.OPTICAL_FLOW: blended = ImageBlender.optical_flow_blend(img_a, img_b, factor) elif blend_method == BlendMethod.RIFE: - blended = ImageBlender.rife_blend(img_a, img_b, factor, rife_binary_path) + blended = ImageBlender.rife_blend( + img_a, img_b, factor, rife_binary_path, True, rife_model, rife_uhd, rife_tta + ) + elif blend_method == BlendMethod.RIFE_PRACTICAL: + blended = ImageBlender.practical_rife_blend( + img_a, img_b, factor, practical_rife_model, practical_rife_ensemble + ) else: # Default: simple alpha blend blended = Image.blend(img_a, img_b, factor) @@ -610,7 +911,12 @@ class ImageBlender: output_quality: int = 95, webp_method: int = 4, blend_method: BlendMethod = BlendMethod.ALPHA, - rife_binary_path: Optional[Path] = None + rife_binary_path: Optional[Path] = None, + rife_model: str = 'rife-v4.6', + rife_uhd: bool = False, + rife_tta: bool = False, + practical_rife_model: str = 'v4.25', + practical_rife_ensemble: bool = False ) -> BlendResult: """Blend two PIL Image objects together. @@ -622,8 +928,13 @@ class ImageBlender: output_format: Output format (png, jpeg, webp). output_quality: Quality for JPEG output (1-100). webp_method: WebP compression method (0-6). - blend_method: The blending method to use (alpha, optical_flow, or rife). + blend_method: The blending method to use (alpha, optical_flow, rife, rife_practical). rife_binary_path: Optional path to rife-ncnn-vulkan binary. + rife_model: RIFE ncnn model to use (e.g., 'rife-v4.6'). + rife_uhd: Enable RIFE ncnn UHD mode. + rife_tta: Enable RIFE ncnn TTA mode. + practical_rife_model: Practical-RIFE model version (e.g., 'v4.25'). + practical_rife_ensemble: Enable Practical-RIFE ensemble mode. Returns: BlendResult with operation status. @@ -643,7 +954,13 @@ class ImageBlender: if blend_method == BlendMethod.OPTICAL_FLOW: blended = ImageBlender.optical_flow_blend(img_a, img_b, factor) elif blend_method == BlendMethod.RIFE: - blended = ImageBlender.rife_blend(img_a, img_b, factor, rife_binary_path) + blended = ImageBlender.rife_blend( + img_a, img_b, factor, rife_binary_path, True, rife_model, rife_uhd, rife_tta + ) + elif blend_method == BlendMethod.RIFE_PRACTICAL: + blended = ImageBlender.practical_rife_blend( + img_a, img_b, factor, practical_rife_model, practical_rife_ensemble + ) else: # Default: simple alpha blend blended = Image.blend(img_a, img_b, factor) @@ -887,7 +1204,12 @@ class TransitionGenerator: self.settings.output_quality, self.settings.webp_method, self.settings.blend_method, - self.settings.rife_binary_path + self.settings.rife_binary_path, + self.settings.rife_model, + self.settings.rife_uhd, + self.settings.rife_tta, + self.settings.practical_rife_model, + self.settings.practical_rife_ensemble ) results.append(result) diff --git a/core/models.py b/core/models.py index c293f84..6333bf9 100644 --- a/core/models.py +++ b/core/models.py @@ -21,7 +21,8 @@ class BlendMethod(Enum): """Blend method types for transitions.""" ALPHA = 'alpha' # Simple cross-dissolve (PIL.Image.blend) OPTICAL_FLOW = 'optical' # OpenCV Farneback optical flow - RIFE = 'rife' # AI frame interpolation (NCNN binary or PyTorch) + RIFE = 'rife' # AI frame interpolation (NCNN binary) + RIFE_PRACTICAL = 'rife_practical' # Practical-RIFE Python/PyTorch implementation class FolderType(Enum): @@ -44,6 +45,12 @@ class TransitionSettings: trans_destination: Optional[Path] = None # separate destination for transition output blend_method: BlendMethod = BlendMethod.ALPHA # blending method rife_binary_path: Optional[Path] = None # path to rife-ncnn-vulkan binary + rife_model: str = 'rife-v4.6' # RIFE model to use + rife_uhd: bool = False # Enable UHD mode for high resolution + rife_tta: bool = False # Enable TTA mode for better quality + # Practical-RIFE settings + practical_rife_model: str = 'v4.25' # v4.25, v4.26, v4.22, etc. + practical_rife_ensemble: bool = False # Ensemble mode for better quality (slower) @dataclass diff --git a/core/rife_worker.py b/core/rife_worker.py new file mode 100644 index 0000000..961e0c6 --- /dev/null +++ b/core/rife_worker.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python +"""RIFE interpolation worker - runs in isolated venv with PyTorch. + +This script is executed via subprocess from the main application. +It handles loading Practical-RIFE models and performing frame interpolation. + +Note: The Practical-RIFE models require the IFNet architecture from the +Practical-RIFE repository. This script downloads and uses the model weights +with a simplified inference implementation. +""" + +import argparse +import os +import sys +import urllib.request +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True), + nn.PReLU(out_planes) + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c//2, 3, 2, 1), + conv(c//2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1./scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor=1./scale, mode="bilinear", align_corners=False) / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + feat + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale*2, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale * 2 + mask = tmp[:, 4:5] + return flow, mask + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + backwarp_tenGrid = {} + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) + + +class IFNet(nn.Module): + """IFNet architecture for RIFE v4.x models.""" + + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7+16, c=192) + self.block1 = IFBlock(8+4+16, c=128) + self.block2 = IFBlock(8+4+16, c=96) + self.block3 = IFBlock(8+4+16, c=64) + self.encode = nn.Sequential( + nn.Conv2d(3, 16, 3, 2, 1), + nn.ConvTranspose2d(16, 4, 4, 2, 1) + ) + + def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]): + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i]( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), + None, scale=scale_list[i]) + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0 = block[i]( + torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1), + flow, scale=scale_list[i]) + flow = flow + fd + mask = mask + m0 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask_final = torch.sigmoid(mask) + merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final) + return merged_final + + +# Model URLs for downloading +MODEL_URLS = { + 'v4.26': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.26/flownet.pkl', + 'v4.25': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.25/flownet.pkl', + 'v4.22': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.22/flownet.pkl', + 'v4.20': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.20/flownet.pkl', + 'v4.18': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.18/flownet.pkl', + 'v4.15': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.15/flownet.pkl', +} + + +def download_model(version: str, model_dir: Path) -> Path: + """Download model if not already cached. + + Args: + version: Model version (e.g., 'v4.25'). + model_dir: Directory to store models. + + Returns: + Path to the downloaded model file. + """ + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / f'flownet_{version}.pkl' + + if model_path.exists(): + return model_path + + url = MODEL_URLS.get(version) + if not url: + raise ValueError(f"Unknown model version: {version}") + + print(f"Downloading RIFE model {version}...", file=sys.stderr) + try: + req = urllib.request.Request(url, headers={'User-Agent': 'video-montage-linker'}) + with urllib.request.urlopen(req, timeout=120) as response: + with open(model_path, 'wb') as f: + f.write(response.read()) + print(f"Model downloaded to {model_path}", file=sys.stderr) + return model_path + except Exception as e: + # Clean up partial download + if model_path.exists(): + model_path.unlink() + raise RuntimeError(f"Failed to download model: {e}") + + +def load_model(model_path: Path, device: torch.device) -> IFNet: + """Load IFNet model from state dict. + + Args: + model_path: Path to flownet.pkl file. + device: Device to load model to. + + Returns: + Loaded IFNet model. + """ + model = IFNet() + state_dict = torch.load(model_path, map_location='cpu') + + # Handle different state dict formats + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + # Remove 'module.' prefix if present (from DataParallel) + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] + # Handle flownet. prefix + if k.startswith('flownet.'): + k = k[8:] + new_state_dict[k] = v + + model.load_state_dict(new_state_dict, strict=False) + model.to(device) + model.eval() + return model + + +def pad_image(img: torch.Tensor, padding: int = 64) -> tuple: + """Pad image to be divisible by padding. + + Args: + img: Input tensor (B, C, H, W). + padding: Padding divisor. + + Returns: + Tuple of (padded image, (original H, original W)). + """ + _, _, h, w = img.shape + ph = ((h - 1) // padding + 1) * padding + pw = ((w - 1) // padding + 1) * padding + pad_h = ph - h + pad_w = pw - w + padded = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate') + return padded, (h, w) + + +@torch.no_grad() +def inference(model: IFNet, img0: torch.Tensor, img1: torch.Tensor, + timestep: float = 0.5, ensemble: bool = False) -> torch.Tensor: + """Perform frame interpolation. + + Args: + model: Loaded IFNet model. + img0: First frame tensor (B, C, H, W) normalized to [0, 1]. + img1: Second frame tensor (B, C, H, W) normalized to [0, 1]. + timestep: Interpolation timestep (0.0 to 1.0). + ensemble: Enable ensemble mode for better quality. + + Returns: + Interpolated frame tensor. + """ + # Pad images + img0_padded, orig_size = pad_image(img0) + img1_padded, _ = pad_image(img1) + h, w = orig_size + + # Create timestep tensor + timestep_tensor = torch.full((1, 1, img0_padded.shape[2], img0_padded.shape[3]), + timestep, device=img0.device) + + if ensemble: + # Ensemble: average of forward and reverse + result1 = model(img0_padded, img1_padded, timestep_tensor) + result2 = model(img1_padded, img0_padded, 1 - timestep_tensor) + result = (result1 + result2) / 2 + else: + result = model(img0_padded, img1_padded, timestep_tensor) + + # Crop back to original size + result = result[:, :, :h, :w] + return result.clamp(0, 1) + + +def load_image(path: Path, device: torch.device) -> torch.Tensor: + """Load image as tensor. + + Args: + path: Path to image file. + device: Device to load tensor to. + + Returns: + Image tensor (1, 3, H, W) normalized to [0, 1]. + """ + img = Image.open(path).convert('RGB') + arr = np.array(img).astype(np.float32) / 255.0 + tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) + return tensor.to(device) + + +def save_image(tensor: torch.Tensor, path: Path) -> None: + """Save tensor as image. + + Args: + tensor: Image tensor (1, 3, H, W) normalized to [0, 1]. + path: Output path. + """ + arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() + arr = (arr * 255).clip(0, 255).astype(np.uint8) + Image.fromarray(arr).save(path) + + +# Global model cache +_model_cache: dict = {} + + +def get_model(version: str, model_dir: Path, device: torch.device) -> IFNet: + """Get or load model (cached). + + Args: + version: Model version. + model_dir: Model cache directory. + device: Device to run on. + + Returns: + IFNet model instance. + """ + cache_key = f"{version}_{device}" + if cache_key not in _model_cache: + model_path = download_model(version, model_dir) + _model_cache[cache_key] = load_model(model_path, device) + return _model_cache[cache_key] + + +def main(): + parser = argparse.ArgumentParser(description='RIFE frame interpolation worker') + parser.add_argument('--input0', required=True, help='Path to first input image') + parser.add_argument('--input1', required=True, help='Path to second input image') + parser.add_argument('--output', required=True, help='Path to output image') + parser.add_argument('--timestep', type=float, default=0.5, help='Interpolation timestep (0-1)') + parser.add_argument('--model', default='v4.25', help='Model version') + parser.add_argument('--model-dir', required=True, help='Model cache directory') + parser.add_argument('--ensemble', action='store_true', help='Enable ensemble mode') + parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use') + + args = parser.parse_args() + + try: + # Select device + if args.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + # Load model + model_dir = Path(args.model_dir) + model = get_model(args.model, model_dir, device) + + # Load images + img0 = load_image(Path(args.input0), device) + img1 = load_image(Path(args.input1), device) + + # Interpolate + result = inference(model, img0, img1, args.timestep, args.ensemble) + + # Save result + save_image(result, Path(args.output)) + + print("Success", file=sys.stderr) + return 0 + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/ui/main_window.py b/ui/main_window.py index 2f9a71f..c9b3ac6 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -5,8 +5,8 @@ import re from pathlib import Path from typing import Optional -from PyQt6.QtCore import Qt, QUrl, QEvent, QPoint -from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QColor +from PyQt6.QtCore import Qt, QUrl, QEvent, QPoint, QTimer +from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QColor, QPainter, QFont, QFontMetrics from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput from PyQt6.QtMultimediaWidgets import QVideoWidget from PyQt6.QtWidgets import ( @@ -38,6 +38,7 @@ from PyQt6.QtWidgets import ( QDialog, QDialogButtonBox, QFormLayout, + QCheckBox, ) from PyQt6.QtGui import QPixmap @@ -53,11 +54,84 @@ from core import ( DatabaseManager, TransitionGenerator, RifeDownloader, + PracticalRifeEnv, SymlinkManager, ) from .widgets import TrimSlider +class TimelineTreeWidget(QTreeWidget): + """QTreeWidget with timeline markers drawn in the background.""" + + def __init__(self, parent: Optional[QWidget] = None) -> None: + super().__init__(parent) + self.fps = 16 + self._text_color = QColor(100, 100, 100) + + def set_fps(self, fps: int) -> None: + """Update FPS for timeline display.""" + self.fps = max(1, fps) + self.viewport().update() + + def paintEvent(self, event) -> None: + """Draw timeline markers in background, then call parent paint.""" + # Draw the timeline background on the viewport + painter = QPainter(self.viewport()) + + frame_count = self.topLevelItemCount() + if frame_count > 0 and self.fps > 0: + # Get row height from first visible item + first_item = self.topLevelItem(0) + if first_item: + # Get column positions + col0_width = self.columnWidth(0) + viewport_width = self.viewport().width() + + # Font for time labels + font = QFont("Monospace", 9) + painter.setFont(font) + metrics = QFontMetrics(font) + + # Draw for each row + for i in range(frame_count): + item = self.topLevelItem(i) + if not item: + continue + + item_rect = self.visualItemRect(item) + if item_rect.isNull() or item_rect.bottom() < 0 or item_rect.top() > self.viewport().height(): + continue # Not visible + + y_center = item_rect.center().y() + + # Calculate time for this frame + time_seconds = i / self.fps + is_major = (i % self.fps == 0) # Every second + + if is_major: + # Format time + minutes = int(time_seconds // 60) + seconds = int(time_seconds % 60) + if minutes > 0: + time_str = f"{minutes}:{seconds:02d}" + else: + time_str = f"{seconds}s" + + text_width = metrics.horizontalAdvance(time_str) + painter.setPen(self._text_color) + + # Draw time label on right of column 0 + painter.drawText(col0_width - text_width - 6, y_center + metrics.ascent() // 2, time_str) + + # Draw time label on right of column 1 (right edge) + painter.drawText(viewport_width - text_width - 6, y_center + metrics.ascent() // 2, time_str) + + painter.end() + + # Call parent to draw the actual tree content + super().paintEvent(event) + + class OverlapDialog(QDialog): """Dialog for setting per-transition overlap frames.""" @@ -141,6 +215,8 @@ class SequenceLinkerUI(QWidget): self._create_layout() self._connect_signals() self.setAcceptDrops(True) + # Initialize sequence table FPS + self.sequence_table.set_fps(self.fps_spin.value()) def _setup_window(self) -> None: """Configure the main window properties.""" @@ -260,14 +336,15 @@ class SequenceLinkerUI(QWidget): self._current_pixmap: Optional[QPixmap] = None self._pan_start = None self._pan_scrollbar_start = None + self._blend_preview_cache: dict[str, QPixmap] = {} # Cache for generated blend frames # Trim slider self.trim_slider = TrimSlider() self.trim_label = QLabel("Frames: All included") self.trim_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - # Sequence table (2-column: Main Frame | Transition Frame) - self.sequence_table = QTreeWidget() + # Sequence table (2-column: Main Frame | Transition Frame) with timeline background + self.sequence_table = TimelineTreeWidget() self.sequence_table.setHeaderLabels(["Main Frame", "Transition Frame"]) self.sequence_table.setColumnCount(2) self.sequence_table.setRootIsDecorated(False) @@ -317,12 +394,14 @@ class SequenceLinkerUI(QWidget): self.blend_method_combo = QComboBox() self.blend_method_combo.addItem("Cross-Dissolve", BlendMethod.ALPHA) self.blend_method_combo.addItem("Optical Flow", BlendMethod.OPTICAL_FLOW) - self.blend_method_combo.addItem("RIFE (AI)", BlendMethod.RIFE) + self.blend_method_combo.addItem("RIFE (ncnn)", BlendMethod.RIFE) + self.blend_method_combo.addItem("RIFE (Practical)", BlendMethod.RIFE_PRACTICAL) self.blend_method_combo.setToolTip( "Blending method:\n" "- Cross-Dissolve: Simple alpha blend (fast, may ghost)\n" "- Optical Flow: Motion-compensated blend (slower, less ghosting)\n" - "- RIFE: AI frame interpolation (best quality, requires rife-ncnn-vulkan)" + "- RIFE (ncnn): AI frame interpolation (fast, Vulkan GPU, models up to v4.6)\n" + "- RIFE (Practical): AI frame interpolation (PyTorch, latest models v4.25/v4.26)" ) # RIFE binary path @@ -338,6 +417,77 @@ class SequenceLinkerUI(QWidget): self.rife_path_btn.setVisible(False) self.rife_download_btn.setVisible(False) + # RIFE model selection + self.rife_model_label = QLabel("Model:") + self.rife_model_combo = QComboBox() + self.rife_model_combo.addItem("v4.6 (Best)", "rife-v4.6") + self.rife_model_combo.addItem("v4", "rife-v4") + self.rife_model_combo.addItem("v3.1", "rife-v3.1") + self.rife_model_combo.addItem("v2.4", "rife-v2.4") + self.rife_model_combo.addItem("Anime", "rife-anime") + self.rife_model_combo.addItem("UHD", "rife-UHD") + self.rife_model_combo.addItem("HD", "rife-HD") + self.rife_model_combo.setToolTip("RIFE model version:\n- v4.6: Latest, best quality\n- Anime: Optimized for animation\n- UHD/HD: For high resolution content") + self.rife_model_label.setVisible(False) + self.rife_model_combo.setVisible(False) + + # RIFE UHD mode + self.rife_uhd_check = QCheckBox("UHD") + self.rife_uhd_check.setToolTip("Enable UHD mode for high resolution images (4K+)") + self.rife_uhd_check.setVisible(False) + + # RIFE TTA mode + self.rife_tta_check = QCheckBox("TTA") + self.rife_tta_check.setToolTip("Enable TTA (Test-Time Augmentation) for better quality (slower)") + self.rife_tta_check.setVisible(False) + + # Practical-RIFE settings + self.practical_model_label = QLabel("Model:") + self.practical_model_combo = QComboBox() + self.practical_model_combo.addItem("v4.26 (Latest)", "v4.26") + self.practical_model_combo.addItem("v4.25 (Recommended)", "v4.25") + self.practical_model_combo.addItem("v4.22", "v4.22") + self.practical_model_combo.addItem("v4.20", "v4.20") + self.practical_model_combo.addItem("v4.18", "v4.18") + self.practical_model_combo.addItem("v4.15", "v4.15") + self.practical_model_combo.setCurrentIndex(1) # Default to v4.25 + self.practical_model_combo.setToolTip( + "Practical-RIFE model version:\n" + "- v4.26: Latest version\n" + "- v4.25: Recommended, good balance of quality and speed" + ) + self.practical_model_label.setVisible(False) + self.practical_model_combo.setVisible(False) + + self.practical_ensemble_check = QCheckBox("Ensemble") + self.practical_ensemble_check.setToolTip("Enable ensemble mode for better quality (slower)") + self.practical_ensemble_check.setVisible(False) + + self.practical_setup_btn = QPushButton("Setup PyTorch") + self.practical_setup_btn.setToolTip("Create local venv and install PyTorch (~2GB download)") + self.practical_setup_btn.setVisible(False) + + self.practical_status_label = QLabel("") + self.practical_status_label.setStyleSheet("color: gray; font-size: 10px;") + self.practical_status_label.setVisible(False) + + # FPS setting for sequence playback and timeline + self.fps_label = QLabel("FPS:") + self.fps_spin = QSpinBox() + self.fps_spin.setRange(1, 120) + self.fps_spin.setValue(16) + self.fps_spin.setToolTip("Frames per second for sequence preview and timeline") + + # Timeline duration label + self.timeline_label = QLabel("Duration: 00:00.000 (0 frames)") + self.timeline_label.setStyleSheet("font-family: monospace;") + + # Sequence playback button and timer + self.seq_play_btn = QPushButton("▶ Play") + self.seq_play_btn.setToolTip("Play image sequence at configured FPS") + self.sequence_timer = QTimer(self) + self.sequence_playing = False + def _create_layout(self) -> None: """Arrange widgets in layouts.""" # === LEFT SIDE PANEL: Source Folders === @@ -397,6 +547,19 @@ class SequenceLinkerUI(QWidget): transition_layout.addWidget(self.rife_path_input) transition_layout.addWidget(self.rife_path_btn) transition_layout.addWidget(self.rife_download_btn) + transition_layout.addWidget(self.rife_model_label) + transition_layout.addWidget(self.rife_model_combo) + transition_layout.addWidget(self.rife_uhd_check) + transition_layout.addWidget(self.rife_tta_check) + transition_layout.addWidget(self.practical_model_label) + transition_layout.addWidget(self.practical_model_combo) + transition_layout.addWidget(self.practical_ensemble_check) + transition_layout.addWidget(self.practical_setup_btn) + transition_layout.addWidget(self.practical_status_label) + transition_layout.addWidget(self.fps_label) + transition_layout.addWidget(self.fps_spin) + transition_layout.addWidget(self.timeline_label) + transition_layout.addWidget(self.seq_play_btn) transition_layout.addStretch() self.transition_group.setLayout(transition_layout) @@ -459,11 +622,13 @@ class SequenceLinkerUI(QWidget): sequence_order_layout.addWidget(self.file_list) self.sequence_tabs.addTab(sequence_order_tab, "Sequence Order") - # Tab 2: With Transitions (2-column view) + # Tab 2: With Transitions (2-column view with timeline rulers) trans_sequence_tab = QWidget() trans_sequence_layout = QVBoxLayout(trans_sequence_tab) trans_sequence_layout.setContentsMargins(0, 0, 0, 0) + trans_sequence_layout.addWidget(self.sequence_table) + self.sequence_tabs.addTab(trans_sequence_tab, "With Transitions") file_list_layout.addWidget(self.sequence_tabs) @@ -555,9 +720,18 @@ class SequenceLinkerUI(QWidget): # Blend method combo change - show/hide RIFE path self.blend_method_combo.currentIndexChanged.connect(self._on_blend_method_changed) + self.curve_combo.currentIndexChanged.connect(self._clear_blend_cache) + self.rife_model_combo.currentIndexChanged.connect(self._clear_blend_cache) + self.rife_uhd_check.stateChanged.connect(self._clear_blend_cache) + self.rife_tta_check.stateChanged.connect(self._clear_blend_cache) self.rife_path_btn.clicked.connect(self._browse_rife_binary) self.rife_download_btn.clicked.connect(self._download_rife_binary) + # Practical-RIFE signals + self.practical_model_combo.currentIndexChanged.connect(self._clear_blend_cache) + self.practical_ensemble_check.stateChanged.connect(self._clear_blend_cache) + self.practical_setup_btn.clicked.connect(self._setup_practical_rife) + # Sequence table selection - show image self.sequence_table.currentItemChanged.connect(self._on_sequence_table_selected) @@ -567,6 +741,14 @@ class SequenceLinkerUI(QWidget): # Update sequence table when switching to "With Transitions" tab self.sequence_tabs.currentChanged.connect(self._on_sequence_tab_changed) + # FPS and sequence playback signals + self.fps_spin.valueChanged.connect(self._update_timeline_display) + self.seq_play_btn.clicked.connect(self._toggle_sequence_play) + self.sequence_timer.timeout.connect(self._advance_sequence_frame) + + # Update sequence table FPS when spinner changes + self.fps_spin.valueChanged.connect(self.sequence_table.set_fps) + def _on_format_changed(self, index: int) -> None: """Handle format combo change to show/hide quality/method widgets.""" fmt = self.blend_format_combo.currentData() @@ -589,15 +771,39 @@ class SequenceLinkerUI(QWidget): def _on_blend_method_changed(self, index: int) -> None: """Handle blend method combo change to show/hide RIFE path widgets.""" method = self.blend_method_combo.currentData() - is_rife = (method == BlendMethod.RIFE) - self.rife_path_label.setVisible(is_rife) - self.rife_path_input.setVisible(is_rife) - self.rife_path_btn.setVisible(is_rife) - self.rife_download_btn.setVisible(is_rife) + is_rife_ncnn = (method == BlendMethod.RIFE) + is_rife_practical = (method == BlendMethod.RIFE_PRACTICAL) - if is_rife: + # RIFE ncnn settings + self.rife_path_label.setVisible(is_rife_ncnn) + self.rife_path_input.setVisible(is_rife_ncnn) + self.rife_path_btn.setVisible(is_rife_ncnn) + self.rife_download_btn.setVisible(is_rife_ncnn) + self.rife_model_label.setVisible(is_rife_ncnn) + self.rife_model_combo.setVisible(is_rife_ncnn) + self.rife_uhd_check.setVisible(is_rife_ncnn) + self.rife_tta_check.setVisible(is_rife_ncnn) + + # Practical-RIFE settings + self.practical_model_label.setVisible(is_rife_practical) + self.practical_model_combo.setVisible(is_rife_practical) + self.practical_ensemble_check.setVisible(is_rife_practical) + self.practical_setup_btn.setVisible(is_rife_practical) + self.practical_status_label.setVisible(is_rife_practical) + + if is_rife_ncnn: self._update_rife_download_button() + if is_rife_practical: + self._update_practical_rife_status() + + # Clear blend preview cache when method changes + self._blend_preview_cache.clear() + + def _clear_blend_cache(self) -> None: + """Clear the blend preview cache.""" + self._blend_preview_cache.clear() + def _browse_rife_binary(self) -> None: """Browse for RIFE binary.""" start_dir = self.last_directory or "" @@ -743,6 +949,94 @@ class SequenceLinkerUI(QWidget): ) self._update_rife_download_button() + def _update_practical_rife_status(self) -> None: + """Update the Practical-RIFE status label and setup button.""" + if PracticalRifeEnv.is_setup(): + torch_version = PracticalRifeEnv.get_torch_version() + if torch_version: + self.practical_status_label.setText(f"Ready (PyTorch {torch_version})") + self.practical_status_label.setStyleSheet("color: green; font-size: 10px;") + else: + self.practical_status_label.setText("Ready") + self.practical_status_label.setStyleSheet("color: green; font-size: 10px;") + self.practical_setup_btn.setText("Reinstall") + self.practical_setup_btn.setToolTip("Reinstall PyTorch environment") + self.practical_model_combo.setEnabled(True) + self.practical_ensemble_check.setEnabled(True) + else: + self.practical_status_label.setText("Not configured") + self.practical_status_label.setStyleSheet("color: orange; font-size: 10px;") + self.practical_setup_btn.setText("Setup PyTorch") + self.practical_setup_btn.setToolTip("Create local venv and install PyTorch (~2GB download)") + self.practical_model_combo.setEnabled(False) + self.practical_ensemble_check.setEnabled(False) + + def _setup_practical_rife(self) -> None: + """Setup Practical-RIFE environment with progress dialog.""" + # Confirm if already setup + if PracticalRifeEnv.is_setup(): + reply = QMessageBox.question( + self, "Reinstall PyTorch?", + "PyTorch environment is already set up.\n" + "Do you want to reinstall it?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No + ) + if reply != QMessageBox.StandardButton.Yes: + return + + # Create progress dialog + progress = QProgressDialog( + "Setting up PyTorch environment...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("Setup Practical-RIFE") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(0) + progress.show() + + # Progress callback + def progress_callback(message, percent): + if not progress.wasCanceled(): + progress.setLabelText(message) + progress.setValue(percent) + QApplication.processEvents() + + def cancelled_check(): + QApplication.processEvents() + return progress.wasCanceled() + + try: + success = PracticalRifeEnv.setup_venv(progress_callback, cancelled_check) + progress.close() + + if progress.wasCanceled(): + self._update_practical_rife_status() + return + + if success: + QMessageBox.information( + self, "Setup Complete", + "PyTorch environment set up successfully!\n\n" + f"Location: {PracticalRifeEnv.VENV_DIR}\n\n" + "You can now use RIFE (Practical) for frame interpolation." + ) + else: + QMessageBox.warning( + self, "Setup Failed", + "Failed to set up PyTorch environment.\n" + "Check your internet connection and try again." + ) + + self._update_practical_rife_status() + + except Exception as e: + progress.close() + QMessageBox.critical( + self, "Setup Error", + f"Error setting up PyTorch: {e}" + ) + self._update_practical_rife_status() + def _on_sequence_tab_changed(self, index: int) -> None: """Handle sequence tab change to update the With Transitions view.""" if index == 1: # "With Transitions" tab @@ -753,10 +1047,12 @@ class SequenceLinkerUI(QWidget): self.sequence_table.clear() if not self.source_folders: + self._update_timeline_display() return files = self._get_files_in_order() if not files: + self._update_timeline_display() return # Group files by folder @@ -774,6 +1070,7 @@ class SequenceLinkerUI(QWidget): item = QTreeWidgetItem([f"{seq_name} ({filename})", ""]) item.setData(0, Qt.ItemDataRole.UserRole, (source_dir, filename, folder_idx, file_idx, 'symlink')) self.sequence_table.addTopLevelItem(item) + self._update_timeline_display() return # Get transition specs @@ -856,6 +1153,9 @@ class SequenceLinkerUI(QWidget): self.sequence_table.addTopLevelItem(item) + # Update timeline display after rebuilding sequence table + self._update_timeline_display() + def _on_sequence_table_selected(self, current, previous) -> None: """Handle sequence table row selection - show image in preview.""" if current is None: @@ -921,20 +1221,6 @@ class SequenceLinkerUI(QWidget): return try: - # Load images - img_a = Image.open(main_path) - img_b = Image.open(trans_path) - - # Resize B to match A if needed - if img_a.size != img_b.size: - img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS) - - # Convert to RGBA - if img_a.mode != 'RGBA': - img_a = img_a.convert('RGBA') - if img_b.mode != 'RGBA': - img_b = img_b.convert('RGBA') - # Calculate blend factor based on position in sequence table # Find this frame's position in the blend sequence row_idx = self.sequence_table.indexOfTopLevelItem(item) @@ -970,17 +1256,51 @@ class SequenceLinkerUI(QWidget): blend_position, blend_count, settings.blend_curve ) - # Blend images using selected method - if settings.blend_method == BlendMethod.OPTICAL_FLOW: - blended = ImageBlender.optical_flow_blend(img_a, img_b, factor) - elif settings.blend_method == BlendMethod.RIFE: - blended = ImageBlender.rife_blend(img_a, img_b, factor, settings.rife_binary_path) - else: - blended = Image.blend(img_a, img_b, factor) + # Create cache key (include RIFE settings when using RIFE) + cache_key = f"{main_path}|{trans_path}|{factor:.6f}|{settings.blend_method.value}|{settings.blend_curve.value}" + if settings.blend_method == BlendMethod.RIFE: + cache_key += f"|{settings.rife_model}|{settings.rife_uhd}|{settings.rife_tta}" - # Convert to QPixmap - qim = ImageQt(blended.convert('RGBA')) - pixmap = QPixmap.fromImage(qim) + # Check cache first + if cache_key in self._blend_preview_cache: + pixmap = self._blend_preview_cache[cache_key] + else: + # Load images + img_a = Image.open(main_path) + img_b = Image.open(trans_path) + + # Resize B to match A if needed + if img_a.size != img_b.size: + img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS) + + # Convert to RGBA + if img_a.mode != 'RGBA': + img_a = img_a.convert('RGBA') + if img_b.mode != 'RGBA': + img_b = img_b.convert('RGBA') + + # Blend images using selected method + if settings.blend_method == BlendMethod.OPTICAL_FLOW: + blended = ImageBlender.optical_flow_blend(img_a, img_b, factor) + elif settings.blend_method == BlendMethod.RIFE: + blended = ImageBlender.rife_blend( + img_a, img_b, factor, settings.rife_binary_path, + model=settings.rife_model, + uhd=settings.rife_uhd, + tta=settings.rife_tta + ) + else: + blended = Image.blend(img_a, img_b, factor) + + # Convert to QPixmap + qim = ImageQt(blended.convert('RGBA')) + pixmap = QPixmap.fromImage(qim) + + # Store in cache + self._blend_preview_cache[cache_key] = pixmap + + img_a.close() + img_b.close() self._current_pixmap = pixmap self._apply_zoom() @@ -990,14 +1310,77 @@ class SequenceLinkerUI(QWidget): seq_name = f"seq{data0[2] + 1:02d}_{data0[3]:04d}" self.image_name_label.setText(f"[B] {seq_name} ({main_file} + {trans_file}) @ {factor:.0%}") - img_a.close() - img_b.close() - except Exception as e: self.image_label.setText(f"Error generating blend preview:\n{e}") self.image_name_label.setText("") self._current_pixmap = None + def _update_timeline_display(self) -> None: + """Update the timeline duration display based on frame count and FPS.""" + frame_count = self.sequence_table.topLevelItemCount() + fps = self.fps_spin.value() + + if fps > 0 and frame_count > 0: + total_seconds = frame_count / fps + minutes = int(total_seconds // 60) + seconds = total_seconds % 60 + self.timeline_label.setText( + f"Duration: {minutes:02d}:{seconds:06.3f} ({frame_count} frames @ {fps}fps)" + ) + else: + self.timeline_label.setText("Duration: 00:00.000 (0 frames)") + + # Refresh the sequence table to update timeline background + self.sequence_table.viewport().update() + + def _toggle_sequence_play(self) -> None: + """Toggle sequence playback.""" + if self.sequence_playing: + self._stop_sequence_play() + else: + self._start_sequence_play() + + def _start_sequence_play(self) -> None: + """Start playing the image sequence.""" + if self.sequence_table.topLevelItemCount() == 0: + return + + fps = self.fps_spin.value() + interval = int(1000 / fps) # milliseconds per frame + self.sequence_timer.setInterval(interval) + self.sequence_timer.start() + self.sequence_playing = True + self.seq_play_btn.setText("⏸ Pause") + + # If no item selected, start from first + if self.sequence_table.currentItem() is None: + first_item = self.sequence_table.topLevelItem(0) + if first_item: + self.sequence_table.setCurrentItem(first_item) + + def _stop_sequence_play(self) -> None: + """Stop sequence playback.""" + self.sequence_timer.stop() + self.sequence_playing = False + self.seq_play_btn.setText("▶ Play") + + def _advance_sequence_frame(self) -> None: + """Advance to next frame in sequence.""" + current_item = self.sequence_table.currentItem() + if current_item is None: + self._stop_sequence_play() + return + + current_idx = self.sequence_table.indexOfTopLevelItem(current_item) + total = self.sequence_table.topLevelItemCount() + + if current_idx < total - 1: + next_item = self.sequence_table.topLevelItem(current_idx + 1) + self.sequence_table.setCurrentItem(next_item) + else: + # Reached end - stop playback + self._stop_sequence_play() + def _browse_trans_destination(self) -> None: """Select transition destination folder via file dialog.""" start_dir = self.last_directory or "" @@ -1512,7 +1895,12 @@ class SequenceLinkerUI(QWidget): output_quality=self.blend_quality_spin.value(), trans_destination=trans_dest, blend_method=self.blend_method_combo.currentData(), - rife_binary_path=rife_path + rife_binary_path=rife_path, + rife_model=self.rife_model_combo.currentData(), + rife_uhd=self.rife_uhd_check.isChecked(), + rife_tta=self.rife_tta_check.isChecked(), + practical_rife_model=self.practical_model_combo.currentData(), + practical_rife_ensemble=self.practical_ensemble_check.isChecked() ) def _refresh_files(self, select_position: str = 'first') -> None: @@ -1762,6 +2150,9 @@ class SequenceLinkerUI(QWidget): video_path = self.video_combo.currentData() if video_path and isinstance(video_path, Path) and video_path.exists(): self.media_player.setSource(QUrl.fromLocalFile(str(video_path))) + # Play and immediately pause to show first frame + self.media_player.play() + self.media_player.pause() def _toggle_play(self) -> None: """Toggle play/pause state."""