From 00f0141b1573dacfdc6af95acba2b2d0a820fe7f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Tue, 3 Feb 2026 23:21:31 +0100 Subject: [PATCH] rife --- .gitignore | 58 +++++++++++++++++++- core/blender.py | 38 +++++++------ core/rife_worker.py | 129 ++++++++++++++++++++++++++++++++------------ ui/main_window.py | 10 +++- 4 files changed, 183 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index aef96bd..6bc4cfc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,61 @@ +# Python __pycache__/ *.pyc *.pyo -.env +*.pyd +.Python +*.so + +# Virtual environments +venv/ venv-rife/ +.venv/ +env/ + +# Environment files +.env +.env.local + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Database +*.db +*.sqlite +*.sqlite3 + +# Downloads and cache +*.pkl +*.pt +*.pth +*.onnx +downloads/ +cache/ +.cache/ + +# RIFE binaries and models +rife-ncnn-vulkan*/ +*.zip + +# Output directories +output/ +outputs/ +temp/ +tmp/ + +# Logs +*.log +logs/ + +# OS files +.DS_Store +Thumbs.db + +# Build artifacts +dist/ +build/ +*.egg-info/ diff --git a/core/blender.py b/core/blender.py index 9058e62..a9544df 100644 --- a/core/blender.py +++ b/core/blender.py @@ -29,7 +29,7 @@ from .models import ( # Cache directory for downloaded binaries CACHE_DIR = Path.home() / '.cache' / 'video-montage-linker' RIFE_GITHUB_API = 'https://api.github.com/repos/nihui/rife-ncnn-vulkan/releases/latest' -PRACTICAL_RIFE_VENV_DIR = Path('./venv-rife') +PRACTICAL_RIFE_VENV_DIR = CACHE_DIR / 'venv-rife' class PracticalRifeEnv: @@ -147,13 +147,13 @@ class PracticalRifeEnv: ) if progress_callback: - progress_callback("Installing numpy...", 90) + progress_callback("Installing additional dependencies...", 90) if cancelled_check and cancelled_check(): return False - # numpy is usually a dependency of torch but ensure it's there + # Install numpy (usually a dependency of torch) and gdown (for Google Drive downloads) subprocess.run( - [str(python), '-m', 'pip', 'install', 'numpy'], + [str(python), '-m', 'pip', 'install', 'numpy', 'gdown'], capture_output=True ) @@ -190,7 +190,7 @@ class PracticalRifeEnv: t: float, model: str = 'v4.25', ensemble: bool = False - ) -> bool: + ) -> tuple[bool, str]: """Run RIFE interpolation via subprocess in venv. Args: @@ -202,15 +202,15 @@ class PracticalRifeEnv: ensemble: Enable ensemble mode. Returns: - True if interpolation succeeded. + Tuple of (success, error_message). """ python = cls.get_venv_python() if not python or not python.exists(): - return False + return False, "venv python not found" script = cls.get_worker_script() if not script.exists(): - return False + return False, f"worker script not found: {script}" cmd = [ str(python), str(script), @@ -232,11 +232,15 @@ class PracticalRifeEnv: text=True, timeout=120 # 2 minute timeout per frame ) - return result.returncode == 0 and output_path.exists() + if result.returncode == 0 and output_path.exists(): + return True, "" + else: + error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}" + return False, error except subprocess.TimeoutExpired: - return False - except Exception: - return False + return False, "timeout (120s)" + except Exception as e: + return False, str(e) class RifeDownloader: @@ -768,7 +772,7 @@ class ImageBlender: AI-interpolated blended PIL Image. """ if not PracticalRifeEnv.is_setup(): - # Fall back to ncnn RIFE or optical flow + print("[Practical-RIFE] Venv not set up, falling back to ncnn RIFE", file=sys.stderr) return ImageBlender.rife_blend(img_a, img_b, t) try: @@ -783,15 +787,17 @@ class ImageBlender: img_b.convert('RGB').save(input_b) # Run Practical-RIFE via subprocess - success = PracticalRifeEnv.run_interpolation( + success, error_msg = PracticalRifeEnv.run_interpolation( input_a, input_b, output_file, t, model, ensemble ) if success and output_file.exists(): return Image.open(output_file).copy() + else: + print(f"[Practical-RIFE] Interpolation failed: {error_msg}, falling back to ncnn RIFE", file=sys.stderr) - except Exception: - pass + except Exception as e: + print(f"[Practical-RIFE] Exception: {e}, falling back to ncnn RIFE", file=sys.stderr) # Fall back to ncnn RIFE or optical flow return ImageBlender.rife_blend(img_a, img_b, t) diff --git a/core/rife_worker.py b/core/rife_worker.py index 961e0c6..4e0de37 100644 --- a/core/rife_worker.py +++ b/core/rife_worker.py @@ -11,8 +11,11 @@ with a simplified inference implementation. import argparse import os +import shutil import sys +import tempfile import urllib.request +import zipfile from pathlib import Path import numpy as np @@ -88,25 +91,31 @@ def warp(tenInput, tenFlow): class IFNet(nn.Module): - """IFNet architecture for RIFE v4.x models.""" + """IFNet architecture for Practical-RIFE v4.25/v4.26 models.""" def __init__(self): super(IFNet, self).__init__() - self.block0 = IFBlock(7+16, c=192) - self.block1 = IFBlock(8+4+16, c=128) - self.block2 = IFBlock(8+4+16, c=96) - self.block3 = IFBlock(8+4+16, c=64) + # v4.25/v4.26 architecture: + # block0 input: img0(3) + img1(3) + f0(4) + f1(4) + timestep(1) = 15 + # block1+ input: img0(3) + img1(3) + wf0(4) + wf1(4) + f0(4) + f1(4) + timestep(1) + mask(1) + flow(4) = 28 + self.block0 = IFBlock(3+3+4+4+1, c=192) + self.block1 = IFBlock(3+3+4+4+4+4+1+1+4, c=128) + self.block2 = IFBlock(3+3+4+4+4+4+1+1+4, c=96) + self.block3 = IFBlock(3+3+4+4+4+4+1+1+4, c=64) + # Encode produces 4-channel features self.encode = nn.Sequential( - nn.Conv2d(3, 16, 3, 2, 1), - nn.ConvTranspose2d(16, 4, 4, 2, 1) + nn.Conv2d(3, 32, 3, 2, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(32, 32, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(32, 32, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.ConvTranspose2d(32, 4, 4, 2, 1) ) def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]): f0 = self.encode(img0[:, :3]) f1 = self.encode(img1[:, :3]) - flow_list = [] - merged = [] - mask_list = [] warped_img0 = img0 warped_img1 = img1 flow = None @@ -121,34 +130,35 @@ class IFNet(nn.Module): wf0 = warp(f0, flow[:, :2]) wf1 = warp(f1, flow[:, 2:4]) fd, m0 = block[i]( - torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1), + torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, f0, f1, timestep, mask), 1), flow, scale=scale_list[i]) flow = flow + fd mask = mask + m0 - mask_list.append(mask) - flow_list.append(flow) warped_img0 = warp(img0, flow[:, :2]) warped_img1 = warp(img1, flow[:, 2:4]) - merged.append((warped_img0, warped_img1)) mask_final = torch.sigmoid(mask) merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final) return merged_final -# Model URLs for downloading +# Model URLs for downloading (Google Drive direct download links) +# File IDs extracted from official Practical-RIFE repository MODEL_URLS = { - 'v4.26': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.26/flownet.pkl', - 'v4.25': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.25/flownet.pkl', - 'v4.22': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.22/flownet.pkl', - 'v4.20': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.20/flownet.pkl', - 'v4.18': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.18/flownet.pkl', - 'v4.15': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.15/flownet.pkl', + 'v4.26': 'https://drive.google.com/uc?export=download&id=1gViYvvQrtETBgU1w8axZSsr7YUuw31uy', + 'v4.25': 'https://drive.google.com/uc?export=download&id=1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg', + 'v4.22': 'https://drive.google.com/uc?export=download&id=1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_', + 'v4.20': 'https://drive.google.com/uc?export=download&id=11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA', + 'v4.18': 'https://drive.google.com/uc?export=download&id=1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s', + 'v4.15': 'https://drive.google.com/uc?export=download&id=1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5', } def download_model(version: str, model_dir: Path) -> Path: """Download model if not already cached. + Google Drive links distribute zip files containing the model. + This function downloads and extracts the flownet.pkl file. + Args: version: Model version (e.g., 'v4.25'). model_dir: Directory to store models. @@ -160,25 +170,76 @@ def download_model(version: str, model_dir: Path) -> Path: model_path = model_dir / f'flownet_{version}.pkl' if model_path.exists(): - return model_path + # Verify it's not a zip file (from previous failed attempt) + with open(model_path, 'rb') as f: + header = f.read(4) + if header == b'PK\x03\x04': # ZIP magic number + print(f"Removing corrupted zip file at {model_path}", file=sys.stderr) + model_path.unlink() + else: + return model_path url = MODEL_URLS.get(version) if not url: raise ValueError(f"Unknown model version: {version}") print(f"Downloading RIFE model {version}...", file=sys.stderr) - try: - req = urllib.request.Request(url, headers={'User-Agent': 'video-montage-linker'}) - with urllib.request.urlopen(req, timeout=120) as response: - with open(model_path, 'wb') as f: - f.write(response.read()) - print(f"Model downloaded to {model_path}", file=sys.stderr) - return model_path - except Exception as e: - # Clean up partial download - if model_path.exists(): - model_path.unlink() - raise RuntimeError(f"Failed to download model: {e}") + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) / 'download' + + # Try using gdown for Google Drive (handles confirmations automatically) + downloaded = False + try: + import gdown + file_id = url.split('id=')[1] if 'id=' in url else None + if file_id: + gdown_url = f'https://drive.google.com/uc?id={file_id}' + gdown.download(gdown_url, str(tmp_path), quiet=False) + downloaded = tmp_path.exists() + except ImportError: + print("gdown not available, trying direct download...", file=sys.stderr) + except Exception as e: + print(f"gdown failed: {e}, trying direct download...", file=sys.stderr) + + # Fallback: direct download + if not downloaded: + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + with urllib.request.urlopen(req, timeout=300) as response: + data = response.read() + if data[:100].startswith(b' IFNet: diff --git a/ui/main_window.py b/ui/main_window.py index c9b3ac6..4bdd19b 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -1289,6 +1289,12 @@ class SequenceLinkerUI(QWidget): uhd=settings.rife_uhd, tta=settings.rife_tta ) + elif settings.blend_method == BlendMethod.RIFE_PRACTICAL: + blended = ImageBlender.practical_rife_blend( + img_a, img_b, factor, + settings.practical_rife_model, + settings.practical_rife_ensemble + ) else: blended = Image.blend(img_a, img_b, factor) @@ -2554,7 +2560,9 @@ class SequenceLinkerUI(QWidget): main_path, trans_path, factor, output_path, settings.output_format, settings.output_quality, settings.webp_method, - settings.blend_method, settings.rife_binary_path + settings.blend_method, settings.rife_binary_path, + settings.rife_model, settings.rife_uhd, settings.rife_tta, + settings.practical_rife_model, settings.practical_rife_ensemble ) if result.success: