Compare commits

...

7 Commits

Author SHA1 Message Date
d26265a115 rife 2026-02-03 23:26:07 +01:00
00f0141b15 rife 2026-02-03 23:21:31 +01:00
2c6ad4ff35 Add Practical-RIFE frame interpolation support
Implement standalone PyTorch-based RIFE interpolation that runs in a
dedicated virtual environment to avoid Qt/OpenCV conflicts:

- Add PracticalRifeEnv class for managing venv and subprocess execution
- Add rife_worker.py standalone interpolation script using Practical-RIFE
- Add RIFE_PRACTICAL blending model with ensemble/fast mode settings
- Add UI controls for Practical-RIFE configuration
- Update .gitignore to exclude venv-rife/ directory

The implementation downloads Practical-RIFE models on first use and runs
interpolation in a separate process with proper progress reporting.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:46:06 +01:00
6bfbefb058 Update README with comprehensive feature documentation
- Document cross-dissolve transitions and blend methods
- Add RIFE auto-download instructions
- Document per-folder trim and per-transition overlap
- Add file structure diagram
- Update installation requirements
- Expand supported formats list

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:51:50 +01:00
bdddce910c Restructure into multi-file architecture
Split monolithic symlink.py into modular components:
- config.py: Constants and configuration
- core/: Models, database, blender, manager
- ui/: Main window and widgets

New features included:
- Cross-dissolve transitions with multiple blend methods
- Alpha blend, Optical Flow, and RIFE (AI) interpolation
- Per-folder trim settings with start/end frame control
- Per-transition asymmetric overlap settings
- Folder type overrides (Main/Transition)
- Dual destination folders (sequence + transitions)
- WebP lossless output with compression method setting
- Video and image sequence preview with zoom/pan
- Session resume from destination folder

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:49:51 +01:00
99858bcfe8 Add alternating row colors to source folder list
Improves readability of the source folder panel.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:49:39 +01:00
716ff34062 cool 2026-02-03 16:21:56 +01:00
13 changed files with 5786 additions and 779 deletions

60
.gitignore vendored
View File

@@ -1,4 +1,64 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
# Virtual environments
venv/
venv-rife/
.venv/
env/
# Environment files
.env
.env.local
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Database
*.db
*.sqlite
*.sqlite3
# Downloads and cache
*.pkl
*.pt
*.pth
*.onnx
downloads/
cache/
.cache/
# RIFE binaries and models
rife-ncnn-vulkan*/
*.zip
# Output directories
output/
outputs/
temp/
tmp/
# Logs
*.log
logs/
# OS files
.DS_Store
Thumbs.db
# Build artifacts
dist/
build/
*.egg-info/
# Git mirror scripts
gitea-push-mirror-setup

134
README.md
View File

@@ -1,97 +1,141 @@
# Video Montage Linker
A PyQt6 application to create sequenced symlinks for image files. Useful for preparing image sequences for video editing or montage creation.
A PyQt6 application for creating sequenced symlinks from image folders with advanced cross-dissolve transitions. Perfect for preparing image sequences for video editing, time-lapse assembly, or montage creation.
## Features
### Multiple Source Folders
### Source Folder Management
- Add multiple source folders to merge images from different locations
- Files are ordered by folder (first added = first in sequence), then alphabetically within each folder
- Drag & drop folders directly onto the window to add them
- Multi-select support for removing folders
- Drag & drop folders directly onto the window
- Alternating folder types: odd positions = **Main**, even positions = **Transition**
- Override folder types via right-click context menu
- Reorder folders with up/down buttons
- Per-folder trim settings (exclude frames from start/end)
### File Management
- Two-column view showing filename and source path
- Drag & drop to reorder files in the sequence
- Multi-select files (Ctrl+click, Shift+click)
- Remove files with Delete key or "Remove Files" button
- Refresh to rescan source folders
### Cross-Dissolve Transitions
Smooth blending between folder boundaries with three blend methods:
### Symlink Creation
- Creates numbered symlinks (`seq_0000.png`, `seq_0001.png`, etc.)
- Uses relative paths for portability
- Automatically cleans up old `seq_*` links before creating new ones
| Method | Description | Quality | Speed |
|--------|-------------|---------|-------|
| **Cross-Dissolve** | Simple alpha blend | Good | Fastest |
| **Optical Flow** | Motion-compensated blend using OpenCV Farneback | Better | Medium |
| **RIFE (AI)** | Neural network frame interpolation | Best | Fast (GPU) |
### Session Tracking
- SQLite database tracks all symlink sessions
- Located at `~/.config/video-montage-linker/symlinks.db`
- List past sessions and clean up by destination
- **Asymmetric overlap**: Set different frame counts for each side of a transition
- **Blend curves**: Linear, Ease In, Ease Out, Ease In/Out
- **Output formats**: PNG, JPEG (with quality), WebP (lossless with method setting)
- **RIFE auto-download**: Automatically downloads rife-ncnn-vulkan binary
### Supported Formats
- PNG, WEBP, JPG, JPEG
### Preview
- **Video Preview**: Play video files from source folders
- **Image Sequence Preview**: Browse frames with zoom (scroll wheel) and pan (drag)
- **Sequence Table**: 2-column view showing Main/Transition frame pairing
- **Trim Slider**: Visual frame range selection per folder
### Dual Export Destinations
- **Sequence destination**: Regular symlinks only
- **Transition destination**: Symlinks + blended transition frames
### Session Persistence
- SQLite database tracks all sessions and settings
- Resume previous session by selecting the same destination folder
- Restores: source folders, trim settings, folder types, transition settings, per-transition overlaps
## Installation
Requires Python 3 and PyQt6:
### Requirements
- Python 3.10+
- PyQt6
- Pillow
- NumPy
- OpenCV (optional, for Optical Flow)
```bash
pip install PyQt6
pip install PyQt6 Pillow numpy opencv-python
```
### RIFE (Optional)
For AI-powered frame interpolation, the app can auto-download [rife-ncnn-vulkan](https://github.com/nihui/rife-ncnn-vulkan) or you can install it manually:
- Select **RIFE (AI)** as the blend method
- Click **Download** to fetch the latest release
- Or specify a custom binary path
## Usage
### GUI Mode
```bash
# Launch the graphical interface
python symlink.py
python symlink.py --gui
python symlink.py # Launch GUI (default)
python symlink.py --gui # Explicit GUI launch
```
1. Click "Add Folder" or drag & drop folders to add source directories
2. Reorder files by dragging them in the list
3. Remove unwanted files (select + Delete key)
4. Select destination folder
5. Click "Generate Virtual Sequence"
**Workflow:**
1. Add source folders (drag & drop or click "Add Folder")
2. Adjust trim settings per folder if needed (right-click or use trim slider)
3. Set destination folder(s)
4. Enable transitions and configure blend method/settings
5. Click **Export Sequence** or **Export with Transitions**
### CLI Mode
```bash
# Create symlinks from a single source
python symlink.py --src /path/to/images --dst /path/to/dest
# Merge multiple source folders
python symlink.py --src /folder1 --src /folder2 --dst /path/to/dest
# Create symlinks from source folders
python symlink.py --src /path/to/folder1 --src /path/to/folder2 --dst /path/to/dest
# List all tracked sessions
python symlink.py --list
# Clean up symlinks and remove session record
# Clean up symlinks and remove session
python symlink.py --clean /path/to/dest
```
## File Structure
```
video-montage-linker/
├── symlink.py # Entry point, CLI
├── config.py # Constants, paths
├── core/
│ ├── models.py # Enums, dataclasses
│ ├── database.py # SQLite session management
│ ├── blender.py # Image blending, RIFE downloader
│ └── manager.py # Symlink operations
└── ui/
├── widgets.py # TrimSlider, custom widgets
└── main_window.py # Main application window
```
## Supported Formats
**Images:** PNG, WEBP, JPG, JPEG, TIFF, BMP, EXR
**Videos (preview only):** MP4, MOV, AVI, MKV, WEBM
## Database
Session data stored at: `~/.config/video-montage-linker/symlinks.db`
## System Installation (Linux)
To add as a system application:
```bash
# Make executable and add to PATH
# Make executable
chmod +x symlink.py
ln -s /path/to/symlink.py ~/.local/bin/video-montage-linker
# Add to PATH
ln -s /full/path/to/symlink.py ~/.local/bin/video-montage-linker
# Create desktop entry
cat > ~/.local/share/applications/video-montage-linker.desktop << 'EOF'
[Desktop Entry]
Name=Video Montage Linker
Comment=Create sequenced symlinks for image files
Exec=/path/to/symlink.py
Comment=Create sequenced symlinks with cross-dissolve transitions
Exec=/full/path/to/symlink.py
Icon=emblem-symbolic-link
Terminal=false
Type=Application
Categories=Utility;Graphics;
Categories=Utility;Graphics;AudioVideo;
EOF
# Update desktop database
update-desktop-database ~/.local/share/applications/
```

10
config.py Normal file
View File

@@ -0,0 +1,10 @@
"""Configuration constants for Video Montage Linker."""
from pathlib import Path
# Supported file extensions
SUPPORTED_EXTENSIONS = ('.png', '.webp', '.jpg', '.jpeg')
VIDEO_EXTENSIONS = ('.mp4', '.webm', '.mkv', '.avi', '.mov', '.wmv', '.flv', '.m4v')
# Database path
DB_PATH = Path.home() / '.config' / 'video-montage-linker' / 'symlinks.db'

48
core/__init__.py Normal file
View File

@@ -0,0 +1,48 @@
"""Core modules for Video Montage Linker."""
from .models import (
BlendCurve,
BlendMethod,
FolderType,
TransitionSettings,
PerTransitionSettings,
BlendResult,
TransitionSpec,
LinkResult,
SymlinkRecord,
SessionRecord,
SymlinkError,
PathValidationError,
SourceNotFoundError,
DestinationError,
CleanupError,
DatabaseError,
)
from .database import DatabaseManager
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv
from .manager import SymlinkManager
__all__ = [
'BlendCurve',
'BlendMethod',
'FolderType',
'TransitionSettings',
'PerTransitionSettings',
'BlendResult',
'TransitionSpec',
'LinkResult',
'SymlinkRecord',
'SessionRecord',
'SymlinkError',
'PathValidationError',
'SourceNotFoundError',
'DestinationError',
'CleanupError',
'DatabaseError',
'DatabaseManager',
'ImageBlender',
'TransitionGenerator',
'RifeDownloader',
'PracticalRifeEnv',
'SymlinkManager',
]

1253
core/blender.py Normal file

File diff suppressed because it is too large Load Diff

616
core/database.py Normal file
View File

@@ -0,0 +1,616 @@
"""Database management for Video Montage Linker."""
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional
from config import DB_PATH
from .models import (
BlendCurve,
BlendMethod,
FolderType,
TransitionSettings,
PerTransitionSettings,
SymlinkRecord,
SessionRecord,
DatabaseError,
)
class DatabaseManager:
"""Manages SQLite database for tracking symlink sessions and links."""
def __init__(self, db_path: Path = DB_PATH) -> None:
"""Initialize database manager.
Args:
db_path: Path to the SQLite database file.
"""
self.db_path = db_path
self._ensure_db_exists()
def _ensure_db_exists(self) -> None:
"""Create database and tables if they don't exist."""
self.db_path.parent.mkdir(parents=True, exist_ok=True)
with self._connect() as conn:
conn.executescript("""
CREATE TABLE IF NOT EXISTS symlink_sessions (
id INTEGER PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
destination TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS symlinks (
id INTEGER PRIMARY KEY,
session_id INTEGER REFERENCES symlink_sessions(id) ON DELETE CASCADE,
source_path TEXT NOT NULL,
link_path TEXT NOT NULL,
original_filename TEXT NOT NULL,
sequence_number INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS sequence_trim_settings (
id INTEGER PRIMARY KEY,
session_id INTEGER REFERENCES symlink_sessions(id) ON DELETE CASCADE,
source_folder TEXT NOT NULL,
trim_start INTEGER DEFAULT 0,
trim_end INTEGER DEFAULT 0,
folder_type TEXT DEFAULT 'auto',
UNIQUE(session_id, source_folder)
);
CREATE TABLE IF NOT EXISTS transition_settings (
id INTEGER PRIMARY KEY,
session_id INTEGER REFERENCES symlink_sessions(id) ON DELETE CASCADE,
enabled INTEGER DEFAULT 0,
blend_curve TEXT DEFAULT 'linear',
output_format TEXT DEFAULT 'png',
webp_method INTEGER DEFAULT 4,
output_quality INTEGER DEFAULT 95,
trans_destination TEXT,
UNIQUE(session_id)
);
CREATE TABLE IF NOT EXISTS per_transition_settings (
id INTEGER PRIMARY KEY,
session_id INTEGER REFERENCES symlink_sessions(id) ON DELETE CASCADE,
trans_folder TEXT NOT NULL,
left_overlap INTEGER DEFAULT 16,
right_overlap INTEGER DEFAULT 16,
UNIQUE(session_id, trans_folder)
);
""")
# Migration: add folder_type column if it doesn't exist
try:
conn.execute("SELECT folder_type FROM sequence_trim_settings LIMIT 1")
except sqlite3.OperationalError:
conn.execute("ALTER TABLE sequence_trim_settings ADD COLUMN folder_type TEXT DEFAULT 'auto'")
# Migration: add webp_method column if it doesn't exist
try:
conn.execute("SELECT webp_method FROM transition_settings LIMIT 1")
except sqlite3.OperationalError:
conn.execute("ALTER TABLE transition_settings ADD COLUMN webp_method INTEGER DEFAULT 4")
# Migration: add trans_destination column if it doesn't exist
try:
conn.execute("SELECT trans_destination FROM transition_settings LIMIT 1")
except sqlite3.OperationalError:
conn.execute("ALTER TABLE transition_settings ADD COLUMN trans_destination TEXT")
# Migration: add blend_method column if it doesn't exist
try:
conn.execute("SELECT blend_method FROM transition_settings LIMIT 1")
except sqlite3.OperationalError:
conn.execute("ALTER TABLE transition_settings ADD COLUMN blend_method TEXT DEFAULT 'alpha'")
# Migration: add rife_binary_path column if it doesn't exist
try:
conn.execute("SELECT rife_binary_path FROM transition_settings LIMIT 1")
except sqlite3.OperationalError:
conn.execute("ALTER TABLE transition_settings ADD COLUMN rife_binary_path TEXT")
# Migration: remove overlap_frames from transition_settings (now per-transition)
# We'll keep it for backward compatibility but won't use it
def _connect(self) -> sqlite3.Connection:
"""Create a database connection with foreign keys enabled."""
conn = sqlite3.connect(self.db_path)
conn.execute("PRAGMA foreign_keys = ON")
return conn
def create_session(self, destination: str) -> int:
"""Create a new linking session.
Args:
destination: The destination directory path.
Returns:
The ID of the created session.
Raises:
DatabaseError: If session creation fails.
"""
try:
with self._connect() as conn:
cursor = conn.execute(
"INSERT INTO symlink_sessions (destination) VALUES (?)",
(destination,)
)
return cursor.lastrowid
except sqlite3.Error as e:
raise DatabaseError(f"Failed to create session: {e}") from e
def record_symlink(
self,
session_id: int,
source: str,
link: str,
filename: str,
seq: int
) -> int:
"""Record a created symlink.
Args:
session_id: The session this symlink belongs to.
source: Full path to the source file.
link: Full path to the created symlink.
filename: Original filename.
seq: Sequence number in the destination.
Returns:
The ID of the created record.
Raises:
DatabaseError: If recording fails.
"""
try:
with self._connect() as conn:
cursor = conn.execute(
"""INSERT INTO symlinks
(session_id, source_path, link_path, original_filename, sequence_number)
VALUES (?, ?, ?, ?, ?)""",
(session_id, source, link, filename, seq)
)
return cursor.lastrowid
except sqlite3.Error as e:
raise DatabaseError(f"Failed to record symlink: {e}") from e
def get_sessions(self) -> list[SessionRecord]:
"""List all sessions with link counts.
Returns:
List of session records.
"""
with self._connect() as conn:
rows = conn.execute("""
SELECT s.id, s.created_at, s.destination, COUNT(l.id) as link_count
FROM symlink_sessions s
LEFT JOIN symlinks l ON s.id = l.session_id
GROUP BY s.id
ORDER BY s.created_at DESC
""").fetchall()
return [
SessionRecord(
id=row[0],
created_at=datetime.fromisoformat(row[1]),
destination=row[2],
link_count=row[3]
)
for row in rows
]
def get_symlinks_by_session(self, session_id: int) -> list[SymlinkRecord]:
"""Get all symlinks for a session.
Args:
session_id: The session ID to query.
Returns:
List of symlink records.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT id, session_id, source_path, link_path,
original_filename, sequence_number, created_at
FROM symlinks WHERE session_id = ?
ORDER BY sequence_number""",
(session_id,)
).fetchall()
return [
SymlinkRecord(
id=row[0],
session_id=row[1],
source_path=row[2],
link_path=row[3],
original_filename=row[4],
sequence_number=row[5],
created_at=datetime.fromisoformat(row[6])
)
for row in rows
]
def get_symlinks_by_destination(self, dest: str) -> list[SymlinkRecord]:
"""Get all symlinks for a destination directory.
Args:
dest: The destination directory path.
Returns:
List of symlink records.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT l.id, l.session_id, l.source_path, l.link_path,
l.original_filename, l.sequence_number, l.created_at
FROM symlinks l
JOIN symlink_sessions s ON l.session_id = s.id
WHERE s.destination = ?
ORDER BY l.sequence_number""",
(dest,)
).fetchall()
return [
SymlinkRecord(
id=row[0],
session_id=row[1],
source_path=row[2],
link_path=row[3],
original_filename=row[4],
sequence_number=row[5],
created_at=datetime.fromisoformat(row[6])
)
for row in rows
]
def delete_session(self, session_id: int) -> None:
"""Delete a session and all its symlink records.
Args:
session_id: The session ID to delete.
Raises:
DatabaseError: If deletion fails.
"""
try:
with self._connect() as conn:
conn.execute("DELETE FROM symlinks WHERE session_id = ?", (session_id,))
conn.execute("DELETE FROM symlink_sessions WHERE id = ?", (session_id,))
except sqlite3.Error as e:
raise DatabaseError(f"Failed to delete session: {e}") from e
def get_sessions_by_destination(self, dest: str) -> list[SessionRecord]:
"""Get all sessions for a destination directory.
Args:
dest: The destination directory path.
Returns:
List of session records.
"""
with self._connect() as conn:
rows = conn.execute("""
SELECT s.id, s.created_at, s.destination, COUNT(l.id) as link_count
FROM symlink_sessions s
LEFT JOIN symlinks l ON s.id = l.session_id
WHERE s.destination = ?
GROUP BY s.id
ORDER BY s.created_at DESC
""", (dest,)).fetchall()
return [
SessionRecord(
id=row[0],
created_at=datetime.fromisoformat(row[1]),
destination=row[2],
link_count=row[3]
)
for row in rows
]
def save_trim_settings(
self,
session_id: int,
source_folder: str,
trim_start: int,
trim_end: int,
folder_type: FolderType = FolderType.AUTO
) -> None:
"""Save trim settings for a folder in a session.
Args:
session_id: The session ID.
source_folder: Path to the source folder.
trim_start: Number of images to trim from start.
trim_end: Number of images to trim from end.
folder_type: The folder type (auto, main, or transition).
Raises:
DatabaseError: If saving fails.
"""
try:
with self._connect() as conn:
conn.execute(
"""INSERT INTO sequence_trim_settings
(session_id, source_folder, trim_start, trim_end, folder_type)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(session_id, source_folder)
DO UPDATE SET trim_start=excluded.trim_start,
trim_end=excluded.trim_end,
folder_type=excluded.folder_type""",
(session_id, source_folder, trim_start, trim_end, folder_type.value)
)
except sqlite3.Error as e:
raise DatabaseError(f"Failed to save trim settings: {e}") from e
def get_trim_settings(
self,
session_id: int,
source_folder: str
) -> tuple[int, int, FolderType]:
"""Get trim settings for a folder in a session.
Args:
session_id: The session ID.
source_folder: Path to the source folder.
Returns:
Tuple of (trim_start, trim_end, folder_type). Returns (0, 0, AUTO) if not found.
"""
with self._connect() as conn:
row = conn.execute(
"""SELECT trim_start, trim_end, folder_type FROM sequence_trim_settings
WHERE session_id = ? AND source_folder = ?""",
(session_id, source_folder)
).fetchone()
if row:
try:
folder_type = FolderType(row[2]) if row[2] else FolderType.AUTO
except ValueError:
folder_type = FolderType.AUTO
return (row[0], row[1], folder_type)
return (0, 0, FolderType.AUTO)
def get_all_trim_settings(self, session_id: int) -> dict[str, tuple[int, int]]:
"""Get all trim settings for a session.
Args:
session_id: The session ID.
Returns:
Dict mapping source folder paths to (trim_start, trim_end) tuples.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT source_folder, trim_start, trim_end
FROM sequence_trim_settings WHERE session_id = ?""",
(session_id,)
).fetchall()
return {row[0]: (row[1], row[2]) for row in rows}
def save_transition_settings(
self,
session_id: int,
settings: TransitionSettings
) -> None:
"""Save transition settings for a session.
Args:
session_id: The session ID.
settings: TransitionSettings to save.
Raises:
DatabaseError: If saving fails.
"""
try:
trans_dest = str(settings.trans_destination) if settings.trans_destination else None
rife_path = str(settings.rife_binary_path) if settings.rife_binary_path else None
with self._connect() as conn:
conn.execute(
"""INSERT INTO transition_settings
(session_id, enabled, blend_curve, output_format, webp_method, output_quality, trans_destination, blend_method, rife_binary_path)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(session_id)
DO UPDATE SET enabled=excluded.enabled,
blend_curve=excluded.blend_curve,
output_format=excluded.output_format,
webp_method=excluded.webp_method,
output_quality=excluded.output_quality,
trans_destination=excluded.trans_destination,
blend_method=excluded.blend_method,
rife_binary_path=excluded.rife_binary_path""",
(session_id, 1 if settings.enabled else 0,
settings.blend_curve.value, settings.output_format,
settings.webp_method, settings.output_quality, trans_dest,
settings.blend_method.value, rife_path)
)
except sqlite3.Error as e:
raise DatabaseError(f"Failed to save transition settings: {e}") from e
def get_transition_settings(self, session_id: int) -> Optional[TransitionSettings]:
"""Get transition settings for a session.
Args:
session_id: The session ID.
Returns:
TransitionSettings or None if not found.
"""
with self._connect() as conn:
row = conn.execute(
"""SELECT enabled, blend_curve, output_format, webp_method, output_quality, trans_destination, blend_method, rife_binary_path
FROM transition_settings WHERE session_id = ?""",
(session_id,)
).fetchone()
if row:
trans_dest = Path(row[5]) if row[5] else None
try:
blend_method = BlendMethod(row[6]) if row[6] else BlendMethod.ALPHA
except ValueError:
blend_method = BlendMethod.ALPHA
rife_path = Path(row[7]) if row[7] else None
return TransitionSettings(
enabled=bool(row[0]),
blend_curve=BlendCurve(row[1]),
output_format=row[2],
webp_method=row[3] if row[3] is not None else 4,
output_quality=row[4],
trans_destination=trans_dest,
blend_method=blend_method,
rife_binary_path=rife_path
)
return None
def save_folder_type_override(
self,
session_id: int,
folder: str,
folder_type: FolderType,
trim_start: int = 0,
trim_end: int = 0
) -> None:
"""Save folder type override for a folder in a session.
Args:
session_id: The session ID.
folder: Path to the source folder.
folder_type: The folder type override.
trim_start: Number of images to trim from start.
trim_end: Number of images to trim from end.
Raises:
DatabaseError: If saving fails.
"""
try:
with self._connect() as conn:
conn.execute(
"""INSERT INTO sequence_trim_settings
(session_id, source_folder, trim_start, trim_end, folder_type)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(session_id, source_folder)
DO UPDATE SET trim_start=excluded.trim_start,
trim_end=excluded.trim_end,
folder_type=excluded.folder_type""",
(session_id, folder, trim_start, trim_end, folder_type.value)
)
except sqlite3.Error as e:
raise DatabaseError(f"Failed to save folder type override: {e}") from e
def get_folder_type_overrides(self, session_id: int) -> dict[str, FolderType]:
"""Get all folder type overrides for a session.
Args:
session_id: The session ID.
Returns:
Dict mapping source folder paths to FolderType.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT source_folder, folder_type
FROM sequence_trim_settings WHERE session_id = ?""",
(session_id,)
).fetchall()
result = {}
for row in rows:
try:
result[row[0]] = FolderType(row[1]) if row[1] else FolderType.AUTO
except ValueError:
result[row[0]] = FolderType.AUTO
return result
def save_per_transition_settings(
self,
session_id: int,
settings: PerTransitionSettings
) -> None:
"""Save per-transition overlap settings.
Args:
session_id: The session ID.
settings: PerTransitionSettings to save.
Raises:
DatabaseError: If saving fails.
"""
try:
with self._connect() as conn:
conn.execute(
"""INSERT INTO per_transition_settings
(session_id, trans_folder, left_overlap, right_overlap)
VALUES (?, ?, ?, ?)
ON CONFLICT(session_id, trans_folder)
DO UPDATE SET left_overlap=excluded.left_overlap,
right_overlap=excluded.right_overlap""",
(session_id, str(settings.trans_folder),
settings.left_overlap, settings.right_overlap)
)
except sqlite3.Error as e:
raise DatabaseError(f"Failed to save per-transition settings: {e}") from e
def get_per_transition_settings(
self,
session_id: int,
trans_folder: str
) -> Optional[PerTransitionSettings]:
"""Get per-transition settings for a specific transition folder.
Args:
session_id: The session ID.
trans_folder: Path to the transition folder.
Returns:
PerTransitionSettings or None if not found.
"""
with self._connect() as conn:
row = conn.execute(
"""SELECT left_overlap, right_overlap FROM per_transition_settings
WHERE session_id = ? AND trans_folder = ?""",
(session_id, trans_folder)
).fetchone()
if row:
return PerTransitionSettings(
trans_folder=Path(trans_folder),
left_overlap=row[0],
right_overlap=row[1]
)
return None
def get_all_per_transition_settings(
self,
session_id: int
) -> dict[str, PerTransitionSettings]:
"""Get all per-transition settings for a session.
Args:
session_id: The session ID.
Returns:
Dict mapping transition folder paths to PerTransitionSettings.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT trans_folder, left_overlap, right_overlap
FROM per_transition_settings WHERE session_id = ?""",
(session_id,)
).fetchall()
return {
row[0]: PerTransitionSettings(
trans_folder=Path(row[0]),
left_overlap=row[1],
right_overlap=row[2]
)
for row in rows
}

205
core/manager.py Normal file
View File

@@ -0,0 +1,205 @@
"""Symlink management for Video Montage Linker."""
import os
import re
from pathlib import Path
from typing import Optional
from config import SUPPORTED_EXTENSIONS
from .models import LinkResult, CleanupError, SourceNotFoundError, DestinationError
from .database import DatabaseManager
class SymlinkManager:
"""Manages symlink creation and cleanup operations."""
def __init__(self, db: Optional[DatabaseManager] = None) -> None:
"""Initialize the symlink manager.
Args:
db: Optional database manager for tracking operations.
"""
self.db = db
@staticmethod
def get_supported_files(directories: list[Path]) -> list[tuple[Path, str]]:
"""Get all supported image files from multiple directories.
Files are returned sorted by directory order (as provided), then
alphabetically by filename within each directory.
Args:
directories: List of source directories to scan.
Returns:
List of (directory, filename) tuples.
"""
files: list[tuple[Path, str]] = []
for directory in directories:
if not directory.is_dir():
continue
dir_files = []
for item in directory.iterdir():
if item.is_file() and item.suffix.lower() in SUPPORTED_EXTENSIONS:
dir_files.append((directory, item.name))
# Sort files within this directory alphabetically
dir_files.sort(key=lambda x: x[1].lower())
files.extend(dir_files)
return files
@staticmethod
def validate_paths(sources: list[Path], dest: Path) -> None:
"""Validate source and destination paths.
Args:
sources: List of source directories.
dest: Destination directory.
Raises:
SourceNotFoundError: If any source directory doesn't exist.
DestinationError: If destination cannot be created or accessed.
"""
if not sources:
raise SourceNotFoundError("No source directories specified")
for source in sources:
if not source.exists():
raise SourceNotFoundError(f"Source directory not found: {source}")
if not source.is_dir():
raise SourceNotFoundError(f"Source is not a directory: {source}")
try:
dest.mkdir(parents=True, exist_ok=True)
except OSError as e:
raise DestinationError(f"Cannot create destination directory: {e}") from e
if not dest.is_dir():
raise DestinationError(f"Destination is not a directory: {dest}")
@staticmethod
def cleanup_old_links(directory: Path) -> int:
"""Remove existing seq* symlinks from a directory.
Handles both old format (seq_0000) and new format (seq01_0000).
Also removes blended image files (not just symlinks) created by
cross-dissolve transitions.
Args:
directory: Directory to clean up.
Returns:
Number of files removed.
Raises:
CleanupError: If cleanup fails.
"""
removed = 0
seq_pattern = re.compile(r'^seq\d*_\d+\.(png|jpg|jpeg|webp)$', re.IGNORECASE)
try:
for item in directory.iterdir():
# Match both old (seq_NNNN) and new (seqNN_NNNN) formats
if item.name.startswith("seq"):
if item.is_symlink():
item.unlink()
removed += 1
elif item.is_file() and seq_pattern.match(item.name):
# Also remove blended image files
item.unlink()
removed += 1
except OSError as e:
raise CleanupError(f"Failed to clean up old links: {e}") from e
return removed
def create_sequence_links(
self,
sources: list[Path],
dest: Path,
files: list[tuple],
trim_settings: Optional[dict[Path, tuple[int, int]]] = None,
) -> tuple[list[LinkResult], Optional[int]]:
"""Create sequenced symlinks from source files to destination.
Args:
sources: List of source directories (for validation).
dest: Destination directory.
files: List of tuples. Can be:
- (source_dir, filename) for CLI mode (uses global sequence)
- (source_dir, filename, folder_idx, file_idx) for GUI mode
trim_settings: Optional dict mapping folder paths to (trim_start, trim_end).
Returns:
Tuple of (list of LinkResult objects, session_id or None).
"""
self.validate_paths(sources, dest)
self.cleanup_old_links(dest)
session_id = None
if self.db:
session_id = self.db.create_session(str(dest))
# Save trim settings if provided
if trim_settings and session_id:
for folder, (trim_start, trim_end) in trim_settings.items():
if trim_start > 0 or trim_end > 0:
self.db.save_trim_settings(
session_id, str(folder), trim_start, trim_end
)
results: list[LinkResult] = []
# Check if we have folder indices (GUI mode) or not (CLI mode)
use_folder_sequences = len(files) > 0 and len(files[0]) >= 4
# For CLI mode without folder indices, calculate them
if not use_folder_sequences:
folder_to_index = {folder: i for i, folder in enumerate(sources)}
folder_file_counts: dict[Path, int] = {}
expanded_files = []
for source_dir, filename in files:
folder_idx = folder_to_index.get(source_dir, 0)
file_idx = folder_file_counts.get(source_dir, 0)
folder_file_counts[source_dir] = file_idx + 1
expanded_files.append((source_dir, filename, folder_idx, file_idx))
files = expanded_files
for i, file_data in enumerate(files):
source_dir, filename, folder_idx, file_idx = file_data
source_path = source_dir / filename
ext = source_path.suffix
link_name = f"seq{folder_idx + 1:02d}_{file_idx:04d}{ext}"
link_path = dest / link_name
# Calculate relative path from destination to source
rel_source = Path(os.path.relpath(source_path.resolve(), dest.resolve()))
try:
link_path.symlink_to(rel_source)
if self.db and session_id:
self.db.record_symlink(
session_id=session_id,
source=str(source_path.resolve()),
link=str(link_path),
filename=filename,
seq=i
)
results.append(LinkResult(
source_path=source_path,
link_path=link_path,
sequence_number=i,
success=True
))
except OSError as e:
results.append(LinkResult(
source_path=source_path,
link_path=link_path,
sequence_number=i,
success=False,
error=str(e)
))
return results, session_id

143
core/models.py Normal file
View File

@@ -0,0 +1,143 @@
"""Data models, enums, and exceptions for Video Montage Linker."""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Optional
# --- Enums ---
class BlendCurve(Enum):
"""Blend curve types for cross-dissolve transitions."""
LINEAR = 'linear'
EASE_IN = 'ease_in'
EASE_OUT = 'ease_out'
EASE_IN_OUT = 'ease_in_out'
class BlendMethod(Enum):
"""Blend method types for transitions."""
ALPHA = 'alpha' # Simple cross-dissolve (PIL.Image.blend)
OPTICAL_FLOW = 'optical' # OpenCV Farneback optical flow
RIFE = 'rife' # AI frame interpolation (NCNN binary)
RIFE_PRACTICAL = 'rife_practical' # Practical-RIFE Python/PyTorch implementation
class FolderType(Enum):
"""Folder type for transition detection."""
AUTO = 'auto'
MAIN = 'main'
TRANSITION = 'transition'
# --- Data Classes ---
@dataclass
class TransitionSettings:
"""Settings for cross-dissolve transitions."""
enabled: bool = False
blend_curve: BlendCurve = BlendCurve.LINEAR
output_format: str = 'png'
webp_method: int = 4 # 0-6, used when format is webp (compression effort)
output_quality: int = 95 # used for jpeg only
trans_destination: Optional[Path] = None # separate destination for transition output
blend_method: BlendMethod = BlendMethod.ALPHA # blending method
rife_binary_path: Optional[Path] = None # path to rife-ncnn-vulkan binary
rife_model: str = 'rife-v4.6' # RIFE model to use
rife_uhd: bool = False # Enable UHD mode for high resolution
rife_tta: bool = False # Enable TTA mode for better quality
# Practical-RIFE settings
practical_rife_model: str = 'v4.25' # v4.25, v4.26, v4.22, etc.
practical_rife_ensemble: bool = False # Ensemble mode for better quality (slower)
@dataclass
class PerTransitionSettings:
"""Per-transition overlap settings for asymmetric cross-dissolves."""
trans_folder: Path
left_overlap: int = 16 # frames from main folder end
right_overlap: int = 16 # frames from trans folder start
@dataclass
class BlendResult:
"""Result of an image blend operation."""
output_path: Path
source_a: Path
source_b: Path
blend_factor: float
success: bool
error: Optional[str] = None
@dataclass
class TransitionSpec:
"""Specification for a transition boundary between two folders."""
main_folder: Path
trans_folder: Path
main_files: list[str]
trans_files: list[str]
left_overlap: int # asymmetric: frames from main folder end
right_overlap: int # asymmetric: frames from trans folder start
# Indices into the overall file list
main_start_idx: int
trans_start_idx: int
@dataclass
class LinkResult:
"""Result of a symlink creation operation."""
source_path: Path
link_path: Path
sequence_number: int
success: bool
error: Optional[str] = None
@dataclass
class SymlinkRecord:
"""Database record of a created symlink."""
id: int
session_id: int
source_path: str
link_path: str
original_filename: str
sequence_number: int
created_at: datetime
@dataclass
class SessionRecord:
"""Database record of a symlink session."""
id: int
created_at: datetime
destination: str
link_count: int = 0
# --- Exceptions ---
class SymlinkError(Exception):
"""Base exception for symlink operations."""
class PathValidationError(SymlinkError):
"""Error validating file paths."""
class SourceNotFoundError(PathValidationError):
"""Source directory does not exist."""
class DestinationError(PathValidationError):
"""Error with destination directory."""
class CleanupError(SymlinkError):
"""Error during cleanup of existing symlinks."""
class DatabaseError(SymlinkError):
"""Error with database operations."""

429
core/rife_worker.py Normal file
View File

@@ -0,0 +1,429 @@
#!/usr/bin/env python
"""RIFE interpolation worker - runs in isolated venv with PyTorch.
This script is executed via subprocess from the main application.
It handles loading Practical-RIFE models and performing frame interpolation.
Note: The Practical-RIFE models require the IFNet architecture from the
Practical-RIFE repository. This script downloads and uses the model weights
with a simplified inference implementation.
"""
import argparse
import os
import shutil
import sys
import tempfile
import urllib.request
import zipfile
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)
def forward(self, x, flow=None, scale=1):
x = F.interpolate(x, scale_factor=1./scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1./scale, mode="bilinear", align_corners=False) / scale
x = torch.cat((x, flow), 1)
feat = self.conv0(x)
feat = self.convblock(feat) + feat
tmp = self.lastconv(feat)
tmp = F.interpolate(tmp, scale_factor=scale*2, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * scale * 2
mask = tmp[:, 4:5]
return flow, mask
def warp(tenInput, tenFlow):
k = (str(tenFlow.device), str(tenFlow.size()))
backwarp_tenGrid = {}
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
class IFNet(nn.Module):
"""IFNet architecture for Practical-RIFE v4.25/v4.26 models."""
def __init__(self):
super(IFNet, self).__init__()
# v4.25/v4.26 architecture:
# block0 input: img0(3) + img1(3) + f0(4) + f1(4) + timestep(1) = 15
# block1+ input: img0(3) + img1(3) + wf0(4) + wf1(4) + f0(4) + f1(4) + timestep(1) + mask(1) + flow(4) = 28
self.block0 = IFBlock(3+3+4+4+1, c=192)
self.block1 = IFBlock(3+3+4+4+4+4+1+1+4, c=128)
self.block2 = IFBlock(3+3+4+4+4+4+1+1+4, c=96)
self.block3 = IFBlock(3+3+4+4+4+4+1+1+4, c=64)
# Encode produces 4-channel features
self.encode = nn.Sequential(
nn.Conv2d(3, 32, 3, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(32, 32, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(32, 32, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.ConvTranspose2d(32, 4, 4, 2, 1)
)
def forward(self, img0, img1, timestep=0.5, scale_list=[8, 4, 2, 1]):
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
warped_img0 = img0
warped_img1 = img1
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
for i in range(4):
if flow is None:
flow, mask = block[i](
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
None, scale=scale_list[i])
else:
wf0 = warp(f0, flow[:, :2])
wf1 = warp(f1, flow[:, 2:4])
fd, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, f0, f1, timestep, mask), 1),
flow, scale=scale_list[i])
flow = flow + fd
mask = mask + m0
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
mask_final = torch.sigmoid(mask)
merged_final = warped_img0 * mask_final + warped_img1 * (1 - mask_final)
return merged_final
# Model URLs for downloading (Google Drive direct download links)
# File IDs extracted from official Practical-RIFE repository
MODEL_URLS = {
'v4.26': 'https://drive.google.com/uc?export=download&id=1gViYvvQrtETBgU1w8axZSsr7YUuw31uy',
'v4.25': 'https://drive.google.com/uc?export=download&id=1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg',
'v4.22': 'https://drive.google.com/uc?export=download&id=1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_',
'v4.20': 'https://drive.google.com/uc?export=download&id=11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA',
'v4.18': 'https://drive.google.com/uc?export=download&id=1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s',
'v4.15': 'https://drive.google.com/uc?export=download&id=1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5',
}
def download_model(version: str, model_dir: Path) -> Path:
"""Download model if not already cached.
Google Drive links distribute zip files containing the model.
This function downloads and extracts the flownet.pkl file.
Args:
version: Model version (e.g., 'v4.25').
model_dir: Directory to store models.
Returns:
Path to the downloaded model file.
"""
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / f'flownet_{version}.pkl'
if model_path.exists():
# Verify it's not a zip file (from previous failed attempt)
with open(model_path, 'rb') as f:
header = f.read(4)
if header == b'PK\x03\x04': # ZIP magic number
print(f"Removing corrupted zip file at {model_path}", file=sys.stderr)
model_path.unlink()
else:
return model_path
url = MODEL_URLS.get(version)
if not url:
raise ValueError(f"Unknown model version: {version}")
print(f"Downloading RIFE model {version}...", file=sys.stderr)
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir) / 'download'
# Try using gdown for Google Drive (handles confirmations automatically)
downloaded = False
try:
import gdown
file_id = url.split('id=')[1] if 'id=' in url else None
if file_id:
gdown_url = f'https://drive.google.com/uc?id={file_id}'
gdown.download(gdown_url, str(tmp_path), quiet=False)
downloaded = tmp_path.exists()
except ImportError:
print("gdown not available, trying direct download...", file=sys.stderr)
except Exception as e:
print(f"gdown failed: {e}, trying direct download...", file=sys.stderr)
# Fallback: direct download
if not downloaded:
try:
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
with urllib.request.urlopen(req, timeout=300) as response:
data = response.read()
if data[:100].startswith(b'<!') or b'<html' in data[:500].lower():
raise RuntimeError("Google Drive returned HTML - install gdown: pip install gdown")
with open(tmp_path, 'wb') as f:
f.write(data)
downloaded = True
except Exception as e:
raise RuntimeError(f"Failed to download model: {e}")
if not downloaded or not tmp_path.exists():
raise RuntimeError("Download failed - no file received")
# Check if downloaded file is a zip archive
with open(tmp_path, 'rb') as f:
header = f.read(4)
if header == b'PK\x03\x04': # ZIP magic number
print(f"Extracting model from zip archive...", file=sys.stderr)
with zipfile.ZipFile(tmp_path, 'r') as zf:
# Find flownet.pkl in the archive
pkl_files = [n for n in zf.namelist() if n.endswith('flownet.pkl')]
if not pkl_files:
raise RuntimeError(f"No flownet.pkl found in zip. Contents: {zf.namelist()}")
# Extract the pkl file
pkl_name = pkl_files[0]
with zf.open(pkl_name) as src, open(model_path, 'wb') as dst:
dst.write(src.read())
else:
# Already a pkl file, just move it
shutil.move(str(tmp_path), str(model_path))
print(f"Model saved to {model_path}", file=sys.stderr)
return model_path
def load_model(model_path: Path, device: torch.device) -> IFNet:
"""Load IFNet model from state dict.
Args:
model_path: Path to flownet.pkl file.
device: Device to load model to.
Returns:
Loaded IFNet model.
"""
model = IFNet()
state_dict = torch.load(model_path, map_location='cpu')
# Handle different state dict formats
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
# Remove 'module.' prefix if present (from DataParallel)
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
# Handle flownet. prefix
if k.startswith('flownet.'):
k = k[8:]
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
model.to(device)
model.eval()
return model
def pad_image(img: torch.Tensor, padding: int = 64) -> tuple:
"""Pad image to be divisible by padding.
Args:
img: Input tensor (B, C, H, W).
padding: Padding divisor.
Returns:
Tuple of (padded image, (original H, original W)).
"""
_, _, h, w = img.shape
ph = ((h - 1) // padding + 1) * padding
pw = ((w - 1) // padding + 1) * padding
pad_h = ph - h
pad_w = pw - w
padded = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
return padded, (h, w)
@torch.no_grad()
def inference(model: IFNet, img0: torch.Tensor, img1: torch.Tensor,
timestep: float = 0.5, ensemble: bool = False) -> torch.Tensor:
"""Perform frame interpolation.
Args:
model: Loaded IFNet model.
img0: First frame tensor (B, C, H, W) normalized to [0, 1].
img1: Second frame tensor (B, C, H, W) normalized to [0, 1].
timestep: Interpolation timestep (0.0 to 1.0).
ensemble: Enable ensemble mode for better quality.
Returns:
Interpolated frame tensor.
"""
# Pad images
img0_padded, orig_size = pad_image(img0)
img1_padded, _ = pad_image(img1)
h, w = orig_size
# Create timestep tensor
timestep_tensor = torch.full((1, 1, img0_padded.shape[2], img0_padded.shape[3]),
timestep, device=img0.device)
if ensemble:
# Ensemble: average of forward and reverse
result1 = model(img0_padded, img1_padded, timestep_tensor)
result2 = model(img1_padded, img0_padded, 1 - timestep_tensor)
result = (result1 + result2) / 2
else:
result = model(img0_padded, img1_padded, timestep_tensor)
# Crop back to original size
result = result[:, :, :h, :w]
return result.clamp(0, 1)
def load_image(path: Path, device: torch.device) -> torch.Tensor:
"""Load image as tensor.
Args:
path: Path to image file.
device: Device to load tensor to.
Returns:
Image tensor (1, 3, H, W) normalized to [0, 1].
"""
img = Image.open(path).convert('RGB')
arr = np.array(img).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
return tensor.to(device)
def save_image(tensor: torch.Tensor, path: Path) -> None:
"""Save tensor as image.
Args:
tensor: Image tensor (1, 3, H, W) normalized to [0, 1].
path: Output path.
"""
arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
arr = (arr * 255).clip(0, 255).astype(np.uint8)
Image.fromarray(arr).save(path)
# Global model cache
_model_cache: dict = {}
def get_model(version: str, model_dir: Path, device: torch.device) -> IFNet:
"""Get or load model (cached).
Args:
version: Model version.
model_dir: Model cache directory.
device: Device to run on.
Returns:
IFNet model instance.
"""
cache_key = f"{version}_{device}"
if cache_key not in _model_cache:
model_path = download_model(version, model_dir)
_model_cache[cache_key] = load_model(model_path, device)
return _model_cache[cache_key]
def main():
parser = argparse.ArgumentParser(description='RIFE frame interpolation worker')
parser.add_argument('--input0', required=True, help='Path to first input image')
parser.add_argument('--input1', required=True, help='Path to second input image')
parser.add_argument('--output', required=True, help='Path to output image')
parser.add_argument('--timestep', type=float, default=0.5, help='Interpolation timestep (0-1)')
parser.add_argument('--model', default='v4.25', help='Model version')
parser.add_argument('--model-dir', required=True, help='Model cache directory')
parser.add_argument('--ensemble', action='store_true', help='Enable ensemble mode')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use')
args = parser.parse_args()
try:
# Select device
if args.device == 'cuda' and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
# Load model
model_dir = Path(args.model_dir)
model = get_model(args.model, model_dir, device)
# Load images
img0 = load_image(Path(args.input0), device)
img1 = load_image(Path(args.input1), device)
# Interpolate
result = inference(model, img0, img1, args.timestep, args.ensemble)
# Save result
save_image(result, Path(args.output))
print("Success", file=sys.stderr)
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
return 1
if __name__ == '__main__':
sys.exit(main())

View File

@@ -5,744 +5,20 @@ Supports both GUI and CLI modes for creating numbered symlinks from one or more
source directories into a single destination directory.
"""
# --- Imports ---
import argparse
import os
import sqlite3
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from PyQt6.QtCore import Qt
from PyQt6.QtGui import QDragEnterEvent, QDropEvent
from PyQt6.QtWidgets import (
QApplication,
QWidget,
QVBoxLayout,
QPushButton,
QLabel,
QFileDialog,
QLineEdit,
QHBoxLayout,
QMessageBox,
QListWidget,
QTreeWidget,
QTreeWidgetItem,
QAbstractItemView,
QGroupBox,
QHeaderView,
from PyQt6.QtWidgets import QApplication
from core import (
DatabaseManager,
SymlinkManager,
CleanupError,
)
from ui import SequenceLinkerUI
# --- Configuration ---
SUPPORTED_EXTENSIONS = ('.png', '.webp', '.jpg', '.jpeg')
DB_PATH = Path.home() / '.config' / 'video-montage-linker' / 'symlinks.db'
# --- Exceptions ---
class SymlinkError(Exception):
"""Base exception for symlink operations."""
class PathValidationError(SymlinkError):
"""Error validating file paths."""
class SourceNotFoundError(PathValidationError):
"""Source directory does not exist."""
class DestinationError(PathValidationError):
"""Error with destination directory."""
class CleanupError(SymlinkError):
"""Error during cleanup of existing symlinks."""
class DatabaseError(SymlinkError):
"""Error with database operations."""
# --- Data Classes ---
@dataclass
class LinkResult:
"""Result of a symlink creation operation."""
source_path: Path
link_path: Path
sequence_number: int
success: bool
error: Optional[str] = None
@dataclass
class SymlinkRecord:
"""Database record of a created symlink."""
id: int
session_id: int
source_path: str
link_path: str
original_filename: str
sequence_number: int
created_at: datetime
@dataclass
class SessionRecord:
"""Database record of a symlink session."""
id: int
created_at: datetime
destination: str
link_count: int = 0
# --- Database ---
class DatabaseManager:
"""Manages SQLite database for tracking symlink sessions and links."""
def __init__(self, db_path: Path = DB_PATH) -> None:
"""Initialize database manager.
Args:
db_path: Path to the SQLite database file.
"""
self.db_path = db_path
self._ensure_db_exists()
def _ensure_db_exists(self) -> None:
"""Create database and tables if they don't exist."""
self.db_path.parent.mkdir(parents=True, exist_ok=True)
with self._connect() as conn:
conn.executescript("""
CREATE TABLE IF NOT EXISTS symlink_sessions (
id INTEGER PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
destination TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS symlinks (
id INTEGER PRIMARY KEY,
session_id INTEGER REFERENCES symlink_sessions(id) ON DELETE CASCADE,
source_path TEXT NOT NULL,
link_path TEXT NOT NULL,
original_filename TEXT NOT NULL,
sequence_number INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
def _connect(self) -> sqlite3.Connection:
"""Create a database connection with foreign keys enabled."""
conn = sqlite3.connect(self.db_path)
conn.execute("PRAGMA foreign_keys = ON")
return conn
def create_session(self, destination: str) -> int:
"""Create a new linking session.
Args:
destination: The destination directory path.
Returns:
The ID of the created session.
Raises:
DatabaseError: If session creation fails.
"""
try:
with self._connect() as conn:
cursor = conn.execute(
"INSERT INTO symlink_sessions (destination) VALUES (?)",
(destination,)
)
return cursor.lastrowid
except sqlite3.Error as e:
raise DatabaseError(f"Failed to create session: {e}") from e
def record_symlink(
self,
session_id: int,
source: str,
link: str,
filename: str,
seq: int
) -> int:
"""Record a created symlink.
Args:
session_id: The session this symlink belongs to.
source: Full path to the source file.
link: Full path to the created symlink.
filename: Original filename.
seq: Sequence number in the destination.
Returns:
The ID of the created record.
Raises:
DatabaseError: If recording fails.
"""
try:
with self._connect() as conn:
cursor = conn.execute(
"""INSERT INTO symlinks
(session_id, source_path, link_path, original_filename, sequence_number)
VALUES (?, ?, ?, ?, ?)""",
(session_id, source, link, filename, seq)
)
return cursor.lastrowid
except sqlite3.Error as e:
raise DatabaseError(f"Failed to record symlink: {e}") from e
def get_sessions(self) -> list[SessionRecord]:
"""List all sessions with link counts.
Returns:
List of session records.
"""
with self._connect() as conn:
rows = conn.execute("""
SELECT s.id, s.created_at, s.destination, COUNT(l.id) as link_count
FROM symlink_sessions s
LEFT JOIN symlinks l ON s.id = l.session_id
GROUP BY s.id
ORDER BY s.created_at DESC
""").fetchall()
return [
SessionRecord(
id=row[0],
created_at=datetime.fromisoformat(row[1]),
destination=row[2],
link_count=row[3]
)
for row in rows
]
def get_symlinks_by_session(self, session_id: int) -> list[SymlinkRecord]:
"""Get all symlinks for a session.
Args:
session_id: The session ID to query.
Returns:
List of symlink records.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT id, session_id, source_path, link_path,
original_filename, sequence_number, created_at
FROM symlinks WHERE session_id = ?
ORDER BY sequence_number""",
(session_id,)
).fetchall()
return [
SymlinkRecord(
id=row[0],
session_id=row[1],
source_path=row[2],
link_path=row[3],
original_filename=row[4],
sequence_number=row[5],
created_at=datetime.fromisoformat(row[6])
)
for row in rows
]
def get_symlinks_by_destination(self, dest: str) -> list[SymlinkRecord]:
"""Get all symlinks for a destination directory.
Args:
dest: The destination directory path.
Returns:
List of symlink records.
"""
with self._connect() as conn:
rows = conn.execute(
"""SELECT l.id, l.session_id, l.source_path, l.link_path,
l.original_filename, l.sequence_number, l.created_at
FROM symlinks l
JOIN symlink_sessions s ON l.session_id = s.id
WHERE s.destination = ?
ORDER BY l.sequence_number""",
(dest,)
).fetchall()
return [
SymlinkRecord(
id=row[0],
session_id=row[1],
source_path=row[2],
link_path=row[3],
original_filename=row[4],
sequence_number=row[5],
created_at=datetime.fromisoformat(row[6])
)
for row in rows
]
def delete_session(self, session_id: int) -> None:
"""Delete a session and all its symlink records.
Args:
session_id: The session ID to delete.
Raises:
DatabaseError: If deletion fails.
"""
try:
with self._connect() as conn:
conn.execute("DELETE FROM symlinks WHERE session_id = ?", (session_id,))
conn.execute("DELETE FROM symlink_sessions WHERE id = ?", (session_id,))
except sqlite3.Error as e:
raise DatabaseError(f"Failed to delete session: {e}") from e
def get_sessions_by_destination(self, dest: str) -> list[SessionRecord]:
"""Get all sessions for a destination directory.
Args:
dest: The destination directory path.
Returns:
List of session records.
"""
with self._connect() as conn:
rows = conn.execute("""
SELECT s.id, s.created_at, s.destination, COUNT(l.id) as link_count
FROM symlink_sessions s
LEFT JOIN symlinks l ON s.id = l.session_id
WHERE s.destination = ?
GROUP BY s.id
ORDER BY s.created_at DESC
""", (dest,)).fetchall()
return [
SessionRecord(
id=row[0],
created_at=datetime.fromisoformat(row[1]),
destination=row[2],
link_count=row[3]
)
for row in rows
]
# --- Business Logic ---
class SymlinkManager:
"""Manages symlink creation and cleanup operations."""
def __init__(self, db: Optional[DatabaseManager] = None) -> None:
"""Initialize the symlink manager.
Args:
db: Optional database manager for tracking operations.
"""
self.db = db
@staticmethod
def get_supported_files(directories: list[Path]) -> list[tuple[Path, str]]:
"""Get all supported image files from multiple directories.
Files are returned sorted by directory order (as provided), then
alphabetically by filename within each directory.
Args:
directories: List of source directories to scan.
Returns:
List of (directory, filename) tuples.
"""
files: list[tuple[Path, str]] = []
for directory in directories:
if not directory.is_dir():
continue
dir_files = []
for item in directory.iterdir():
if item.is_file() and item.suffix.lower() in SUPPORTED_EXTENSIONS:
dir_files.append((directory, item.name))
# Sort files within this directory alphabetically
dir_files.sort(key=lambda x: x[1].lower())
files.extend(dir_files)
return files
@staticmethod
def validate_paths(sources: list[Path], dest: Path) -> None:
"""Validate source and destination paths.
Args:
sources: List of source directories.
dest: Destination directory.
Raises:
SourceNotFoundError: If any source directory doesn't exist.
DestinationError: If destination cannot be created or accessed.
"""
if not sources:
raise SourceNotFoundError("No source directories specified")
for source in sources:
if not source.exists():
raise SourceNotFoundError(f"Source directory not found: {source}")
if not source.is_dir():
raise SourceNotFoundError(f"Source is not a directory: {source}")
try:
dest.mkdir(parents=True, exist_ok=True)
except OSError as e:
raise DestinationError(f"Cannot create destination directory: {e}") from e
if not dest.is_dir():
raise DestinationError(f"Destination is not a directory: {dest}")
@staticmethod
def cleanup_old_links(directory: Path) -> int:
"""Remove existing seq_* symlinks from a directory.
Args:
directory: Directory to clean up.
Returns:
Number of files removed.
Raises:
CleanupError: If cleanup fails.
"""
removed = 0
try:
for item in directory.iterdir():
if item.name.startswith("seq_") and item.is_symlink():
item.unlink()
removed += 1
except OSError as e:
raise CleanupError(f"Failed to clean up old links: {e}") from e
return removed
def create_sequence_links(
self,
sources: list[Path],
dest: Path,
files: list[tuple[Path, str]],
) -> list[LinkResult]:
"""Create sequenced symlinks from source files to destination.
Args:
sources: List of source directories (for validation).
dest: Destination directory.
files: List of (source_dir, filename) tuples in desired order.
Returns:
List of LinkResult objects for each operation.
"""
self.validate_paths(sources, dest)
self.cleanup_old_links(dest)
session_id = None
if self.db:
session_id = self.db.create_session(str(dest))
results: list[LinkResult] = []
for i, (source_dir, filename) in enumerate(files):
source_path = source_dir / filename
ext = source_path.suffix
link_name = f"seq_{i:04d}{ext}"
link_path = dest / link_name
# Calculate relative path from destination to source
rel_source = Path(os.path.relpath(source_path.resolve(), dest.resolve()))
try:
link_path.symlink_to(rel_source)
if self.db and session_id:
self.db.record_symlink(
session_id=session_id,
source=str(source_path.resolve()),
link=str(link_path),
filename=filename,
seq=i
)
results.append(LinkResult(
source_path=source_path,
link_path=link_path,
sequence_number=i,
success=True
))
except OSError as e:
results.append(LinkResult(
source_path=source_path,
link_path=link_path,
sequence_number=i,
success=False,
error=str(e)
))
return results
# --- GUI ---
class SequenceLinkerUI(QWidget):
"""PyQt6 GUI for the Video Montage Linker."""
def __init__(self) -> None:
"""Initialize the UI."""
super().__init__()
self.source_folders: list[Path] = []
self.last_directory: Optional[str] = None
self.db = DatabaseManager()
self.manager = SymlinkManager(self.db)
self._setup_window()
self._create_widgets()
self._create_layout()
self._connect_signals()
self.setAcceptDrops(True)
def _setup_window(self) -> None:
"""Configure the main window properties."""
self.setWindowTitle('Video Montage Linker')
self.setMinimumSize(700, 600)
def _create_widgets(self) -> None:
"""Create all UI widgets."""
# Source folders group
self.source_group = QGroupBox("Source Folders (drag & drop folders here)")
self.source_list = QListWidget()
self.source_list.setMaximumHeight(100)
self.source_list.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection)
self.add_source_btn = QPushButton("Add Folder")
self.remove_source_btn = QPushButton("Remove Folder")
# Destination
self.dst_label = QLabel("Destination Folder:")
self.dst_path = QLineEdit(placeholderText="Select destination folder")
self.dst_btn = QPushButton("Browse")
# File list
self.files_label = QLabel("Sequence Order (Drag to reorder, Del to remove):")
self.file_list = QTreeWidget()
self.file_list.setHeaderLabels(["Filename", "Source Folder"])
self.file_list.setDragDropMode(QAbstractItemView.DragDropMode.InternalMove)
self.file_list.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection)
self.file_list.setRootIsDecorated(False)
self.file_list.header().setStretchLastSection(True)
self.file_list.header().setSectionResizeMode(0, QHeaderView.ResizeMode.Interactive)
# Action buttons
self.remove_files_btn = QPushButton("Remove Files")
self.refresh_btn = QPushButton("Refresh Files")
self.run_btn = QPushButton("Generate Virtual Sequence")
self.run_btn.setStyleSheet(
"background-color: #3498db; color: white; "
"height: 40px; font-weight: bold;"
)
def _create_layout(self) -> None:
"""Arrange widgets in layouts."""
main_layout = QVBoxLayout()
# Source folders group layout
source_group_layout = QVBoxLayout()
source_btn_layout = QHBoxLayout()
source_btn_layout.addWidget(self.add_source_btn)
source_btn_layout.addWidget(self.remove_source_btn)
source_btn_layout.addStretch()
source_group_layout.addWidget(self.source_list)
source_group_layout.addLayout(source_btn_layout)
self.source_group.setLayout(source_group_layout)
# Destination layout
dst_layout = QHBoxLayout()
dst_layout.addWidget(self.dst_path)
dst_layout.addWidget(self.dst_btn)
# Button layout
btn_layout = QHBoxLayout()
btn_layout.addWidget(self.remove_files_btn)
btn_layout.addWidget(self.refresh_btn)
btn_layout.addStretch()
# Assemble main layout
main_layout.addWidget(self.source_group)
main_layout.addWidget(self.dst_label)
main_layout.addLayout(dst_layout)
main_layout.addWidget(self.files_label)
main_layout.addWidget(self.file_list)
main_layout.addLayout(btn_layout)
main_layout.addWidget(self.run_btn)
self.setLayout(main_layout)
def _connect_signals(self) -> None:
"""Connect widget signals to slots."""
self.add_source_btn.clicked.connect(self._add_source_folder)
self.remove_source_btn.clicked.connect(self._remove_source_folder)
self.dst_btn.clicked.connect(self._browse_destination)
self.remove_files_btn.clicked.connect(self._remove_selected_files)
self.refresh_btn.clicked.connect(self._refresh_files)
self.run_btn.clicked.connect(self._process_links)
def _add_source_folder(self, folder_path: Optional[str] = None) -> None:
"""Add a source folder via file dialog or direct path.
Args:
folder_path: Optional path to add directly (for drag-drop).
"""
if folder_path:
path = folder_path
else:
start_dir = self.last_directory or ""
path = QFileDialog.getExistingDirectory(
self, "Select Source Folder", start_dir
)
if path:
folder = Path(path)
if folder.is_dir() and folder not in self.source_folders:
self.source_folders.append(folder)
self.source_list.addItem(str(folder))
self.last_directory = str(folder.parent)
self._refresh_files()
def _remove_source_folder(self) -> None:
"""Remove selected source folder(s)."""
selected = self.source_list.selectedItems()
if not selected:
return
# Remove in reverse order to maintain correct indices
rows = sorted([self.source_list.row(item) for item in selected], reverse=True)
for row in rows:
self.source_list.takeItem(row)
del self.source_folders[row]
self._refresh_files()
def _remove_selected_files(self) -> None:
"""Remove selected files from the file list."""
selected = self.file_list.selectedItems()
if not selected:
return
# Remove in reverse order to maintain correct indices
rows = sorted([self.file_list.indexOfTopLevelItem(item) for item in selected], reverse=True)
for row in rows:
self.file_list.takeTopLevelItem(row)
def _browse_destination(self) -> None:
"""Select destination folder via file dialog."""
start_dir = self.last_directory or ""
path = QFileDialog.getExistingDirectory(
self, "Select Destination Folder", start_dir
)
if path:
self.dst_path.setText(path)
self.last_directory = str(Path(path).parent)
def keyPressEvent(self, event) -> None:
"""Handle key press events for deletion."""
if event.key() == Qt.Key.Key_Delete:
# Check which widget has focus
if self.file_list.hasFocus():
self._remove_selected_files()
elif self.source_list.hasFocus():
self._remove_source_folder()
else:
super().keyPressEvent(event)
def dragEnterEvent(self, event: QDragEnterEvent) -> None:
"""Accept drag events with URLs (folders)."""
if event.mimeData().hasUrls():
event.acceptProposedAction()
def dropEvent(self, event: QDropEvent) -> None:
"""Handle dropped folders."""
for url in event.mimeData().urls():
path = url.toLocalFile()
if path and Path(path).is_dir():
self._add_source_folder(path)
def _refresh_files(self) -> None:
"""Refresh the file list from all source folders."""
self.file_list.clear()
if not self.source_folders:
return
files = self.manager.get_supported_files(self.source_folders)
for source_dir, filename in files:
item = QTreeWidgetItem([filename, str(source_dir)])
item.setData(0, Qt.ItemDataRole.UserRole, (source_dir, filename))
self.file_list.addTopLevelItem(item)
def _get_files_in_order(self) -> list[tuple[Path, str]]:
"""Get files in the current list order.
Returns:
List of (source_dir, filename) tuples in display order.
"""
files = []
for i in range(self.file_list.topLevelItemCount()):
item = self.file_list.topLevelItem(i)
data = item.data(0, Qt.ItemDataRole.UserRole)
if data:
files.append(data)
return files
def _process_links(self) -> None:
"""Create symlinks based on current configuration."""
dst = self.dst_path.text()
if not self.source_folders:
QMessageBox.warning(self, "Error", "Add at least one source folder!")
return
if not dst:
QMessageBox.warning(self, "Error", "Select a destination folder!")
return
files = self._get_files_in_order()
if not files:
QMessageBox.warning(self, "Error", "No files to process!")
return
try:
results = self.manager.create_sequence_links(
sources=self.source_folders,
dest=Path(dst),
files=files
)
successful = sum(1 for r in results if r.success)
failed = sum(1 for r in results if not r.success)
if failed > 0:
QMessageBox.warning(
self, "Partial Success",
f"Linked {successful} files, {failed} failed.\n"
f"Destination: {dst}"
)
else:
QMessageBox.information(
self, "Success",
f"Linked {successful} files to {dst}"
)
except SymlinkError as e:
QMessageBox.critical(self, "Error", str(e))
except Exception as e:
QMessageBox.critical(self, "Unexpected Error", str(e))
# --- CLI ---
def create_parser() -> argparse.ArgumentParser:
"""Create the argument parser for CLI mode.
@@ -851,6 +127,8 @@ def run_cli(args: argparse.Namespace) -> int:
# Create symlinks
if args.src and args.dst:
from core import SymlinkError
sources = [Path(s).resolve() for s in args.src]
dest = Path(args.dst).resolve()
@@ -864,7 +142,7 @@ def run_cli(args: argparse.Namespace) -> int:
print(f"Found {len(files)} files in {len(sources)} source folder(s)")
results = manager.create_sequence_links(
results, _ = manager.create_sequence_links(
sources=sources,
dest=dest,
files=files
@@ -899,7 +177,6 @@ def run_cli(args: argparse.Namespace) -> int:
return 0
# --- Entry Point ---
def main() -> int:
"""Main entry point for the application.
@@ -910,7 +187,6 @@ def main() -> int:
args = parser.parse_args()
# Determine if we should launch GUI
# GUI is launched if: --gui flag, OR no arguments at all
launch_gui = args.gui or (
not args.src and
not args.dst and

9
ui/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
"""UI modules for Video Montage Linker."""
from .widgets import TrimSlider
from .main_window import SequenceLinkerUI
__all__ = [
'TrimSlider',
'SequenceLinkerUI',
]

2623
ui/main_window.py Normal file

File diff suppressed because it is too large Load Diff

291
ui/widgets.py Normal file
View File

@@ -0,0 +1,291 @@
"""Custom widgets for Video Montage Linker UI."""
from typing import Optional
from PyQt6.QtCore import Qt, pyqtSignal, QRect
from PyQt6.QtGui import QPainter, QColor, QBrush, QPen, QMouseEvent
from PyQt6.QtWidgets import QWidget
class TrimSlider(QWidget):
"""A slider widget with two draggable handles for trimming sequences.
Allows setting in/out points for a sequence by dragging left and right handles.
Gray areas indicate trimmed regions, colored area indicates included images.
"""
trimChanged = pyqtSignal(int, int, str) # Emits (trim_start, trim_end, 'left' or 'right')
def __init__(self, parent: Optional[QWidget] = None) -> None:
"""Initialize the trim slider.
Args:
parent: Parent widget.
"""
super().__init__(parent)
self._total = 0
self._trim_start = 0
self._trim_end = 0
self._current_pos = 0
self._dragging: Optional[str] = None # 'left', 'right', or None
self._handle_width = 10
self._track_height = 20
self._enabled = True
self.setMinimumHeight(40)
self.setMinimumWidth(100)
self.setCursor(Qt.CursorShape.ArrowCursor)
self.setMouseTracking(True)
def setRange(self, total: int) -> None:
"""Set the total number of items in the sequence.
Args:
total: Total number of items.
"""
self._total = max(0, total)
# Clamp trim values to valid range
self._trim_start = min(self._trim_start, max(0, self._total - 1))
self._trim_end = min(self._trim_end, max(0, self._total - 1 - self._trim_start))
self.update()
def setTrimStart(self, value: int) -> None:
"""Set the trim start value.
Args:
value: Number of items to trim from start.
"""
max_start = max(0, self._total - 1 - self._trim_end)
self._trim_start = max(0, min(value, max_start))
self.update()
def setTrimEnd(self, value: int) -> None:
"""Set the trim end value.
Args:
value: Number of items to trim from end.
"""
max_end = max(0, self._total - 1 - self._trim_start)
self._trim_end = max(0, min(value, max_end))
self.update()
def setCurrentPosition(self, pos: int) -> None:
"""Set the current position indicator.
Args:
pos: Current position index.
"""
self._current_pos = max(0, min(pos, self._total - 1)) if self._total > 0 else 0
self.update()
def trimStart(self) -> int:
"""Get the trim start value."""
return self._trim_start
def trimEnd(self) -> int:
"""Get the trim end value."""
return self._trim_end
def total(self) -> int:
"""Get the total number of items."""
return self._total
def includedRange(self) -> tuple[int, int]:
"""Get the range of included items (after trimming).
Returns:
Tuple of (first_included_index, last_included_index).
Returns (-1, -1) if no items are included.
"""
if self._total == 0:
return (-1, -1)
first = self._trim_start
last = self._total - 1 - self._trim_end
if first > last:
return (-1, -1)
return (first, last)
def setEnabled(self, enabled: bool) -> None:
"""Enable or disable the widget."""
self._enabled = enabled
self.update()
def _track_rect(self) -> QRect:
"""Get the rectangle for the slider track."""
margin = self._handle_width
return QRect(
margin,
(self.height() - self._track_height) // 2,
self.width() - 2 * margin,
self._track_height
)
def _value_to_x(self, value: int) -> int:
"""Convert a value to an x coordinate."""
track = self._track_rect()
if self._total <= 1:
return track.left()
ratio = value / (self._total - 1)
return int(track.left() + ratio * track.width())
def _x_to_value(self, x: int) -> int:
"""Convert an x coordinate to a value."""
track = self._track_rect()
if track.width() == 0 or self._total <= 1:
return 0
ratio = (x - track.left()) / track.width()
ratio = max(0.0, min(1.0, ratio))
return int(round(ratio * (self._total - 1)))
def _left_handle_rect(self) -> QRect:
"""Get the rectangle for the left (trim start) handle."""
x = self._value_to_x(self._trim_start)
return QRect(
x - self._handle_width // 2,
(self.height() - self._track_height - 10) // 2,
self._handle_width,
self._track_height + 10
)
def _right_handle_rect(self) -> QRect:
"""Get the rectangle for the right (trim end) handle."""
x = self._value_to_x(self._total - 1 - self._trim_end) if self._total > 0 else 0
return QRect(
x - self._handle_width // 2,
(self.height() - self._track_height - 10) // 2,
self._handle_width,
self._track_height + 10
)
def paintEvent(self, event) -> None:
"""Paint the trim slider."""
painter = QPainter(self)
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
track = self._track_rect()
# Colors
bg_color = QColor(60, 60, 60)
trimmed_color = QColor(80, 80, 80)
included_color = QColor(52, 152, 219) if self._enabled else QColor(100, 100, 100)
handle_color = QColor(200, 200, 200) if self._enabled else QColor(120, 120, 120)
position_color = QColor(255, 255, 255)
# Draw background track
painter.fillRect(track, bg_color)
if self._total > 0:
# Draw trimmed regions (darker)
left_trim_x = self._value_to_x(self._trim_start)
right_trim_x = self._value_to_x(self._total - 1 - self._trim_end)
# Left trimmed region
if self._trim_start > 0:
left_rect = QRect(track.left(), track.top(),
left_trim_x - track.left(), track.height())
painter.fillRect(left_rect, trimmed_color)
# Right trimmed region
if self._trim_end > 0:
right_rect = QRect(right_trim_x, track.top(),
track.right() - right_trim_x, track.height())
painter.fillRect(right_rect, trimmed_color)
# Draw included region
if left_trim_x < right_trim_x:
included_rect = QRect(left_trim_x, track.top(),
right_trim_x - left_trim_x, track.height())
painter.fillRect(included_rect, included_color)
# Draw current position indicator
if self._trim_start <= self._current_pos <= (self._total - 1 - self._trim_end):
pos_x = self._value_to_x(self._current_pos)
painter.setPen(QPen(position_color, 2))
painter.drawLine(pos_x, track.top() - 2, pos_x, track.bottom() + 2)
# Draw handles
painter.setBrush(QBrush(handle_color))
painter.setPen(QPen(Qt.GlobalColor.black, 1))
# Left handle
left_handle = self._left_handle_rect()
painter.drawRect(left_handle)
# Right handle
right_handle = self._right_handle_rect()
painter.drawRect(right_handle)
painter.end()
def mousePressEvent(self, event: QMouseEvent) -> None:
"""Handle mouse press to start dragging handles."""
if not self._enabled or self._total == 0:
return
pos = event.pos()
# Check if clicking on handles (check right first since it may overlap)
right_rect = self._right_handle_rect()
left_rect = self._left_handle_rect()
# Expand hit area slightly for easier grabbing
expand = 5
left_expanded = left_rect.adjusted(-expand, -expand, expand, expand)
right_expanded = right_rect.adjusted(-expand, -expand, expand, expand)
if right_expanded.contains(pos):
self._dragging = 'right'
elif left_expanded.contains(pos):
self._dragging = 'left'
else:
self._dragging = None
def mouseMoveEvent(self, event: QMouseEvent) -> None:
"""Handle mouse move to drag handles."""
if not self._enabled:
return
pos = event.pos()
# Update cursor based on position
if self._dragging:
self.setCursor(Qt.CursorShape.SizeHorCursor)
else:
left_rect = self._left_handle_rect()
right_rect = self._right_handle_rect()
expand = 5
left_expanded = left_rect.adjusted(-expand, -expand, expand, expand)
right_expanded = right_rect.adjusted(-expand, -expand, expand, expand)
if left_expanded.contains(pos) or right_expanded.contains(pos):
self.setCursor(Qt.CursorShape.SizeHorCursor)
else:
self.setCursor(Qt.CursorShape.ArrowCursor)
if self._dragging and self._total > 0:
value = self._x_to_value(pos.x())
if self._dragging == 'left':
# Left handle: set trim_start, clamped to not exceed right
max_start = self._total - 1 - self._trim_end
new_start = max(0, min(value, max_start))
if new_start != self._trim_start:
self._trim_start = new_start
self.update()
self.trimChanged.emit(self._trim_start, self._trim_end, 'left')
elif self._dragging == 'right':
# Right handle: set trim_end based on position
# value is the index position, trim_end is count from end
max_val = self._total - 1 - self._trim_start
clamped_value = max(self._trim_start, min(value, self._total - 1))
new_end = self._total - 1 - clamped_value
if new_end != self._trim_end:
self._trim_end = max(0, new_end)
self.update()
self.trimChanged.emit(self._trim_start, self._trim_end, 'right')
def mouseReleaseEvent(self, event: QMouseEvent) -> None:
"""Handle mouse release to stop dragging."""
self._dragging = None
self.setCursor(Qt.CursorShape.ArrowCursor)