feat: scan results panel, model switching, batch scan, and training improvements

- Replace librosa with direct ffmpeg subprocess for 10x faster audio loading
- Add ScanResultsPanel with per-model tabs, seek-on-click, delete, and export
- Persist scan results in DB per (filename, profile, model)
- Add model selector dropdown to switch between trained embedding models
- Add "Scan All" button for batch scanning playlist videos
- Support manual negative examples via negative class folder
- Configurable auto-negative margin (default 30s, 0 = disabled)
- Deduplicate nearby training markers (8s min gap)
- Parallel audio loading with ThreadPoolExecutor during training
- Progress callbacks from training for UI status updates
- Cache bypass in scan_video (skip audio loading when embeddings cached)
- Move all caches (models, embeddings, downloads) into project directory
- Add 8cut.sh launcher script with auto venv/conda detection
- Fix 11 bugs across thread safety, signal handling, and state management

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-18 16:12:52 +02:00
parent f597ff29e8
commit 6870e5aaf3
5 changed files with 741 additions and 73 deletions
+1
View File
@@ -5,5 +5,6 @@ __pycache__/
.worktrees/ .worktrees/
.venv/ .venv/
models/ models/
cache/
*.joblib *.joblib
*.pt *.pt
Executable
+29
View File
@@ -0,0 +1,29 @@
#!/bin/bash
# Launch 8-cut with auto-detected venv/conda environment
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ENV_NAME="8cut"
CONDA_PREFIX_BASE="/media/p5/miniforge3"
# 1. Try .venv in project dir
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
source "$SCRIPT_DIR/.venv/bin/activate"
exec python "$SCRIPT_DIR/main.py" "$@"
fi
# 2. Try conda env (works without shell init)
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
if [ -x "$CONDA_PYTHON" ]; then
exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
fi
# 3. Try conda via shell hook (interactive shells)
if command -v conda &>/dev/null; then
eval "$(conda shell.bash hook 2>/dev/null)"
if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then
conda activate "$ENV_NAME"
exec python "$SCRIPT_DIR/main.py" "$@"
fi
fi
# 4. Fallback to system Python
exec python3 "$SCRIPT_DIR/main.py" "$@"
+168 -30
View File
@@ -2,15 +2,39 @@
import hashlib import hashlib
import os import os
import subprocess
import warnings
import numpy as np import numpy as np
import librosa
from .paths import _log from .paths import _bin, _log
_SR = 16000 # lower sr = faster _SR = 16000 # lower sr = faster
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
"""Load audio from any file as mono float32 numpy array using ffmpeg directly."""
cmd = [
_bin("ffmpeg"), "-i", path,
"-vn", # skip video
"-ac", "1", # mono
"-ar", str(sr), # resample
"-f", "f32le", # raw 32-bit float little-endian
"-loglevel", "error",
"pipe:1",
]
proc = subprocess.run(cmd, capture_output=True, timeout=300)
if proc.returncode != 0:
raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode().strip()}")
return np.frombuffer(proc.stdout, dtype=np.float32)
_WINDOW = 8.0 # seconds _WINDOW = 8.0 # seconds
_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models") _PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v") _MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")
# Redirect torch hub and huggingface downloads into the project
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Embedding extraction (lazy-loaded) # Embedding extraction (lazy-loaded)
@@ -33,7 +57,7 @@ _EMBED_MODELS = {
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
_BEATS_CHECKPOINT = os.path.join( _BEATS_CHECKPOINT = os.path.join(
os.path.expanduser("~"), ".cache", "huggingface", "hub", _DL_CACHE_DIR, "huggingface", "hub",
"models--lpepino--beats_ckpts", "snapshots", "models--lpepino--beats_ckpts", "snapshots",
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt", "5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
) )
@@ -86,6 +110,30 @@ def _w2v_cache_path(video_path: str, hop: float, window: float,
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz") return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
def _w2v_cache_exists(video_path: str, hop: float, window: float,
model_name: str | None = None) -> bool:
"""Check if embedding cache exists for a video."""
try:
path = _w2v_cache_path(video_path, hop, window, model_name)
return os.path.exists(path)
except Exception:
return False
def _w2v_cache_load(video_path: str, hop: float, window: float,
model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
"""Load embeddings from cache. Returns (timestamps, embeddings) or None."""
try:
path = _w2v_cache_path(video_path, hop, window, model_name)
if os.path.exists(path):
data = np.load(path)
_log(f"audio_scan: cache hit ({path})")
return data["timestamps"], data["embeddings"]
except Exception as e:
_log(f"audio_scan: cache read failed: {e}")
return None
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
hop: float = 1.0, window: float = _WINDOW, hop: float = 1.0, window: float = _WINDOW,
video_path: str | None = None, video_path: str | None = None,
@@ -162,6 +210,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
gt_soft: list[float], tolerance: float = 12.0, gt_soft: list[float], tolerance: float = 12.0,
neg_margin: float = 120.0, neg_margin: float = 120.0,
model_name: str | None = None, model_name: str | None = None,
gt_negative: list[float] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract embeddings only near positives and distant negatives. """Extract embeddings only near positives and distant negatives.
@@ -180,13 +229,24 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
if 0 <= t <= duration - _WINDOW: if 0 <= t <= duration - _WINDOW:
pos_times.add(int(t)) pos_times.add(int(t))
# Negative windows: every 4s, far from any marker # Manual negative windows: near explicit negative markers
manual_neg_times = set()
if gt_negative:
for gt in gt_negative:
for offset in range(-int(tolerance), int(tolerance) + 1):
t = gt + offset
if 0 <= t <= duration - _WINDOW:
manual_neg_times.add(int(t))
# Don't let manual negatives overlap with positives
manual_neg_times -= pos_times
# Auto negative windows: every 4s, far from any marker (skip if margin <= 0)
neg_times = set() neg_times = set()
for t in range(0, int(duration - _WINDOW), 4): for t in range(0, int(duration - _WINDOW), 4):
if min((abs(t - g) for g in all_gt), default=9999) > neg_margin: if neg_margin > 0 and min((abs(t - g) for g in all_gt), default=9999) > neg_margin:
neg_times.add(t) neg_times.add(t)
all_times = sorted(pos_times | neg_times) all_times = sorted(pos_times | neg_times | manual_neg_times)
# Filter out windows that go past the end # Filter out windows that go past the end
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)] valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
@@ -225,9 +285,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
for i, t in enumerate(timestamps): for i, t in enumerate(timestamps):
di = min((abs(t - g) for g in gt_intense), default=9999) di = min((abs(t - g) for g in gt_intense), default=9999)
da = min((abs(t - g) for g in all_gt), default=9999) da = min((abs(t - g) for g in all_gt), default=9999)
dm = min((abs(t - g) for g in (gt_negative or [])), default=9999)
if di < tolerance: if di < tolerance:
labels[i] = 1 labels[i] = 1
elif da > neg_margin: elif dm < tolerance or (neg_margin > 0 and da > neg_margin):
labels[i] = -1 labels[i] = -1
return timestamps, embeddings, labels return timestamps, embeddings, labels
@@ -241,7 +302,9 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
tolerance: float = 12.0, tolerance: float = 12.0,
neg_margin: float = 120.0, neg_margin: float = 120.0,
embed_model: str | None = None, embed_model: str | None = None,
cancel_flag: object = None) -> dict: cancel_flag: object = None,
n_workers: int = 4,
progress_cb: object = None) -> dict:
"""Train a classifier from labeled videos. """Train a classifier from labeled videos.
Args: Args:
@@ -250,24 +313,62 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
tolerance/neg_margin: labeling parameters tolerance/neg_margin: labeling parameters
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
cancel_flag: object with _cancel attribute; if set, training aborts early cancel_flag: object with _cancel attribute; if set, training aborts early
n_workers: number of threads for parallel audio loading
Returns: Returns:
dict with 'classifier', 'embed_model', and metadata, or None on failure. dict with 'classifier', 'embed_model', and metadata, or None on failure.
""" """
from concurrent.futures import ThreadPoolExecutor, as_completed
from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingClassifier
all_X, all_y = [], [] def _progress(msg: str) -> None:
_log(msg)
if progress_cb:
progress_cb(msg)
for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos): def _load_audio(path: str) -> np.ndarray:
return _load_audio_ffmpeg(path, sr=_SR)
# Phase 1: load all audio in parallel (cap workers — disk I/O bound)
n = len(video_infos)
load_workers = min(n_workers, 4)
_progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
audio_data: dict[int, np.ndarray] = {}
with ThreadPoolExecutor(max_workers=load_workers) as pool:
future_to_idx = {
pool.submit(_load_audio, vi[0]): i
for i, vi in enumerate(video_infos)
}
failed = set()
for future in as_completed(future_to_idx):
if cancel_flag and getattr(cancel_flag, '_cancel', False):
_log("audio_scan: training cancelled")
return None
idx = future_to_idx[future]
try:
audio_data[idx] = future.result()
except Exception as e:
_log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
failed.add(idx)
_progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")
# Phase 2: extract embeddings sequentially on GPU
_progress(f"Extracting embeddings: 0/{n}")
all_X, all_y = [], []
for vi, vinfo in enumerate(video_infos):
if vi in failed:
continue
vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2]
gt_negative = vinfo[3] if len(vinfo) > 3 else []
if cancel_flag and getattr(cancel_flag, '_cancel', False): if cancel_flag and getattr(cancel_flag, '_cancel', False):
_log("audio_scan: training cancelled") _log("audio_scan: training cancelled")
return None return None
_log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}") _progress(f"Extracting embeddings: {vi+1}/{n}")
y, _ = librosa.load(vpath, sr=_SR, mono=True) y = audio_data.pop(vi)
timestamps, embeddings, labels = _extract_w2v_targeted( timestamps, embeddings, labels = _extract_w2v_targeted(
y, _SR, gt_intense, gt_soft, tolerance, neg_margin, y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
model_name=embed_model, model_name=embed_model, gt_negative=gt_negative,
) )
if len(timestamps) == 0: if len(timestamps) == 0:
continue continue
@@ -306,6 +407,7 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
train_idx = np.concatenate([pos_idx, neg_sample]) train_idx = np.concatenate([pos_idx, neg_sample])
rng.shuffle(train_idx) rng.shuffle(train_idx)
_progress(f"Fitting classifier on {len(train_idx)} samples...")
clf = GradientBoostingClassifier( clf = GradientBoostingClassifier(
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42, n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42,
) )
@@ -334,11 +436,41 @@ def load_classifier(model_path: str) -> dict | None:
return joblib.load(model_path) return joblib.load(model_path)
def default_model_path(profile_name: str = "default") -> str: def default_model_path(profile_name: str = "default",
"""Return the default path for a profile's classifier model.""" embed_model: str | None = None) -> str:
"""Return the path for a profile's classifier model.
When embed_model is given the file is ``{profile}_{model}.joblib``,
otherwise ``{profile}.joblib`` (legacy single-model layout).
"""
if embed_model:
return os.path.join(_MODEL_DIR, f"{profile_name}_{embed_model}.joblib")
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib") return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
def list_trained_models(profile_name: str = "default") -> list[str]:
"""Return embedding model names that have a trained .joblib for *profile_name*.
Looks for files matching ``{profile}_{MODEL}.joblib`` in the models dir.
"""
prefix = f"{profile_name}_"
suffix = ".joblib"
result = []
if not os.path.isdir(_MODEL_DIR):
return result
for fname in os.listdir(_MODEL_DIR):
if fname.startswith(prefix) and fname.endswith(suffix):
model_name = fname[len(prefix):-len(suffix)]
if model_name in _EMBED_MODELS:
result.append(model_name)
# Also check legacy {profile}.joblib
legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
if os.path.exists(legacy) and not result:
# Legacy model — we don't know the embed model, but it's usable
result.append("")
return sorted(result)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Scanning # Scanning
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -359,22 +491,28 @@ def scan_video(
_log("audio_scan: no model provided") _log("audio_scan: no model provided")
return [] return []
_log(f"audio_scan: loading {video_path}")
y, sr = librosa.load(video_path, sr=_SR, mono=True)
duration = len(y) / sr
_log(f"audio_scan: {duration:.1f}s loaded, extracting features...")
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return []
clf = model["classifier"] clf = model["classifier"]
embed_model = model.get("embed_model") embed_model = model.get("embed_model")
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...") # Try cache first — skip expensive audio loading if embeddings exist
timestamps, window_vectors = _extract_w2v_windows( cached = _w2v_cache_load(video_path, hop, window, embed_model)
y, sr, hop=hop, window=window, video_path=video_path, if cached is not None:
cancel_flag=cancel_flag, model_name=embed_model, timestamps, window_vectors = cached
) else:
_log(f"audio_scan: loading {video_path}")
y = _load_audio_ffmpeg(video_path, sr=_SR)
sr = _SR
_log(f"audio_scan: {len(y)/sr:.1f}s loaded")
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return []
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
timestamps, window_vectors = _extract_w2v_windows(
y, sr, hop=hop, window=window, video_path=video_path,
cancel_flag=cancel_flag, model_name=embed_model,
)
if len(timestamps) == 0: if len(timestamps) == 0:
_log("audio_scan: video shorter than window") _log("audio_scan: video shorter than window")
return [] return []
+131 -10
View File
@@ -81,6 +81,21 @@ class ProcessedDB:
" PRIMARY KEY (filename, profile)" " PRIMARY KEY (filename, profile)"
")" ")"
) )
self._con.execute(
"CREATE TABLE IF NOT EXISTS scan_results ("
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
" filename TEXT NOT NULL,"
" profile TEXT NOT NULL DEFAULT 'default',"
" model TEXT NOT NULL,"
" start_time REAL NOT NULL,"
" end_time REAL NOT NULL,"
" score REAL NOT NULL"
")"
)
self._con.execute(
"CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model"
" ON scan_results(filename, profile, model)"
)
self._con.commit() self._con.commit()
def add(self, filename: str, start_time: float, output_path: str, def add(self, filename: str, start_time: float, output_path: str,
@@ -248,18 +263,20 @@ class ProcessedDB:
return sorted(folder_names) return sorted(folder_names)
def get_training_data(self, profile: str, positive_folder: str, def get_training_data(self, profile: str, positive_folder: str,
negative_folder: str = "",
fallback_video_dir: str = "", fallback_video_dir: str = "",
) -> list[tuple[str, list[float], list[float]]]: ) -> list[tuple[str, list[float], list[float], list[float]]]:
"""Build training video_infos from DB data. """Build training video_infos from DB data.
Args: Args:
profile: profile name profile: profile name
positive_folder: export folder name for positive class (e.g. "mp4_Intense") positive_folder: export folder name for positive class (e.g. "mp4_Intense")
negative_folder: export folder name for explicit negatives (optional)
fallback_video_dir: if source_path is empty, try filename in this dir fallback_video_dir: if source_path is empty, try filename in this dir
Returns: Returns:
list of (source_video_path, positive_times, soft_times) per video. list of (source_video_path, positive_times, soft_times, negative_times)
Soft times = clips from any other export folder. per video. Soft times = clips from any other non-negative folder.
""" """
if not self._enabled: if not self._enabled:
return [] return []
@@ -269,8 +286,9 @@ class ProcessedDB:
(profile,), (profile,),
).fetchall() ).fetchall()
# Collect times by video, split by positive vs other folders # Collect times by video, split by folder role
pos_by_video: dict[str, set[float]] = {} pos_by_video: dict[str, set[float]] = {}
neg_by_video: dict[str, set[float]] = {}
soft_by_video: dict[str, set[float]] = {} soft_by_video: dict[str, set[float]] = {}
source_by_filename: dict[str, str] = {} source_by_filename: dict[str, str] = {}
@@ -280,26 +298,43 @@ class ProcessedDB:
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op))) grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent == positive_folder: if grandparent == positive_folder:
pos_by_video.setdefault(fn, set()).add(st) pos_by_video.setdefault(fn, set()).add(st)
elif negative_folder and grandparent == negative_folder:
neg_by_video.setdefault(fn, set()).add(st)
else: else:
soft_by_video.setdefault(fn, set()).add(st) soft_by_video.setdefault(fn, set()).add(st)
# Remove positive times from soft to avoid conflicting labels # Remove positive times from soft/neg to avoid conflicting labels
for fn in pos_by_video: for fn in pos_by_video:
if fn in soft_by_video: if fn in soft_by_video:
soft_by_video[fn] -= pos_by_video[fn] soft_by_video[fn] -= pos_by_video[fn]
if fn in neg_by_video:
neg_by_video[fn] -= pos_by_video[fn]
# Deduplicate nearby markers (spread clips from same position)
def _dedup_times(times: set[float], min_gap: float = 8.0) -> list[float]:
if not times:
return []
ordered = sorted(times)
result = [ordered[0]]
for t in ordered[1:]:
if t - result[-1] >= min_gap:
result.append(t)
return result
# Include videos that have positives OR explicit negatives
all_videos = set(pos_by_video) | set(neg_by_video)
result = [] result = []
for fn in pos_by_video: for fn in all_videos:
sp = source_by_filename.get(fn, "") sp = source_by_filename.get(fn, "")
if not sp or not os.path.exists(sp): if not sp or not os.path.exists(sp):
# Fallback: try video_dir / filename
if fallback_video_dir: if fallback_video_dir:
sp = os.path.join(fallback_video_dir, fn) sp = os.path.join(fallback_video_dir, fn)
if not sp or not os.path.exists(sp): if not sp or not os.path.exists(sp):
continue continue
gt_pos = sorted(pos_by_video[fn]) gt_pos = _dedup_times(pos_by_video.get(fn, set()))
gt_soft = sorted(soft_by_video.get(fn, set())) gt_soft = _dedup_times(soft_by_video.get(fn, set()))
result.append((sp, gt_pos, gt_soft)) gt_neg = _dedup_times(neg_by_video.get(fn, set()))
result.append((sp, gt_pos, gt_soft, gt_neg))
return result return result
def get_training_stats(self, profile: str) -> dict[str, dict]: def get_training_stats(self, profile: str) -> dict[str, dict]:
@@ -329,6 +364,92 @@ class ProcessedDB:
stats[folder_name] = {"videos": len(videos), "clips": clips} stats[folder_name] = {"videos": len(videos), "clips": clips}
return stats return stats
# ── Scan results ─────────────────────────────────────────────
def save_scan_results(self, filename: str, profile: str, model: str,
regions: list[tuple[float, float, float]]) -> None:
"""Replace scan results for (filename, profile, model) with new regions.
regions: list of (start_time, end_time, score).
"""
if not self._enabled:
return
with self._lock:
self._con.execute(
"DELETE FROM scan_results"
" WHERE filename = ? AND profile = ? AND model = ?",
(filename, profile, model),
)
self._con.executemany(
"INSERT INTO scan_results"
" (filename, profile, model, start_time, end_time, score)"
" VALUES (?, ?, ?, ?, ?, ?)",
[(filename, profile, model, s, e, sc) for s, e, sc in regions],
)
self._con.commit()
def get_scan_results(self, filename: str, profile: str
) -> dict[str, list[tuple[int, float, float, float]]]:
"""Return scan results grouped by model.
Returns {model: [(row_id, start_time, end_time, score), ...]} sorted by
start_time.
"""
if not self._enabled:
return {}
rows = self._con.execute(
"SELECT id, model, start_time, end_time, score FROM scan_results"
" WHERE filename = ? AND profile = ?"
" ORDER BY model, start_time",
(filename, profile),
).fetchall()
result: dict[str, list[tuple[int, float, float, float]]] = {}
for row_id, model, s, e, sc in rows:
result.setdefault(model, []).append((row_id, s, e, sc))
return result
def delete_scan_result(self, row_id: int) -> None:
"""Delete a single scan result row."""
if not self._enabled:
return
with self._lock:
self._con.execute("DELETE FROM scan_results WHERE id = ?", (row_id,))
self._con.commit()
def get_scan_models(self, filename: str, profile: str) -> list[str]:
"""Return model names that have scan results for this file."""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT DISTINCT model FROM scan_results"
" WHERE filename = ? AND profile = ? ORDER BY model",
(filename, profile),
).fetchall()
return [r[0] for r in rows]
def get_scanned_filenames(self, profile: str, model: str) -> set[str]:
"""Return filenames that already have scan results for this model."""
if not self._enabled:
return set()
rows = self._con.execute(
"SELECT DISTINCT filename FROM scan_results"
" WHERE profile = ? AND model = ?",
(profile, model),
).fetchall()
return {r[0] for r in rows}
def get_training_filenames(self, profile: str) -> set[str]:
"""Return filenames used in training (have exported clips)."""
if not self._enabled:
return set()
rows = self._con.execute(
"SELECT DISTINCT filename FROM processed WHERE profile = ?",
(profile,),
).fetchall()
return {r[0] for r in rows}
# ── Hidden files ───────────────────────────────────────────
def hide_file(self, filename: str, profile: str = "default") -> None: def hide_file(self, filename: str, profile: str = "default") -> None:
if not self._enabled: if not self._enabled:
return return
+412 -33
View File
@@ -16,6 +16,7 @@ from PyQt6.QtWidgets import (
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip, QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout, QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
QTableWidget, QTableWidgetItem, QTabWidget, QHeaderView,
) )
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
@@ -244,6 +245,15 @@ class TrainDialog(QDialog):
self._cmb_positive.addItem(label, userData=folder_name) self._cmb_positive.addItem(label, userData=folder_name)
form.addRow("Positive class:", self._cmb_positive) form.addRow("Positive class:", self._cmb_positive)
# Negative class selector (optional)
self._cmb_negative = QComboBox()
self._cmb_negative.addItem("(auto only)", userData="")
for folder_name, info in stats.items():
label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)"
self._cmb_negative.addItem(label, userData=folder_name)
self._cmb_negative.currentIndexChanged.connect(lambda: self._debounce.start())
form.addRow("Negative class:", self._cmb_negative)
# Model selector # Model selector
self._cmb_model = QComboBox() self._cmb_model = QComboBox()
for name in _EMBED_MODELS: for name in _EMBED_MODELS:
@@ -251,6 +261,18 @@ class TrainDialog(QDialog):
self._cmb_model.setCurrentText("WAV2VEC2_BASE") self._cmb_model.setCurrentText("WAV2VEC2_BASE")
form.addRow("Model:", self._cmb_model) form.addRow("Model:", self._cmb_model)
# Auto-negative margin (0 = disabled)
self._spn_neg_margin = QDoubleSpinBox()
self._spn_neg_margin.setDecimals(0)
self._spn_neg_margin.setRange(0.0, 600.0)
self._spn_neg_margin.setSingleStep(10.0)
self._spn_neg_margin.setValue(30.0)
self._spn_neg_margin.setSuffix("s")
self._spn_neg_margin.setSpecialValueText("Disabled")
self._spn_neg_margin.setToolTip(
"Auto-sample negatives from regions this far from any marker. 0 = disabled.")
form.addRow("Auto-neg margin:", self._spn_neg_margin)
# Video source directory (fallback for old DB rows without source_path) # Video source directory (fallback for old DB rows without source_path)
self._txt_video_dir = QLineEdit(video_dir) self._txt_video_dir = QLineEdit(video_dir)
self._txt_video_dir.setPlaceholderText("Directory containing source videos") self._txt_video_dir.setPlaceholderText("Directory containing source videos")
@@ -265,7 +287,13 @@ class TrainDialog(QDialog):
btn_browse.setFixedWidth(30) btn_browse.setFixedWidth(30)
btn_browse.clicked.connect(self._browse_video_dir) btn_browse.clicked.connect(self._browse_video_dir)
vid_row.addWidget(btn_browse) vid_row.addWidget(btn_browse)
form.addRow("Video dir:", vid_row) self._lbl_video_dir = QLabel("Video dir:")
self._video_dir_widget = QWidget()
self._video_dir_widget.setLayout(vid_row)
form.addRow(self._lbl_video_dir, self._video_dir_widget)
# Hidden by default — shown only if some videos are missing source_path
self._lbl_video_dir.setVisible(False)
self._video_dir_widget.setVisible(False)
layout.addLayout(form) layout.addLayout(form)
@@ -297,17 +325,32 @@ class TrainDialog(QDialog):
if not folder: if not folder:
self._lbl_stats.setText("No export folder data available.") self._lbl_stats.setText("No export folder data available.")
return return
neg_folder = self._cmb_negative.currentData() or ""
# First check without fallback to see if source_paths are sufficient
video_infos_no_fb = self._db.get_training_data(
self._profile, folder, negative_folder=neg_folder,
)
video_infos = self._db.get_training_data( video_infos = self._db.get_training_data(
self._profile, folder, self._profile, folder, negative_folder=neg_folder,
fallback_video_dir=self._txt_video_dir.text(), fallback_video_dir=self._txt_video_dir.text(),
) )
# Show video dir field only when the fallback helps find extra videos
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
self._lbl_video_dir.setVisible(needs_fallback)
self._video_dir_widget.setVisible(needs_fallback)
n_videos = len(video_infos) n_videos = len(video_infos)
n_pos = sum(len(gt) for _, gt, _ in video_infos) n_pos = sum(len(vi[1]) for vi in video_infos)
n_soft = sum(len(s) for _, _, s in video_infos) n_soft = sum(len(vi[2]) for vi in video_infos)
lines = [f"<b>{n_videos}</b> videos with positive clips"] n_neg = sum(len(vi[3]) for vi in video_infos)
lines.append(f"<b>{n_pos}</b> positive markers, <b>{n_soft}</b> soft/buffer markers") lines = [f"<b>{n_videos}</b> videos"]
lines.append(f"<b>{n_pos}</b> positive, <b>{n_soft}</b> soft/buffer"
+ (f", <b>{n_neg}</b> manual negative" if n_neg else "")
+ " markers")
if n_videos == 0: if n_videos == 0:
lines.append("<i>No source videos found. Set Video dir above.</i>") lines.append("<i>No source videos found. Set Video dir below.</i>")
self._lbl_video_dir.setVisible(True)
self._video_dir_widget.setVisible(True)
elif n_videos < 3: elif n_videos < 3:
lines.append("<i>Recommend at least 3 videos for decent results.</i>") lines.append("<i>Recommend at least 3 videos for decent results.</i>")
self._lbl_stats.setText("<br>".join(lines)) self._lbl_stats.setText("<br>".join(lines))
@@ -316,6 +359,14 @@ class TrainDialog(QDialog):
def positive_folder(self) -> str: def positive_folder(self) -> str:
return self._cmb_positive.currentData() or "" return self._cmb_positive.currentData() or ""
@property
def negative_folder(self) -> str:
return self._cmb_negative.currentData() or ""
@property
def neg_margin(self) -> float:
return self._spn_neg_margin.value()
@property @property
def embed_model(self) -> str: def embed_model(self) -> str:
return self._cmb_model.currentText() return self._cmb_model.currentText()
@@ -332,11 +383,14 @@ class TrainWorker(QThread):
progress = pyqtSignal(str) # per-video status progress = pyqtSignal(str) # per-video status
def __init__(self, video_infos: list, model_path: str, def __init__(self, video_infos: list, model_path: str,
embed_model: str | None = None): embed_model: str | None = None, n_workers: int = 4,
neg_margin: float = 120.0):
super().__init__() super().__init__()
self._video_infos = video_infos self._video_infos = video_infos
self._model_path = model_path self._model_path = model_path
self._embed_model = embed_model self._embed_model = embed_model
self._n_workers = n_workers
self._neg_margin = neg_margin
self._cancel = False self._cancel = False
def cancel(self) -> None: def cancel(self) -> None:
@@ -349,8 +403,11 @@ class TrainWorker(QThread):
result = train_classifier( result = train_classifier(
self._video_infos, self._video_infos,
model_path=self._model_path, model_path=self._model_path,
neg_margin=self._neg_margin,
embed_model=self._embed_model, embed_model=self._embed_model,
cancel_flag=self, cancel_flag=self,
n_workers=self._n_workers,
progress_cb=self.progress.emit,
) )
if self._cancel: if self._cancel:
return return
@@ -363,6 +420,152 @@ class TrainWorker(QThread):
self.error.emit(str(e)) self.error.emit(str(e))
class ScanResultsPanel(QWidget):
"""Tabbed panel showing scan results per model, with seek-on-click and delete."""
seek_requested = pyqtSignal(float) # request main window to seek to time
export_requested = pyqtSignal(list) # emit list of (start, end, score) to export
def __init__(self, db, parent=None):
super().__init__(parent)
self._db = db
self._filename = ""
self._profile = ""
layout = QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(2)
self._tabs = QTabWidget()
self._tabs.setTabsClosable(False)
layout.addWidget(self._tabs)
btn_row = QHBoxLayout()
self._btn_export = QPushButton("Export Scan Results")
self._btn_export.setToolTip("Export clips from the active tab's scan results")
self._btn_export.clicked.connect(self._on_export)
btn_row.addStretch()
btn_row.addWidget(self._btn_export)
layout.addLayout(btn_row)
def load_for_file(self, filename: str, profile: str) -> None:
"""Load saved scan results from DB for a file."""
self._filename = filename
self._profile = profile
self._tabs.clear()
results = self._db.get_scan_results(filename, profile)
for model, rows in results.items():
self._add_tab(model, rows)
def add_scan_results(self, model: str,
regions: list[tuple[float, float, float]]) -> None:
"""Add/replace a tab with new scan results and save to DB."""
# Save to DB
self._db.save_scan_results(self._filename, self._profile, model, regions)
# Build row data with IDs from DB
db_results = self._db.get_scan_results(self._filename, self._profile)
rows = db_results.get(model, [])
# Remove existing tab for this model
for i in range(self._tabs.count()):
if self._tabs.tabText(i).rsplit(" (", 1)[0] == model:
self._tabs.removeTab(i)
break
self._add_tab(model, rows)
# Switch to the new tab
for i in range(self._tabs.count()):
if self._tabs.tabText(i).rsplit(" (", 1)[0] == model:
self._tabs.setCurrentIndex(i)
break
def _add_tab(self, model: str,
rows: list[tuple[int, float, float, float]]) -> None:
"""Create a table tab. rows: [(row_id, start, end, score), ...]"""
table = QTableWidget(len(rows), 3)
table.setHorizontalHeaderLabels(["Time", "End", "Score"])
table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
table.setSelectionMode(QTableWidget.SelectionMode.ExtendedSelection)
table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
table.verticalHeader().setVisible(False)
header = table.horizontalHeader()
header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
for i, (row_id, start, end, score) in enumerate(rows):
t_item = QTableWidgetItem(format_time(start))
t_item.setData(Qt.ItemDataRole.UserRole, row_id)
t_item.setData(Qt.ItemDataRole.UserRole + 1, start)
table.setItem(i, 0, t_item)
e_item = QTableWidgetItem(format_time(end))
e_item.setData(Qt.ItemDataRole.UserRole, end)
table.setItem(i, 1, e_item)
table.setItem(i, 2, QTableWidgetItem(f"{score:.2f}"))
table.itemSelectionChanged.connect(
lambda t=table: self._on_selection_changed(t))
self._tabs.addTab(table, f"{model} ({len(rows)})")
def _on_selection_changed(self, table: QTableWidget) -> None:
items = table.selectedItems()
if items:
row = items[0].row()
start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
if start is not None:
self.seek_requested.emit(float(start))
def delete_selected(self) -> None:
"""Delete selected rows from active tab and DB."""
table = self._tabs.currentWidget()
if not isinstance(table, QTableWidget):
return
rows_to_delete = sorted(
{idx.row() for idx in table.selectedIndexes()}, reverse=True)
tab_idx = self._tabs.currentIndex()
model = self._tabs.tabText(tab_idx).rsplit(" (", 1)[0]
for row in rows_to_delete:
row_id = table.item(row, 0).data(Qt.ItemDataRole.UserRole)
if row_id is not None:
self._db.delete_scan_result(row_id)
table.removeRow(row)
# Update tab title with new count
count = table.rowCount()
self._tabs.setTabText(tab_idx, f"{model} ({count})")
def _get_tab_regions(self, table: QTableWidget
) -> list[tuple[float, float, float]]:
"""Extract (start, end, score) from a table widget."""
regions = []
for row in range(table.rowCount()):
start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
end = table.item(row, 1).data(Qt.ItemDataRole.UserRole)
score = float(table.item(row, 2).text())
regions.append((float(start), float(end), score))
return regions
def _on_export(self) -> None:
table = self._tabs.currentWidget()
if not isinstance(table, QTableWidget):
return
regions = self._get_tab_regions(table)
if regions:
self.export_requested.emit(regions)
def current_regions(self) -> list[tuple[float, float, float]]:
"""Return (start, end, score) for all rows in the active tab."""
table = self._tabs.currentWidget()
if not isinstance(table, QTableWidget):
return []
return self._get_tab_regions(table)
def has_results(self) -> bool:
return self._tabs.count() > 0
def keyPressEvent(self, event):
if event.key() in (Qt.Key.Key_Delete, Qt.Key.Key_Backspace):
self.delete_selected()
else:
super().keyPressEvent(event)
class TimelineWidget(QWidget): class TimelineWidget(QWidget):
cursor_changed = pyqtSignal(float) # emits position in seconds cursor_changed = pyqtSignal(float) # emits position in seconds
seek_changed = pyqtSignal(float) # emits seek position (lock mode) seek_changed = pyqtSignal(float) # emits seek position (lock mode)
@@ -1710,6 +1913,15 @@ class MainWindow(QMainWindow):
self._btn_train.clicked.connect(self._open_train_dialog) self._btn_train.clicked.connect(self._open_train_dialog)
self._train_worker: TrainWorker | None = None self._train_worker: TrainWorker | None = None
self._btn_scan_all = QPushButton("Scan All")
self._btn_scan_all.setToolTip("Scan all playlist videos that haven't been scanned yet")
self._btn_scan_all.clicked.connect(self._start_scan_all)
self._scan_all_queue: list[str] = []
self._cmb_scan_model = QComboBox()
self._cmb_scan_model.setToolTip("Trained embedding model to use for scanning")
self._cmb_scan_model.setMinimumWidth(120)
self._spn_auto_fuse = QDoubleSpinBox() self._spn_auto_fuse = QDoubleSpinBox()
self._spn_auto_fuse.setDecimals(1) self._spn_auto_fuse.setDecimals(1)
self._spn_auto_fuse.setRange(0.0, 60.0) self._spn_auto_fuse.setRange(0.0, 60.0)
@@ -1800,6 +2012,7 @@ class MainWindow(QMainWindow):
if idx >= 0: if idx >= 0:
self._cmb_profile.setCurrentIndex(idx) self._cmb_profile.setCurrentIndex(idx)
self._cmb_profile.activated.connect(self._on_profile_activated) self._cmb_profile.activated.connect(self._on_profile_activated)
self._refresh_scan_models()
self._btn_shortcuts = QPushButton("?") self._btn_shortcuts = QPushButton("?")
self._btn_shortcuts.setFixedWidth(28) self._btn_shortcuts.setFixedWidth(28)
@@ -1864,11 +2077,13 @@ class MainWindow(QMainWindow):
settings_row.addWidget(self._chk_rand_portrait) settings_row.addWidget(self._chk_rand_portrait)
settings_row.addWidget(self._chk_rand_square) settings_row.addWidget(self._chk_rand_square)
settings_row.addWidget(self._chk_track) settings_row.addWidget(self._chk_track)
settings_row.addWidget(self._cmb_scan_model)
settings_row.addWidget(self._btn_scan) settings_row.addWidget(self._btn_scan)
settings_row.addWidget(self._btn_auto_export) settings_row.addWidget(self._btn_auto_export)
settings_row.addWidget(self._spn_auto_fuse) settings_row.addWidget(self._spn_auto_fuse)
settings_row.addWidget(self._sld_threshold) settings_row.addWidget(self._sld_threshold)
settings_row.addWidget(self._btn_train) settings_row.addWidget(self._btn_train)
settings_row.addWidget(self._btn_scan_all)
settings_row.addStretch() settings_row.addStretch()
self._lbl_status = QLabel() self._lbl_status = QLabel()
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;") self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
@@ -1918,13 +2133,20 @@ class MainWindow(QMainWindow):
left_layout.addLayout(left_top) left_layout.addLayout(left_top)
left_layout.addWidget(self._playlist) left_layout.addWidget(self._playlist)
# Scan results panel (right side)
self._scan_panel = ScanResultsPanel(self._db)
self._scan_panel.seek_requested.connect(self._on_scan_seek)
self._scan_panel.export_requested.connect(self._on_scan_export)
# Root: horizontal splitter # Root: horizontal splitter
splitter = QSplitter(Qt.Orientation.Horizontal) splitter = QSplitter(Qt.Orientation.Horizontal)
splitter.addWidget(left) splitter.addWidget(left)
splitter.addWidget(right) splitter.addWidget(right)
splitter.setSizes([200, 900]) splitter.addWidget(self._scan_panel)
splitter.setSizes([200, 900, 200])
splitter.setCollapsible(0, False) splitter.setCollapsible(0, False)
splitter.setCollapsible(1, False) splitter.setCollapsible(1, False)
splitter.setCollapsible(2, True)
self.setCentralWidget(splitter) self.setCentralWidget(splitter)
self.setStatusBar(None) self.setStatusBar(None)
@@ -2061,6 +2283,7 @@ class MainWindow(QMainWindow):
self._btn_delete.setEnabled(False) self._btn_delete.setEnabled(False)
self._update_next_label() self._update_next_label()
self._apply_playlist_filters() self._apply_playlist_filters()
self._refresh_scan_models()
if self._file_path: if self._file_path:
self._refresh_markers() self._refresh_markers()
_log(f"Profile switched: {text}") _log(f"Profile switched: {text}")
@@ -2184,7 +2407,13 @@ class MainWindow(QMainWindow):
if self._scan_worker and self._scan_worker.isRunning(): if self._scan_worker and self._scan_worker.isRunning():
self._scan_worker.cancel() self._scan_worker.cancel()
self._cleanup_scan_worker() self._cleanup_scan_worker()
self._scan_all_queue.clear()
self._btn_scan.setEnabled(True) self._btn_scan.setEnabled(True)
self._btn_scan_all.setEnabled(True)
# Load saved scan results for this file
if self._file_path:
filename = os.path.basename(self._file_path)
self._scan_panel.load_for_file(filename, self._profile)
dur = self._mpv.get_duration() dur = self._mpv.get_duration()
self._timeline.set_duration(dur) self._timeline.set_duration(dur)
@@ -2653,8 +2882,42 @@ class MainWindow(QMainWindow):
return return
self._step_cursor(markers[0][0] - self._cursor) # wrap to first self._step_cursor(markers[0][0] - self._cursor) # wrap to first
def _load_selected_scan_model(self) -> tuple:
"""Load the classifier selected in the scan model combo.
Returns (model_dict, label_str) or (None, "") on failure.
"""
from core.audio_scan import load_classifier, default_model_path
sel = self._cmb_scan_model.currentText()
if not sel or sel == "(no model)":
self._show_status("No trained model — click Train first")
return None, ""
embed_name = None if sel == "(legacy)" else sel
model_path = default_model_path(self._profile, embed_name)
model = load_classifier(model_path)
if model is None:
self._show_status(f"Model file missing: {model_path}")
return None, ""
return model, sel
def _refresh_scan_models(self) -> None:
"""Populate the scan model combo with trained models for the current profile."""
from core.audio_scan import list_trained_models
prev = self._cmb_scan_model.currentText()
self._cmb_scan_model.clear()
models = list_trained_models(self._profile)
if not models:
self._cmb_scan_model.addItem("(no model)")
else:
for m in models:
self._cmb_scan_model.addItem(m if m else "(legacy)")
# Restore previous selection if still available
idx = self._cmb_scan_model.findText(prev)
if idx >= 0:
self._cmb_scan_model.setCurrentIndex(idx)
def _cleanup_scan_worker(self) -> None: def _cleanup_scan_worker(self) -> None:
"""Disconnect signals and schedule deletion of old scan worker.""" """Disconnect signals, cancel, and schedule deletion of old scan worker."""
if self._scan_worker is not None: if self._scan_worker is not None:
try: try:
self._scan_worker.scan_done.disconnect() self._scan_worker.scan_done.disconnect()
@@ -2662,8 +2925,8 @@ class MainWindow(QMainWindow):
self._scan_worker.progress.disconnect() self._scan_worker.progress.disconnect()
except TypeError: except TypeError:
pass # already disconnected pass # already disconnected
self._scan_worker.cancel()
if self._scan_worker.isRunning(): if self._scan_worker.isRunning():
# QThread.finished fires when run() returns, even on cancel
self._scan_worker.finished.connect(self._scan_worker.deleteLater) self._scan_worker.finished.connect(self._scan_worker.deleteLater)
else: else:
self._scan_worker.deleteLater() self._scan_worker.deleteLater()
@@ -2682,17 +2945,14 @@ class MainWindow(QMainWindow):
threshold = self._sld_threshold.value() threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path model, model_label = self._load_selected_scan_model()
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
if model is None: if model is None:
self._show_status("No trained model — click Train first")
return return
self._btn_scan.setEnabled(False) self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path self._scan_file_path = self._file_path
self._show_status("Scanning...") self._scan_model_label = model_label
self._show_status(f"Scanning ({model_label})...")
self._scan_worker = ScanWorker( self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold, self._file_path, model=model, threshold=threshold,
) )
@@ -2708,6 +2968,10 @@ class MainWindow(QMainWindow):
if self._file_path != getattr(self, '_scan_file_path', None): if self._file_path != getattr(self, '_scan_file_path', None):
return return
self._timeline.set_scan_regions(regions) self._timeline.set_scan_regions(regions)
model_label = getattr(self, '_scan_model_label', '')
if model_label and self._file_path:
filename = os.path.basename(self._file_path)
self._scan_panel.add_scan_results(model_label, regions)
self._show_status(f"Scan complete: {len(regions)} matching regions") self._show_status(f"Scan complete: {len(regions)} matching regions")
def _on_scan_error(self, msg: str) -> None: def _on_scan_error(self, msg: str) -> None:
@@ -2715,6 +2979,105 @@ class MainWindow(QMainWindow):
self._btn_auto_export.setEnabled(True) self._btn_auto_export.setEnabled(True)
self._show_status(f"Scan error: {msg}") self._show_status(f"Scan error: {msg}")
def _on_scan_seek(self, t: float) -> None:
"""Seek player when a scan result row is clicked."""
if self._file_path:
self._cursor = t
self._mpv.seek(t)
self._timeline.set_cursor(t)
dur = self._mpv.get_duration()
self._lbl_time.setText(f"{format_time(t)} / {format_time(dur)}")
def _on_scan_export(self, regions: list) -> None:
"""Export clips from scan results panel."""
if not self._file_path or not regions:
return
if self._export_worker and self._export_worker.isRunning():
self._show_status("Export already running…")
return
self._auto_export_regions(regions)
# ── Scan All ───────────────────────────────────────────────
def _start_scan_all(self) -> None:
"""Scan all playlist videos not yet scanned with the selected model."""
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
model, model_label = self._load_selected_scan_model()
if model is None:
return
# Build queue: playlist files minus already-scanned and training files
all_paths = self._playlist._paths
scanned = self._db.get_scanned_filenames(self._profile, model_label)
training = self._db.get_training_filenames(self._profile)
skip = scanned | training
self._scan_all_queue = [
p for p in all_paths if os.path.basename(p) not in skip
]
if not self._scan_all_queue:
self._show_status("All videos already scanned or used for training")
return
self._scan_all_model = model
self._scan_all_model_label = model_label
self._scan_all_profile = self._profile
self._scan_all_total = len(self._scan_all_queue)
self._btn_scan_all.setEnabled(False)
self._btn_scan.setEnabled(False)
self._show_status(
f"Scan All: 0/{self._scan_all_total} ({model_label})")
self._scan_all_next()
def _scan_all_next(self) -> None:
"""Start scanning the next video in the queue."""
if not self._scan_all_queue:
self._btn_scan_all.setEnabled(True)
self._btn_scan.setEnabled(True)
done = self._scan_all_total
self._show_status(f"Scan All complete: {done} videos scanned")
return
self._cleanup_scan_worker()
path = self._scan_all_queue.pop(0)
remaining = self._scan_all_total - len(self._scan_all_queue)
self._scan_all_current_path = path
self._show_status(
f"Scan All: {remaining}/{self._scan_all_total}"
f"{os.path.basename(path)}")
threshold = self._sld_threshold.value()
self._scan_worker = ScanWorker(
path, model=self._scan_all_model, threshold=threshold,
)
self._scan_worker.scan_done.connect(self._on_scan_all_done)
self._scan_worker.error.connect(self._on_scan_all_error)
self._scan_worker.start()
def _on_scan_all_done(self, regions: list) -> None:
"""Save batch scan results and continue to next video."""
path = getattr(self, '_scan_all_current_path', '')
model_label = getattr(self, '_scan_all_model_label', '')
if path and model_label:
filename = os.path.basename(path)
profile = getattr(self, '_scan_all_profile', self._profile)
self._db.save_scan_results(
filename, profile, model_label, regions)
# If this is the currently loaded file, update the panel
if self._file_path and os.path.basename(self._file_path) == filename:
self._scan_panel.load_for_file(filename, self._profile)
self._timeline.set_scan_regions(regions)
self._scan_all_next()
def _on_scan_all_error(self, msg: str) -> None:
"""Log error and continue to next video."""
path = getattr(self, '_scan_all_current_path', '')
_log(f"Scan All error on {os.path.basename(path)}: {msg}")
self._scan_all_next()
# ── Training ──────────────────────────────────────────────── # ── Training ────────────────────────────────────────────────
def _cleanup_train_worker(self) -> None: def _cleanup_train_worker(self) -> None:
@@ -2751,6 +3114,8 @@ class MainWindow(QMainWindow):
return return
pos_folder = dlg.positive_folder pos_folder = dlg.positive_folder
neg_folder = dlg.negative_folder
neg_margin = dlg.neg_margin
embed_model = dlg.embed_model embed_model = dlg.embed_model
video_dir = dlg.video_dir video_dir = dlg.video_dir
if not pos_folder: if not pos_folder:
@@ -2762,20 +3127,22 @@ class MainWindow(QMainWindow):
self._settings.setValue("train_video_dir", video_dir) self._settings.setValue("train_video_dir", video_dir)
video_infos = self._db.get_training_data( video_infos = self._db.get_training_data(
self._profile, pos_folder, fallback_video_dir=video_dir, self._profile, pos_folder, negative_folder=neg_folder,
fallback_video_dir=video_dir,
) )
if not video_infos: if not video_infos:
self._show_status("No training data found for this subprofile") self._show_status("No training data found for this subprofile")
return return
from core.audio_scan import default_model_path from core.audio_scan import default_model_path
model_path = default_model_path(self._profile) model_path = default_model_path(self._profile, embed_model)
self._cleanup_train_worker() self._cleanup_train_worker()
self._btn_train.setEnabled(False) self._btn_train.setEnabled(False)
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...") self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
self._train_worker = TrainWorker(video_infos, model_path, embed_model) n_workers = self._spn_workers.value()
self._train_worker = TrainWorker(video_infos, model_path, embed_model, n_workers, neg_margin)
self._train_worker.train_done.connect(self._on_train_done) self._train_worker.train_done.connect(self._on_train_done)
self._train_worker.error.connect(self._on_train_error) self._train_worker.error.connect(self._on_train_error)
self._train_worker.progress.connect(self._show_status) self._train_worker.progress.connect(self._show_status)
@@ -2783,6 +3150,7 @@ class MainWindow(QMainWindow):
def _on_train_done(self, model_path: str): def _on_train_done(self, model_path: str):
self._btn_train.setEnabled(True) self._btn_train.setEnabled(True)
self._refresh_scan_models()
self._show_status(f"Model trained and saved") self._show_status(f"Model trained and saved")
_log(f"Training complete: {model_path}") _log(f"Training complete: {model_path}")
@@ -2810,22 +3178,19 @@ class MainWindow(QMainWindow):
threshold = self._sld_threshold.value() threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path model, model_label = self._load_selected_scan_model()
model_path = default_model_path(self._profile) if model is None:
model = load_classifier(model_path)
if model is not None:
self._scan_file_path = self._file_path
self._show_status("Auto: scanning with classifier...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
else:
self._show_status("Auto: no trained model — click Train first")
self._btn_auto_export.setEnabled(True) self._btn_auto_export.setEnabled(True)
self._btn_scan.setEnabled(True) self._btn_scan.setEnabled(True)
return return
self._scan_file_path = self._file_path
self._scan_model_label = model_label
self._show_status(f"Auto: scanning ({model_label})...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
self._scan_worker.scan_done.connect(self._on_auto_scan_done) self._scan_worker.scan_done.connect(self._on_auto_scan_done)
self._scan_worker.error.connect(self._on_scan_error) self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status) self._scan_worker.progress.connect(self._show_status)
@@ -2879,7 +3244,15 @@ class MainWindow(QMainWindow):
return return
self._timeline.set_scan_regions(regions) self._timeline.set_scan_regions(regions)
# Also save to scan panel
model_label = getattr(self, '_scan_model_label', '')
if model_label and self._file_path:
self._scan_panel.add_scan_results(model_label, regions)
self._auto_export_regions(regions)
def _auto_export_regions(self, regions: list) -> None:
"""Export clips from a list of (start, end, score) regions."""
if not regions: if not regions:
self._show_status("Auto: no regions found") self._show_status("Auto: no regions found")
self._btn_auto_export.setEnabled(True) self._btn_auto_export.setEnabled(True)
@@ -2896,6 +3269,7 @@ class MainWindow(QMainWindow):
# Build export jobs — one 8s clip per position # Build export jobs — one 8s clip per position
folder = self._txt_folder.text() folder = self._txt_folder.text()
name = self._txt_name.text() or "clip" name = self._txt_name.text() or "clip"
self._auto_export_name = name
fmt = self._cmb_format.currentText() fmt = self._cmb_format.currentText()
image_sequence = fmt == "WebP sequence" image_sequence = fmt == "WebP sequence"
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
@@ -2959,7 +3333,7 @@ class MainWindow(QMainWindow):
"""Record each auto-exported clip to DB.""" """Record each auto-exported clip to DB."""
# Find the start_time for this clip from stashed positions # Find the start_time for this clip from stashed positions
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042" counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
name = self._txt_name.text() or "clip" name = getattr(self, '_auto_export_name', self._txt_name.text() or "clip")
start_t = None start_t = None
for t, c in self._auto_export_positions: for t, c in self._auto_export_positions:
if counter_str == f"{name}_{c:03d}": if counter_str == f"{name}_{c:03d}":
@@ -3306,6 +3680,11 @@ class MainWindow(QMainWindow):
# Cancel background workers to prevent callbacks into dead objects. # Cancel background workers to prevent callbacks into dead objects.
self._cleanup_scan_worker() self._cleanup_scan_worker()
self._cleanup_train_worker() self._cleanup_train_worker()
if self._export_worker and self._export_worker.isRunning():
self._export_worker.cancel()
self._export_worker.wait(3000)
if hasattr(self, '_db_worker') and self._db_worker and self._db_worker.isRunning():
self._db_worker.wait(1000)
# Stop timers first to prevent callbacks into dead objects. # Stop timers first to prevent callbacks into dead objects.
self._preview_timer.stop() self._preview_timer.stop()
self._mpv._render_timer.stop() self._mpv._render_timer.stop()