rife
This commit is contained in:
58
.gitignore
vendored
58
.gitignore
vendored
@@ -1,5 +1,61 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
.env
|
||||
*.pyd
|
||||
.Python
|
||||
*.so
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
venv-rife/
|
||||
.venv/
|
||||
env/
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Downloads and cache
|
||||
*.pkl
|
||||
*.pt
|
||||
*.pth
|
||||
*.onnx
|
||||
downloads/
|
||||
cache/
|
||||
.cache/
|
||||
|
||||
# RIFE binaries and models
|
||||
rife-ncnn-vulkan*/
|
||||
*.zip
|
||||
|
||||
# Output directories
|
||||
output/
|
||||
outputs/
|
||||
temp/
|
||||
tmp/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Build artifacts
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
|
||||
@@ -29,7 +29,7 @@ from .models import (
|
||||
# Cache directory for downloaded binaries
|
||||
CACHE_DIR = Path.home() / '.cache' / 'video-montage-linker'
|
||||
RIFE_GITHUB_API = 'https://api.github.com/repos/nihui/rife-ncnn-vulkan/releases/latest'
|
||||
PRACTICAL_RIFE_VENV_DIR = Path('./venv-rife')
|
||||
PRACTICAL_RIFE_VENV_DIR = CACHE_DIR / 'venv-rife'
|
||||
|
||||
|
||||
class PracticalRifeEnv:
|
||||
@@ -147,13 +147,13 @@ class PracticalRifeEnv:
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("Installing numpy...", 90)
|
||||
progress_callback("Installing additional dependencies...", 90)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
# numpy is usually a dependency of torch but ensure it's there
|
||||
# Install numpy (usually a dependency of torch) and gdown (for Google Drive downloads)
|
||||
subprocess.run(
|
||||
[str(python), '-m', 'pip', 'install', 'numpy'],
|
||||
[str(python), '-m', 'pip', 'install', 'numpy', 'gdown'],
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
@@ -190,7 +190,7 @@ class PracticalRifeEnv:
|
||||
t: float,
|
||||
model: str = 'v4.25',
|
||||
ensemble: bool = False
|
||||
) -> bool:
|
||||
) -> tuple[bool, str]:
|
||||
"""Run RIFE interpolation via subprocess in venv.
|
||||
|
||||
Args:
|
||||
@@ -202,15 +202,15 @@ class PracticalRifeEnv:
|
||||
ensemble: Enable ensemble mode.
|
||||
|
||||
Returns:
|
||||
True if interpolation succeeded.
|
||||
Tuple of (success, error_message).
|
||||
"""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False
|
||||
return False, "venv python not found"
|
||||
|
||||
script = cls.get_worker_script()
|
||||
if not script.exists():
|
||||
return False
|
||||
return False, f"worker script not found: {script}"
|
||||
|
||||
cmd = [
|
||||
str(python), str(script),
|
||||
@@ -232,11 +232,15 @@ class PracticalRifeEnv:
|
||||
text=True,
|
||||
timeout=120 # 2 minute timeout per frame
|
||||
)
|
||||
return result.returncode == 0 and output_path.exists()
|
||||
if result.returncode == 0 and output_path.exists():
|
||||
return True, ""
|
||||
else:
|
||||
error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}"
|
||||
return False, error
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
return False, "timeout (120s)"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
class RifeDownloader:
|
||||
@@ -768,7 +772,7 @@ class ImageBlender:
|
||||
AI-interpolated blended PIL Image.
|
||||
"""
|
||||
if not PracticalRifeEnv.is_setup():
|
||||
# Fall back to ncnn RIFE or optical flow
|
||||
print("[Practical-RIFE] Venv not set up, falling back to ncnn RIFE", file=sys.stderr)
|
||||
return ImageBlender.rife_blend(img_a, img_b, t)
|
||||
|
||||
try:
|
||||
@@ -783,15 +787,17 @@ class ImageBlender:
|
||||
img_b.convert('RGB').save(input_b)
|
||||
|
||||
# Run Practical-RIFE via subprocess
|
||||
success = PracticalRifeEnv.run_interpolation(
|
||||
success, error_msg = PracticalRifeEnv.run_interpolation(
|
||||
input_a, input_b, output_file, t, model, ensemble
|
||||
)
|
||||
|
||||
if success and output_file.exists():
|
||||
return Image.open(output_file).copy()
|
||||
else:
|
||||
print(f"[Practical-RIFE] Interpolation failed: {error_msg}, falling back to ncnn RIFE", file=sys.stderr)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"[Practical-RIFE] Exception: {e}, falling back to ncnn RIFE", file=sys.stderr)
|
||||
|
||||
# Fall back to ncnn RIFE or optical flow
|
||||
return ImageBlender.rife_blend(img_a, img_b, t)
|
||||
|
||||
@@ -11,8 +11,11 @@ with a simplified inference implementation.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -88,25 +91,31 @@ def warp(tenInput, tenFlow):
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
"""IFNet architecture for RIFE v4.x models."""
|
||||
"""IFNet architecture for Practical-RIFE v4.25/v4.26 models."""
|
||||
|
||||
def __init__(self):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(7+16, c=192)
|
||||
self.block1 = IFBlock(8+4+16, c=128)
|
||||
self.block2 = IFBlock(8+4+16, c=96)
|
||||
self.block3 = IFBlock(8+4+16, c=64)
|
||||
# v4.25/v4.26 architecture:
|
||||
# block0 input: img0(3) + img1(3) + f0(4) + f1(4) + timestep(1) = 15
|
||||
# block1+ input: img0(3) + img1(3) + wf0(4) + wf1(4) + f0(4) + f1(4) + timestep(1) + mask(1) + flow(4) = 28
|
||||
self.block0 = IFBlock(3+3+4+4+1, c=192)
|
||||
self.block1 = IFBlock(3+3+4+4+4+4+1+1+4, c=128)
|
||||
self.block2 = IFBlock(3+3+4+4+4+4+1+1+4, c=96)
|
||||
self.block3 = IFBlock(3+3+4+4+4+4+1+1+4, c=64)
|
||||
# Encode produces 4-channel features
|
||||
self.encode = nn.Sequential(
|
||||
nn.Conv2d(3, 16, 3, 2, 1),
|
||||
nn.ConvTranspose2d(16, 4, 4, 2, 1)
|
||||
nn.Conv2d(3, 32, 3, 2, 1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(32, 32, 3, 1, 1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(32, 32, 3, 1, 1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.ConvTranspose2d(32, 4, 4, 2, 1)
|
||||
)
|
||||
|
||||
def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]):
|
||||
f0 = self.encode(img0[:, :3])
|
||||
f1 = self.encode(img1[:, :3])
|
||||
flow_list = []
|
||||
merged = []
|
||||
mask_list = []
|
||||
warped_img0 = img0
|
||||
warped_img1 = img1
|
||||
flow = None
|
||||
@@ -121,34 +130,35 @@ class IFNet(nn.Module):
|
||||
wf0 = warp(f0, flow[:, :2])
|
||||
wf1 = warp(f1, flow[:, 2:4])
|
||||
fd, m0 = block[i](
|
||||
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1),
|
||||
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, f0, f1, timestep, mask), 1),
|
||||
flow, scale=scale_list[i])
|
||||
flow = flow + fd
|
||||
mask = mask + m0
|
||||
mask_list.append(mask)
|
||||
flow_list.append(flow)
|
||||
warped_img0 = warp(img0, flow[:, :2])
|
||||
warped_img1 = warp(img1, flow[:, 2:4])
|
||||
merged.append((warped_img0, warped_img1))
|
||||
mask_final = torch.sigmoid(mask)
|
||||
merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final)
|
||||
return merged_final
|
||||
|
||||
|
||||
# Model URLs for downloading
|
||||
# Model URLs for downloading (Google Drive direct download links)
|
||||
# File IDs extracted from official Practical-RIFE repository
|
||||
MODEL_URLS = {
|
||||
'v4.26': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.26/flownet.pkl',
|
||||
'v4.25': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.25/flownet.pkl',
|
||||
'v4.22': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.22/flownet.pkl',
|
||||
'v4.20': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.20/flownet.pkl',
|
||||
'v4.18': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.18/flownet.pkl',
|
||||
'v4.15': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.15/flownet.pkl',
|
||||
'v4.26': 'https://drive.google.com/uc?export=download&id=1gViYvvQrtETBgU1w8axZSsr7YUuw31uy',
|
||||
'v4.25': 'https://drive.google.com/uc?export=download&id=1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg',
|
||||
'v4.22': 'https://drive.google.com/uc?export=download&id=1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_',
|
||||
'v4.20': 'https://drive.google.com/uc?export=download&id=11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA',
|
||||
'v4.18': 'https://drive.google.com/uc?export=download&id=1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s',
|
||||
'v4.15': 'https://drive.google.com/uc?export=download&id=1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5',
|
||||
}
|
||||
|
||||
|
||||
def download_model(version: str, model_dir: Path) -> Path:
|
||||
"""Download model if not already cached.
|
||||
|
||||
Google Drive links distribute zip files containing the model.
|
||||
This function downloads and extracts the flownet.pkl file.
|
||||
|
||||
Args:
|
||||
version: Model version (e.g., 'v4.25').
|
||||
model_dir: Directory to store models.
|
||||
@@ -160,25 +170,76 @@ def download_model(version: str, model_dir: Path) -> Path:
|
||||
model_path = model_dir / f'flownet_{version}.pkl'
|
||||
|
||||
if model_path.exists():
|
||||
return model_path
|
||||
# Verify it's not a zip file (from previous failed attempt)
|
||||
with open(model_path, 'rb') as f:
|
||||
header = f.read(4)
|
||||
if header == b'PK\x03\x04': # ZIP magic number
|
||||
print(f"Removing corrupted zip file at {model_path}", file=sys.stderr)
|
||||
model_path.unlink()
|
||||
else:
|
||||
return model_path
|
||||
|
||||
url = MODEL_URLS.get(version)
|
||||
if not url:
|
||||
raise ValueError(f"Unknown model version: {version}")
|
||||
|
||||
print(f"Downloading RIFE model {version}...", file=sys.stderr)
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'video-montage-linker'})
|
||||
with urllib.request.urlopen(req, timeout=120) as response:
|
||||
with open(model_path, 'wb') as f:
|
||||
f.write(response.read())
|
||||
print(f"Model downloaded to {model_path}", file=sys.stderr)
|
||||
return model_path
|
||||
except Exception as e:
|
||||
# Clean up partial download
|
||||
if model_path.exists():
|
||||
model_path.unlink()
|
||||
raise RuntimeError(f"Failed to download model: {e}")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_path = Path(tmpdir) / 'download'
|
||||
|
||||
# Try using gdown for Google Drive (handles confirmations automatically)
|
||||
downloaded = False
|
||||
try:
|
||||
import gdown
|
||||
file_id = url.split('id=')[1] if 'id=' in url else None
|
||||
if file_id:
|
||||
gdown_url = f'https://drive.google.com/uc?id={file_id}'
|
||||
gdown.download(gdown_url, str(tmp_path), quiet=False)
|
||||
downloaded = tmp_path.exists()
|
||||
except ImportError:
|
||||
print("gdown not available, trying direct download...", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"gdown failed: {e}, trying direct download...", file=sys.stderr)
|
||||
|
||||
# Fallback: direct download
|
||||
if not downloaded:
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
|
||||
with urllib.request.urlopen(req, timeout=300) as response:
|
||||
data = response.read()
|
||||
if data[:100].startswith(b'<!') or b'<html' in data[:500].lower():
|
||||
raise RuntimeError("Google Drive returned HTML - install gdown: pip install gdown")
|
||||
with open(tmp_path, 'wb') as f:
|
||||
f.write(data)
|
||||
downloaded = True
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download model: {e}")
|
||||
|
||||
if not downloaded or not tmp_path.exists():
|
||||
raise RuntimeError("Download failed - no file received")
|
||||
|
||||
# Check if downloaded file is a zip archive
|
||||
with open(tmp_path, 'rb') as f:
|
||||
header = f.read(4)
|
||||
|
||||
if header == b'PK\x03\x04': # ZIP magic number
|
||||
print(f"Extracting model from zip archive...", file=sys.stderr)
|
||||
with zipfile.ZipFile(tmp_path, 'r') as zf:
|
||||
# Find flownet.pkl in the archive
|
||||
pkl_files = [n for n in zf.namelist() if n.endswith('flownet.pkl')]
|
||||
if not pkl_files:
|
||||
raise RuntimeError(f"No flownet.pkl found in zip. Contents: {zf.namelist()}")
|
||||
# Extract the pkl file
|
||||
pkl_name = pkl_files[0]
|
||||
with zf.open(pkl_name) as src, open(model_path, 'wb') as dst:
|
||||
dst.write(src.read())
|
||||
else:
|
||||
# Already a pkl file, just move it
|
||||
shutil.move(str(tmp_path), str(model_path))
|
||||
|
||||
print(f"Model saved to {model_path}", file=sys.stderr)
|
||||
return model_path
|
||||
|
||||
|
||||
def load_model(model_path: Path, device: torch.device) -> IFNet:
|
||||
|
||||
@@ -1289,6 +1289,12 @@ class SequenceLinkerUI(QWidget):
|
||||
uhd=settings.rife_uhd,
|
||||
tta=settings.rife_tta
|
||||
)
|
||||
elif settings.blend_method == BlendMethod.RIFE_PRACTICAL:
|
||||
blended = ImageBlender.practical_rife_blend(
|
||||
img_a, img_b, factor,
|
||||
settings.practical_rife_model,
|
||||
settings.practical_rife_ensemble
|
||||
)
|
||||
else:
|
||||
blended = Image.blend(img_a, img_b, factor)
|
||||
|
||||
@@ -2554,7 +2560,9 @@ class SequenceLinkerUI(QWidget):
|
||||
main_path, trans_path, factor,
|
||||
output_path, settings.output_format,
|
||||
settings.output_quality, settings.webp_method,
|
||||
settings.blend_method, settings.rife_binary_path
|
||||
settings.blend_method, settings.rife_binary_path,
|
||||
settings.rife_model, settings.rife_uhd, settings.rife_tta,
|
||||
settings.practical_rife_model, settings.practical_rife_ensemble
|
||||
)
|
||||
|
||||
if result.success:
|
||||
|
||||
Reference in New Issue
Block a user