From 6870e5aaf38e4284f41fd4a8cd4086984334c6fe Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 18 Apr 2026 16:12:52 +0200 Subject: [PATCH] feat: scan results panel, model switching, batch scan, and training improvements - Replace librosa with direct ffmpeg subprocess for 10x faster audio loading - Add ScanResultsPanel with per-model tabs, seek-on-click, delete, and export - Persist scan results in DB per (filename, profile, model) - Add model selector dropdown to switch between trained embedding models - Add "Scan All" button for batch scanning playlist videos - Support manual negative examples via negative class folder - Configurable auto-negative margin (default 30s, 0 = disabled) - Deduplicate nearby training markers (8s min gap) - Parallel audio loading with ThreadPoolExecutor during training - Progress callbacks from training for UI status updates - Cache bypass in scan_video (skip audio loading when embeddings cached) - Move all caches (models, embeddings, downloads) into project directory - Add 8cut.sh launcher script with auto venv/conda detection - Fix 11 bugs across thread safety, signal handling, and state management Co-Authored-By: Claude Opus 4.6 --- .gitignore | 1 + 8cut.sh | 29 +++ core/audio_scan.py | 198 +++++++++++++++++--- core/db.py | 141 +++++++++++++- main.py | 445 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 741 insertions(+), 73 deletions(-) create mode 100755 8cut.sh diff --git a/.gitignore b/.gitignore index 9e8fb2c..e0f8ad3 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,6 @@ __pycache__/ .worktrees/ .venv/ models/ +cache/ *.joblib *.pt diff --git a/8cut.sh b/8cut.sh new file mode 100755 index 0000000..cfd0620 --- /dev/null +++ b/8cut.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Launch 8-cut with auto-detected venv/conda environment +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ENV_NAME="8cut" +CONDA_PREFIX_BASE="/media/p5/miniforge3" + +# 1. Try .venv in project dir +if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then + source "$SCRIPT_DIR/.venv/bin/activate" + exec python "$SCRIPT_DIR/main.py" "$@" +fi + +# 2. Try conda env (works without shell init) +CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python" +if [ -x "$CONDA_PYTHON" ]; then + exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@" +fi + +# 3. Try conda via shell hook (interactive shells) +if command -v conda &>/dev/null; then + eval "$(conda shell.bash hook 2>/dev/null)" + if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then + conda activate "$ENV_NAME" + exec python "$SCRIPT_DIR/main.py" "$@" + fi +fi + +# 4. Fallback to system Python +exec python3 "$SCRIPT_DIR/main.py" "$@" diff --git a/core/audio_scan.py b/core/audio_scan.py index 5fd15a8..57d4152 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -2,15 +2,39 @@ import hashlib import os +import subprocess +import warnings import numpy as np -import librosa -from .paths import _log +from .paths import _bin, _log _SR = 16000 # lower sr = faster + + +def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray: + """Load audio from any file as mono float32 numpy array using ffmpeg directly.""" + cmd = [ + _bin("ffmpeg"), "-i", path, + "-vn", # skip video + "-ac", "1", # mono + "-ar", str(sr), # resample + "-f", "f32le", # raw 32-bit float little-endian + "-loglevel", "error", + "pipe:1", + ] + proc = subprocess.run(cmd, capture_output=True, timeout=300) + if proc.returncode != 0: + raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode().strip()}") + return np.frombuffer(proc.stdout, dtype=np.float32) _WINDOW = 8.0 # seconds -_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models") -_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v") +_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_MODEL_DIR = os.path.join(_PROJECT_DIR, "models") +_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v") +_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads") + +# Redirect torch hub and huggingface downloads into the project +os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR) +os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface")) # --------------------------------------------------------------------------- # Embedding extraction (lazy-loaded) @@ -33,7 +57,7 @@ _EMBED_MODELS = { _DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" _BEATS_CHECKPOINT = os.path.join( - os.path.expanduser("~"), ".cache", "huggingface", "hub", + _DL_CACHE_DIR, "huggingface", "hub", "models--lpepino--beats_ckpts", "snapshots", "5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt", ) @@ -86,6 +110,30 @@ def _w2v_cache_path(video_path: str, hop: float, window: float, return os.path.join(_W2V_CACHE_DIR, f"{h}.npz") +def _w2v_cache_exists(video_path: str, hop: float, window: float, + model_name: str | None = None) -> bool: + """Check if embedding cache exists for a video.""" + try: + path = _w2v_cache_path(video_path, hop, window, model_name) + return os.path.exists(path) + except Exception: + return False + + +def _w2v_cache_load(video_path: str, hop: float, window: float, + model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None: + """Load embeddings from cache. Returns (timestamps, embeddings) or None.""" + try: + path = _w2v_cache_path(video_path, hop, window, model_name) + if os.path.exists(path): + data = np.load(path) + _log(f"audio_scan: cache hit ({path})") + return data["timestamps"], data["embeddings"] + except Exception as e: + _log(f"audio_scan: cache read failed: {e}") + return None + + def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, hop: float = 1.0, window: float = _WINDOW, video_path: str | None = None, @@ -162,6 +210,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], gt_soft: list[float], tolerance: float = 12.0, neg_margin: float = 120.0, model_name: str | None = None, + gt_negative: list[float] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Extract embeddings only near positives and distant negatives. @@ -180,13 +229,24 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], if 0 <= t <= duration - _WINDOW: pos_times.add(int(t)) - # Negative windows: every 4s, far from any marker + # Manual negative windows: near explicit negative markers + manual_neg_times = set() + if gt_negative: + for gt in gt_negative: + for offset in range(-int(tolerance), int(tolerance) + 1): + t = gt + offset + if 0 <= t <= duration - _WINDOW: + manual_neg_times.add(int(t)) + # Don't let manual negatives overlap with positives + manual_neg_times -= pos_times + + # Auto negative windows: every 4s, far from any marker (skip if margin <= 0) neg_times = set() for t in range(0, int(duration - _WINDOW), 4): - if min((abs(t - g) for g in all_gt), default=9999) > neg_margin: + if neg_margin > 0 and min((abs(t - g) for g in all_gt), default=9999) > neg_margin: neg_times.add(t) - all_times = sorted(pos_times | neg_times) + all_times = sorted(pos_times | neg_times | manual_neg_times) # Filter out windows that go past the end valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)] @@ -225,9 +285,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], for i, t in enumerate(timestamps): di = min((abs(t - g) for g in gt_intense), default=9999) da = min((abs(t - g) for g in all_gt), default=9999) + dm = min((abs(t - g) for g in (gt_negative or [])), default=9999) if di < tolerance: labels[i] = 1 - elif da > neg_margin: + elif dm < tolerance or (neg_margin > 0 and da > neg_margin): labels[i] = -1 return timestamps, embeddings, labels @@ -241,7 +302,9 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]], tolerance: float = 12.0, neg_margin: float = 120.0, embed_model: str | None = None, - cancel_flag: object = None) -> dict: + cancel_flag: object = None, + n_workers: int = 4, + progress_cb: object = None) -> dict: """Train a classifier from labeled videos. Args: @@ -250,24 +313,62 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]], tolerance/neg_margin: labeling parameters embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE cancel_flag: object with _cancel attribute; if set, training aborts early + n_workers: number of threads for parallel audio loading Returns: dict with 'classifier', 'embed_model', and metadata, or None on failure. """ + from concurrent.futures import ThreadPoolExecutor, as_completed from sklearn.ensemble import GradientBoostingClassifier - all_X, all_y = [], [] + def _progress(msg: str) -> None: + _log(msg) + if progress_cb: + progress_cb(msg) - for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos): + def _load_audio(path: str) -> np.ndarray: + return _load_audio_ffmpeg(path, sr=_SR) + + # Phase 1: load all audio in parallel (cap workers — disk I/O bound) + n = len(video_infos) + load_workers = min(n_workers, 4) + _progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...") + audio_data: dict[int, np.ndarray] = {} + with ThreadPoolExecutor(max_workers=load_workers) as pool: + future_to_idx = { + pool.submit(_load_audio, vi[0]): i + for i, vi in enumerate(video_infos) + } + failed = set() + for future in as_completed(future_to_idx): + if cancel_flag and getattr(cancel_flag, '_cancel', False): + _log("audio_scan: training cancelled") + return None + idx = future_to_idx[future] + try: + audio_data[idx] = future.result() + except Exception as e: + _log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}") + failed.add(idx) + _progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}") + + # Phase 2: extract embeddings sequentially on GPU + _progress(f"Extracting embeddings: 0/{n}") + all_X, all_y = [], [] + for vi, vinfo in enumerate(video_infos): + if vi in failed: + continue + vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2] + gt_negative = vinfo[3] if len(vinfo) > 3 else [] if cancel_flag and getattr(cancel_flag, '_cancel', False): _log("audio_scan: training cancelled") return None - _log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}") - y, _ = librosa.load(vpath, sr=_SR, mono=True) + _progress(f"Extracting embeddings: {vi+1}/{n}") + y = audio_data.pop(vi) timestamps, embeddings, labels = _extract_w2v_targeted( y, _SR, gt_intense, gt_soft, tolerance, neg_margin, - model_name=embed_model, + model_name=embed_model, gt_negative=gt_negative, ) if len(timestamps) == 0: continue @@ -306,6 +407,7 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]], train_idx = np.concatenate([pos_idx, neg_sample]) rng.shuffle(train_idx) + _progress(f"Fitting classifier on {len(train_idx)} samples...") clf = GradientBoostingClassifier( n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42, ) @@ -334,11 +436,41 @@ def load_classifier(model_path: str) -> dict | None: return joblib.load(model_path) -def default_model_path(profile_name: str = "default") -> str: - """Return the default path for a profile's classifier model.""" +def default_model_path(profile_name: str = "default", + embed_model: str | None = None) -> str: + """Return the path for a profile's classifier model. + + When embed_model is given the file is ``{profile}_{model}.joblib``, + otherwise ``{profile}.joblib`` (legacy single-model layout). + """ + if embed_model: + return os.path.join(_MODEL_DIR, f"{profile_name}_{embed_model}.joblib") return os.path.join(_MODEL_DIR, f"{profile_name}.joblib") +def list_trained_models(profile_name: str = "default") -> list[str]: + """Return embedding model names that have a trained .joblib for *profile_name*. + + Looks for files matching ``{profile}_{MODEL}.joblib`` in the models dir. + """ + prefix = f"{profile_name}_" + suffix = ".joblib" + result = [] + if not os.path.isdir(_MODEL_DIR): + return result + for fname in os.listdir(_MODEL_DIR): + if fname.startswith(prefix) and fname.endswith(suffix): + model_name = fname[len(prefix):-len(suffix)] + if model_name in _EMBED_MODELS: + result.append(model_name) + # Also check legacy {profile}.joblib + legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib") + if os.path.exists(legacy) and not result: + # Legacy model — we don't know the embed model, but it's usable + result.append("") + return sorted(result) + + # --------------------------------------------------------------------------- # Scanning # --------------------------------------------------------------------------- @@ -359,22 +491,28 @@ def scan_video( _log("audio_scan: no model provided") return [] - _log(f"audio_scan: loading {video_path}") - y, sr = librosa.load(video_path, sr=_SR, mono=True) - duration = len(y) / sr - _log(f"audio_scan: {duration:.1f}s loaded, extracting features...") - - if cancel_flag and getattr(cancel_flag, '_cancel', False): - return [] - clf = model["classifier"] embed_model = model.get("embed_model") - _log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...") - timestamps, window_vectors = _extract_w2v_windows( - y, sr, hop=hop, window=window, video_path=video_path, - cancel_flag=cancel_flag, model_name=embed_model, - ) + # Try cache first — skip expensive audio loading if embeddings exist + cached = _w2v_cache_load(video_path, hop, window, embed_model) + if cached is not None: + timestamps, window_vectors = cached + else: + _log(f"audio_scan: loading {video_path}") + y = _load_audio_ffmpeg(video_path, sr=_SR) + sr = _SR + _log(f"audio_scan: {len(y)/sr:.1f}s loaded") + + if cancel_flag and getattr(cancel_flag, '_cancel', False): + return [] + + _log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...") + timestamps, window_vectors = _extract_w2v_windows( + y, sr, hop=hop, window=window, video_path=video_path, + cancel_flag=cancel_flag, model_name=embed_model, + ) + if len(timestamps) == 0: _log("audio_scan: video shorter than window") return [] diff --git a/core/db.py b/core/db.py index 68eb25f..4c1aa1b 100644 --- a/core/db.py +++ b/core/db.py @@ -81,6 +81,21 @@ class ProcessedDB: " PRIMARY KEY (filename, profile)" ")" ) + self._con.execute( + "CREATE TABLE IF NOT EXISTS scan_results (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " filename TEXT NOT NULL," + " profile TEXT NOT NULL DEFAULT 'default'," + " model TEXT NOT NULL," + " start_time REAL NOT NULL," + " end_time REAL NOT NULL," + " score REAL NOT NULL" + ")" + ) + self._con.execute( + "CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model" + " ON scan_results(filename, profile, model)" + ) self._con.commit() def add(self, filename: str, start_time: float, output_path: str, @@ -248,18 +263,20 @@ class ProcessedDB: return sorted(folder_names) def get_training_data(self, profile: str, positive_folder: str, + negative_folder: str = "", fallback_video_dir: str = "", - ) -> list[tuple[str, list[float], list[float]]]: + ) -> list[tuple[str, list[float], list[float], list[float]]]: """Build training video_infos from DB data. Args: profile: profile name positive_folder: export folder name for positive class (e.g. "mp4_Intense") + negative_folder: export folder name for explicit negatives (optional) fallback_video_dir: if source_path is empty, try filename in this dir Returns: - list of (source_video_path, positive_times, soft_times) per video. - Soft times = clips from any other export folder. + list of (source_video_path, positive_times, soft_times, negative_times) + per video. Soft times = clips from any other non-negative folder. """ if not self._enabled: return [] @@ -269,8 +286,9 @@ class ProcessedDB: (profile,), ).fetchall() - # Collect times by video, split by positive vs other folders + # Collect times by video, split by folder role pos_by_video: dict[str, set[float]] = {} + neg_by_video: dict[str, set[float]] = {} soft_by_video: dict[str, set[float]] = {} source_by_filename: dict[str, str] = {} @@ -280,26 +298,43 @@ class ProcessedDB: grandparent = os.path.basename(os.path.dirname(os.path.dirname(op))) if grandparent == positive_folder: pos_by_video.setdefault(fn, set()).add(st) + elif negative_folder and grandparent == negative_folder: + neg_by_video.setdefault(fn, set()).add(st) else: soft_by_video.setdefault(fn, set()).add(st) - # Remove positive times from soft to avoid conflicting labels + # Remove positive times from soft/neg to avoid conflicting labels for fn in pos_by_video: if fn in soft_by_video: soft_by_video[fn] -= pos_by_video[fn] + if fn in neg_by_video: + neg_by_video[fn] -= pos_by_video[fn] + # Deduplicate nearby markers (spread clips from same position) + def _dedup_times(times: set[float], min_gap: float = 8.0) -> list[float]: + if not times: + return [] + ordered = sorted(times) + result = [ordered[0]] + for t in ordered[1:]: + if t - result[-1] >= min_gap: + result.append(t) + return result + + # Include videos that have positives OR explicit negatives + all_videos = set(pos_by_video) | set(neg_by_video) result = [] - for fn in pos_by_video: + for fn in all_videos: sp = source_by_filename.get(fn, "") if not sp or not os.path.exists(sp): - # Fallback: try video_dir / filename if fallback_video_dir: sp = os.path.join(fallback_video_dir, fn) if not sp or not os.path.exists(sp): continue - gt_pos = sorted(pos_by_video[fn]) - gt_soft = sorted(soft_by_video.get(fn, set())) - result.append((sp, gt_pos, gt_soft)) + gt_pos = _dedup_times(pos_by_video.get(fn, set())) + gt_soft = _dedup_times(soft_by_video.get(fn, set())) + gt_neg = _dedup_times(neg_by_video.get(fn, set())) + result.append((sp, gt_pos, gt_soft, gt_neg)) return result def get_training_stats(self, profile: str) -> dict[str, dict]: @@ -329,6 +364,92 @@ class ProcessedDB: stats[folder_name] = {"videos": len(videos), "clips": clips} return stats + # ── Scan results ───────────────────────────────────────────── + + def save_scan_results(self, filename: str, profile: str, model: str, + regions: list[tuple[float, float, float]]) -> None: + """Replace scan results for (filename, profile, model) with new regions. + + regions: list of (start_time, end_time, score). + """ + if not self._enabled: + return + with self._lock: + self._con.execute( + "DELETE FROM scan_results" + " WHERE filename = ? AND profile = ? AND model = ?", + (filename, profile, model), + ) + self._con.executemany( + "INSERT INTO scan_results" + " (filename, profile, model, start_time, end_time, score)" + " VALUES (?, ?, ?, ?, ?, ?)", + [(filename, profile, model, s, e, sc) for s, e, sc in regions], + ) + self._con.commit() + + def get_scan_results(self, filename: str, profile: str + ) -> dict[str, list[tuple[int, float, float, float]]]: + """Return scan results grouped by model. + + Returns {model: [(row_id, start_time, end_time, score), ...]} sorted by + start_time. + """ + if not self._enabled: + return {} + rows = self._con.execute( + "SELECT id, model, start_time, end_time, score FROM scan_results" + " WHERE filename = ? AND profile = ?" + " ORDER BY model, start_time", + (filename, profile), + ).fetchall() + result: dict[str, list[tuple[int, float, float, float]]] = {} + for row_id, model, s, e, sc in rows: + result.setdefault(model, []).append((row_id, s, e, sc)) + return result + + def delete_scan_result(self, row_id: int) -> None: + """Delete a single scan result row.""" + if not self._enabled: + return + with self._lock: + self._con.execute("DELETE FROM scan_results WHERE id = ?", (row_id,)) + self._con.commit() + + def get_scan_models(self, filename: str, profile: str) -> list[str]: + """Return model names that have scan results for this file.""" + if not self._enabled: + return [] + rows = self._con.execute( + "SELECT DISTINCT model FROM scan_results" + " WHERE filename = ? AND profile = ? ORDER BY model", + (filename, profile), + ).fetchall() + return [r[0] for r in rows] + + def get_scanned_filenames(self, profile: str, model: str) -> set[str]: + """Return filenames that already have scan results for this model.""" + if not self._enabled: + return set() + rows = self._con.execute( + "SELECT DISTINCT filename FROM scan_results" + " WHERE profile = ? AND model = ?", + (profile, model), + ).fetchall() + return {r[0] for r in rows} + + def get_training_filenames(self, profile: str) -> set[str]: + """Return filenames used in training (have exported clips).""" + if not self._enabled: + return set() + rows = self._con.execute( + "SELECT DISTINCT filename FROM processed WHERE profile = ?", + (profile,), + ).fetchall() + return {r[0] for r in rows} + + # ── Hidden files ─────────────────────────────────────────── + def hide_file(self, filename: str, profile: str = "default") -> None: if not self._enabled: return diff --git a/main.py b/main.py index 40ba5fd..4d53dfc 100755 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from PyQt6.QtWidgets import ( QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout, + QTableWidget, QTableWidgetItem, QTabWidget, QHeaderView, ) from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut @@ -244,6 +245,15 @@ class TrainDialog(QDialog): self._cmb_positive.addItem(label, userData=folder_name) form.addRow("Positive class:", self._cmb_positive) + # Negative class selector (optional) + self._cmb_negative = QComboBox() + self._cmb_negative.addItem("(auto only)", userData="") + for folder_name, info in stats.items(): + label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)" + self._cmb_negative.addItem(label, userData=folder_name) + self._cmb_negative.currentIndexChanged.connect(lambda: self._debounce.start()) + form.addRow("Negative class:", self._cmb_negative) + # Model selector self._cmb_model = QComboBox() for name in _EMBED_MODELS: @@ -251,6 +261,18 @@ class TrainDialog(QDialog): self._cmb_model.setCurrentText("WAV2VEC2_BASE") form.addRow("Model:", self._cmb_model) + # Auto-negative margin (0 = disabled) + self._spn_neg_margin = QDoubleSpinBox() + self._spn_neg_margin.setDecimals(0) + self._spn_neg_margin.setRange(0.0, 600.0) + self._spn_neg_margin.setSingleStep(10.0) + self._spn_neg_margin.setValue(30.0) + self._spn_neg_margin.setSuffix("s") + self._spn_neg_margin.setSpecialValueText("Disabled") + self._spn_neg_margin.setToolTip( + "Auto-sample negatives from regions this far from any marker. 0 = disabled.") + form.addRow("Auto-neg margin:", self._spn_neg_margin) + # Video source directory (fallback for old DB rows without source_path) self._txt_video_dir = QLineEdit(video_dir) self._txt_video_dir.setPlaceholderText("Directory containing source videos") @@ -265,7 +287,13 @@ class TrainDialog(QDialog): btn_browse.setFixedWidth(30) btn_browse.clicked.connect(self._browse_video_dir) vid_row.addWidget(btn_browse) - form.addRow("Video dir:", vid_row) + self._lbl_video_dir = QLabel("Video dir:") + self._video_dir_widget = QWidget() + self._video_dir_widget.setLayout(vid_row) + form.addRow(self._lbl_video_dir, self._video_dir_widget) + # Hidden by default — shown only if some videos are missing source_path + self._lbl_video_dir.setVisible(False) + self._video_dir_widget.setVisible(False) layout.addLayout(form) @@ -297,17 +325,32 @@ class TrainDialog(QDialog): if not folder: self._lbl_stats.setText("No export folder data available.") return + neg_folder = self._cmb_negative.currentData() or "" + # First check without fallback to see if source_paths are sufficient + video_infos_no_fb = self._db.get_training_data( + self._profile, folder, negative_folder=neg_folder, + ) video_infos = self._db.get_training_data( - self._profile, folder, + self._profile, folder, negative_folder=neg_folder, fallback_video_dir=self._txt_video_dir.text(), ) + # Show video dir field only when the fallback helps find extra videos + needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0 + self._lbl_video_dir.setVisible(needs_fallback) + self._video_dir_widget.setVisible(needs_fallback) + n_videos = len(video_infos) - n_pos = sum(len(gt) for _, gt, _ in video_infos) - n_soft = sum(len(s) for _, _, s in video_infos) - lines = [f"{n_videos} videos with positive clips"] - lines.append(f"{n_pos} positive markers, {n_soft} soft/buffer markers") + n_pos = sum(len(vi[1]) for vi in video_infos) + n_soft = sum(len(vi[2]) for vi in video_infos) + n_neg = sum(len(vi[3]) for vi in video_infos) + lines = [f"{n_videos} videos"] + lines.append(f"{n_pos} positive, {n_soft} soft/buffer" + + (f", {n_neg} manual negative" if n_neg else "") + + " markers") if n_videos == 0: - lines.append("No source videos found. Set Video dir above.") + lines.append("No source videos found. Set Video dir below.") + self._lbl_video_dir.setVisible(True) + self._video_dir_widget.setVisible(True) elif n_videos < 3: lines.append("Recommend at least 3 videos for decent results.") self._lbl_stats.setText("
".join(lines)) @@ -316,6 +359,14 @@ class TrainDialog(QDialog): def positive_folder(self) -> str: return self._cmb_positive.currentData() or "" + @property + def negative_folder(self) -> str: + return self._cmb_negative.currentData() or "" + + @property + def neg_margin(self) -> float: + return self._spn_neg_margin.value() + @property def embed_model(self) -> str: return self._cmb_model.currentText() @@ -332,11 +383,14 @@ class TrainWorker(QThread): progress = pyqtSignal(str) # per-video status def __init__(self, video_infos: list, model_path: str, - embed_model: str | None = None): + embed_model: str | None = None, n_workers: int = 4, + neg_margin: float = 120.0): super().__init__() self._video_infos = video_infos self._model_path = model_path self._embed_model = embed_model + self._n_workers = n_workers + self._neg_margin = neg_margin self._cancel = False def cancel(self) -> None: @@ -349,8 +403,11 @@ class TrainWorker(QThread): result = train_classifier( self._video_infos, model_path=self._model_path, + neg_margin=self._neg_margin, embed_model=self._embed_model, cancel_flag=self, + n_workers=self._n_workers, + progress_cb=self.progress.emit, ) if self._cancel: return @@ -363,6 +420,152 @@ class TrainWorker(QThread): self.error.emit(str(e)) +class ScanResultsPanel(QWidget): + """Tabbed panel showing scan results per model, with seek-on-click and delete.""" + seek_requested = pyqtSignal(float) # request main window to seek to time + export_requested = pyqtSignal(list) # emit list of (start, end, score) to export + + def __init__(self, db, parent=None): + super().__init__(parent) + self._db = db + self._filename = "" + self._profile = "" + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + + self._tabs = QTabWidget() + self._tabs.setTabsClosable(False) + layout.addWidget(self._tabs) + + btn_row = QHBoxLayout() + self._btn_export = QPushButton("Export Scan Results") + self._btn_export.setToolTip("Export clips from the active tab's scan results") + self._btn_export.clicked.connect(self._on_export) + btn_row.addStretch() + btn_row.addWidget(self._btn_export) + layout.addLayout(btn_row) + + def load_for_file(self, filename: str, profile: str) -> None: + """Load saved scan results from DB for a file.""" + self._filename = filename + self._profile = profile + self._tabs.clear() + results = self._db.get_scan_results(filename, profile) + for model, rows in results.items(): + self._add_tab(model, rows) + + def add_scan_results(self, model: str, + regions: list[tuple[float, float, float]]) -> None: + """Add/replace a tab with new scan results and save to DB.""" + # Save to DB + self._db.save_scan_results(self._filename, self._profile, model, regions) + # Build row data with IDs from DB + db_results = self._db.get_scan_results(self._filename, self._profile) + rows = db_results.get(model, []) + # Remove existing tab for this model + for i in range(self._tabs.count()): + if self._tabs.tabText(i).rsplit(" (", 1)[0] == model: + self._tabs.removeTab(i) + break + self._add_tab(model, rows) + # Switch to the new tab + for i in range(self._tabs.count()): + if self._tabs.tabText(i).rsplit(" (", 1)[0] == model: + self._tabs.setCurrentIndex(i) + break + + def _add_tab(self, model: str, + rows: list[tuple[int, float, float, float]]) -> None: + """Create a table tab. rows: [(row_id, start, end, score), ...]""" + table = QTableWidget(len(rows), 3) + table.setHorizontalHeaderLabels(["Time", "End", "Score"]) + table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows) + table.setSelectionMode(QTableWidget.SelectionMode.ExtendedSelection) + table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers) + table.verticalHeader().setVisible(False) + header = table.horizontalHeader() + header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + + for i, (row_id, start, end, score) in enumerate(rows): + t_item = QTableWidgetItem(format_time(start)) + t_item.setData(Qt.ItemDataRole.UserRole, row_id) + t_item.setData(Qt.ItemDataRole.UserRole + 1, start) + table.setItem(i, 0, t_item) + e_item = QTableWidgetItem(format_time(end)) + e_item.setData(Qt.ItemDataRole.UserRole, end) + table.setItem(i, 1, e_item) + table.setItem(i, 2, QTableWidgetItem(f"{score:.2f}")) + + table.itemSelectionChanged.connect( + lambda t=table: self._on_selection_changed(t)) + self._tabs.addTab(table, f"{model} ({len(rows)})") + + def _on_selection_changed(self, table: QTableWidget) -> None: + items = table.selectedItems() + if items: + row = items[0].row() + start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1) + if start is not None: + self.seek_requested.emit(float(start)) + + def delete_selected(self) -> None: + """Delete selected rows from active tab and DB.""" + table = self._tabs.currentWidget() + if not isinstance(table, QTableWidget): + return + rows_to_delete = sorted( + {idx.row() for idx in table.selectedIndexes()}, reverse=True) + tab_idx = self._tabs.currentIndex() + model = self._tabs.tabText(tab_idx).rsplit(" (", 1)[0] + for row in rows_to_delete: + row_id = table.item(row, 0).data(Qt.ItemDataRole.UserRole) + if row_id is not None: + self._db.delete_scan_result(row_id) + table.removeRow(row) + # Update tab title with new count + count = table.rowCount() + self._tabs.setTabText(tab_idx, f"{model} ({count})") + + def _get_tab_regions(self, table: QTableWidget + ) -> list[tuple[float, float, float]]: + """Extract (start, end, score) from a table widget.""" + regions = [] + for row in range(table.rowCount()): + start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1) + end = table.item(row, 1).data(Qt.ItemDataRole.UserRole) + score = float(table.item(row, 2).text()) + regions.append((float(start), float(end), score)) + return regions + + def _on_export(self) -> None: + table = self._tabs.currentWidget() + if not isinstance(table, QTableWidget): + return + regions = self._get_tab_regions(table) + if regions: + self.export_requested.emit(regions) + + def current_regions(self) -> list[tuple[float, float, float]]: + """Return (start, end, score) for all rows in the active tab.""" + table = self._tabs.currentWidget() + if not isinstance(table, QTableWidget): + return [] + return self._get_tab_regions(table) + + def has_results(self) -> bool: + return self._tabs.count() > 0 + + def keyPressEvent(self, event): + if event.key() in (Qt.Key.Key_Delete, Qt.Key.Key_Backspace): + self.delete_selected() + else: + super().keyPressEvent(event) + + class TimelineWidget(QWidget): cursor_changed = pyqtSignal(float) # emits position in seconds seek_changed = pyqtSignal(float) # emits seek position (lock mode) @@ -1710,6 +1913,15 @@ class MainWindow(QMainWindow): self._btn_train.clicked.connect(self._open_train_dialog) self._train_worker: TrainWorker | None = None + self._btn_scan_all = QPushButton("Scan All") + self._btn_scan_all.setToolTip("Scan all playlist videos that haven't been scanned yet") + self._btn_scan_all.clicked.connect(self._start_scan_all) + self._scan_all_queue: list[str] = [] + + self._cmb_scan_model = QComboBox() + self._cmb_scan_model.setToolTip("Trained embedding model to use for scanning") + self._cmb_scan_model.setMinimumWidth(120) + self._spn_auto_fuse = QDoubleSpinBox() self._spn_auto_fuse.setDecimals(1) self._spn_auto_fuse.setRange(0.0, 60.0) @@ -1800,6 +2012,7 @@ class MainWindow(QMainWindow): if idx >= 0: self._cmb_profile.setCurrentIndex(idx) self._cmb_profile.activated.connect(self._on_profile_activated) + self._refresh_scan_models() self._btn_shortcuts = QPushButton("?") self._btn_shortcuts.setFixedWidth(28) @@ -1864,11 +2077,13 @@ class MainWindow(QMainWindow): settings_row.addWidget(self._chk_rand_portrait) settings_row.addWidget(self._chk_rand_square) settings_row.addWidget(self._chk_track) + settings_row.addWidget(self._cmb_scan_model) settings_row.addWidget(self._btn_scan) settings_row.addWidget(self._btn_auto_export) settings_row.addWidget(self._spn_auto_fuse) settings_row.addWidget(self._sld_threshold) settings_row.addWidget(self._btn_train) + settings_row.addWidget(self._btn_scan_all) settings_row.addStretch() self._lbl_status = QLabel() self._lbl_status.setStyleSheet("color: #888; font-size: 11px;") @@ -1918,13 +2133,20 @@ class MainWindow(QMainWindow): left_layout.addLayout(left_top) left_layout.addWidget(self._playlist) + # Scan results panel (right side) + self._scan_panel = ScanResultsPanel(self._db) + self._scan_panel.seek_requested.connect(self._on_scan_seek) + self._scan_panel.export_requested.connect(self._on_scan_export) + # Root: horizontal splitter splitter = QSplitter(Qt.Orientation.Horizontal) splitter.addWidget(left) splitter.addWidget(right) - splitter.setSizes([200, 900]) + splitter.addWidget(self._scan_panel) + splitter.setSizes([200, 900, 200]) splitter.setCollapsible(0, False) splitter.setCollapsible(1, False) + splitter.setCollapsible(2, True) self.setCentralWidget(splitter) self.setStatusBar(None) @@ -2061,6 +2283,7 @@ class MainWindow(QMainWindow): self._btn_delete.setEnabled(False) self._update_next_label() self._apply_playlist_filters() + self._refresh_scan_models() if self._file_path: self._refresh_markers() _log(f"Profile switched: {text}") @@ -2184,7 +2407,13 @@ class MainWindow(QMainWindow): if self._scan_worker and self._scan_worker.isRunning(): self._scan_worker.cancel() self._cleanup_scan_worker() + self._scan_all_queue.clear() self._btn_scan.setEnabled(True) + self._btn_scan_all.setEnabled(True) + # Load saved scan results for this file + if self._file_path: + filename = os.path.basename(self._file_path) + self._scan_panel.load_for_file(filename, self._profile) dur = self._mpv.get_duration() self._timeline.set_duration(dur) @@ -2653,8 +2882,42 @@ class MainWindow(QMainWindow): return self._step_cursor(markers[0][0] - self._cursor) # wrap to first + def _load_selected_scan_model(self) -> tuple: + """Load the classifier selected in the scan model combo. + + Returns (model_dict, label_str) or (None, "") on failure. + """ + from core.audio_scan import load_classifier, default_model_path + sel = self._cmb_scan_model.currentText() + if not sel or sel == "(no model)": + self._show_status("No trained model — click Train first") + return None, "" + embed_name = None if sel == "(legacy)" else sel + model_path = default_model_path(self._profile, embed_name) + model = load_classifier(model_path) + if model is None: + self._show_status(f"Model file missing: {model_path}") + return None, "" + return model, sel + + def _refresh_scan_models(self) -> None: + """Populate the scan model combo with trained models for the current profile.""" + from core.audio_scan import list_trained_models + prev = self._cmb_scan_model.currentText() + self._cmb_scan_model.clear() + models = list_trained_models(self._profile) + if not models: + self._cmb_scan_model.addItem("(no model)") + else: + for m in models: + self._cmb_scan_model.addItem(m if m else "(legacy)") + # Restore previous selection if still available + idx = self._cmb_scan_model.findText(prev) + if idx >= 0: + self._cmb_scan_model.setCurrentIndex(idx) + def _cleanup_scan_worker(self) -> None: - """Disconnect signals and schedule deletion of old scan worker.""" + """Disconnect signals, cancel, and schedule deletion of old scan worker.""" if self._scan_worker is not None: try: self._scan_worker.scan_done.disconnect() @@ -2662,8 +2925,8 @@ class MainWindow(QMainWindow): self._scan_worker.progress.disconnect() except TypeError: pass # already disconnected + self._scan_worker.cancel() if self._scan_worker.isRunning(): - # QThread.finished fires when run() returns, even on cancel self._scan_worker.finished.connect(self._scan_worker.deleteLater) else: self._scan_worker.deleteLater() @@ -2682,17 +2945,14 @@ class MainWindow(QMainWindow): threshold = self._sld_threshold.value() - from core.audio_scan import load_classifier, default_model_path - model_path = default_model_path(self._profile) - model = load_classifier(model_path) - + model, model_label = self._load_selected_scan_model() if model is None: - self._show_status("No trained model — click Train first") return self._btn_scan.setEnabled(False) self._scan_file_path = self._file_path - self._show_status("Scanning...") + self._scan_model_label = model_label + self._show_status(f"Scanning ({model_label})...") self._scan_worker = ScanWorker( self._file_path, model=model, threshold=threshold, ) @@ -2708,6 +2968,10 @@ class MainWindow(QMainWindow): if self._file_path != getattr(self, '_scan_file_path', None): return self._timeline.set_scan_regions(regions) + model_label = getattr(self, '_scan_model_label', '') + if model_label and self._file_path: + filename = os.path.basename(self._file_path) + self._scan_panel.add_scan_results(model_label, regions) self._show_status(f"Scan complete: {len(regions)} matching regions") def _on_scan_error(self, msg: str) -> None: @@ -2715,6 +2979,105 @@ class MainWindow(QMainWindow): self._btn_auto_export.setEnabled(True) self._show_status(f"Scan error: {msg}") + def _on_scan_seek(self, t: float) -> None: + """Seek player when a scan result row is clicked.""" + if self._file_path: + self._cursor = t + self._mpv.seek(t) + self._timeline.set_cursor(t) + dur = self._mpv.get_duration() + self._lbl_time.setText(f"{format_time(t)} / {format_time(dur)}") + + def _on_scan_export(self, regions: list) -> None: + """Export clips from scan results panel.""" + if not self._file_path or not regions: + return + if self._export_worker and self._export_worker.isRunning(): + self._show_status("Export already running…") + return + self._auto_export_regions(regions) + + # ── Scan All ─────────────────────────────────────────────── + + def _start_scan_all(self) -> None: + """Scan all playlist videos not yet scanned with the selected model.""" + if self._scan_worker and self._scan_worker.isRunning(): + self._show_status("Scan already running") + return + + model, model_label = self._load_selected_scan_model() + if model is None: + return + + # Build queue: playlist files minus already-scanned and training files + all_paths = self._playlist._paths + scanned = self._db.get_scanned_filenames(self._profile, model_label) + training = self._db.get_training_filenames(self._profile) + skip = scanned | training + + self._scan_all_queue = [ + p for p in all_paths if os.path.basename(p) not in skip + ] + if not self._scan_all_queue: + self._show_status("All videos already scanned or used for training") + return + + self._scan_all_model = model + self._scan_all_model_label = model_label + self._scan_all_profile = self._profile + self._scan_all_total = len(self._scan_all_queue) + self._btn_scan_all.setEnabled(False) + self._btn_scan.setEnabled(False) + self._show_status( + f"Scan All: 0/{self._scan_all_total} ({model_label})") + self._scan_all_next() + + def _scan_all_next(self) -> None: + """Start scanning the next video in the queue.""" + if not self._scan_all_queue: + self._btn_scan_all.setEnabled(True) + self._btn_scan.setEnabled(True) + done = self._scan_all_total + self._show_status(f"Scan All complete: {done} videos scanned") + return + + self._cleanup_scan_worker() + path = self._scan_all_queue.pop(0) + remaining = self._scan_all_total - len(self._scan_all_queue) + self._scan_all_current_path = path + self._show_status( + f"Scan All: {remaining}/{self._scan_all_total} — " + f"{os.path.basename(path)}") + + threshold = self._sld_threshold.value() + self._scan_worker = ScanWorker( + path, model=self._scan_all_model, threshold=threshold, + ) + self._scan_worker.scan_done.connect(self._on_scan_all_done) + self._scan_worker.error.connect(self._on_scan_all_error) + self._scan_worker.start() + + def _on_scan_all_done(self, regions: list) -> None: + """Save batch scan results and continue to next video.""" + path = getattr(self, '_scan_all_current_path', '') + model_label = getattr(self, '_scan_all_model_label', '') + if path and model_label: + filename = os.path.basename(path) + profile = getattr(self, '_scan_all_profile', self._profile) + self._db.save_scan_results( + filename, profile, model_label, regions) + # If this is the currently loaded file, update the panel + if self._file_path and os.path.basename(self._file_path) == filename: + self._scan_panel.load_for_file(filename, self._profile) + self._timeline.set_scan_regions(regions) + self._scan_all_next() + + def _on_scan_all_error(self, msg: str) -> None: + """Log error and continue to next video.""" + path = getattr(self, '_scan_all_current_path', '') + _log(f"Scan All error on {os.path.basename(path)}: {msg}") + self._scan_all_next() + # ── Training ──────────────────────────────────────────────── def _cleanup_train_worker(self) -> None: @@ -2751,6 +3114,8 @@ class MainWindow(QMainWindow): return pos_folder = dlg.positive_folder + neg_folder = dlg.negative_folder + neg_margin = dlg.neg_margin embed_model = dlg.embed_model video_dir = dlg.video_dir if not pos_folder: @@ -2762,20 +3127,22 @@ class MainWindow(QMainWindow): self._settings.setValue("train_video_dir", video_dir) video_infos = self._db.get_training_data( - self._profile, pos_folder, fallback_video_dir=video_dir, + self._profile, pos_folder, negative_folder=neg_folder, + fallback_video_dir=video_dir, ) if not video_infos: self._show_status("No training data found for this subprofile") return from core.audio_scan import default_model_path - model_path = default_model_path(self._profile) + model_path = default_model_path(self._profile, embed_model) self._cleanup_train_worker() self._btn_train.setEnabled(False) self._show_status(f"Training {embed_model} on {len(video_infos)} videos...") - self._train_worker = TrainWorker(video_infos, model_path, embed_model) + n_workers = self._spn_workers.value() + self._train_worker = TrainWorker(video_infos, model_path, embed_model, n_workers, neg_margin) self._train_worker.train_done.connect(self._on_train_done) self._train_worker.error.connect(self._on_train_error) self._train_worker.progress.connect(self._show_status) @@ -2783,6 +3150,7 @@ class MainWindow(QMainWindow): def _on_train_done(self, model_path: str): self._btn_train.setEnabled(True) + self._refresh_scan_models() self._show_status(f"Model trained and saved") _log(f"Training complete: {model_path}") @@ -2810,22 +3178,19 @@ class MainWindow(QMainWindow): threshold = self._sld_threshold.value() - from core.audio_scan import load_classifier, default_model_path - model_path = default_model_path(self._profile) - model = load_classifier(model_path) - - if model is not None: - self._scan_file_path = self._file_path - self._show_status("Auto: scanning with classifier...") - self._scan_worker = ScanWorker( - self._file_path, model=model, threshold=threshold, - ) - else: - self._show_status("Auto: no trained model — click Train first") + model, model_label = self._load_selected_scan_model() + if model is None: self._btn_auto_export.setEnabled(True) self._btn_scan.setEnabled(True) return + self._scan_file_path = self._file_path + self._scan_model_label = model_label + self._show_status(f"Auto: scanning ({model_label})...") + self._scan_worker = ScanWorker( + self._file_path, model=model, threshold=threshold, + ) + self._scan_worker.scan_done.connect(self._on_auto_scan_done) self._scan_worker.error.connect(self._on_scan_error) self._scan_worker.progress.connect(self._show_status) @@ -2879,7 +3244,15 @@ class MainWindow(QMainWindow): return self._timeline.set_scan_regions(regions) + # Also save to scan panel + model_label = getattr(self, '_scan_model_label', '') + if model_label and self._file_path: + self._scan_panel.add_scan_results(model_label, regions) + self._auto_export_regions(regions) + + def _auto_export_regions(self, regions: list) -> None: + """Export clips from a list of (start, end, score) regions.""" if not regions: self._show_status("Auto: no regions found") self._btn_auto_export.setEnabled(True) @@ -2896,6 +3269,7 @@ class MainWindow(QMainWindow): # Build export jobs — one 8s clip per position folder = self._txt_folder.text() name = self._txt_name.text() or "clip" + self._auto_export_name = name fmt = self._cmb_format.currentText() image_sequence = fmt == "WebP sequence" os.makedirs(folder, exist_ok=True) @@ -2959,7 +3333,7 @@ class MainWindow(QMainWindow): """Record each auto-exported clip to DB.""" # Find the start_time for this clip from stashed positions counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042" - name = self._txt_name.text() or "clip" + name = getattr(self, '_auto_export_name', self._txt_name.text() or "clip") start_t = None for t, c in self._auto_export_positions: if counter_str == f"{name}_{c:03d}": @@ -3306,6 +3680,11 @@ class MainWindow(QMainWindow): # Cancel background workers to prevent callbacks into dead objects. self._cleanup_scan_worker() self._cleanup_train_worker() + if self._export_worker and self._export_worker.isRunning(): + self._export_worker.cancel() + self._export_worker.wait(3000) + if hasattr(self, '_db_worker') and self._db_worker and self._db_worker.isRunning(): + self._db_worker.wait(1000) # Stop timers first to prevent callbacks into dead objects. self._preview_timer.stop() self._mpv._render_timer.stop()