feat: scan results panel, model switching, batch scan, and training improvements
- Replace librosa with direct ffmpeg subprocess for 10x faster audio loading
- Add ScanResultsPanel with per-model tabs, seek-on-click, delete, and export
- Persist scan results in DB per (filename, profile, model)
- Add model selector dropdown to switch between trained embedding models
- Add "Scan All" button for batch scanning playlist videos
- Support manual negative examples via negative class folder
- Configurable auto-negative margin (default 30s, 0 = disabled)
- Deduplicate nearby training markers (8s min gap)
- Parallel audio loading with ThreadPoolExecutor during training
- Progress callbacks from training for UI status updates
- Cache bypass in scan_video (skip audio loading when embeddings cached)
- Move all caches (models, embeddings, downloads) into project directory
- Add 8cut.sh launcher script with auto venv/conda detection
- Fix 11 bugs across thread safety, signal handling, and state management

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,5 +5,6 @@ __pycache__/
|
|||||||
.worktrees/
|
.worktrees/
|
||||||
.venv/
|
.venv/
|
||||||
models/
|
models/
|
||||||
|
cache/
|
||||||
*.joblib
|
*.joblib
|
||||||
*.pt
|
*.pt
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
# Launch 8-cut with auto-detected venv/conda environment.
#
# Candidates are tried in order; the first one found replaces this process
# via exec, so nothing after a successful branch ever runs:
#   1. .venv inside the project directory
#   2. the conda env's interpreter under $CONDA_PREFIX_BASE (no shell init needed)
#   3. the conda env activated through the shell hook
#   4. system python3

# Abort rather than run "/main.py" if the script directory cannot be resolved.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" || exit 1
ENV_NAME="8cut"
# Conda install root. Export CONDA_PREFIX_BASE to use a different install;
# the default keeps the original behaviour.
CONDA_PREFIX_BASE="${CONDA_PREFIX_BASE:-/media/p5/miniforge3}"

# 1. Try .venv in project dir
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
    source "$SCRIPT_DIR/.venv/bin/activate"
    exec python "$SCRIPT_DIR/main.py" "$@"
fi

# 2. Try conda env (works without shell init)
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
if [ -x "$CONDA_PYTHON" ]; then
    exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
fi

# 3. Try conda via shell hook (interactive shells)
if command -v conda &>/dev/null; then
    eval "$(conda shell.bash hook 2>/dev/null)"
    # Compare the name column exactly: a word-grep would also match the env
    # name inside another env's path or inside a longer name such as "my-8cut".
    if conda env list 2>/dev/null \
        | awk -v name="$ENV_NAME" '$1 == name { found = 1 } END { exit !found }'; then
        conda activate "$ENV_NAME"
        exec python "$SCRIPT_DIR/main.py" "$@"
    fi
fi

# 4. Fallback to system Python
exec python3 "$SCRIPT_DIR/main.py" "$@"
|
||||||
+161
-23
@@ -2,15 +2,39 @@
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import librosa
|
|
||||||
|
|
||||||
from .paths import _log
|
from .paths import _bin, _log
|
||||||
|
|
||||||
_SR = 16000 # lower sr = faster
|
_SR = 16000 # lower sr = faster
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
    """Load audio from any file as a mono float32 numpy array using ffmpeg directly.

    Args:
        path: media file to decode; the video stream is skipped.
        sr: target sample rate in Hz; ffmpeg resamples to it.

    Returns:
        1-D float32 array built with ``np.frombuffer`` over ffmpeg's raw
        ``f32le`` output (a view of the captured bytes, not a copy).

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message
            carries ffmpeg's stderr.
        subprocess.TimeoutExpired: decoding did not finish within 300 seconds.
    """
    cmd = [
        _bin("ffmpeg"),
        "-nostdin",             # never read interactive commands from stdin
        "-i", path,
        "-vn",                  # skip video
        "-ac", "1",             # mono
        "-ar", str(sr),         # resample
        "-f", "f32le",          # raw 32-bit float little-endian
        "-loglevel", "error",
        "pipe:1",
    ]
    # stdin is detached as well: several decoders may run in parallel threads
    # and must not compete for (or block on) the launching terminal's input.
    proc = subprocess.run(
        cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=300,
    )
    if proc.returncode != 0:
        # stderr may contain bytes that are not valid UTF-8 (e.g. file names);
        # a decode error here would hide the real ffmpeg failure.
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed: {detail}")
    return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||||
_WINDOW = 8.0 # seconds
|
_WINDOW = 8.0 # seconds
|
||||||
_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models")
|
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v")
|
_MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
|
||||||
|
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
|
||||||
|
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")
|
||||||
|
|
||||||
|
# Redirect torch hub and huggingface downloads into the project
|
||||||
|
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
|
||||||
|
os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Embedding extraction (lazy-loaded)
|
# Embedding extraction (lazy-loaded)
|
||||||
@@ -33,7 +57,7 @@ _EMBED_MODELS = {
|
|||||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
# The checkpoint is looked up inside the huggingface hub cache. HF_HOME is
# only populated with setdefault, so a value the user exported beforehand
# stays in effect; resolve against the effective HF_HOME instead of assuming
# the project-local download cache.
_BEATS_CHECKPOINT = os.path.join(
    os.environ.get("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface")),
    "hub",
    "models--lpepino--beats_ckpts", "snapshots",
    "5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
)
|
||||||
@@ -86,6 +110,30 @@ def _w2v_cache_path(video_path: str, hop: float, window: float,
|
|||||||
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
|
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_exists(video_path: str, hop: float, window: float,
                      model_name: str | None = None) -> bool:
    """Report whether an embedding cache file is present for a video.

    The cache location comes from ``_w2v_cache_path`` with the same
    arguments. Any exception raised while resolving or probing the path is
    treated as "not cached".
    """
    try:
        return os.path.exists(
            _w2v_cache_path(video_path, hop, window, model_name)
        )
    except Exception:
        # Best-effort probe: callers only need a yes/no answer.
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_load(video_path: str, hop: float, window: float,
                    model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
    """Load embeddings from cache.

    Args:
        video_path, hop, window, model_name: forwarded unchanged to
            ``_w2v_cache_path`` to locate the ``.npz`` file.

    Returns:
        ``(timestamps, embeddings)`` read from the archive, or ``None`` when
        no cache file exists or it cannot be read. Read failures are logged
        and never raised.
    """
    try:
        path = _w2v_cache_path(video_path, hop, window, model_name)
        if not os.path.exists(path):
            return None
        # np.load on an .npz returns a lazy archive that keeps the file open.
        # Materialize both arrays inside the with-block so the handle is
        # closed before returning instead of leaking until garbage collection.
        with np.load(path) as data:
            timestamps = data["timestamps"]
            embeddings = data["embeddings"]
        # Log the hit only once both arrays were actually read.
        _log(f"audio_scan: cache hit ({path})")
        return timestamps, embeddings
    except Exception as e:
        _log(f"audio_scan: cache read failed: {e}")
    return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||||
hop: float = 1.0, window: float = _WINDOW,
|
hop: float = 1.0, window: float = _WINDOW,
|
||||||
video_path: str | None = None,
|
video_path: str | None = None,
|
||||||
@@ -162,6 +210,7 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
gt_soft: list[float], tolerance: float = 12.0,
|
gt_soft: list[float], tolerance: float = 12.0,
|
||||||
neg_margin: float = 120.0,
|
neg_margin: float = 120.0,
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
|
gt_negative: list[float] | None = None,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Extract embeddings only near positives and distant negatives.
|
"""Extract embeddings only near positives and distant negatives.
|
||||||
|
|
||||||
@@ -180,13 +229,24 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
if 0 <= t <= duration - _WINDOW:
|
if 0 <= t <= duration - _WINDOW:
|
||||||
pos_times.add(int(t))
|
pos_times.add(int(t))
|
||||||
|
|
||||||
# Negative windows: every 4s, far from any marker
|
# Manual negative windows: near explicit negative markers
|
||||||
|
manual_neg_times = set()
|
||||||
|
if gt_negative:
|
||||||
|
for gt in gt_negative:
|
||||||
|
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||||
|
t = gt + offset
|
||||||
|
if 0 <= t <= duration - _WINDOW:
|
||||||
|
manual_neg_times.add(int(t))
|
||||||
|
# Don't let manual negatives overlap with positives
|
||||||
|
manual_neg_times -= pos_times
|
||||||
|
|
||||||
|
# Auto negative windows: every 4s, far from any marker (skip if margin <= 0)
|
||||||
neg_times = set()
|
neg_times = set()
|
||||||
for t in range(0, int(duration - _WINDOW), 4):
|
for t in range(0, int(duration - _WINDOW), 4):
|
||||||
if min((abs(t - g) for g in all_gt), default=9999) > neg_margin:
|
if neg_margin > 0 and min((abs(t - g) for g in all_gt), default=9999) > neg_margin:
|
||||||
neg_times.add(t)
|
neg_times.add(t)
|
||||||
|
|
||||||
all_times = sorted(pos_times | neg_times)
|
all_times = sorted(pos_times | neg_times | manual_neg_times)
|
||||||
# Filter out windows that go past the end
|
# Filter out windows that go past the end
|
||||||
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
|
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
|
||||||
|
|
||||||
@@ -225,9 +285,10 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
for i, t in enumerate(timestamps):
|
for i, t in enumerate(timestamps):
|
||||||
di = min((abs(t - g) for g in gt_intense), default=9999)
|
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||||
da = min((abs(t - g) for g in all_gt), default=9999)
|
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||||
|
dm = min((abs(t - g) for g in (gt_negative or [])), default=9999)
|
||||||
if di < tolerance:
|
if di < tolerance:
|
||||||
labels[i] = 1
|
labels[i] = 1
|
||||||
elif da > neg_margin:
|
elif dm < tolerance or (neg_margin > 0 and da > neg_margin):
|
||||||
labels[i] = -1
|
labels[i] = -1
|
||||||
return timestamps, embeddings, labels
|
return timestamps, embeddings, labels
|
||||||
|
|
||||||
@@ -241,7 +302,9 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
|||||||
tolerance: float = 12.0,
|
tolerance: float = 12.0,
|
||||||
neg_margin: float = 120.0,
|
neg_margin: float = 120.0,
|
||||||
embed_model: str | None = None,
|
embed_model: str | None = None,
|
||||||
cancel_flag: object = None) -> dict:
|
cancel_flag: object = None,
|
||||||
|
n_workers: int = 4,
|
||||||
|
progress_cb: object = None) -> dict:
|
||||||
"""Train a classifier from labeled videos.
|
"""Train a classifier from labeled videos.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -250,24 +313,62 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
|||||||
tolerance/neg_margin: labeling parameters
|
tolerance/neg_margin: labeling parameters
|
||||||
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
|
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
|
||||||
cancel_flag: object with _cancel attribute; if set, training aborts early
|
cancel_flag: object with _cancel attribute; if set, training aborts early
|
||||||
|
n_workers: number of threads for parallel audio loading
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with 'classifier', 'embed_model', and metadata, or None on failure.
|
dict with 'classifier', 'embed_model', and metadata, or None on failure.
|
||||||
"""
|
"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from sklearn.ensemble import GradientBoostingClassifier
|
from sklearn.ensemble import GradientBoostingClassifier
|
||||||
|
|
||||||
all_X, all_y = [], []
|
def _progress(msg: str) -> None:
|
||||||
|
_log(msg)
|
||||||
|
if progress_cb:
|
||||||
|
progress_cb(msg)
|
||||||
|
|
||||||
for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos):
|
def _load_audio(path: str) -> np.ndarray:
|
||||||
|
return _load_audio_ffmpeg(path, sr=_SR)
|
||||||
|
|
||||||
|
# Phase 1: load all audio in parallel (cap workers — disk I/O bound)
|
||||||
|
n = len(video_infos)
|
||||||
|
load_workers = min(n_workers, 4)
|
||||||
|
_progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
|
||||||
|
audio_data: dict[int, np.ndarray] = {}
|
||||||
|
with ThreadPoolExecutor(max_workers=load_workers) as pool:
|
||||||
|
future_to_idx = {
|
||||||
|
pool.submit(_load_audio, vi[0]): i
|
||||||
|
for i, vi in enumerate(video_infos)
|
||||||
|
}
|
||||||
|
failed = set()
|
||||||
|
for future in as_completed(future_to_idx):
|
||||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
_log("audio_scan: training cancelled")
|
_log("audio_scan: training cancelled")
|
||||||
return None
|
return None
|
||||||
_log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}")
|
idx = future_to_idx[future]
|
||||||
y, _ = librosa.load(vpath, sr=_SR, mono=True)
|
try:
|
||||||
|
audio_data[idx] = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
|
||||||
|
failed.add(idx)
|
||||||
|
_progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")
|
||||||
|
|
||||||
|
# Phase 2: extract embeddings sequentially on GPU
|
||||||
|
_progress(f"Extracting embeddings: 0/{n}")
|
||||||
|
all_X, all_y = [], []
|
||||||
|
for vi, vinfo in enumerate(video_infos):
|
||||||
|
if vi in failed:
|
||||||
|
continue
|
||||||
|
vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2]
|
||||||
|
gt_negative = vinfo[3] if len(vinfo) > 3 else []
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: training cancelled")
|
||||||
|
return None
|
||||||
|
_progress(f"Extracting embeddings: {vi+1}/{n}")
|
||||||
|
y = audio_data.pop(vi)
|
||||||
|
|
||||||
timestamps, embeddings, labels = _extract_w2v_targeted(
|
timestamps, embeddings, labels = _extract_w2v_targeted(
|
||||||
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
|
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
|
||||||
model_name=embed_model,
|
model_name=embed_model, gt_negative=gt_negative,
|
||||||
)
|
)
|
||||||
if len(timestamps) == 0:
|
if len(timestamps) == 0:
|
||||||
continue
|
continue
|
||||||
@@ -306,6 +407,7 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
|||||||
train_idx = np.concatenate([pos_idx, neg_sample])
|
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||||
rng.shuffle(train_idx)
|
rng.shuffle(train_idx)
|
||||||
|
|
||||||
|
_progress(f"Fitting classifier on {len(train_idx)} samples...")
|
||||||
clf = GradientBoostingClassifier(
|
clf = GradientBoostingClassifier(
|
||||||
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42,
|
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42,
|
||||||
)
|
)
|
||||||
@@ -334,11 +436,41 @@ def load_classifier(model_path: str) -> dict | None:
|
|||||||
return joblib.load(model_path)
|
return joblib.load(model_path)
|
||||||
|
|
||||||
|
|
||||||
def default_model_path(profile_name: str = "default",
                       embed_model: str | None = None) -> str:
    """Build the path of a profile's classifier file inside the models dir.

    With a non-empty *embed_model* the file name is
    ``{profile}_{model}.joblib``; without one it is ``{profile}.joblib``
    (legacy single-model layout).
    """
    stem = f"{profile_name}_{embed_model}" if embed_model else profile_name
    return os.path.join(_MODEL_DIR, f"{stem}.joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def list_trained_models(profile_name: str = "default") -> list[str]:
    """List embedding model names with a trained ``.joblib`` for *profile_name*.

    Files named ``{profile}_{MODEL}.joblib`` in the models dir are considered,
    and only ``MODEL`` values that are keys of ``_EMBED_MODELS`` are reported.
    When none match but a legacy ``{profile}.joblib`` exists, the result is
    ``[""]`` (embedding model unknown, classifier still usable). Returns an
    empty list when the models dir does not exist.
    """
    if not os.path.isdir(_MODEL_DIR):
        return []

    head, tail = f"{profile_name}_", ".joblib"
    found = []
    for entry in os.listdir(_MODEL_DIR):
        if not (entry.startswith(head) and entry.endswith(tail)):
            continue
        candidate = entry[len(head):-len(tail)]
        if candidate in _EMBED_MODELS:
            found.append(candidate)

    # Legacy single-model layout is only reported when nothing else matched.
    legacy_path = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
    if not found and os.path.exists(legacy_path):
        found.append("")
    return sorted(found)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Scanning
|
# Scanning
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -359,22 +491,28 @@ def scan_video(
|
|||||||
_log("audio_scan: no model provided")
|
_log("audio_scan: no model provided")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
clf = model["classifier"]
|
||||||
|
embed_model = model.get("embed_model")
|
||||||
|
|
||||||
|
# Try cache first — skip expensive audio loading if embeddings exist
|
||||||
|
cached = _w2v_cache_load(video_path, hop, window, embed_model)
|
||||||
|
if cached is not None:
|
||||||
|
timestamps, window_vectors = cached
|
||||||
|
else:
|
||||||
_log(f"audio_scan: loading {video_path}")
|
_log(f"audio_scan: loading {video_path}")
|
||||||
y, sr = librosa.load(video_path, sr=_SR, mono=True)
|
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||||
duration = len(y) / sr
|
sr = _SR
|
||||||
_log(f"audio_scan: {duration:.1f}s loaded, extracting features...")
|
_log(f"audio_scan: {len(y)/sr:.1f}s loaded")
|
||||||
|
|
||||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
clf = model["classifier"]
|
|
||||||
embed_model = model.get("embed_model")
|
|
||||||
|
|
||||||
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
|
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
|
||||||
timestamps, window_vectors = _extract_w2v_windows(
|
timestamps, window_vectors = _extract_w2v_windows(
|
||||||
y, sr, hop=hop, window=window, video_path=video_path,
|
y, sr, hop=hop, window=window, video_path=video_path,
|
||||||
cancel_flag=cancel_flag, model_name=embed_model,
|
cancel_flag=cancel_flag, model_name=embed_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(timestamps) == 0:
|
if len(timestamps) == 0:
|
||||||
_log("audio_scan: video shorter than window")
|
_log("audio_scan: video shorter than window")
|
||||||
return []
|
return []
|
||||||
|
|||||||
+131
-10
@@ -81,6 +81,21 @@ class ProcessedDB:
|
|||||||
" PRIMARY KEY (filename, profile)"
|
" PRIMARY KEY (filename, profile)"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS scan_results ("
|
||||||
|
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||||
|
" filename TEXT NOT NULL,"
|
||||||
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
|
" model TEXT NOT NULL,"
|
||||||
|
" start_time REAL NOT NULL,"
|
||||||
|
" end_time REAL NOT NULL,"
|
||||||
|
" score REAL NOT NULL"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model"
|
||||||
|
" ON scan_results(filename, profile, model)"
|
||||||
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def add(self, filename: str, start_time: float, output_path: str,
|
def add(self, filename: str, start_time: float, output_path: str,
|
||||||
@@ -248,18 +263,20 @@ class ProcessedDB:
|
|||||||
return sorted(folder_names)
|
return sorted(folder_names)
|
||||||
|
|
||||||
def get_training_data(self, profile: str, positive_folder: str,
|
def get_training_data(self, profile: str, positive_folder: str,
|
||||||
|
negative_folder: str = "",
|
||||||
fallback_video_dir: str = "",
|
fallback_video_dir: str = "",
|
||||||
) -> list[tuple[str, list[float], list[float]]]:
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
"""Build training video_infos from DB data.
|
"""Build training video_infos from DB data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
profile: profile name
|
profile: profile name
|
||||||
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
|
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
|
||||||
|
negative_folder: export folder name for explicit negatives (optional)
|
||||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of (source_video_path, positive_times, soft_times) per video.
|
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||||
Soft times = clips from any other export folder.
|
per video. Soft times = clips from any other non-negative folder.
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return []
|
return []
|
||||||
@@ -269,8 +286,9 @@ class ProcessedDB:
|
|||||||
(profile,),
|
(profile,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# Collect times by video, split by positive vs other folders
|
# Collect times by video, split by folder role
|
||||||
pos_by_video: dict[str, set[float]] = {}
|
pos_by_video: dict[str, set[float]] = {}
|
||||||
|
neg_by_video: dict[str, set[float]] = {}
|
||||||
soft_by_video: dict[str, set[float]] = {}
|
soft_by_video: dict[str, set[float]] = {}
|
||||||
source_by_filename: dict[str, str] = {}
|
source_by_filename: dict[str, str] = {}
|
||||||
|
|
||||||
@@ -280,26 +298,43 @@ class ProcessedDB:
|
|||||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
if grandparent == positive_folder:
|
if grandparent == positive_folder:
|
||||||
pos_by_video.setdefault(fn, set()).add(st)
|
pos_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif negative_folder and grandparent == negative_folder:
|
||||||
|
neg_by_video.setdefault(fn, set()).add(st)
|
||||||
else:
|
else:
|
||||||
soft_by_video.setdefault(fn, set()).add(st)
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
# Remove positive times from soft to avoid conflicting labels
|
# Remove positive times from soft/neg to avoid conflicting labels
|
||||||
for fn in pos_by_video:
|
for fn in pos_by_video:
|
||||||
if fn in soft_by_video:
|
if fn in soft_by_video:
|
||||||
soft_by_video[fn] -= pos_by_video[fn]
|
soft_by_video[fn] -= pos_by_video[fn]
|
||||||
|
if fn in neg_by_video:
|
||||||
|
neg_by_video[fn] -= pos_by_video[fn]
|
||||||
|
|
||||||
|
# Deduplicate nearby markers (spread clips from same position)
|
||||||
|
def _dedup_times(times: set[float], min_gap: float = 8.0) -> list[float]:
    """Thin out markers so kept times are at least *min_gap* apart.

    Times are visited in ascending order; a time is kept when it is the first
    one or lies at least *min_gap* after the most recently kept time.
    Returns the kept times in ascending order (empty list for empty input).
    """
    kept: list[float] = []
    for t in sorted(times):
        if not kept or t - kept[-1] >= min_gap:
            kept.append(t)
    return kept
|
||||||
|
|
||||||
|
# Include videos that have positives OR explicit negatives
|
||||||
|
all_videos = set(pos_by_video) | set(neg_by_video)
|
||||||
result = []
|
result = []
|
||||||
for fn in pos_by_video:
|
for fn in all_videos:
|
||||||
sp = source_by_filename.get(fn, "")
|
sp = source_by_filename.get(fn, "")
|
||||||
if not sp or not os.path.exists(sp):
|
if not sp or not os.path.exists(sp):
|
||||||
# Fallback: try video_dir / filename
|
|
||||||
if fallback_video_dir:
|
if fallback_video_dir:
|
||||||
sp = os.path.join(fallback_video_dir, fn)
|
sp = os.path.join(fallback_video_dir, fn)
|
||||||
if not sp or not os.path.exists(sp):
|
if not sp or not os.path.exists(sp):
|
||||||
continue
|
continue
|
||||||
gt_pos = sorted(pos_by_video[fn])
|
gt_pos = _dedup_times(pos_by_video.get(fn, set()))
|
||||||
gt_soft = sorted(soft_by_video.get(fn, set()))
|
gt_soft = _dedup_times(soft_by_video.get(fn, set()))
|
||||||
result.append((sp, gt_pos, gt_soft))
|
gt_neg = _dedup_times(neg_by_video.get(fn, set()))
|
||||||
|
result.append((sp, gt_pos, gt_soft, gt_neg))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_training_stats(self, profile: str) -> dict[str, dict]:
|
def get_training_stats(self, profile: str) -> dict[str, dict]:
|
||||||
@@ -329,6 +364,92 @@ class ProcessedDB:
|
|||||||
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
# ── Scan results ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def save_scan_results(self, filename: str, profile: str, model: str,
                      regions: list[tuple[float, float, float]]) -> None:
    """Replace scan results for (filename, profile, model) with new regions.

    The delete of the old rows and the insert of the new ones are applied as
    one transaction: if anything fails, the previously stored rows are kept
    and the error propagates to the caller.

    Args:
        filename: video file name the results belong to.
        profile: profile name.
        model: model name the results were produced with.
        regions: list of (start_time, end_time, score).
    """
    if not self._enabled:
        return
    # Build the rows before touching the DB so a malformed region tuple
    # fails early, with nothing deleted.
    rows = [(filename, profile, model, s, e, sc) for s, e, sc in regions]
    with self._lock:
        # The sqlite3 connection context manager commits on success and rolls
        # back on error. Without the rollback a failed insert would leave the
        # DELETE pending, and the next commit elsewhere would silently wipe
        # the old results.
        with self._con:
            self._con.execute(
                "DELETE FROM scan_results"
                " WHERE filename = ? AND profile = ? AND model = ?",
                (filename, profile, model),
            )
            self._con.executemany(
                "INSERT INTO scan_results"
                " (filename, profile, model, start_time, end_time, score)"
                " VALUES (?, ?, ?, ?, ?, ?)",
                rows,
            )
|
||||||
|
|
||||||
|
def get_scan_results(self, filename: str, profile: str
                     ) -> dict[str, list[tuple[int, float, float, float]]]:
    """Fetch the stored scan results of one file, grouped by model.

    Returns ``{model: [(row_id, start_time, end_time, score), ...]}`` with
    each model's entries ordered by start_time. Returns an empty dict when
    the DB is disabled or nothing is stored for (filename, profile).
    """
    grouped: dict[str, list[tuple[int, float, float, float]]] = {}
    if not self._enabled:
        return grouped
    cursor = self._con.execute(
        "SELECT id, model, start_time, end_time, score FROM scan_results"
        " WHERE filename = ? AND profile = ?"
        " ORDER BY model, start_time",
        (filename, profile),
    )
    for row_id, model, start, end, score in cursor.fetchall():
        if model not in grouped:
            grouped[model] = []
        grouped[model].append((row_id, start, end, score))
    return grouped
|
||||||
|
|
||||||
|
def delete_scan_result(self, row_id: int) -> None:
    """Remove the scan result row whose id is *row_id* and commit.

    Does nothing when the DB is disabled.
    """
    if self._enabled:
        with self._lock:
            con = self._con
            con.execute("DELETE FROM scan_results WHERE id = ?", (row_id,))
            con.commit()
|
||||||
|
|
||||||
|
def get_scan_models(self, filename: str, profile: str) -> list[str]:
    """List the distinct model names with stored scan results for a file.

    Names come back in the order produced by the query's ``ORDER BY model``.
    Returns an empty list when the DB is disabled.
    """
    if not self._enabled:
        return []
    cursor = self._con.execute(
        "SELECT DISTINCT model FROM scan_results"
        " WHERE filename = ? AND profile = ? ORDER BY model",
        (filename, profile),
    )
    return [record[0] for record in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_scanned_filenames(self, profile: str, model: str) -> set[str]:
    """Collect the filenames that have scan results for (profile, model).

    Returns an empty set when the DB is disabled.
    """
    if not self._enabled:
        return set()
    cursor = self._con.execute(
        "SELECT DISTINCT filename FROM scan_results"
        " WHERE profile = ? AND model = ?",
        (profile, model),
    )
    return {record[0] for record in cursor.fetchall()}
|
||||||
|
|
||||||
|
def get_training_filenames(self, profile: str) -> set[str]:
    """Collect the distinct filenames recorded in ``processed`` for *profile*.

    These are the files treated as training material (files that have
    exported clips). Returns an empty set when the DB is disabled.
    """
    if not self._enabled:
        return set()
    cursor = self._con.execute(
        "SELECT DISTINCT filename FROM processed WHERE profile = ?",
        (profile,),
    )
    return {record[0] for record in cursor.fetchall()}
|
||||||
|
|
||||||
|
# ── Hidden files ───────────────────────────────────────────
|
||||||
|
|
||||||
def hide_file(self, filename: str, profile: str = "default") -> None:
|
def hide_file(self, filename: str, profile: str = "default") -> None:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from PyQt6.QtWidgets import (
|
|||||||
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
|
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
|
||||||
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
|
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
|
||||||
QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
|
QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
|
||||||
|
QTableWidget, QTableWidgetItem, QTabWidget, QHeaderView,
|
||||||
)
|
)
|
||||||
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
|
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
|
||||||
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
|
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
|
||||||
@@ -244,6 +245,15 @@ class TrainDialog(QDialog):
|
|||||||
self._cmb_positive.addItem(label, userData=folder_name)
|
self._cmb_positive.addItem(label, userData=folder_name)
|
||||||
form.addRow("Positive class:", self._cmb_positive)
|
form.addRow("Positive class:", self._cmb_positive)
|
||||||
|
|
||||||
|
# Negative class selector (optional)
|
||||||
|
self._cmb_negative = QComboBox()
|
||||||
|
self._cmb_negative.addItem("(auto only)", userData="")
|
||||||
|
for folder_name, info in stats.items():
|
||||||
|
label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)"
|
||||||
|
self._cmb_negative.addItem(label, userData=folder_name)
|
||||||
|
self._cmb_negative.currentIndexChanged.connect(lambda: self._debounce.start())
|
||||||
|
form.addRow("Negative class:", self._cmb_negative)
|
||||||
|
|
||||||
# Model selector
|
# Model selector
|
||||||
self._cmb_model = QComboBox()
|
self._cmb_model = QComboBox()
|
||||||
for name in _EMBED_MODELS:
|
for name in _EMBED_MODELS:
|
||||||
@@ -251,6 +261,18 @@ class TrainDialog(QDialog):
|
|||||||
self._cmb_model.setCurrentText("WAV2VEC2_BASE")
|
self._cmb_model.setCurrentText("WAV2VEC2_BASE")
|
||||||
form.addRow("Model:", self._cmb_model)
|
form.addRow("Model:", self._cmb_model)
|
||||||
|
|
||||||
|
# Auto-negative margin (0 = disabled)
|
||||||
|
self._spn_neg_margin = QDoubleSpinBox()
|
||||||
|
self._spn_neg_margin.setDecimals(0)
|
||||||
|
self._spn_neg_margin.setRange(0.0, 600.0)
|
||||||
|
self._spn_neg_margin.setSingleStep(10.0)
|
||||||
|
self._spn_neg_margin.setValue(30.0)
|
||||||
|
self._spn_neg_margin.setSuffix("s")
|
||||||
|
self._spn_neg_margin.setSpecialValueText("Disabled")
|
||||||
|
self._spn_neg_margin.setToolTip(
|
||||||
|
"Auto-sample negatives from regions this far from any marker. 0 = disabled.")
|
||||||
|
form.addRow("Auto-neg margin:", self._spn_neg_margin)
|
||||||
|
|
||||||
# Video source directory (fallback for old DB rows without source_path)
|
# Video source directory (fallback for old DB rows without source_path)
|
||||||
self._txt_video_dir = QLineEdit(video_dir)
|
self._txt_video_dir = QLineEdit(video_dir)
|
||||||
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
||||||
@@ -265,7 +287,13 @@ class TrainDialog(QDialog):
|
|||||||
btn_browse.setFixedWidth(30)
|
btn_browse.setFixedWidth(30)
|
||||||
btn_browse.clicked.connect(self._browse_video_dir)
|
btn_browse.clicked.connect(self._browse_video_dir)
|
||||||
vid_row.addWidget(btn_browse)
|
vid_row.addWidget(btn_browse)
|
||||||
form.addRow("Video dir:", vid_row)
|
self._lbl_video_dir = QLabel("Video dir:")
|
||||||
|
self._video_dir_widget = QWidget()
|
||||||
|
self._video_dir_widget.setLayout(vid_row)
|
||||||
|
form.addRow(self._lbl_video_dir, self._video_dir_widget)
|
||||||
|
# Hidden by default — shown only if some videos are missing source_path
|
||||||
|
self._lbl_video_dir.setVisible(False)
|
||||||
|
self._video_dir_widget.setVisible(False)
|
||||||
|
|
||||||
layout.addLayout(form)
|
layout.addLayout(form)
|
||||||
|
|
||||||
@@ -297,17 +325,32 @@ class TrainDialog(QDialog):
|
|||||||
if not folder:
|
if not folder:
|
||||||
self._lbl_stats.setText("No export folder data available.")
|
self._lbl_stats.setText("No export folder data available.")
|
||||||
return
|
return
|
||||||
|
neg_folder = self._cmb_negative.currentData() or ""
|
||||||
|
# First check without fallback to see if source_paths are sufficient
|
||||||
|
video_infos_no_fb = self._db.get_training_data(
|
||||||
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
|
)
|
||||||
video_infos = self._db.get_training_data(
|
video_infos = self._db.get_training_data(
|
||||||
self._profile, folder,
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
fallback_video_dir=self._txt_video_dir.text(),
|
fallback_video_dir=self._txt_video_dir.text(),
|
||||||
)
|
)
|
||||||
|
# Show video dir field only when the fallback helps find extra videos
|
||||||
|
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
||||||
|
self._lbl_video_dir.setVisible(needs_fallback)
|
||||||
|
self._video_dir_widget.setVisible(needs_fallback)
|
||||||
|
|
||||||
n_videos = len(video_infos)
|
n_videos = len(video_infos)
|
||||||
n_pos = sum(len(gt) for _, gt, _ in video_infos)
|
n_pos = sum(len(vi[1]) for vi in video_infos)
|
||||||
n_soft = sum(len(s) for _, _, s in video_infos)
|
n_soft = sum(len(vi[2]) for vi in video_infos)
|
||||||
lines = [f"<b>{n_videos}</b> videos with positive clips"]
|
n_neg = sum(len(vi[3]) for vi in video_infos)
|
||||||
lines.append(f"<b>{n_pos}</b> positive markers, <b>{n_soft}</b> soft/buffer markers")
|
lines = [f"<b>{n_videos}</b> videos"]
|
||||||
|
lines.append(f"<b>{n_pos}</b> positive, <b>{n_soft}</b> soft/buffer"
|
||||||
|
+ (f", <b>{n_neg}</b> manual negative" if n_neg else "")
|
||||||
|
+ " markers")
|
||||||
if n_videos == 0:
|
if n_videos == 0:
|
||||||
lines.append("<i>No source videos found. Set Video dir above.</i>")
|
lines.append("<i>No source videos found. Set Video dir below.</i>")
|
||||||
|
self._lbl_video_dir.setVisible(True)
|
||||||
|
self._video_dir_widget.setVisible(True)
|
||||||
elif n_videos < 3:
|
elif n_videos < 3:
|
||||||
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
|
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
|
||||||
self._lbl_stats.setText("<br>".join(lines))
|
self._lbl_stats.setText("<br>".join(lines))
|
||||||
@@ -316,6 +359,14 @@ class TrainDialog(QDialog):
|
|||||||
def positive_folder(self) -> str:
|
def positive_folder(self) -> str:
|
||||||
return self._cmb_positive.currentData() or ""
|
return self._cmb_positive.currentData() or ""
|
||||||
|
|
||||||
|
@property
def negative_folder(self) -> str:
    """Folder name chosen in the negative-class combo, or "" for "(auto only)"."""
    chosen = self._cmb_negative.currentData()
    # currentData() is "" for the "(auto only)" entry and may be None when
    # the combo has no current item; both collapse to the empty string.
    return chosen if chosen else ""
|
||||||
|
|
||||||
|
@property
def neg_margin(self) -> float:
    """Current value of the auto-negative margin spin box (0 = disabled)."""
    margin = self._spn_neg_margin.value()
    return margin
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embed_model(self) -> str:
|
def embed_model(self) -> str:
|
||||||
return self._cmb_model.currentText()
|
return self._cmb_model.currentText()
|
||||||
@@ -332,11 +383,14 @@ class TrainWorker(QThread):
|
|||||||
progress = pyqtSignal(str) # per-video status
|
progress = pyqtSignal(str) # per-video status
|
||||||
|
|
||||||
def __init__(self, video_infos: list, model_path: str,
             embed_model: str | None = None, n_workers: int = 4,
             neg_margin: float = 120.0):
    """Store the training inputs; the work itself happens in ``run()``.

    Args:
        video_infos: Per-video training data handed to ``train_classifier``.
        model_path: Passed to ``train_classifier`` as ``model_path``.
        embed_model: Embedding model name, or ``None``.
        n_workers: Worker count forwarded to ``train_classifier``.
        neg_margin: Auto-negative margin forwarded to ``train_classifier``.
    """
    super().__init__()
    # Cleared flag first: ``cancel()`` flips it, ``run()`` checks it.
    self._cancel = False
    self._video_infos = video_infos
    self._model_path = model_path
    self._embed_model = embed_model
    self._neg_margin = neg_margin
    self._n_workers = n_workers
|
||||||
|
|
||||||
def cancel(self) -> None:
|
def cancel(self) -> None:
|
||||||
@@ -349,8 +403,11 @@ class TrainWorker(QThread):
|
|||||||
result = train_classifier(
|
result = train_classifier(
|
||||||
self._video_infos,
|
self._video_infos,
|
||||||
model_path=self._model_path,
|
model_path=self._model_path,
|
||||||
|
neg_margin=self._neg_margin,
|
||||||
embed_model=self._embed_model,
|
embed_model=self._embed_model,
|
||||||
cancel_flag=self,
|
cancel_flag=self,
|
||||||
|
n_workers=self._n_workers,
|
||||||
|
progress_cb=self.progress.emit,
|
||||||
)
|
)
|
||||||
if self._cancel:
|
if self._cancel:
|
||||||
return
|
return
|
||||||
@@ -363,6 +420,152 @@ class TrainWorker(QThread):
|
|||||||
self.error.emit(str(e))
|
self.error.emit(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
class ScanResultsPanel(QWidget):
    """Tabbed panel showing scan results per model, with seek-on-click and delete.

    There is one tab per model.  Each tab is a read-only ``QTableWidget`` with
    the columns Time / End / Score.  The raw values are stored on the items so
    they can be read back without parsing display text:

    * column 0, ``UserRole``      -> DB row id
    * column 0, ``UserRole + 1``  -> region start
    * column 1, ``UserRole``      -> region end
    * column 2, ``UserRole``      -> unrounded score

    Tab labels have the form ``"<model> (<row count>)"``.
    """
    seek_requested = pyqtSignal(float)    # request main window to seek to time
    export_requested = pyqtSignal(list)   # emit list of (start, end, score) to export

    def __init__(self, db, parent=None):
        """Build the (initially empty) tab widget and the export button.

        Args:
            db: Object providing ``get_scan_results``, ``save_scan_results``
                and ``delete_scan_result``.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self._db = db
        self._filename = ""
        self._profile = ""

        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(2)

        self._tabs = QTabWidget()
        self._tabs.setTabsClosable(False)
        layout.addWidget(self._tabs)

        btn_row = QHBoxLayout()
        self._btn_export = QPushButton("Export Scan Results")
        self._btn_export.setToolTip("Export clips from the active tab's scan results")
        self._btn_export.clicked.connect(self._on_export)
        btn_row.addStretch()
        btn_row.addWidget(self._btn_export)
        layout.addLayout(btn_row)

    # ── tab bookkeeping ─────────────────────────────────────────

    def _tab_model(self, index: int) -> str:
        """Return the model name encoded in the label of tab *index*."""
        return self._tabs.tabText(index).rsplit(" (", 1)[0]

    def _find_tab(self, model: str) -> int:
        """Return the index of the tab for *model*, or -1 if there is none."""
        for i in range(self._tabs.count()):
            if self._tab_model(i) == model:
                return i
        return -1

    def _discard_tab(self, index: int) -> None:
        """Remove tab *index* and destroy its page widget.

        ``QTabWidget.removeTab()`` / ``clear()`` only detach the page; without
        an explicit ``deleteLater()`` every replaced table would stay alive as
        a hidden child of the tab widget.
        """
        page = self._tabs.widget(index)
        self._tabs.removeTab(index)
        if page is not None:
            page.deleteLater()

    # ── public API ──────────────────────────────────────────────

    def load_for_file(self, filename: str, profile: str) -> None:
        """Load saved scan results from DB for a file, replacing all tabs."""
        self._filename = filename
        self._profile = profile
        while self._tabs.count():
            self._discard_tab(0)
        results = self._db.get_scan_results(filename, profile)
        for model, rows in results.items():
            self._add_tab(model, rows)

    def add_scan_results(self, model: str,
                         regions: list[tuple[float, float, float]]) -> None:
        """Add/replace a tab with new scan results and save to DB."""
        self._db.save_scan_results(self._filename, self._profile, model, regions)
        # Re-read so that every row carries its DB id (needed for delete).
        db_results = self._db.get_scan_results(self._filename, self._profile)
        rows = db_results.get(model, [])

        existing = self._find_tab(model)
        if existing >= 0:
            self._discard_tab(existing)
        self._add_tab(model, rows)

        # Switch to the new tab
        fresh = self._find_tab(model)
        if fresh >= 0:
            self._tabs.setCurrentIndex(fresh)

    def _add_tab(self, model: str,
                 rows: list[tuple[int, float, float, float]]) -> None:
        """Create a table tab. rows: [(row_id, start, end, score), ...]"""
        table = QTableWidget(len(rows), 3)
        table.setHorizontalHeaderLabels(["Time", "End", "Score"])
        table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
        table.setSelectionMode(QTableWidget.SelectionMode.ExtendedSelection)
        table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
        table.verticalHeader().setVisible(False)
        header = table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
        header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
        header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)

        for i, (row_id, start, end, score) in enumerate(rows):
            t_item = QTableWidgetItem(format_time(start))
            t_item.setData(Qt.ItemDataRole.UserRole, row_id)
            t_item.setData(Qt.ItemDataRole.UserRole + 1, start)
            table.setItem(i, 0, t_item)
            e_item = QTableWidgetItem(format_time(end))
            e_item.setData(Qt.ItemDataRole.UserRole, end)
            table.setItem(i, 1, e_item)
            # Display is rounded to 2 decimals; keep the exact value alongside
            # so exports do not silently round the score.
            s_item = QTableWidgetItem(f"{score:.2f}")
            s_item.setData(Qt.ItemDataRole.UserRole, score)
            table.setItem(i, 2, s_item)

        # Bind the table as a default argument so each tab's handler refers to
        # its own table rather than the last one created.
        table.itemSelectionChanged.connect(
            lambda t=table: self._on_selection_changed(t))
        self._tabs.addTab(table, f"{model} ({len(rows)})")

    def _on_selection_changed(self, table: QTableWidget) -> None:
        """Emit ``seek_requested`` with the start of the first selected item's row."""
        items = table.selectedItems()
        if not items:
            return
        row = items[0].row()
        start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
        if start is not None:
            self.seek_requested.emit(float(start))

    def delete_selected(self) -> None:
        """Delete selected rows from active tab and DB."""
        table = self._tabs.currentWidget()
        if not isinstance(table, QTableWidget):
            return
        # Highest row first so earlier removals do not shift later indices.
        rows_to_delete = sorted(
            {idx.row() for idx in table.selectedIndexes()}, reverse=True)
        if not rows_to_delete:
            return
        tab_idx = self._tabs.currentIndex()
        model = self._tab_model(tab_idx)

        # Removing rows changes the selection; silence the table meanwhile so
        # a multi-row delete does not make the player seek to rows that are
        # themselves about to be removed.
        was_blocked = table.blockSignals(True)
        try:
            for row in rows_to_delete:
                row_id = table.item(row, 0).data(Qt.ItemDataRole.UserRole)
                if row_id is not None:
                    self._db.delete_scan_result(row_id)
                table.removeRow(row)
        finally:
            table.blockSignals(was_blocked)

        # Update tab title with new count
        self._tabs.setTabText(tab_idx, f"{model} ({table.rowCount()})")

    def _get_tab_regions(self, table: QTableWidget
                         ) -> list[tuple[float, float, float]]:
        """Extract (start, end, score) from a table widget."""
        regions = []
        for row in range(table.rowCount()):
            start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
            end = table.item(row, 1).data(Qt.ItemDataRole.UserRole)
            score_item = table.item(row, 2)
            raw_score = score_item.data(Qt.ItemDataRole.UserRole)
            # Fall back to the display text for items without a stored value.
            score = float(raw_score) if raw_score is not None \
                else float(score_item.text())
            regions.append((float(start), float(end), score))
        return regions

    def _on_export(self) -> None:
        """Emit ``export_requested`` with the active tab's regions, if any."""
        table = self._tabs.currentWidget()
        if not isinstance(table, QTableWidget):
            return
        regions = self._get_tab_regions(table)
        if regions:
            self.export_requested.emit(regions)

    def current_regions(self) -> list[tuple[float, float, float]]:
        """Return (start, end, score) for all rows in the active tab."""
        table = self._tabs.currentWidget()
        if not isinstance(table, QTableWidget):
            return []
        return self._get_tab_regions(table)

    def has_results(self) -> bool:
        """Return True when at least one result tab exists."""
        return self._tabs.count() > 0

    def keyPressEvent(self, event):
        """Treat Delete/Backspace as "delete selected rows"; defer other keys."""
        if event.key() in (Qt.Key.Key_Delete, Qt.Key.Key_Backspace):
            self.delete_selected()
        else:
            super().keyPressEvent(event)
|
||||||
|
|
||||||
|
|
||||||
class TimelineWidget(QWidget):
|
class TimelineWidget(QWidget):
|
||||||
cursor_changed = pyqtSignal(float) # emits position in seconds
|
cursor_changed = pyqtSignal(float) # emits position in seconds
|
||||||
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
|
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
|
||||||
@@ -1710,6 +1913,15 @@ class MainWindow(QMainWindow):
|
|||||||
self._btn_train.clicked.connect(self._open_train_dialog)
|
self._btn_train.clicked.connect(self._open_train_dialog)
|
||||||
self._train_worker: TrainWorker | None = None
|
self._train_worker: TrainWorker | None = None
|
||||||
|
|
||||||
|
self._btn_scan_all = QPushButton("Scan All")
|
||||||
|
self._btn_scan_all.setToolTip("Scan all playlist videos that haven't been scanned yet")
|
||||||
|
self._btn_scan_all.clicked.connect(self._start_scan_all)
|
||||||
|
self._scan_all_queue: list[str] = []
|
||||||
|
|
||||||
|
self._cmb_scan_model = QComboBox()
|
||||||
|
self._cmb_scan_model.setToolTip("Trained embedding model to use for scanning")
|
||||||
|
self._cmb_scan_model.setMinimumWidth(120)
|
||||||
|
|
||||||
self._spn_auto_fuse = QDoubleSpinBox()
|
self._spn_auto_fuse = QDoubleSpinBox()
|
||||||
self._spn_auto_fuse.setDecimals(1)
|
self._spn_auto_fuse.setDecimals(1)
|
||||||
self._spn_auto_fuse.setRange(0.0, 60.0)
|
self._spn_auto_fuse.setRange(0.0, 60.0)
|
||||||
@@ -1800,6 +2012,7 @@ class MainWindow(QMainWindow):
|
|||||||
if idx >= 0:
|
if idx >= 0:
|
||||||
self._cmb_profile.setCurrentIndex(idx)
|
self._cmb_profile.setCurrentIndex(idx)
|
||||||
self._cmb_profile.activated.connect(self._on_profile_activated)
|
self._cmb_profile.activated.connect(self._on_profile_activated)
|
||||||
|
self._refresh_scan_models()
|
||||||
|
|
||||||
self._btn_shortcuts = QPushButton("?")
|
self._btn_shortcuts = QPushButton("?")
|
||||||
self._btn_shortcuts.setFixedWidth(28)
|
self._btn_shortcuts.setFixedWidth(28)
|
||||||
@@ -1864,11 +2077,13 @@ class MainWindow(QMainWindow):
|
|||||||
settings_row.addWidget(self._chk_rand_portrait)
|
settings_row.addWidget(self._chk_rand_portrait)
|
||||||
settings_row.addWidget(self._chk_rand_square)
|
settings_row.addWidget(self._chk_rand_square)
|
||||||
settings_row.addWidget(self._chk_track)
|
settings_row.addWidget(self._chk_track)
|
||||||
|
settings_row.addWidget(self._cmb_scan_model)
|
||||||
settings_row.addWidget(self._btn_scan)
|
settings_row.addWidget(self._btn_scan)
|
||||||
settings_row.addWidget(self._btn_auto_export)
|
settings_row.addWidget(self._btn_auto_export)
|
||||||
settings_row.addWidget(self._spn_auto_fuse)
|
settings_row.addWidget(self._spn_auto_fuse)
|
||||||
settings_row.addWidget(self._sld_threshold)
|
settings_row.addWidget(self._sld_threshold)
|
||||||
settings_row.addWidget(self._btn_train)
|
settings_row.addWidget(self._btn_train)
|
||||||
|
settings_row.addWidget(self._btn_scan_all)
|
||||||
settings_row.addStretch()
|
settings_row.addStretch()
|
||||||
self._lbl_status = QLabel()
|
self._lbl_status = QLabel()
|
||||||
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
|
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
|
||||||
@@ -1918,13 +2133,20 @@ class MainWindow(QMainWindow):
|
|||||||
left_layout.addLayout(left_top)
|
left_layout.addLayout(left_top)
|
||||||
left_layout.addWidget(self._playlist)
|
left_layout.addWidget(self._playlist)
|
||||||
|
|
||||||
|
# Scan results panel (right side)
|
||||||
|
self._scan_panel = ScanResultsPanel(self._db)
|
||||||
|
self._scan_panel.seek_requested.connect(self._on_scan_seek)
|
||||||
|
self._scan_panel.export_requested.connect(self._on_scan_export)
|
||||||
|
|
||||||
# Root: horizontal splitter
|
# Root: horizontal splitter
|
||||||
splitter = QSplitter(Qt.Orientation.Horizontal)
|
splitter = QSplitter(Qt.Orientation.Horizontal)
|
||||||
splitter.addWidget(left)
|
splitter.addWidget(left)
|
||||||
splitter.addWidget(right)
|
splitter.addWidget(right)
|
||||||
splitter.setSizes([200, 900])
|
splitter.addWidget(self._scan_panel)
|
||||||
|
splitter.setSizes([200, 900, 200])
|
||||||
splitter.setCollapsible(0, False)
|
splitter.setCollapsible(0, False)
|
||||||
splitter.setCollapsible(1, False)
|
splitter.setCollapsible(1, False)
|
||||||
|
splitter.setCollapsible(2, True)
|
||||||
|
|
||||||
self.setCentralWidget(splitter)
|
self.setCentralWidget(splitter)
|
||||||
self.setStatusBar(None)
|
self.setStatusBar(None)
|
||||||
@@ -2061,6 +2283,7 @@ class MainWindow(QMainWindow):
|
|||||||
self._btn_delete.setEnabled(False)
|
self._btn_delete.setEnabled(False)
|
||||||
self._update_next_label()
|
self._update_next_label()
|
||||||
self._apply_playlist_filters()
|
self._apply_playlist_filters()
|
||||||
|
self._refresh_scan_models()
|
||||||
if self._file_path:
|
if self._file_path:
|
||||||
self._refresh_markers()
|
self._refresh_markers()
|
||||||
_log(f"Profile switched: {text}")
|
_log(f"Profile switched: {text}")
|
||||||
@@ -2184,7 +2407,13 @@ class MainWindow(QMainWindow):
|
|||||||
if self._scan_worker and self._scan_worker.isRunning():
|
if self._scan_worker and self._scan_worker.isRunning():
|
||||||
self._scan_worker.cancel()
|
self._scan_worker.cancel()
|
||||||
self._cleanup_scan_worker()
|
self._cleanup_scan_worker()
|
||||||
|
self._scan_all_queue.clear()
|
||||||
self._btn_scan.setEnabled(True)
|
self._btn_scan.setEnabled(True)
|
||||||
|
self._btn_scan_all.setEnabled(True)
|
||||||
|
# Load saved scan results for this file
|
||||||
|
if self._file_path:
|
||||||
|
filename = os.path.basename(self._file_path)
|
||||||
|
self._scan_panel.load_for_file(filename, self._profile)
|
||||||
|
|
||||||
dur = self._mpv.get_duration()
|
dur = self._mpv.get_duration()
|
||||||
self._timeline.set_duration(dur)
|
self._timeline.set_duration(dur)
|
||||||
@@ -2653,8 +2882,42 @@ class MainWindow(QMainWindow):
|
|||||||
return
|
return
|
||||||
self._step_cursor(markers[0][0] - self._cursor) # wrap to first
|
self._step_cursor(markers[0][0] - self._cursor) # wrap to first
|
||||||
|
|
||||||
|
def _load_selected_scan_model(self) -> tuple:
    """Load the classifier chosen in the scan-model combo box.

    Returns:
        ``(model, label)`` on success, where *label* is the combo text.
        ``(None, "")`` when nothing usable is selected or the classifier
        could not be loaded; a status message is shown in both cases.
    """
    from core.audio_scan import load_classifier, default_model_path

    label = self._cmb_scan_model.currentText()
    if label in ("", "(no model)"):
        self._show_status("No trained model — click Train first")
        return None, ""

    # "(legacy)" is the display name for a model stored without an
    # embedding-model name.
    if label == "(legacy)":
        embed_name = None
    else:
        embed_name = label

    path = default_model_path(self._profile, embed_name)
    classifier = load_classifier(path)
    if classifier is not None:
        return classifier, label

    self._show_status(f"Model file missing: {path}")
    return None, ""
|
||||||
|
|
||||||
|
def _refresh_scan_models(self) -> None:
    """Rebuild the scan-model combo from the current profile's trained models.

    Shows a single "(no model)" entry when nothing is trained, labels an
    empty model name as "(legacy)", and keeps the previous selection when
    it is still present.
    """
    from core.audio_scan import list_trained_models

    combo = self._cmb_scan_model
    previous = combo.currentText()
    combo.clear()

    trained = list_trained_models(self._profile)
    if trained:
        labels = [name if name else "(legacy)" for name in trained]
    else:
        labels = ["(no model)"]
    for label in labels:
        combo.addItem(label)

    # Restore previous selection if still available
    restored = combo.findText(previous)
    if restored >= 0:
        combo.setCurrentIndex(restored)
|
||||||
|
|
||||||
def _cleanup_scan_worker(self) -> None:
|
def _cleanup_scan_worker(self) -> None:
|
||||||
"""Disconnect signals and schedule deletion of old scan worker."""
|
"""Disconnect signals, cancel, and schedule deletion of old scan worker."""
|
||||||
if self._scan_worker is not None:
|
if self._scan_worker is not None:
|
||||||
try:
|
try:
|
||||||
self._scan_worker.scan_done.disconnect()
|
self._scan_worker.scan_done.disconnect()
|
||||||
@@ -2662,8 +2925,8 @@ class MainWindow(QMainWindow):
|
|||||||
self._scan_worker.progress.disconnect()
|
self._scan_worker.progress.disconnect()
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass # already disconnected
|
pass # already disconnected
|
||||||
|
self._scan_worker.cancel()
|
||||||
if self._scan_worker.isRunning():
|
if self._scan_worker.isRunning():
|
||||||
# QThread.finished fires when run() returns, even on cancel
|
|
||||||
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
|
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
|
||||||
else:
|
else:
|
||||||
self._scan_worker.deleteLater()
|
self._scan_worker.deleteLater()
|
||||||
@@ -2682,17 +2945,14 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
threshold = self._sld_threshold.value()
|
threshold = self._sld_threshold.value()
|
||||||
|
|
||||||
from core.audio_scan import load_classifier, default_model_path
|
model, model_label = self._load_selected_scan_model()
|
||||||
model_path = default_model_path(self._profile)
|
|
||||||
model = load_classifier(model_path)
|
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
self._show_status("No trained model — click Train first")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._btn_scan.setEnabled(False)
|
self._btn_scan.setEnabled(False)
|
||||||
self._scan_file_path = self._file_path
|
self._scan_file_path = self._file_path
|
||||||
self._show_status("Scanning...")
|
self._scan_model_label = model_label
|
||||||
|
self._show_status(f"Scanning ({model_label})...")
|
||||||
self._scan_worker = ScanWorker(
|
self._scan_worker = ScanWorker(
|
||||||
self._file_path, model=model, threshold=threshold,
|
self._file_path, model=model, threshold=threshold,
|
||||||
)
|
)
|
||||||
@@ -2708,6 +2968,10 @@ class MainWindow(QMainWindow):
|
|||||||
if self._file_path != getattr(self, '_scan_file_path', None):
|
if self._file_path != getattr(self, '_scan_file_path', None):
|
||||||
return
|
return
|
||||||
self._timeline.set_scan_regions(regions)
|
self._timeline.set_scan_regions(regions)
|
||||||
|
model_label = getattr(self, '_scan_model_label', '')
|
||||||
|
if model_label and self._file_path:
|
||||||
|
filename = os.path.basename(self._file_path)
|
||||||
|
self._scan_panel.add_scan_results(model_label, regions)
|
||||||
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
||||||
|
|
||||||
def _on_scan_error(self, msg: str) -> None:
|
def _on_scan_error(self, msg: str) -> None:
|
||||||
@@ -2715,6 +2979,105 @@ class MainWindow(QMainWindow):
|
|||||||
self._btn_auto_export.setEnabled(True)
|
self._btn_auto_export.setEnabled(True)
|
||||||
self._show_status(f"Scan error: {msg}")
|
self._show_status(f"Scan error: {msg}")
|
||||||
|
|
||||||
|
def _on_scan_seek(self, t: float) -> None:
    """Move cursor, player and timeline to *t* when a scan row is clicked."""
    if not self._file_path:
        return  # nothing loaded, nothing to seek in
    self._cursor = t
    self._mpv.seek(t)
    self._timeline.set_cursor(t)
    duration = self._mpv.get_duration()
    position_text = f"{format_time(t)} / {format_time(duration)}"
    self._lbl_time.setText(position_text)
|
||||||
|
|
||||||
|
def _on_scan_export(self, regions: list) -> None:
    """Export clips for the regions sent by the scan results panel.

    Does nothing without a loaded file or regions, and refuses to start
    while another export worker is still running.
    """
    if not (self._file_path and regions):
        return
    worker = self._export_worker
    if worker and worker.isRunning():
        self._show_status("Export already running…")
        return
    self._auto_export_regions(regions)
|
||||||
|
|
||||||
|
# ── Scan All ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _start_scan_all(self) -> None:
    """Queue every eligible playlist video for a batch scan and start it.

    A video is skipped when its basename is already recorded as scanned
    with the selected model for this profile, or is reported by the DB as
    a training file for this profile.
    """
    active = self._scan_worker
    if active and active.isRunning():
        self._show_status("Scan already running")
        return

    model, model_label = self._load_selected_scan_model()
    if model is None:
        return

    # Build queue: playlist files minus already-scanned and training files
    already_scanned = self._db.get_scanned_filenames(self._profile, model_label)
    used_for_training = self._db.get_training_filenames(self._profile)
    excluded = already_scanned | used_for_training

    queue = []
    for path in self._playlist._paths:
        if os.path.basename(path) not in excluded:
            queue.append(path)
    self._scan_all_queue = queue

    if not queue:
        self._show_status("All videos already scanned or used for training")
        return

    # Snapshot model and profile so the whole batch is saved consistently
    # even if the UI selection changes while it runs.
    self._scan_all_model = model
    self._scan_all_model_label = model_label
    self._scan_all_profile = self._profile
    self._scan_all_total = len(queue)

    for button in (self._btn_scan_all, self._btn_scan):
        button.setEnabled(False)
    self._show_status(
        f"Scan All: 0/{self._scan_all_total} ({model_label})")
    self._scan_all_next()
|
||||||
|
|
||||||
|
def _scan_all_next(self) -> None:
    """Start scanning the next queued video, or finish the batch if none remain."""
    if not self._scan_all_queue:
        # Queue drained: re-enable the buttons and report the batch size.
        for button in (self._btn_scan_all, self._btn_scan):
            button.setEnabled(True)
        total = self._scan_all_total
        self._show_status(f"Scan All complete: {total} videos scanned")
        return

    self._cleanup_scan_worker()

    current = self._scan_all_queue.pop(0)
    # 1-based position of the video that is about to be scanned.
    position = self._scan_all_total - len(self._scan_all_queue)
    self._scan_all_current_path = current
    self._show_status(
        f"Scan All: {position}/{self._scan_all_total} — "
        f"{os.path.basename(current)}")

    worker = ScanWorker(
        current,
        model=self._scan_all_model,
        threshold=self._sld_threshold.value(),
    )
    self._scan_worker = worker
    worker.scan_done.connect(self._on_scan_all_done)
    worker.error.connect(self._on_scan_all_error)
    worker.start()
|
||||||
|
|
||||||
|
def _on_scan_all_done(self, regions: list) -> None:
    """Persist one batch-scan result, refresh the UI if relevant, then continue."""
    scanned_path = getattr(self, '_scan_all_current_path', '')
    label = getattr(self, '_scan_all_model_label', '')

    if scanned_path and label:
        filename = os.path.basename(scanned_path)
        # Save under the profile the batch was started with.
        batch_profile = getattr(self, '_scan_all_profile', self._profile)
        self._db.save_scan_results(filename, batch_profile, label, regions)

        # If this is the currently loaded file, update the panel
        loaded = self._file_path
        if loaded and os.path.basename(loaded) == filename:
            self._scan_panel.load_for_file(filename, self._profile)
            self._timeline.set_scan_regions(regions)

    self._scan_all_next()
|
||||||
|
|
||||||
|
def _on_scan_all_error(self, msg: str) -> None:
    """Log a failed batch scan and move on to the next queued video."""
    failed_path = getattr(self, '_scan_all_current_path', '')
    failed_name = os.path.basename(failed_path)
    _log(f"Scan All error on {failed_name}: {msg}")
    # One bad file does not abort the batch.
    self._scan_all_next()
|
||||||
|
|
||||||
# ── Training ────────────────────────────────────────────────
|
# ── Training ────────────────────────────────────────────────
|
||||||
|
|
||||||
def _cleanup_train_worker(self) -> None:
|
def _cleanup_train_worker(self) -> None:
|
||||||
@@ -2751,6 +3114,8 @@ class MainWindow(QMainWindow):
|
|||||||
return
|
return
|
||||||
|
|
||||||
pos_folder = dlg.positive_folder
|
pos_folder = dlg.positive_folder
|
||||||
|
neg_folder = dlg.negative_folder
|
||||||
|
neg_margin = dlg.neg_margin
|
||||||
embed_model = dlg.embed_model
|
embed_model = dlg.embed_model
|
||||||
video_dir = dlg.video_dir
|
video_dir = dlg.video_dir
|
||||||
if not pos_folder:
|
if not pos_folder:
|
||||||
@@ -2762,20 +3127,22 @@ class MainWindow(QMainWindow):
|
|||||||
self._settings.setValue("train_video_dir", video_dir)
|
self._settings.setValue("train_video_dir", video_dir)
|
||||||
|
|
||||||
video_infos = self._db.get_training_data(
|
video_infos = self._db.get_training_data(
|
||||||
self._profile, pos_folder, fallback_video_dir=video_dir,
|
self._profile, pos_folder, negative_folder=neg_folder,
|
||||||
|
fallback_video_dir=video_dir,
|
||||||
)
|
)
|
||||||
if not video_infos:
|
if not video_infos:
|
||||||
self._show_status("No training data found for this subprofile")
|
self._show_status("No training data found for this subprofile")
|
||||||
return
|
return
|
||||||
|
|
||||||
from core.audio_scan import default_model_path
|
from core.audio_scan import default_model_path
|
||||||
model_path = default_model_path(self._profile)
|
model_path = default_model_path(self._profile, embed_model)
|
||||||
|
|
||||||
self._cleanup_train_worker()
|
self._cleanup_train_worker()
|
||||||
self._btn_train.setEnabled(False)
|
self._btn_train.setEnabled(False)
|
||||||
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
|
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
|
||||||
|
|
||||||
self._train_worker = TrainWorker(video_infos, model_path, embed_model)
|
n_workers = self._spn_workers.value()
|
||||||
|
self._train_worker = TrainWorker(video_infos, model_path, embed_model, n_workers, neg_margin)
|
||||||
self._train_worker.train_done.connect(self._on_train_done)
|
self._train_worker.train_done.connect(self._on_train_done)
|
||||||
self._train_worker.error.connect(self._on_train_error)
|
self._train_worker.error.connect(self._on_train_error)
|
||||||
self._train_worker.progress.connect(self._show_status)
|
self._train_worker.progress.connect(self._show_status)
|
||||||
@@ -2783,6 +3150,7 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
def _on_train_done(self, model_path: str):
|
def _on_train_done(self, model_path: str):
|
||||||
self._btn_train.setEnabled(True)
|
self._btn_train.setEnabled(True)
|
||||||
|
self._refresh_scan_models()
|
||||||
self._show_status(f"Model trained and saved")
|
self._show_status(f"Model trained and saved")
|
||||||
_log(f"Training complete: {model_path}")
|
_log(f"Training complete: {model_path}")
|
||||||
|
|
||||||
@@ -2810,22 +3178,19 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
threshold = self._sld_threshold.value()
|
threshold = self._sld_threshold.value()
|
||||||
|
|
||||||
from core.audio_scan import load_classifier, default_model_path
|
model, model_label = self._load_selected_scan_model()
|
||||||
model_path = default_model_path(self._profile)
|
if model is None:
|
||||||
model = load_classifier(model_path)
|
|
||||||
|
|
||||||
if model is not None:
|
|
||||||
self._scan_file_path = self._file_path
|
|
||||||
self._show_status("Auto: scanning with classifier...")
|
|
||||||
self._scan_worker = ScanWorker(
|
|
||||||
self._file_path, model=model, threshold=threshold,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._show_status("Auto: no trained model — click Train first")
|
|
||||||
self._btn_auto_export.setEnabled(True)
|
self._btn_auto_export.setEnabled(True)
|
||||||
self._btn_scan.setEnabled(True)
|
self._btn_scan.setEnabled(True)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._scan_file_path = self._file_path
|
||||||
|
self._scan_model_label = model_label
|
||||||
|
self._show_status(f"Auto: scanning ({model_label})...")
|
||||||
|
self._scan_worker = ScanWorker(
|
||||||
|
self._file_path, model=model, threshold=threshold,
|
||||||
|
)
|
||||||
|
|
||||||
self._scan_worker.scan_done.connect(self._on_auto_scan_done)
|
self._scan_worker.scan_done.connect(self._on_auto_scan_done)
|
||||||
self._scan_worker.error.connect(self._on_scan_error)
|
self._scan_worker.error.connect(self._on_scan_error)
|
||||||
self._scan_worker.progress.connect(self._show_status)
|
self._scan_worker.progress.connect(self._show_status)
|
||||||
@@ -2879,7 +3244,15 @@ class MainWindow(QMainWindow):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._timeline.set_scan_regions(regions)
|
self._timeline.set_scan_regions(regions)
|
||||||
|
# Also save to scan panel
|
||||||
|
model_label = getattr(self, '_scan_model_label', '')
|
||||||
|
if model_label and self._file_path:
|
||||||
|
self._scan_panel.add_scan_results(model_label, regions)
|
||||||
|
|
||||||
|
self._auto_export_regions(regions)
|
||||||
|
|
||||||
|
def _auto_export_regions(self, regions: list) -> None:
|
||||||
|
"""Export clips from a list of (start, end, score) regions."""
|
||||||
if not regions:
|
if not regions:
|
||||||
self._show_status("Auto: no regions found")
|
self._show_status("Auto: no regions found")
|
||||||
self._btn_auto_export.setEnabled(True)
|
self._btn_auto_export.setEnabled(True)
|
||||||
@@ -2896,6 +3269,7 @@ class MainWindow(QMainWindow):
|
|||||||
# Build export jobs — one 8s clip per position
|
# Build export jobs — one 8s clip per position
|
||||||
folder = self._txt_folder.text()
|
folder = self._txt_folder.text()
|
||||||
name = self._txt_name.text() or "clip"
|
name = self._txt_name.text() or "clip"
|
||||||
|
self._auto_export_name = name
|
||||||
fmt = self._cmb_format.currentText()
|
fmt = self._cmb_format.currentText()
|
||||||
image_sequence = fmt == "WebP sequence"
|
image_sequence = fmt == "WebP sequence"
|
||||||
os.makedirs(folder, exist_ok=True)
|
os.makedirs(folder, exist_ok=True)
|
||||||
@@ -2959,7 +3333,7 @@ class MainWindow(QMainWindow):
|
|||||||
"""Record each auto-exported clip to DB."""
|
"""Record each auto-exported clip to DB."""
|
||||||
# Find the start_time for this clip from stashed positions
|
# Find the start_time for this clip from stashed positions
|
||||||
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
|
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
|
||||||
name = self._txt_name.text() or "clip"
|
name = getattr(self, '_auto_export_name', self._txt_name.text() or "clip")
|
||||||
start_t = None
|
start_t = None
|
||||||
for t, c in self._auto_export_positions:
|
for t, c in self._auto_export_positions:
|
||||||
if counter_str == f"{name}_{c:03d}":
|
if counter_str == f"{name}_{c:03d}":
|
||||||
@@ -3306,6 +3680,11 @@ class MainWindow(QMainWindow):
|
|||||||
# Cancel background workers to prevent callbacks into dead objects.
|
# Cancel background workers to prevent callbacks into dead objects.
|
||||||
self._cleanup_scan_worker()
|
self._cleanup_scan_worker()
|
||||||
self._cleanup_train_worker()
|
self._cleanup_train_worker()
|
||||||
|
if self._export_worker and self._export_worker.isRunning():
|
||||||
|
self._export_worker.cancel()
|
||||||
|
self._export_worker.wait(3000)
|
||||||
|
if hasattr(self, '_db_worker') and self._db_worker and self._db_worker.isRunning():
|
||||||
|
self._db_worker.wait(1000)
|
||||||
# Stop timers first to prevent callbacks into dead objects.
|
# Stop timers first to prevent callbacks into dead objects.
|
||||||
self._preview_timer.stop()
|
self._preview_timer.stop()
|
||||||
self._mpv._render_timer.stop()
|
self._mpv._render_timer.stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user