Add Practical-RIFE frame interpolation support

Implement standalone PyTorch-based RIFE interpolation that runs in a
dedicated virtual environment to avoid Qt/OpenCV conflicts:

- Add PracticalRifeEnv class for managing venv and subprocess execution
- Add rife_worker.py standalone interpolation script using Practical-RIFE
- Add RIFE_PRACTICAL blending model with ensemble/fast mode settings
- Add UI controls for Practical-RIFE configuration
- Update .gitignore to exclude venv-rife/ directory

The implementation downloads Practical-RIFE models on first use and runs
interpolation in a separate process with proper progress reporting.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-03 20:46:06 +01:00
parent 6bfbefb058
commit 2c6ad4ff35
6 changed files with 1148 additions and 58 deletions

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@ __pycache__/
*.pyc
*.pyo
.env
venv-rife/

View File

@@ -19,7 +19,7 @@ from .models import (
DatabaseError,
)
from .database import DatabaseManager
from .blender import ImageBlender, TransitionGenerator, RifeDownloader
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv
from .manager import SymlinkManager
__all__ = [
@@ -43,5 +43,6 @@ __all__ = [
'ImageBlender',
'TransitionGenerator',
'RifeDownloader',
'PracticalRifeEnv',
'SymlinkManager',
]

View File

@@ -29,6 +29,214 @@ from .models import (
# Cache directory for downloaded binaries
CACHE_DIR = Path.home() / '.cache' / 'video-montage-linker'
RIFE_GITHUB_API = 'https://api.github.com/repos/nihui/rife-ncnn-vulkan/releases/latest'
PRACTICAL_RIFE_VENV_DIR = Path('./venv-rife')
class PracticalRifeEnv:
"""Manages isolated Python environment for Practical-RIFE."""
VENV_DIR = PRACTICAL_RIFE_VENV_DIR
MODEL_CACHE_DIR = CACHE_DIR / 'practical-rife'
REQUIRED_PACKAGES = ['torch', 'torchvision', 'numpy']
# Available Practical-RIFE models
AVAILABLE_MODELS = ['v4.26', 'v4.25', 'v4.22', 'v4.20', 'v4.18', 'v4.15']
@classmethod
def get_venv_python(cls) -> Optional[Path]:
"""Get path to venv Python executable."""
if cls.VENV_DIR.exists():
if sys.platform == 'win32':
return cls.VENV_DIR / 'Scripts' / 'python.exe'
return cls.VENV_DIR / 'bin' / 'python'
return None
@classmethod
def is_setup(cls) -> bool:
"""Check if venv exists and has required packages."""
python = cls.get_venv_python()
if not python or not python.exists():
return False
# Check if torch is importable
result = subprocess.run(
[str(python), '-c', 'import torch; print(torch.__version__)'],
capture_output=True
)
return result.returncode == 0
@classmethod
def get_torch_version(cls) -> Optional[str]:
"""Get installed torch version in venv."""
python = cls.get_venv_python()
if not python or not python.exists():
return None
result = subprocess.run(
[str(python), '-c', 'import torch; print(torch.__version__)'],
capture_output=True,
text=True
)
if result.returncode == 0:
return result.stdout.strip()
return None
@classmethod
def setup_venv(cls, progress_callback=None, cancelled_check=None) -> bool:
"""Create venv and install PyTorch.
Args:
progress_callback: Optional callback(message, percent) for progress.
cancelled_check: Optional callable that returns True if cancelled.
Returns:
True if setup was successful.
"""
import venv
try:
# 1. Create venv
if progress_callback:
progress_callback("Creating virtual environment...", 10)
if cancelled_check and cancelled_check():
return False
# Remove old venv if exists
if cls.VENV_DIR.exists():
shutil.rmtree(cls.VENV_DIR)
venv.create(cls.VENV_DIR, with_pip=True)
# 2. Get pip path
python = cls.get_venv_python()
if not python:
return False
# 3. Upgrade pip
if progress_callback:
progress_callback("Upgrading pip...", 20)
if cancelled_check and cancelled_check():
return False
subprocess.run(
[str(python), '-m', 'pip', 'install', '--upgrade', 'pip'],
capture_output=True,
check=True
)
# 4. Install PyTorch (this is the big download)
if progress_callback:
progress_callback("Installing PyTorch (this may take a while)...", 30)
if cancelled_check and cancelled_check():
return False
# Try to install with CUDA support first, fall back to CPU
# Use pip index to get the right version
result = subprocess.run(
[str(python), '-m', 'pip', 'install', 'torch', 'torchvision'],
capture_output=True,
text=True
)
if result.returncode != 0:
# Try CPU-only version
subprocess.run(
[str(python), '-m', 'pip', 'install',
'torch', 'torchvision',
'--index-url', 'https://download.pytorch.org/whl/cpu'],
capture_output=True,
check=True
)
if progress_callback:
progress_callback("Installing numpy...", 90)
if cancelled_check and cancelled_check():
return False
# numpy is usually a dependency of torch but ensure it's there
subprocess.run(
[str(python), '-m', 'pip', 'install', 'numpy'],
capture_output=True
)
if progress_callback:
progress_callback("Setup complete!", 100)
return cls.is_setup()
except Exception as e:
# Cleanup on error
if cls.VENV_DIR.exists():
try:
shutil.rmtree(cls.VENV_DIR)
except Exception:
pass
return False
@classmethod
def get_available_models(cls) -> list[str]:
"""Return list of available model versions."""
return cls.AVAILABLE_MODELS.copy()
@classmethod
def get_worker_script(cls) -> Path:
"""Get path to the RIFE worker script."""
return Path(__file__).parent / 'rife_worker.py'
@classmethod
def run_interpolation(
cls,
img_a_path: Path,
img_b_path: Path,
output_path: Path,
t: float,
model: str = 'v4.25',
ensemble: bool = False
) -> bool:
"""Run RIFE interpolation via subprocess in venv.
Args:
img_a_path: Path to first input image.
img_b_path: Path to second input image.
output_path: Path to output image.
t: Timestep for interpolation (0.0 to 1.0).
model: Model version to use.
ensemble: Enable ensemble mode.
Returns:
True if interpolation succeeded.
"""
python = cls.get_venv_python()
if not python or not python.exists():
return False
script = cls.get_worker_script()
if not script.exists():
return False
cmd = [
str(python), str(script),
'--input0', str(img_a_path),
'--input1', str(img_b_path),
'--output', str(output_path),
'--timestep', str(t),
'--model', model,
'--model-dir', str(cls.MODEL_CACHE_DIR)
]
if ensemble:
cmd.append('--ensemble')
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=120 # 2 minute timeout per frame
)
return result.returncode == 0 and output_path.exists()
except subprocess.TimeoutExpired:
return False
except Exception:
return False
class RifeDownloader:
@@ -398,7 +606,10 @@ class ImageBlender:
img_b: Image.Image,
t: float,
binary_path: Optional[Path] = None,
auto_download: bool = True
auto_download: bool = True,
model: str = 'rife-v4.6',
uhd: bool = False,
tta: bool = False
) -> Image.Image:
"""Blend using RIFE AI frame interpolation.
@@ -411,27 +622,30 @@ class ImageBlender:
t: Interpolation factor 0.0 (100% A) to 1.0 (100% B).
binary_path: Optional path to rife-ncnn-vulkan binary.
auto_download: Whether to auto-download RIFE if not found.
model: RIFE model to use (e.g., 'rife-v4.6', 'rife-anime').
uhd: Enable UHD mode for high resolution images.
tta: Enable TTA mode for better quality (slower).
Returns:
AI-interpolated blended PIL Image.
"""
# Try NCNN binary first (specified path)
if binary_path and binary_path.exists():
result = ImageBlender._rife_ncnn(img_a, img_b, t, binary_path)
result = ImageBlender._rife_ncnn(img_a, img_b, t, binary_path, model, uhd, tta)
if result is not None:
return result
# Try to find rife-ncnn-vulkan in PATH
ncnn_path = shutil.which('rife-ncnn-vulkan')
if ncnn_path:
result = ImageBlender._rife_ncnn(img_a, img_b, t, Path(ncnn_path))
result = ImageBlender._rife_ncnn(img_a, img_b, t, Path(ncnn_path), model, uhd, tta)
if result is not None:
return result
# Try cached binary
cached = RifeDownloader.get_cached_binary()
if cached:
result = ImageBlender._rife_ncnn(img_a, img_b, t, cached)
result = ImageBlender._rife_ncnn(img_a, img_b, t, cached, model, uhd, tta)
if result is not None:
return result
@@ -439,7 +653,7 @@ class ImageBlender:
if auto_download:
downloaded = RifeDownloader.ensure_binary()
if downloaded:
result = ImageBlender._rife_ncnn(img_a, img_b, t, downloaded)
result = ImageBlender._rife_ncnn(img_a, img_b, t, downloaded, model, uhd, tta)
if result is not None:
return result
@@ -451,7 +665,10 @@ class ImageBlender:
img_a: Image.Image,
img_b: Image.Image,
t: float,
binary: Path
binary: Path,
model: str = 'rife-v4.6',
uhd: bool = False,
tta: bool = False
) -> Optional[Image.Image]:
"""Use rife-ncnn-vulkan binary for interpolation.
@@ -460,6 +677,9 @@ class ImageBlender:
img_b: Second PIL Image.
t: Interpolation timestep (0.0 to 1.0).
binary: Path to rife-ncnn-vulkan binary.
model: RIFE model to use.
uhd: Enable UHD mode.
tta: Enable TTA mode.
Returns:
Interpolated PIL Image, or None if failed.
@@ -485,6 +705,19 @@ class ImageBlender:
'-o', str(output_file),
]
# Add model path (models are in same directory as binary)
model_path = binary.parent / model
if model_path.exists():
cmd.extend(['-m', str(model_path)])
# Add UHD mode flag
if uhd:
cmd.append('-u')
# Add TTA mode flag (spatial)
if tta:
cmd.append('-x')
# Some versions support -s for timestep
# Try with timestep first, fall back to simple interpolation
try:
@@ -492,7 +725,7 @@ class ImageBlender:
cmd + ['-s', str(t)],
check=True,
capture_output=True,
timeout=30
timeout=60 # Increased timeout for TTA mode
)
except subprocess.CalledProcessError:
# Try without timestep (generates middle frame at t=0.5)
@@ -500,7 +733,7 @@ class ImageBlender:
cmd,
check=True,
capture_output=True,
timeout=30
timeout=60
)
if output_file.exists():
@@ -511,6 +744,58 @@ class ImageBlender:
return None
@staticmethod
def practical_rife_blend(
img_a: Image.Image,
img_b: Image.Image,
t: float,
model: str = 'v4.25',
ensemble: bool = False
) -> Image.Image:
"""Blend using Practical-RIFE Python/PyTorch implementation.
Runs RIFE interpolation via subprocess in an isolated venv.
Falls back to ncnn RIFE or optical flow if unavailable.
Args:
img_a: First PIL Image (source frame).
img_b: Second PIL Image (target frame).
t: Interpolation factor 0.0 (100% A) to 1.0 (100% B).
model: Practical-RIFE model version (e.g., 'v4.25', 'v4.26').
ensemble: Enable ensemble mode for better quality (slower).
Returns:
AI-interpolated blended PIL Image.
"""
if not PracticalRifeEnv.is_setup():
# Fall back to ncnn RIFE or optical flow
return ImageBlender.rife_blend(img_a, img_b, t)
try:
with tempfile.TemporaryDirectory() as tmpdir:
tmp = Path(tmpdir)
input_a = tmp / 'a.png'
input_b = tmp / 'b.png'
output_file = tmp / 'out.png'
# Save input images
img_a.convert('RGB').save(input_a)
img_b.convert('RGB').save(input_b)
# Run Practical-RIFE via subprocess
success = PracticalRifeEnv.run_interpolation(
input_a, input_b, output_file, t, model, ensemble
)
if success and output_file.exists():
return Image.open(output_file).copy()
except Exception:
pass
# Fall back to ncnn RIFE or optical flow
return ImageBlender.rife_blend(img_a, img_b, t)
@staticmethod
def blend_images(
img_a_path: Path,
@@ -521,7 +806,12 @@ class ImageBlender:
output_quality: int = 95,
webp_method: int = 4,
blend_method: BlendMethod = BlendMethod.ALPHA,
rife_binary_path: Optional[Path] = None
rife_binary_path: Optional[Path] = None,
rife_model: str = 'rife-v4.6',
rife_uhd: bool = False,
rife_tta: bool = False,
practical_rife_model: str = 'v4.25',
practical_rife_ensemble: bool = False
) -> BlendResult:
"""Blend two images together.
@@ -533,8 +823,13 @@ class ImageBlender:
output_format: Output format (png, jpeg, webp).
output_quality: Quality for JPEG output (1-100).
webp_method: WebP compression method (0-6, higher = smaller but slower).
blend_method: The blending method to use (alpha, optical_flow, or rife).
blend_method: The blending method to use (alpha, optical_flow, rife, rife_practical).
rife_binary_path: Optional path to rife-ncnn-vulkan binary.
rife_model: RIFE ncnn model to use (e.g., 'rife-v4.6').
rife_uhd: Enable RIFE ncnn UHD mode.
rife_tta: Enable RIFE ncnn TTA mode.
practical_rife_model: Practical-RIFE model version (e.g., 'v4.25').
practical_rife_ensemble: Enable Practical-RIFE ensemble mode.
Returns:
BlendResult with operation status.
@@ -557,7 +852,13 @@ class ImageBlender:
if blend_method == BlendMethod.OPTICAL_FLOW:
blended = ImageBlender.optical_flow_blend(img_a, img_b, factor)
elif blend_method == BlendMethod.RIFE:
blended = ImageBlender.rife_blend(img_a, img_b, factor, rife_binary_path)
blended = ImageBlender.rife_blend(
img_a, img_b, factor, rife_binary_path, True, rife_model, rife_uhd, rife_tta
)
elif blend_method == BlendMethod.RIFE_PRACTICAL:
blended = ImageBlender.practical_rife_blend(
img_a, img_b, factor, practical_rife_model, practical_rife_ensemble
)
else:
# Default: simple alpha blend
blended = Image.blend(img_a, img_b, factor)
@@ -610,7 +911,12 @@ class ImageBlender:
output_quality: int = 95,
webp_method: int = 4,
blend_method: BlendMethod = BlendMethod.ALPHA,
rife_binary_path: Optional[Path] = None
rife_binary_path: Optional[Path] = None,
rife_model: str = 'rife-v4.6',
rife_uhd: bool = False,
rife_tta: bool = False,
practical_rife_model: str = 'v4.25',
practical_rife_ensemble: bool = False
) -> BlendResult:
"""Blend two PIL Image objects together.
@@ -622,8 +928,13 @@ class ImageBlender:
output_format: Output format (png, jpeg, webp).
output_quality: Quality for JPEG output (1-100).
webp_method: WebP compression method (0-6).
blend_method: The blending method to use (alpha, optical_flow, or rife).
blend_method: The blending method to use (alpha, optical_flow, rife, rife_practical).
rife_binary_path: Optional path to rife-ncnn-vulkan binary.
rife_model: RIFE ncnn model to use (e.g., 'rife-v4.6').
rife_uhd: Enable RIFE ncnn UHD mode.
rife_tta: Enable RIFE ncnn TTA mode.
practical_rife_model: Practical-RIFE model version (e.g., 'v4.25').
practical_rife_ensemble: Enable Practical-RIFE ensemble mode.
Returns:
BlendResult with operation status.
@@ -643,7 +954,13 @@ class ImageBlender:
if blend_method == BlendMethod.OPTICAL_FLOW:
blended = ImageBlender.optical_flow_blend(img_a, img_b, factor)
elif blend_method == BlendMethod.RIFE:
blended = ImageBlender.rife_blend(img_a, img_b, factor, rife_binary_path)
blended = ImageBlender.rife_blend(
img_a, img_b, factor, rife_binary_path, True, rife_model, rife_uhd, rife_tta
)
elif blend_method == BlendMethod.RIFE_PRACTICAL:
blended = ImageBlender.practical_rife_blend(
img_a, img_b, factor, practical_rife_model, practical_rife_ensemble
)
else:
# Default: simple alpha blend
blended = Image.blend(img_a, img_b, factor)
@@ -887,7 +1204,12 @@ class TransitionGenerator:
self.settings.output_quality,
self.settings.webp_method,
self.settings.blend_method,
self.settings.rife_binary_path
self.settings.rife_binary_path,
self.settings.rife_model,
self.settings.rife_uhd,
self.settings.rife_tta,
self.settings.practical_rife_model,
self.settings.practical_rife_ensemble
)
results.append(result)

View File

@@ -21,7 +21,8 @@ class BlendMethod(Enum):
"""Blend method types for transitions."""
ALPHA = 'alpha' # Simple cross-dissolve (PIL.Image.blend)
OPTICAL_FLOW = 'optical' # OpenCV Farneback optical flow
RIFE = 'rife' # AI frame interpolation (NCNN binary or PyTorch)
RIFE = 'rife' # AI frame interpolation (NCNN binary)
RIFE_PRACTICAL = 'rife_practical' # Practical-RIFE Python/PyTorch implementation
class FolderType(Enum):
@@ -44,6 +45,12 @@ class TransitionSettings:
trans_destination: Optional[Path] = None # separate destination for transition output
blend_method: BlendMethod = BlendMethod.ALPHA # blending method
rife_binary_path: Optional[Path] = None # path to rife-ncnn-vulkan binary
rife_model: str = 'rife-v4.6' # RIFE model to use
rife_uhd: bool = False # Enable UHD mode for high resolution
rife_tta: bool = False # Enable TTA mode for better quality
# Practical-RIFE settings
practical_rife_model: str = 'v4.25' # v4.25, v4.26, v4.22, etc.
practical_rife_ensemble: bool = False # Ensemble mode for better quality (slower)
@dataclass

368
core/rife_worker.py Normal file
View File

@@ -0,0 +1,368 @@
#!/usr/bin/env python
"""RIFE interpolation worker - runs in isolated venv with PyTorch.
This script is executed via subprocess from the main application.
It handles loading Practical-RIFE models and performing frame interpolation.
Note: The Practical-RIFE models require the IFNet architecture from the
Practical-RIFE repository. This script downloads and uses the model weights
with a simplified inference implementation.
"""
import argparse
import os
import sys
import urllib.request
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)
def forward(self, x, flow=None, scale=1):
x = F.interpolate(x, scale_factor=1./scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1./scale, mode="bilinear", align_corners=False) / scale
x = torch.cat((x, flow), 1)
feat = self.conv0(x)
feat = self.convblock(feat) + feat
tmp = self.lastconv(feat)
tmp = F.interpolate(tmp, scale_factor=scale*2, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * scale * 2
mask = tmp[:, 4:5]
return flow, mask
def warp(tenInput, tenFlow):
k = (str(tenFlow.device), str(tenFlow.size()))
backwarp_tenGrid = {}
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
class IFNet(nn.Module):
"""IFNet architecture for RIFE v4.x models."""
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+16, c=192)
self.block1 = IFBlock(8+4+16, c=128)
self.block2 = IFBlock(8+4+16, c=96)
self.block3 = IFBlock(8+4+16, c=64)
self.encode = nn.Sequential(
nn.Conv2d(3, 16, 3, 2, 1),
nn.ConvTranspose2d(16, 4, 4, 2, 1)
)
def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]):
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
for i in range(4):
if flow is None:
flow, mask = block[i](
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
None, scale=scale_list[i])
else:
wf0 = warp(f0, flow[:, :2])
wf1 = warp(f1, flow[:, 2:4])
fd, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1),
flow, scale=scale_list[i])
flow = flow + fd
mask = mask + m0
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append((warped_img0, warped_img1))
mask_final = torch.sigmoid(mask)
merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final)
return merged_final
# Model URLs for downloading
MODEL_URLS = {
'v4.26': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.26/flownet.pkl',
'v4.25': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.25/flownet.pkl',
'v4.22': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.22/flownet.pkl',
'v4.20': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.20/flownet.pkl',
'v4.18': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.18/flownet.pkl',
'v4.15': 'https://github.com/hzwer/Practical-RIFE/raw/main/train_log_v4.15/flownet.pkl',
}
def download_model(version: str, model_dir: Path) -> Path:
"""Download model if not already cached.
Args:
version: Model version (e.g., 'v4.25').
model_dir: Directory to store models.
Returns:
Path to the downloaded model file.
"""
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / f'flownet_{version}.pkl'
if model_path.exists():
return model_path
url = MODEL_URLS.get(version)
if not url:
raise ValueError(f"Unknown model version: {version}")
print(f"Downloading RIFE model {version}...", file=sys.stderr)
try:
req = urllib.request.Request(url, headers={'User-Agent': 'video-montage-linker'})
with urllib.request.urlopen(req, timeout=120) as response:
with open(model_path, 'wb') as f:
f.write(response.read())
print(f"Model downloaded to {model_path}", file=sys.stderr)
return model_path
except Exception as e:
# Clean up partial download
if model_path.exists():
model_path.unlink()
raise RuntimeError(f"Failed to download model: {e}")
def load_model(model_path: Path, device: torch.device) -> IFNet:
"""Load IFNet model from state dict.
Args:
model_path: Path to flownet.pkl file.
device: Device to load model to.
Returns:
Loaded IFNet model.
"""
model = IFNet()
state_dict = torch.load(model_path, map_location='cpu')
# Handle different state dict formats
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
# Remove 'module.' prefix if present (from DataParallel)
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
# Handle flownet. prefix
if k.startswith('flownet.'):
k = k[8:]
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
model.to(device)
model.eval()
return model
def pad_image(img: torch.Tensor, padding: int = 64) -> tuple:
"""Pad image to be divisible by padding.
Args:
img: Input tensor (B, C, H, W).
padding: Padding divisor.
Returns:
Tuple of (padded image, (original H, original W)).
"""
_, _, h, w = img.shape
ph = ((h - 1) // padding + 1) * padding
pw = ((w - 1) // padding + 1) * padding
pad_h = ph - h
pad_w = pw - w
padded = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
return padded, (h, w)
@torch.no_grad()
def inference(model: IFNet, img0: torch.Tensor, img1: torch.Tensor,
timestep: float = 0.5, ensemble: bool = False) -> torch.Tensor:
"""Perform frame interpolation.
Args:
model: Loaded IFNet model.
img0: First frame tensor (B, C, H, W) normalized to [0, 1].
img1: Second frame tensor (B, C, H, W) normalized to [0, 1].
timestep: Interpolation timestep (0.0 to 1.0).
ensemble: Enable ensemble mode for better quality.
Returns:
Interpolated frame tensor.
"""
# Pad images
img0_padded, orig_size = pad_image(img0)
img1_padded, _ = pad_image(img1)
h, w = orig_size
# Create timestep tensor
timestep_tensor = torch.full((1, 1, img0_padded.shape[2], img0_padded.shape[3]),
timestep, device=img0.device)
if ensemble:
# Ensemble: average of forward and reverse
result1 = model(img0_padded, img1_padded, timestep_tensor)
result2 = model(img1_padded, img0_padded, 1 - timestep_tensor)
result = (result1 + result2) / 2
else:
result = model(img0_padded, img1_padded, timestep_tensor)
# Crop back to original size
result = result[:, :, :h, :w]
return result.clamp(0, 1)
def load_image(path: Path, device: torch.device) -> torch.Tensor:
"""Load image as tensor.
Args:
path: Path to image file.
device: Device to load tensor to.
Returns:
Image tensor (1, 3, H, W) normalized to [0, 1].
"""
img = Image.open(path).convert('RGB')
arr = np.array(img).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
return tensor.to(device)
def save_image(tensor: torch.Tensor, path: Path) -> None:
"""Save tensor as image.
Args:
tensor: Image tensor (1, 3, H, W) normalized to [0, 1].
path: Output path.
"""
arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
arr = (arr * 255).clip(0, 255).astype(np.uint8)
Image.fromarray(arr).save(path)
# Global model cache
_model_cache: dict = {}
def get_model(version: str, model_dir: Path, device: torch.device) -> IFNet:
"""Get or load model (cached).
Args:
version: Model version.
model_dir: Model cache directory.
device: Device to run on.
Returns:
IFNet model instance.
"""
cache_key = f"{version}_{device}"
if cache_key not in _model_cache:
model_path = download_model(version, model_dir)
_model_cache[cache_key] = load_model(model_path, device)
return _model_cache[cache_key]
def main():
parser = argparse.ArgumentParser(description='RIFE frame interpolation worker')
parser.add_argument('--input0', required=True, help='Path to first input image')
parser.add_argument('--input1', required=True, help='Path to second input image')
parser.add_argument('--output', required=True, help='Path to output image')
parser.add_argument('--timestep', type=float, default=0.5, help='Interpolation timestep (0-1)')
parser.add_argument('--model', default='v4.25', help='Model version')
parser.add_argument('--model-dir', required=True, help='Model cache directory')
parser.add_argument('--ensemble', action='store_true', help='Enable ensemble mode')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use')
args = parser.parse_args()
try:
# Select device
if args.device == 'cuda' and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
# Load model
model_dir = Path(args.model_dir)
model = get_model(args.model, model_dir, device)
# Load images
img0 = load_image(Path(args.input0), device)
img1 = load_image(Path(args.input1), device)
# Interpolate
result = inference(model, img0, img1, args.timestep, args.ensemble)
# Save result
save_image(result, Path(args.output))
print("Success", file=sys.stderr)
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
return 1
if __name__ == '__main__':
sys.exit(main())

View File

@@ -5,8 +5,8 @@ import re
from pathlib import Path
from typing import Optional
from PyQt6.QtCore import Qt, QUrl, QEvent, QPoint
from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QColor
from PyQt6.QtCore import Qt, QUrl, QEvent, QPoint, QTimer
from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QColor, QPainter, QFont, QFontMetrics
from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput
from PyQt6.QtMultimediaWidgets import QVideoWidget
from PyQt6.QtWidgets import (
@@ -38,6 +38,7 @@ from PyQt6.QtWidgets import (
QDialog,
QDialogButtonBox,
QFormLayout,
QCheckBox,
)
from PyQt6.QtGui import QPixmap
@@ -53,11 +54,84 @@ from core import (
DatabaseManager,
TransitionGenerator,
RifeDownloader,
PracticalRifeEnv,
SymlinkManager,
)
from .widgets import TrimSlider
class TimelineTreeWidget(QTreeWidget):
"""QTreeWidget with timeline markers drawn in the background."""
def __init__(self, parent: Optional[QWidget] = None) -> None:
super().__init__(parent)
self.fps = 16
self._text_color = QColor(100, 100, 100)
def set_fps(self, fps: int) -> None:
"""Update FPS for timeline display."""
self.fps = max(1, fps)
self.viewport().update()
def paintEvent(self, event) -> None:
"""Draw timeline markers in background, then call parent paint."""
# Draw the timeline background on the viewport
painter = QPainter(self.viewport())
frame_count = self.topLevelItemCount()
if frame_count > 0 and self.fps > 0:
# Get row height from first visible item
first_item = self.topLevelItem(0)
if first_item:
# Get column positions
col0_width = self.columnWidth(0)
viewport_width = self.viewport().width()
# Font for time labels
font = QFont("Monospace", 9)
painter.setFont(font)
metrics = QFontMetrics(font)
# Draw for each row
for i in range(frame_count):
item = self.topLevelItem(i)
if not item:
continue
item_rect = self.visualItemRect(item)
if item_rect.isNull() or item_rect.bottom() < 0 or item_rect.top() > self.viewport().height():
continue # Not visible
y_center = item_rect.center().y()
# Calculate time for this frame
time_seconds = i / self.fps
is_major = (i % self.fps == 0) # Every second
if is_major:
# Format time
minutes = int(time_seconds // 60)
seconds = int(time_seconds % 60)
if minutes > 0:
time_str = f"{minutes}:{seconds:02d}"
else:
time_str = f"{seconds}s"
text_width = metrics.horizontalAdvance(time_str)
painter.setPen(self._text_color)
# Draw time label on right of column 0
painter.drawText(col0_width - text_width - 6, y_center + metrics.ascent() // 2, time_str)
# Draw time label on right of column 1 (right edge)
painter.drawText(viewport_width - text_width - 6, y_center + metrics.ascent() // 2, time_str)
painter.end()
# Call parent to draw the actual tree content
super().paintEvent(event)
class OverlapDialog(QDialog):
"""Dialog for setting per-transition overlap frames."""
@@ -141,6 +215,8 @@ class SequenceLinkerUI(QWidget):
self._create_layout()
self._connect_signals()
self.setAcceptDrops(True)
# Initialize sequence table FPS
self.sequence_table.set_fps(self.fps_spin.value())
def _setup_window(self) -> None:
"""Configure the main window properties."""
@@ -260,14 +336,15 @@ class SequenceLinkerUI(QWidget):
self._current_pixmap: Optional[QPixmap] = None
self._pan_start = None
self._pan_scrollbar_start = None
self._blend_preview_cache: dict[str, QPixmap] = {} # Cache for generated blend frames
# Trim slider
self.trim_slider = TrimSlider()
self.trim_label = QLabel("Frames: All included")
self.trim_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
# Sequence table (2-column: Main Frame | Transition Frame)
self.sequence_table = QTreeWidget()
# Sequence table (2-column: Main Frame | Transition Frame) with timeline background
self.sequence_table = TimelineTreeWidget()
self.sequence_table.setHeaderLabels(["Main Frame", "Transition Frame"])
self.sequence_table.setColumnCount(2)
self.sequence_table.setRootIsDecorated(False)
@@ -317,12 +394,14 @@ class SequenceLinkerUI(QWidget):
self.blend_method_combo = QComboBox()
self.blend_method_combo.addItem("Cross-Dissolve", BlendMethod.ALPHA)
self.blend_method_combo.addItem("Optical Flow", BlendMethod.OPTICAL_FLOW)
self.blend_method_combo.addItem("RIFE (AI)", BlendMethod.RIFE)
self.blend_method_combo.addItem("RIFE (ncnn)", BlendMethod.RIFE)
self.blend_method_combo.addItem("RIFE (Practical)", BlendMethod.RIFE_PRACTICAL)
self.blend_method_combo.setToolTip(
"Blending method:\n"
"- Cross-Dissolve: Simple alpha blend (fast, may ghost)\n"
"- Optical Flow: Motion-compensated blend (slower, less ghosting)\n"
"- RIFE: AI frame interpolation (best quality, requires rife-ncnn-vulkan)"
"- RIFE (ncnn): AI frame interpolation (fast, Vulkan GPU, models up to v4.6)\n"
"- RIFE (Practical): AI frame interpolation (PyTorch, latest models v4.25/v4.26)"
)
# RIFE binary path
@@ -338,6 +417,77 @@ class SequenceLinkerUI(QWidget):
self.rife_path_btn.setVisible(False)
self.rife_download_btn.setVisible(False)
# RIFE model selection
self.rife_model_label = QLabel("Model:")
self.rife_model_combo = QComboBox()
self.rife_model_combo.addItem("v4.6 (Best)", "rife-v4.6")
self.rife_model_combo.addItem("v4", "rife-v4")
self.rife_model_combo.addItem("v3.1", "rife-v3.1")
self.rife_model_combo.addItem("v2.4", "rife-v2.4")
self.rife_model_combo.addItem("Anime", "rife-anime")
self.rife_model_combo.addItem("UHD", "rife-UHD")
self.rife_model_combo.addItem("HD", "rife-HD")
self.rife_model_combo.setToolTip("RIFE model version:\n- v4.6: Latest, best quality\n- Anime: Optimized for animation\n- UHD/HD: For high resolution content")
self.rife_model_label.setVisible(False)
self.rife_model_combo.setVisible(False)
# RIFE UHD mode
self.rife_uhd_check = QCheckBox("UHD")
self.rife_uhd_check.setToolTip("Enable UHD mode for high resolution images (4K+)")
self.rife_uhd_check.setVisible(False)
# RIFE TTA mode
self.rife_tta_check = QCheckBox("TTA")
self.rife_tta_check.setToolTip("Enable TTA (Test-Time Augmentation) for better quality (slower)")
self.rife_tta_check.setVisible(False)
# Practical-RIFE settings
self.practical_model_label = QLabel("Model:")
self.practical_model_combo = QComboBox()
self.practical_model_combo.addItem("v4.26 (Latest)", "v4.26")
self.practical_model_combo.addItem("v4.25 (Recommended)", "v4.25")
self.practical_model_combo.addItem("v4.22", "v4.22")
self.practical_model_combo.addItem("v4.20", "v4.20")
self.practical_model_combo.addItem("v4.18", "v4.18")
self.practical_model_combo.addItem("v4.15", "v4.15")
self.practical_model_combo.setCurrentIndex(1) # Default to v4.25
self.practical_model_combo.setToolTip(
"Practical-RIFE model version:\n"
"- v4.26: Latest version\n"
"- v4.25: Recommended, good balance of quality and speed"
)
self.practical_model_label.setVisible(False)
self.practical_model_combo.setVisible(False)
self.practical_ensemble_check = QCheckBox("Ensemble")
self.practical_ensemble_check.setToolTip("Enable ensemble mode for better quality (slower)")
self.practical_ensemble_check.setVisible(False)
self.practical_setup_btn = QPushButton("Setup PyTorch")
self.practical_setup_btn.setToolTip("Create local venv and install PyTorch (~2GB download)")
self.practical_setup_btn.setVisible(False)
self.practical_status_label = QLabel("")
self.practical_status_label.setStyleSheet("color: gray; font-size: 10px;")
self.practical_status_label.setVisible(False)
# FPS setting for sequence playback and timeline
self.fps_label = QLabel("FPS:")
self.fps_spin = QSpinBox()
self.fps_spin.setRange(1, 120)
self.fps_spin.setValue(16)
self.fps_spin.setToolTip("Frames per second for sequence preview and timeline")
# Timeline duration label
self.timeline_label = QLabel("Duration: 00:00.000 (0 frames)")
self.timeline_label.setStyleSheet("font-family: monospace;")
# Sequence playback button and timer
self.seq_play_btn = QPushButton("▶ Play")
self.seq_play_btn.setToolTip("Play image sequence at configured FPS")
self.sequence_timer = QTimer(self)
self.sequence_playing = False
def _create_layout(self) -> None:
"""Arrange widgets in layouts."""
# === LEFT SIDE PANEL: Source Folders ===
@@ -397,6 +547,19 @@ class SequenceLinkerUI(QWidget):
transition_layout.addWidget(self.rife_path_input)
transition_layout.addWidget(self.rife_path_btn)
transition_layout.addWidget(self.rife_download_btn)
transition_layout.addWidget(self.rife_model_label)
transition_layout.addWidget(self.rife_model_combo)
transition_layout.addWidget(self.rife_uhd_check)
transition_layout.addWidget(self.rife_tta_check)
transition_layout.addWidget(self.practical_model_label)
transition_layout.addWidget(self.practical_model_combo)
transition_layout.addWidget(self.practical_ensemble_check)
transition_layout.addWidget(self.practical_setup_btn)
transition_layout.addWidget(self.practical_status_label)
transition_layout.addWidget(self.fps_label)
transition_layout.addWidget(self.fps_spin)
transition_layout.addWidget(self.timeline_label)
transition_layout.addWidget(self.seq_play_btn)
transition_layout.addStretch()
self.transition_group.setLayout(transition_layout)
@@ -459,11 +622,13 @@ class SequenceLinkerUI(QWidget):
sequence_order_layout.addWidget(self.file_list)
self.sequence_tabs.addTab(sequence_order_tab, "Sequence Order")
# Tab 2: With Transitions (2-column view)
# Tab 2: With Transitions (2-column view with timeline rulers)
trans_sequence_tab = QWidget()
trans_sequence_layout = QVBoxLayout(trans_sequence_tab)
trans_sequence_layout.setContentsMargins(0, 0, 0, 0)
trans_sequence_layout.addWidget(self.sequence_table)
self.sequence_tabs.addTab(trans_sequence_tab, "With Transitions")
file_list_layout.addWidget(self.sequence_tabs)
@@ -555,9 +720,18 @@ class SequenceLinkerUI(QWidget):
# Blend method combo change - show/hide RIFE path
self.blend_method_combo.currentIndexChanged.connect(self._on_blend_method_changed)
self.curve_combo.currentIndexChanged.connect(self._clear_blend_cache)
self.rife_model_combo.currentIndexChanged.connect(self._clear_blend_cache)
self.rife_uhd_check.stateChanged.connect(self._clear_blend_cache)
self.rife_tta_check.stateChanged.connect(self._clear_blend_cache)
self.rife_path_btn.clicked.connect(self._browse_rife_binary)
self.rife_download_btn.clicked.connect(self._download_rife_binary)
# Practical-RIFE signals
self.practical_model_combo.currentIndexChanged.connect(self._clear_blend_cache)
self.practical_ensemble_check.stateChanged.connect(self._clear_blend_cache)
self.practical_setup_btn.clicked.connect(self._setup_practical_rife)
# Sequence table selection - show image
self.sequence_table.currentItemChanged.connect(self._on_sequence_table_selected)
@@ -567,6 +741,14 @@ class SequenceLinkerUI(QWidget):
# Update sequence table when switching to "With Transitions" tab
self.sequence_tabs.currentChanged.connect(self._on_sequence_tab_changed)
# FPS and sequence playback signals
self.fps_spin.valueChanged.connect(self._update_timeline_display)
self.seq_play_btn.clicked.connect(self._toggle_sequence_play)
self.sequence_timer.timeout.connect(self._advance_sequence_frame)
# Update sequence table FPS when spinner changes
self.fps_spin.valueChanged.connect(self.sequence_table.set_fps)
def _on_format_changed(self, index: int) -> None:
"""Handle format combo change to show/hide quality/method widgets."""
fmt = self.blend_format_combo.currentData()
@@ -589,15 +771,39 @@ class SequenceLinkerUI(QWidget):
def _on_blend_method_changed(self, index: int) -> None:
"""Handle blend method combo change to show/hide RIFE path widgets."""
method = self.blend_method_combo.currentData()
is_rife = (method == BlendMethod.RIFE)
self.rife_path_label.setVisible(is_rife)
self.rife_path_input.setVisible(is_rife)
self.rife_path_btn.setVisible(is_rife)
self.rife_download_btn.setVisible(is_rife)
is_rife_ncnn = (method == BlendMethod.RIFE)
is_rife_practical = (method == BlendMethod.RIFE_PRACTICAL)
if is_rife:
# RIFE ncnn settings
self.rife_path_label.setVisible(is_rife_ncnn)
self.rife_path_input.setVisible(is_rife_ncnn)
self.rife_path_btn.setVisible(is_rife_ncnn)
self.rife_download_btn.setVisible(is_rife_ncnn)
self.rife_model_label.setVisible(is_rife_ncnn)
self.rife_model_combo.setVisible(is_rife_ncnn)
self.rife_uhd_check.setVisible(is_rife_ncnn)
self.rife_tta_check.setVisible(is_rife_ncnn)
# Practical-RIFE settings
self.practical_model_label.setVisible(is_rife_practical)
self.practical_model_combo.setVisible(is_rife_practical)
self.practical_ensemble_check.setVisible(is_rife_practical)
self.practical_setup_btn.setVisible(is_rife_practical)
self.practical_status_label.setVisible(is_rife_practical)
if is_rife_ncnn:
self._update_rife_download_button()
if is_rife_practical:
self._update_practical_rife_status()
# Clear blend preview cache when method changes
self._blend_preview_cache.clear()
def _clear_blend_cache(self) -> None:
"""Clear the blend preview cache."""
self._blend_preview_cache.clear()
def _browse_rife_binary(self) -> None:
"""Browse for RIFE binary."""
start_dir = self.last_directory or ""
@@ -743,6 +949,94 @@ class SequenceLinkerUI(QWidget):
)
self._update_rife_download_button()
def _update_practical_rife_status(self) -> None:
"""Update the Practical-RIFE status label and setup button."""
if PracticalRifeEnv.is_setup():
torch_version = PracticalRifeEnv.get_torch_version()
if torch_version:
self.practical_status_label.setText(f"Ready (PyTorch {torch_version})")
self.practical_status_label.setStyleSheet("color: green; font-size: 10px;")
else:
self.practical_status_label.setText("Ready")
self.practical_status_label.setStyleSheet("color: green; font-size: 10px;")
self.practical_setup_btn.setText("Reinstall")
self.practical_setup_btn.setToolTip("Reinstall PyTorch environment")
self.practical_model_combo.setEnabled(True)
self.practical_ensemble_check.setEnabled(True)
else:
self.practical_status_label.setText("Not configured")
self.practical_status_label.setStyleSheet("color: orange; font-size: 10px;")
self.practical_setup_btn.setText("Setup PyTorch")
self.practical_setup_btn.setToolTip("Create local venv and install PyTorch (~2GB download)")
self.practical_model_combo.setEnabled(False)
self.practical_ensemble_check.setEnabled(False)
def _setup_practical_rife(self) -> None:
"""Setup Practical-RIFE environment with progress dialog."""
# Confirm if already setup
if PracticalRifeEnv.is_setup():
reply = QMessageBox.question(
self, "Reinstall PyTorch?",
"PyTorch environment is already set up.\n"
"Do you want to reinstall it?",
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No
)
if reply != QMessageBox.StandardButton.Yes:
return
# Create progress dialog
progress = QProgressDialog(
"Setting up PyTorch environment...", "Cancel", 0, 100, self
)
progress.setWindowTitle("Setup Practical-RIFE")
progress.setWindowModality(Qt.WindowModality.WindowModal)
progress.setMinimumDuration(0)
progress.setValue(0)
progress.show()
# Progress callback
def progress_callback(message, percent):
if not progress.wasCanceled():
progress.setLabelText(message)
progress.setValue(percent)
QApplication.processEvents()
def cancelled_check():
QApplication.processEvents()
return progress.wasCanceled()
try:
success = PracticalRifeEnv.setup_venv(progress_callback, cancelled_check)
progress.close()
if progress.wasCanceled():
self._update_practical_rife_status()
return
if success:
QMessageBox.information(
self, "Setup Complete",
"PyTorch environment set up successfully!\n\n"
f"Location: {PracticalRifeEnv.VENV_DIR}\n\n"
"You can now use RIFE (Practical) for frame interpolation."
)
else:
QMessageBox.warning(
self, "Setup Failed",
"Failed to set up PyTorch environment.\n"
"Check your internet connection and try again."
)
self._update_practical_rife_status()
except Exception as e:
progress.close()
QMessageBox.critical(
self, "Setup Error",
f"Error setting up PyTorch: {e}"
)
self._update_practical_rife_status()
def _on_sequence_tab_changed(self, index: int) -> None:
"""Handle sequence tab change to update the With Transitions view."""
if index == 1: # "With Transitions" tab
@@ -753,10 +1047,12 @@ class SequenceLinkerUI(QWidget):
self.sequence_table.clear()
if not self.source_folders:
self._update_timeline_display()
return
files = self._get_files_in_order()
if not files:
self._update_timeline_display()
return
# Group files by folder
@@ -774,6 +1070,7 @@ class SequenceLinkerUI(QWidget):
item = QTreeWidgetItem([f"{seq_name} ({filename})", ""])
item.setData(0, Qt.ItemDataRole.UserRole, (source_dir, filename, folder_idx, file_idx, 'symlink'))
self.sequence_table.addTopLevelItem(item)
self._update_timeline_display()
return
# Get transition specs
@@ -856,6 +1153,9 @@ class SequenceLinkerUI(QWidget):
self.sequence_table.addTopLevelItem(item)
# Update timeline display after rebuilding sequence table
self._update_timeline_display()
def _on_sequence_table_selected(self, current, previous) -> None:
"""Handle sequence table row selection - show image in preview."""
if current is None:
@@ -921,20 +1221,6 @@ class SequenceLinkerUI(QWidget):
return
try:
# Load images
img_a = Image.open(main_path)
img_b = Image.open(trans_path)
# Resize B to match A if needed
if img_a.size != img_b.size:
img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS)
# Convert to RGBA
if img_a.mode != 'RGBA':
img_a = img_a.convert('RGBA')
if img_b.mode != 'RGBA':
img_b = img_b.convert('RGBA')
# Calculate blend factor based on position in sequence table
# Find this frame's position in the blend sequence
row_idx = self.sequence_table.indexOfTopLevelItem(item)
@@ -970,17 +1256,51 @@ class SequenceLinkerUI(QWidget):
blend_position, blend_count, settings.blend_curve
)
# Blend images using selected method
if settings.blend_method == BlendMethod.OPTICAL_FLOW:
blended = ImageBlender.optical_flow_blend(img_a, img_b, factor)
elif settings.blend_method == BlendMethod.RIFE:
blended = ImageBlender.rife_blend(img_a, img_b, factor, settings.rife_binary_path)
else:
blended = Image.blend(img_a, img_b, factor)
# Create cache key (include RIFE settings when using RIFE)
cache_key = f"{main_path}|{trans_path}|{factor:.6f}|{settings.blend_method.value}|{settings.blend_curve.value}"
if settings.blend_method == BlendMethod.RIFE:
cache_key += f"|{settings.rife_model}|{settings.rife_uhd}|{settings.rife_tta}"
# Convert to QPixmap
qim = ImageQt(blended.convert('RGBA'))
pixmap = QPixmap.fromImage(qim)
# Check cache first
if cache_key in self._blend_preview_cache:
pixmap = self._blend_preview_cache[cache_key]
else:
# Load images
img_a = Image.open(main_path)
img_b = Image.open(trans_path)
# Resize B to match A if needed
if img_a.size != img_b.size:
img_b = img_b.resize(img_a.size, Image.Resampling.LANCZOS)
# Convert to RGBA
if img_a.mode != 'RGBA':
img_a = img_a.convert('RGBA')
if img_b.mode != 'RGBA':
img_b = img_b.convert('RGBA')
# Blend images using selected method
if settings.blend_method == BlendMethod.OPTICAL_FLOW:
blended = ImageBlender.optical_flow_blend(img_a, img_b, factor)
elif settings.blend_method == BlendMethod.RIFE:
blended = ImageBlender.rife_blend(
img_a, img_b, factor, settings.rife_binary_path,
model=settings.rife_model,
uhd=settings.rife_uhd,
tta=settings.rife_tta
)
else:
blended = Image.blend(img_a, img_b, factor)
# Convert to QPixmap
qim = ImageQt(blended.convert('RGBA'))
pixmap = QPixmap.fromImage(qim)
# Store in cache
self._blend_preview_cache[cache_key] = pixmap
img_a.close()
img_b.close()
self._current_pixmap = pixmap
self._apply_zoom()
@@ -990,14 +1310,77 @@ class SequenceLinkerUI(QWidget):
seq_name = f"seq{data0[2] + 1:02d}_{data0[3]:04d}"
self.image_name_label.setText(f"[B] {seq_name} ({main_file} + {trans_file}) @ {factor:.0%}")
img_a.close()
img_b.close()
except Exception as e:
self.image_label.setText(f"Error generating blend preview:\n{e}")
self.image_name_label.setText("")
self._current_pixmap = None
def _update_timeline_display(self) -> None:
"""Update the timeline duration display based on frame count and FPS."""
frame_count = self.sequence_table.topLevelItemCount()
fps = self.fps_spin.value()
if fps > 0 and frame_count > 0:
total_seconds = frame_count / fps
minutes = int(total_seconds // 60)
seconds = total_seconds % 60
self.timeline_label.setText(
f"Duration: {minutes:02d}:{seconds:06.3f} ({frame_count} frames @ {fps}fps)"
)
else:
self.timeline_label.setText("Duration: 00:00.000 (0 frames)")
# Refresh the sequence table to update timeline background
self.sequence_table.viewport().update()
def _toggle_sequence_play(self) -> None:
"""Toggle sequence playback."""
if self.sequence_playing:
self._stop_sequence_play()
else:
self._start_sequence_play()
def _start_sequence_play(self) -> None:
"""Start playing the image sequence."""
if self.sequence_table.topLevelItemCount() == 0:
return
fps = self.fps_spin.value()
interval = int(1000 / fps) # milliseconds per frame
self.sequence_timer.setInterval(interval)
self.sequence_timer.start()
self.sequence_playing = True
self.seq_play_btn.setText("⏸ Pause")
# If no item selected, start from first
if self.sequence_table.currentItem() is None:
first_item = self.sequence_table.topLevelItem(0)
if first_item:
self.sequence_table.setCurrentItem(first_item)
def _stop_sequence_play(self) -> None:
"""Stop sequence playback."""
self.sequence_timer.stop()
self.sequence_playing = False
self.seq_play_btn.setText("▶ Play")
def _advance_sequence_frame(self) -> None:
"""Advance to next frame in sequence."""
current_item = self.sequence_table.currentItem()
if current_item is None:
self._stop_sequence_play()
return
current_idx = self.sequence_table.indexOfTopLevelItem(current_item)
total = self.sequence_table.topLevelItemCount()
if current_idx < total - 1:
next_item = self.sequence_table.topLevelItem(current_idx + 1)
self.sequence_table.setCurrentItem(next_item)
else:
# Reached end - stop playback
self._stop_sequence_play()
def _browse_trans_destination(self) -> None:
"""Select transition destination folder via file dialog."""
start_dir = self.last_directory or ""
@@ -1512,7 +1895,12 @@ class SequenceLinkerUI(QWidget):
output_quality=self.blend_quality_spin.value(),
trans_destination=trans_dest,
blend_method=self.blend_method_combo.currentData(),
rife_binary_path=rife_path
rife_binary_path=rife_path,
rife_model=self.rife_model_combo.currentData(),
rife_uhd=self.rife_uhd_check.isChecked(),
rife_tta=self.rife_tta_check.isChecked(),
practical_rife_model=self.practical_model_combo.currentData(),
practical_rife_ensemble=self.practical_ensemble_check.isChecked()
)
def _refresh_files(self, select_position: str = 'first') -> None:
@@ -1762,6 +2150,9 @@ class SequenceLinkerUI(QWidget):
video_path = self.video_combo.currentData()
if video_path and isinstance(video_path, Path) and video_path.exists():
self.media_player.setSource(QUrl.fromLocalFile(str(video_path)))
# Play and immediately pause to show first frame
self.media_player.play()
self.media_player.pause()
def _toggle_play(self) -> None:
"""Toggle play/pause state."""