This commit is contained in:
2026-02-03 23:21:31 +01:00
parent 2c6ad4ff35
commit 00f0141b15
4 changed files with 183 additions and 52 deletions

58
.gitignore vendored
View File

@@ -1,5 +1,61 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
# Virtual environments
venv/
venv-rife/
.venv/
env/
# Environment files
.env
.env.local
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Database
*.db
*.sqlite
*.sqlite3
# Downloads and cache
*.pkl
*.pt
*.pth
*.onnx
downloads/
cache/
.cache/
# RIFE binaries and models
rife-ncnn-vulkan*/
*.zip
# Output directories
output/
outputs/
temp/
tmp/
# Logs
*.log
logs/
# OS files
.DS_Store
Thumbs.db
# Build artifacts
dist/
build/
*.egg-info/

View File

@@ -29,7 +29,7 @@ from .models import (
# Cache directory for downloaded binaries
CACHE_DIR = Path.home() / '.cache' / 'video-montage-linker'
RIFE_GITHUB_API = 'https://api.github.com/repos/nihui/rife-ncnn-vulkan/releases/latest'
PRACTICAL_RIFE_VENV_DIR = Path('./venv-rife')
PRACTICAL_RIFE_VENV_DIR = CACHE_DIR / 'venv-rife'
class PracticalRifeEnv:
@@ -147,13 +147,13 @@ class PracticalRifeEnv:
)
if progress_callback:
progress_callback("Installing numpy...", 90)
progress_callback("Installing additional dependencies...", 90)
if cancelled_check and cancelled_check():
return False
# numpy is usually a dependency of torch but ensure it's there
# Install numpy (usually a dependency of torch) and gdown (for Google Drive downloads)
subprocess.run(
[str(python), '-m', 'pip', 'install', 'numpy'],
[str(python), '-m', 'pip', 'install', 'numpy', 'gdown'],
capture_output=True
)
@@ -190,7 +190,7 @@ class PracticalRifeEnv:
t: float,
model: str = 'v4.25',
ensemble: bool = False
) -> bool:
) -> tuple[bool, str]:
"""Run RIFE interpolation via subprocess in venv.
Args:
@@ -202,15 +202,15 @@ class PracticalRifeEnv:
ensemble: Enable ensemble mode.
Returns:
True if interpolation succeeded.
Tuple of (success, error_message).
"""
python = cls.get_venv_python()
if not python or not python.exists():
return False
return False, "venv python not found"
script = cls.get_worker_script()
if not script.exists():
return False
return False, f"worker script not found: {script}"
cmd = [
str(python), str(script),
@@ -232,11 +232,15 @@ class PracticalRifeEnv:
text=True,
timeout=120 # 2 minute timeout per frame
)
return result.returncode == 0 and output_path.exists()
if result.returncode == 0 and output_path.exists():
return True, ""
else:
error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}"
return False, error
except subprocess.TimeoutExpired:
return False
except Exception:
return False
return False, "timeout (120s)"
except Exception as e:
return False, str(e)
class RifeDownloader:
@@ -768,7 +772,7 @@ class ImageBlender:
AI-interpolated blended PIL Image.
"""
if not PracticalRifeEnv.is_setup():
# Fall back to ncnn RIFE or optical flow
print("[Practical-RIFE] Venv not set up, falling back to ncnn RIFE", file=sys.stderr)
return ImageBlender.rife_blend(img_a, img_b, t)
try:
@@ -783,15 +787,17 @@ class ImageBlender:
img_b.convert('RGB').save(input_b)
# Run Practical-RIFE via subprocess
success = PracticalRifeEnv.run_interpolation(
success, error_msg = PracticalRifeEnv.run_interpolation(
input_a, input_b, output_file, t, model, ensemble
)
if success and output_file.exists():
return Image.open(output_file).copy()
else:
print(f"[Practical-RIFE] Interpolation failed: {error_msg}, falling back to ncnn RIFE", file=sys.stderr)
except Exception:
pass
except Exception as e:
print(f"[Practical-RIFE] Exception: {e}, falling back to ncnn RIFE", file=sys.stderr)
# Fall back to ncnn RIFE or optical flow
return ImageBlender.rife_blend(img_a, img_b, t)

View File

@@ -11,8 +11,11 @@ with a simplified inference implementation.
import argparse
import os
import shutil
import sys
import tempfile
import urllib.request
import zipfile
from pathlib import Path
import numpy as np
@@ -88,25 +91,31 @@ def warp(tenInput, tenFlow):
class IFNet(nn.Module):
"""IFNet architecture for RIFE v4.x models."""
"""IFNet architecture for Practical-RIFE v4.25/v4.26 models."""
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+16, c=192)
self.block1 = IFBlock(8+4+16, c=128)
self.block2 = IFBlock(8+4+16, c=96)
self.block3 = IFBlock(8+4+16, c=64)
# v4.25/v4.26 architecture:
# block0 input: img0(3) + img1(3) + f0(4) + f1(4) + timestep(1) = 15
# block1+ input: img0(3) + img1(3) + wf0(4) + wf1(4) + f0(4) + f1(4) + timestep(1) + mask(1) + flow(4) = 28
self.block0 = IFBlock(3+3+4+4+1, c=192)
self.block1 = IFBlock(3+3+4+4+4+4+1+1+4, c=128)
self.block2 = IFBlock(3+3+4+4+4+4+1+1+4, c=96)
self.block3 = IFBlock(3+3+4+4+4+4+1+1+4, c=64)
# Encode produces 4-channel features
self.encode = nn.Sequential(
nn.Conv2d(3, 16, 3, 2, 1),
nn.ConvTranspose2d(16, 4, 4, 2, 1)
nn.Conv2d(3, 32, 3, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(32, 32, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(32, 32, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.ConvTranspose2d(32, 4, 4, 2, 1)
)
def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]):
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = None
@@ -121,34 +130,35 @@ class IFNet(nn.Module):
wf0 = warp(f0, flow[:, :2])
wf1 = warp(f1, flow[:, 2:4])
fd, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1),
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, f0, f1, timestep, mask), 1),
flow, scale=scale_list[i])
flow = flow + fd
mask = mask + m0
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append((warped_img0, warped_img1))
mask_final = torch.sigmoid(mask)
merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final)
return merged_final
# Model URLs for downloading
# Model URLs for downloading (Google Drive direct download links)
# File IDs extracted from official Practical-RIFE repository
MODEL_URLS = {
'v4.26': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.26/flownet.pkl',
'v4.25': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.25/flownet.pkl',
'v4.22': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.22/flownet.pkl',
'v4.20': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.20/flownet.pkl',
'v4.18': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.18/flownet.pkl',
'v4.15': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.15/flownet.pkl',
'v4.26': 'https://drive.google.com/uc?export=download&id=1gViYvvQrtETBgU1w8axZSsr7YUuw31uy',
'v4.25': 'https://drive.google.com/uc?export=download&id=1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg',
'v4.22': 'https://drive.google.com/uc?export=download&id=1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_',
'v4.20': 'https://drive.google.com/uc?export=download&id=11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA',
'v4.18': 'https://drive.google.com/uc?export=download&id=1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s',
'v4.15': 'https://drive.google.com/uc?export=download&id=1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5',
}
def download_model(version: str, model_dir: Path) -> Path:
"""Download model if not already cached.
Google Drive links distribute zip files containing the model.
This function downloads and extracts the flownet.pkl file.
Args:
version: Model version (e.g., 'v4.25').
model_dir: Directory to store models.
@@ -160,25 +170,76 @@ def download_model(version: str, model_dir: Path) -> Path:
model_path = model_dir / f'flownet_{version}.pkl'
if model_path.exists():
return model_path
# Verify it's not a zip file (from previous failed attempt)
with open(model_path, 'rb') as f:
header = f.read(4)
if header == b'PK\x03\x04': # ZIP magic number
print(f"Removing corrupted zip file at {model_path}", file=sys.stderr)
model_path.unlink()
else:
return model_path
url = MODEL_URLS.get(version)
if not url:
raise ValueError(f"Unknown model version: {version}")
print(f"Downloading RIFE model {version}...", file=sys.stderr)
try:
req = urllib.request.Request(url, headers={'User-Agent': 'video-montage-linker'})
with urllib.request.urlopen(req, timeout=120) as response:
with open(model_path, 'wb') as f:
f.write(response.read())
print(f"Model downloaded to {model_path}", file=sys.stderr)
return model_path
except Exception as e:
# Clean up partial download
if model_path.exists():
model_path.unlink()
raise RuntimeError(f"Failed to download model: {e}")
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir) / 'download'
# Try using gdown for Google Drive (handles confirmations automatically)
downloaded = False
try:
import gdown
file_id = url.split('id=')[1] if 'id=' in url else None
if file_id:
gdown_url = f'https://drive.google.com/uc?id={file_id}'
gdown.download(gdown_url, str(tmp_path), quiet=False)
downloaded = tmp_path.exists()
except ImportError:
print("gdown not available, trying direct download...", file=sys.stderr)
except Exception as e:
print(f"gdown failed: {e}, trying direct download...", file=sys.stderr)
# Fallback: direct download
if not downloaded:
try:
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
with urllib.request.urlopen(req, timeout=300) as response:
data = response.read()
if data[:100].startswith(b'<!') or b'<html' in data[:500].lower():
raise RuntimeError("Google Drive returned HTML - install gdown: pip install gdown")
with open(tmp_path, 'wb') as f:
f.write(data)
downloaded = True
except Exception as e:
raise RuntimeError(f"Failed to download model: {e}")
if not downloaded or not tmp_path.exists():
raise RuntimeError("Download failed - no file received")
# Check if downloaded file is a zip archive
with open(tmp_path, 'rb') as f:
header = f.read(4)
if header == b'PK\x03\x04': # ZIP magic number
print(f"Extracting model from zip archive...", file=sys.stderr)
with zipfile.ZipFile(tmp_path, 'r') as zf:
# Find flownet.pkl in the archive
pkl_files = [n for n in zf.namelist() if n.endswith('flownet.pkl')]
if not pkl_files:
raise RuntimeError(f"No flownet.pkl found in zip. Contents: {zf.namelist()}")
# Extract the pkl file
pkl_name = pkl_files[0]
with zf.open(pkl_name) as src, open(model_path, 'wb') as dst:
dst.write(src.read())
else:
# Already a pkl file, just move it
shutil.move(str(tmp_path), str(model_path))
print(f"Model saved to {model_path}", file=sys.stderr)
return model_path
def load_model(model_path: Path, device: torch.device) -> IFNet:

View File

@@ -1289,6 +1289,12 @@ class SequenceLinkerUI(QWidget):
uhd=settings.rife_uhd,
tta=settings.rife_tta
)
elif settings.blend_method == BlendMethod.RIFE_PRACTICAL:
blended = ImageBlender.practical_rife_blend(
img_a, img_b, factor,
settings.practical_rife_model,
settings.practical_rife_ensemble
)
else:
blended = Image.blend(img_a, img_b, factor)
@@ -2554,7 +2560,9 @@ class SequenceLinkerUI(QWidget):
main_path, trans_path, factor,
output_path, settings.output_format,
settings.output_quality, settings.webp_method,
settings.blend_method, settings.rife_binary_path
settings.blend_method, settings.rife_binary_path,
settings.rife_model, settings.rife_uhd, settings.rife_tta,
settings.practical_rife_model, settings.practical_rife_ensemble
)
if result.success: