def list_model_versions(profile_name: str = "default",
                        embed_model: str | None = None) -> list[tuple[str, str]]:
    """Return the active model followed by its timestamped backups.

    Args:
        profile_name: profile whose model path is resolved through
            ``default_model_path``.
        embed_model: embedding model name forwarded to ``default_model_path``.

    Returns:
        List of ``(label, file_path)``. When the active model file exists it
        comes first with the label ``"current"``. Backups found in
        ``_MODEL_DIR`` follow, newest first, labelled with the
        ``YYYYMMDD_HHMMSS`` timestamp embedded in their file name.
    """
    import re

    current = default_model_path(profile_name, embed_model)
    stem, ext = os.path.splitext(os.path.basename(current))

    versions: list[tuple[str, str]] = []
    if os.path.exists(current):
        versions.append(("current", current))
    if not os.path.isdir(_MODEL_DIR):
        return versions

    # fullmatch: the whole file name must be "<stem>_<timestamp><ext>".
    pattern = re.compile(re.escape(stem) + r"_(\d{8}_\d{6})" + re.escape(ext))
    backups: list[tuple[str, str]] = []
    for fname in os.listdir(_MODEL_DIR):
        m = pattern.fullmatch(fname)
        if m:
            backups.append((m.group(1), os.path.join(_MODEL_DIR, fname)))

    # Backups are kept in their own list so that every one of them is sorted,
    # even when there is no active model file to put in front.
    backups.sort(key=lambda v: v[0], reverse=True)
    return versions + backups


def restore_model_version(version_path: str, profile_name: str = "default",
                          embed_model: str | None = None) -> None:
    """Make the backup at *version_path* the active model.

    The backup is first copied to a staging file next to the active model.
    Only once that copy has succeeded is the active model moved aside to a
    timestamped backup and the staged copy moved into place, so a missing or
    unreadable *version_path* leaves the active model untouched.

    Args:
        version_path: path of the model file to restore.
        profile_name: profile whose active model is replaced.
        embed_model: embedding model name forwarded to ``default_model_path``.

    Raises:
        OSError: if *version_path* cannot be copied or a file cannot be moved.
    """
    import shutil
    from datetime import datetime, timedelta

    current = default_model_path(profile_name, embed_model)
    if os.path.abspath(version_path) == os.path.abspath(current):
        return  # already the active model

    staged = f"{current}.restoring"
    try:
        shutil.copy2(version_path, staged)
        if os.path.exists(current):
            stem, ext = os.path.splitext(current)
            stamp = datetime.now()
            backup = f"{stem}_{stamp:%Y%m%d_%H%M%S}{ext}"
            # Never overwrite an existing backup (possibly the very file being
            # restored) when two operations land in the same second.
            while os.path.exists(backup):
                stamp += timedelta(seconds=1)
                backup = f"{stem}_{stamp:%Y%m%d_%H%M%S}{ext}"
            os.replace(current, backup)
        os.replace(staged, current)
    except Exception:
        try:
            os.remove(staged)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    _log(f"audio_scan: restored {os.path.basename(version_path)} as active model")
@@ -478,6 +529,25 @@ def list_trained_models(profile_name: str = "default") -> list[str]: # Scanning # --------------------------------------------------------------------------- +def _fuse_regions(regions: list[tuple[float, float, float]] + ) -> list[tuple[float, float, float]]: + """Merge overlapping/adjacent regions, keeping max score.""" + if not regions: + return [] + by_start = sorted(regions, key=lambda r: r[0]) + fused: list[tuple[float, float, float]] = [] + s, e, sc = by_start[0] + for s2, e2, sc2 in by_start[1:]: + if s2 <= e: # overlapping or touching + e = max(e, e2) + sc = max(sc, sc2) + else: + fused.append((s, e, sc)) + s, e, sc = s2, e2, sc2 + fused.append((s, e, sc)) + return fused + + def scan_video( video_path: str, model: dict = None, @@ -532,9 +602,10 @@ def scan_video( probs = clf.predict_proba(normed)[:, 1] mask = probs >= threshold - results = [ + raw = [ (timestamps[i], timestamps[i] + window, float(probs[i])) for i in np.nonzero(mask)[0] ] - _log(f"audio_scan: {len(results)} regions above threshold {threshold}") + results = _fuse_regions(raw) + _log(f"audio_scan: {len(results)} regions above threshold {threshold} (from {len(raw)} raw)") return results diff --git a/core/db.py b/core/db.py index 659b74b..ebde646 100644 --- a/core/db.py +++ b/core/db.py @@ -65,6 +65,7 @@ class ProcessedDB: "spread": "REAL NOT NULL DEFAULT 3.0", "profile": "TEXT NOT NULL DEFAULT 'default'", "source_path": "TEXT NOT NULL DEFAULT ''", + "scan_export": "INTEGER NOT NULL DEFAULT 0", } for col, typedef in new_cols.items(): if col not in cols: @@ -96,6 +97,19 @@ class ProcessedDB: "CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model" " ON scan_results(filename, profile, model)" ) + self._con.execute( + "CREATE TABLE IF NOT EXISTS hard_negatives (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " filename TEXT NOT NULL," + " profile TEXT NOT NULL DEFAULT 'default'," + " start_time REAL NOT NULL," + " source_path TEXT NOT NULL DEFAULT ''" + ")" + ) + 
self._con.execute( + "CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile" + " ON hard_negatives(filename, profile)" + ) self._con.commit() def add(self, filename: str, start_time: float, output_path: str, @@ -103,7 +117,8 @@ class ProcessedDB: short_side: int | None = None, portrait_ratio: str = "", crop_center: float = 0.5, fmt: str = "MP4", clip_count: int = 3, spread: float = 3.0, - profile: str = "default", source_path: str = "") -> None: + profile: str = "default", source_path: str = "", + scan_export: bool = False) -> None: if not self._enabled: return with self._lock: @@ -111,11 +126,12 @@ class ProcessedDB: "INSERT INTO processed" " (filename, start_time, output_path, label, category," " short_side, portrait_ratio, crop_center, format," - " clip_count, spread, profile, source_path, processed_at)" - " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + " clip_count, spread, profile, source_path, scan_export, processed_at)" + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (filename, start_time, output_path, label, category, short_side, portrait_ratio, crop_center, fmt, clip_count, spread, profile, source_path, + 1 if scan_export else 0, datetime.now(timezone.utc).isoformat()), ) self._con.commit() @@ -207,7 +223,8 @@ class ProcessedDB: def _get_markers_for(self, match: str, profile: str = "default") -> list[tuple[float, int, str]]: rows = self._con.execute( "SELECT start_time, output_path FROM processed" - " WHERE filename = ? AND profile = ? ORDER BY start_time", + " WHERE filename = ? AND profile = ? AND scan_export = 0" + " ORDER BY start_time", (match, profile), ).fetchall() # Deduplicate by start_time — batch exports share the same cursor. @@ -269,6 +286,7 @@ class ProcessedDB: def get_training_data(self, profile: str, positive_folder: str, negative_folder: str = "", fallback_video_dir: str = "", + include_scan_exports: bool = False, ) -> list[tuple[str, list[float], list[float], list[float]]]: """Build training video_infos from DB data. 
# ProcessedDB methods for the hard_negatives table.

def add_hard_negatives(self, filename: str, profile: str,
                       times: list[float], source_path: str = "") -> None:
    """Save timestamps as hard-negative training examples.

    A timestamp that is already stored for this file and profile is skipped,
    so marking the same region twice does not create duplicate rows. All rows
    are written in one transaction: either every new timestamp is stored or,
    if the database raises, none of them is.

    Args:
        filename: video file name the timestamps belong to.
        profile: profile the negatives are recorded under.
        times: start times to store; an empty list is a no-op.
        source_path: path of the source video stored alongside each row.
    """
    if not self._enabled or not times:
        return
    rows = [(filename, profile, t, source_path, filename, profile, t)
            for t in times]
    # The connection context manager commits on success and rolls back on
    # error; the lock is released only after that has happened.
    with self._lock, self._con:
        self._con.executemany(
            "INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
            " SELECT ?, ?, ?, ?"
            " WHERE NOT EXISTS ("
            "SELECT 1 FROM hard_negatives"
            " WHERE filename = ? AND profile = ? AND start_time = ?)",
            rows,
        )


def get_hard_negative_times(self, filename: str, profile: str) -> set[float]:
    """Return start_times marked as hard negatives for this file.

    Returns an empty set when the database is disabled.
    """
    if not self._enabled:
        return set()
    rows = self._con.execute(
        "SELECT start_time FROM hard_negatives"
        " WHERE filename = ? AND profile = ?",
        (filename, profile),
    ).fetchall()
    return {start_time for (start_time,) in rows}


def remove_hard_negatives(self, filename: str, profile: str,
                          times: list[float]) -> None:
    """Remove specific hard-negative timestamps.

    Every row matching one of *times* for this file and profile is deleted
    in a single transaction; timestamps that are not stored are ignored.
    """
    if not self._enabled or not times:
        return
    with self._lock, self._con:
        self._con.executemany(
            "DELETE FROM hard_negatives"
            " WHERE filename = ? AND profile = ? AND start_time = ?",
            [(filename, profile, t) for t in times],
        )
0 = disabled.") form.addRow("Auto-neg margin:", self._spn_neg_margin) + self._chk_scan_exports = QCheckBox("Include scan-exported clips in training") + self._chk_scan_exports.setToolTip("When checked, clips auto-exported from scan results are included as training data") + self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start()) + form.addRow("", self._chk_scan_exports) + # Video source directory (fallback for old DB rows without source_path) self._txt_video_dir = QLineEdit(video_dir) self._txt_video_dir.setPlaceholderText("Directory containing source videos") @@ -326,13 +331,16 @@ class TrainDialog(QDialog): self._lbl_stats.setText("No export folder data available.") return neg_folder = self._cmb_negative.currentData() or "" + inc_scan = self._chk_scan_exports.isChecked() # First check without fallback to see if source_paths are sufficient video_infos_no_fb = self._db.get_training_data( self._profile, folder, negative_folder=neg_folder, + include_scan_exports=inc_scan, ) video_infos = self._db.get_training_data( self._profile, folder, negative_folder=neg_folder, fallback_video_dir=self._txt_video_dir.text(), + include_scan_exports=inc_scan, ) # Show video dir field only when the fallback helps find extra videos needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0 @@ -375,6 +383,10 @@ class TrainDialog(QDialog): def video_dir(self) -> str: return self._txt_video_dir.text() + @property + def include_scan_exports(self) -> bool: + return self._chk_scan_exports.isChecked() + class TrainWorker(QThread): """Trains an audio classifier off the main thread.""" @@ -424,12 +436,16 @@ class ScanResultsPanel(QWidget): """Tabbed panel showing scan results per model, with seek-on-click and delete.""" seek_requested = pyqtSignal(float) # request main window to seek to time export_requested = pyqtSignal(list) # emit list of (start, end, score) to export + negatives_requested = pyqtSignal(list) # emit list of start times to mark 
as hard negatives + negatives_removed = pyqtSignal(list) # emit list of start times to un-mark as negatives + tab_changed = pyqtSignal() # active tab changed def __init__(self, db, parent=None): super().__init__(parent) self._db = db self._filename = "" self._profile = "" + self._neg_times: set[float] = set() layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -437,13 +453,18 @@ class ScanResultsPanel(QWidget): self._tabs = QTabWidget() self._tabs.setTabsClosable(False) + self._tabs.currentChanged.connect(lambda: self.tab_changed.emit()) layout.addWidget(self._tabs) btn_row = QHBoxLayout() + self._btn_neg = QPushButton("Add to Negatives") + self._btn_neg.setToolTip("Mark selected rows as hard-negative training examples and remove them") + self._btn_neg.clicked.connect(self._on_add_negatives) self._btn_export = QPushButton("Export Scan Results") self._btn_export.setToolTip("Export clips from the active tab's scan results") self._btn_export.clicked.connect(self._on_export) btn_row.addStretch() + btn_row.addWidget(self._btn_neg) btn_row.addWidget(self._btn_export) layout.addLayout(btn_row) @@ -451,6 +472,7 @@ class ScanResultsPanel(QWidget): """Load saved scan results from DB for a file.""" self._filename = filename self._profile = profile + self._neg_times = self._db.get_hard_negative_times(filename, profile) self._tabs.clear() results = self._db.get_scan_results(filename, profile) for model, rows in results.items(): @@ -490,6 +512,7 @@ class ScanResultsPanel(QWidget): header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + red = QColor(220, 60, 60) for i, (row_id, start, end, score) in enumerate(rows): t_item = QTableWidgetItem(format_time(start)) t_item.setData(Qt.ItemDataRole.UserRole, row_id) @@ -499,6 +522,9 @@ class ScanResultsPanel(QWidget): e_item.setData(Qt.ItemDataRole.UserRole, end) table.setItem(i, 1, e_item) table.setItem(i, 2, 
def _on_add_negatives(self) -> None:
    """Flip the hard-negative state of every selected row in the active tab.

    Rows whose start time is already in ``self._neg_times`` are un-marked and
    painted with the table's regular foreground colour; all other selected
    rows are marked and painted red. The added and removed start times are
    then emitted through ``negatives_requested`` and ``negatives_removed``.
    """
    table = self._tabs.currentWidget()
    if not isinstance(table, QTableWidget):
        return
    rows = sorted({index.row() for index in table.selectedIndexes()})
    if not rows:
        return

    newly_marked: list[float] = []
    newly_cleared: list[float] = []
    negative_fg = QColor(220, 60, 60)
    regular_fg = table.palette().color(table.foregroundRole())

    for row in rows:
        # The start time is stored on the first column under UserRole + 1.
        raw_start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
        if raw_start is None:
            continue
        start_time = float(raw_start)
        was_negative = start_time in self._neg_times
        if was_negative:
            newly_cleared.append(start_time)
            self._neg_times.discard(start_time)
        else:
            newly_marked.append(start_time)
            self._neg_times.add(start_time)
        colour = regular_fg if was_negative else negative_fg
        for col in range(3):
            table.item(row, col).setForeground(colour)

    if newly_marked:
        self.negatives_requested.emit(newly_marked)
    if newly_cleared:
        self.negatives_removed.emit(newly_cleared)
def set_export_count(self, n: int) -> None:
    """Show the estimated clip count on the export button.

    A positive *n* is appended in parentheses; for zero or a negative value
    the button shows its plain label.
    """
    label = "Export Scan Results"
    if n > 0:
        label = f"Export Scan Results ({n})"
    self._btn_export.setText(label)
-634,9 +706,11 @@ class TimelineWidget(QWidget): self._rebuild_hover_cache() self.update() - def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None: + def set_scan_regions(self, regions: list[tuple[float, float, float]], + neg_times: set[float] | None = None) -> None: """regions: list of (start_time, end_time, score)""" self._scan_regions = regions + self._scan_neg_times = neg_times or set() self.update() def clear_scan_regions(self) -> None: @@ -734,12 +808,13 @@ class TimelineWidget(QWidget): # ── selection region (full clip span) ───────────────────────── x_start = int(self._cursor / self._duration * w) - x_end = int(min(self._cursor + self._clip_span, self._duration) / self._duration * w) - sel_w = max(x_end - x_start, 1) - p.fillRect(x_start, rh, sel_w, th, QColor(60, 130, 220, 90)) + if not self._scan_mode: + x_end = int(min(self._cursor + self._clip_span, self._duration) / self._duration * w) + sel_w = max(x_end - x_start, 1) + p.fillRect(x_start, rh, sel_w, th, QColor(60, 130, 220, 90)) # ── playback progress fill ──────────────────────────────────── - if self._play_pos is not None and self._play_pos > self._cursor: + if not self._scan_mode and self._play_pos is not None and self._play_pos > self._cursor: prog_end = min(self._play_pos, self._cursor + self._clip_span, self._duration) x_prog = int(prog_end / self._duration * w) prog_w = max(x_prog - x_start, 0) @@ -747,9 +822,10 @@ class TimelineWidget(QWidget): p.fillRect(x_start, rh, prog_w, th, QColor(100, 200, 255, 60)) # left/right edges of selection - p.setPen(QPen(QColor(60, 130, 220, 180), 1)) - p.drawLine(x_start, rh, x_start, h) - p.drawLine(x_end, rh, x_end, h) + if not self._scan_mode: + p.setPen(QPen(QColor(60, 130, 220, 180), 1)) + p.drawLine(x_start, rh, x_start, h) + p.drawLine(x_end, rh, x_end, h) # ── scan regions ────────────────────────────────────────────── if self._scan_regions and self._duration > 0: @@ -757,19 +833,28 @@ class TimelineWidget(QWidget): x1 = 
int(start / self._duration * w) x2 = int(end / self._duration * w) alpha = int(40 + score * 80) # 40–120 opacity - p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha)) + if start in self._scan_neg_times: + p.fillRect(x1, rh, x2 - x1, h - rh, QColor(220, 60, 60, alpha)) + else: + p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha)) # ── export markers ──────────────────────────────────────────── - p.setFont(self._marker_font) - for (t, num, _path) in self._markers: - mx = int(t / self._duration * w) - p.setPen(self._marker_pen) - p.drawLine(mx, rh, mx, h) - # small filled rectangle label - p.fillRect(mx, rh + 2, 14, 12, QColor(200, 50, 50)) - p.setPen(QColor(255, 255, 255)) - p.drawText(mx + 1, rh + 2, 13, 12, - Qt.AlignmentFlag.AlignCenter, str(num)) + if not self._scan_mode: + p.setFont(self._marker_font) + for (t, num, _path) in self._markers: + mx = int(t / self._duration * w) + p.setPen(self._marker_pen) + p.drawLine(mx, rh, mx, h) + # small filled rectangle label + p.fillRect(mx, rh + 2, 14, 12, QColor(200, 50, 50)) + p.setPen(QColor(255, 255, 255)) + p.drawText(mx + 1, rh + 2, 13, 12, + Qt.AlignmentFlag.AlignCenter, str(num)) + + # ── scan mode cursor line ───────────────────────────────────── + if self._scan_mode: + p.setPen(QPen(QColor(255, 255, 255, 200), 2)) + p.drawLine(x_start, rh, x_start, h) # ── crop keyframe diamonds ──────────────────────────────────── if self._crop_keyframes and self._duration > 0: @@ -895,21 +980,28 @@ class TimelineWidget(QWidget): if abs(x - frac * w) <= 10: hit_path = output_path break - if hit_kf_time is None and hit_path is None: - return from PyQt6.QtWidgets import QMenu menu = QMenu(self) act_kf = None act_marker = None + act_clear = None if hit_kf_time is not None: act_kf = menu.addAction(f"Delete keyframe @ {format_time(hit_kf_time)}") if hit_path is not None: act_marker = menu.addAction(f"Delete marker: {os.path.basename(hit_path)}") + if self._markers: + if hit_kf_time is not None or hit_path 
is not None: + menu.addSeparator() + act_clear = menu.addAction(f"Clear all markers ({len(self._markers)})") + if menu.isEmpty(): + return chosen = menu.exec(event.globalPos()) if chosen and chosen == act_kf: self.keyframe_delete_requested.emit(hit_kf_time) elif chosen and chosen == act_marker: self.marker_delete_requested.emit(hit_path) + elif chosen and chosen == act_clear: + self.markers_clear_requested.emit() def _seek(self, x: float): t = self._pos_to_time(int(x)) @@ -1739,6 +1831,7 @@ class MainWindow(QMainWindow): self._timeline.cursor_changed.connect(self._on_cursor_changed) self._timeline.seek_changed.connect(self._on_seek_changed) self._timeline.marker_delete_requested.connect(self._on_delete_marker) + self._timeline.markers_clear_requested.connect(self._on_clear_markers) self._timeline.keyframe_delete_requested.connect(self._on_delete_keyframe) self._mpv.time_pos_changed.connect(self._timeline.set_play_position) self._timeline.marker_clicked.connect(self._on_marker_clicked) @@ -1862,6 +1955,7 @@ class MainWindow(QMainWindow): ) self._spn_spread.valueChanged.connect(lambda: self._preview_timer.start()) self._spn_spread.valueChanged.connect(self._update_play_loop) + self._spn_spread.valueChanged.connect(lambda: self._update_scan_export_count()) self._chk_rand_portrait = QCheckBox("1 random portrait") self._chk_rand_portrait.setToolTip( @@ -1900,6 +1994,11 @@ class MainWindow(QMainWindow): ) # ── audio scan controls ────────────────────────────────────── + self._btn_scan_mode = QPushButton("Review") + self._btn_scan_mode.setCheckable(True) + self._btn_scan_mode.setToolTip("Scan review mode: hide spread/markers, free cursor movement") + self._btn_scan_mode.toggled.connect(self._toggle_scan_mode) + self._btn_scan = QPushButton("Scan") self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips") self._btn_scan.clicked.connect(self._start_scan) @@ -1919,8 +2018,10 @@ class MainWindow(QMainWindow): self._scan_all_queue: list[str] = 
[] self._cmb_scan_model = QComboBox() - self._cmb_scan_model.setToolTip("Trained embedding model to use for scanning") + self._cmb_scan_model.setToolTip("Trained embedding model to use for scanning\nRight-click to rollback to a previous version") self._cmb_scan_model.setMinimumWidth(120) + self._cmb_scan_model.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self._cmb_scan_model.customContextMenuRequested.connect(self._show_model_versions_menu) self._spn_auto_fuse = QDoubleSpinBox() self._spn_auto_fuse.setDecimals(1) @@ -1933,6 +2034,7 @@ class MainWindow(QMainWindow): self._spn_auto_fuse.valueChanged.connect( lambda v: self._settings.setValue("auto_fuse", str(v)) ) + self._spn_auto_fuse.valueChanged.connect(lambda: self._update_scan_export_count()) self._sld_threshold = QDoubleSpinBox() self._sld_threshold.setDecimals(2) @@ -2079,6 +2181,7 @@ class MainWindow(QMainWindow): settings_row.addWidget(self._chk_track) settings_row.addWidget(self._cmb_scan_model) settings_row.addWidget(self._btn_scan) + settings_row.addWidget(self._btn_scan_mode) settings_row.addWidget(self._btn_auto_export) settings_row.addWidget(self._spn_auto_fuse) settings_row.addWidget(self._sld_threshold) @@ -2137,6 +2240,9 @@ class MainWindow(QMainWindow): self._scan_panel = ScanResultsPanel(self._db) self._scan_panel.seek_requested.connect(self._on_scan_seek) self._scan_panel.export_requested.connect(self._on_scan_export) + self._scan_panel.negatives_requested.connect(self._on_scan_negatives) + self._scan_panel.negatives_removed.connect(self._on_scan_negatives_removed) + self._scan_panel.tab_changed.connect(self._update_scan_export_count) # Root: horizontal splitter splitter = QSplitter(Qt.Orientation.Horizontal) @@ -2414,6 +2520,11 @@ class MainWindow(QMainWindow): if self._file_path: filename = os.path.basename(self._file_path) self._scan_panel.load_for_file(filename, self._profile) + self._timeline.set_scan_regions( + self._scan_panel.current_regions(), + 
def _show_model_versions_menu(self, pos) -> None:
    """Show a context menu listing model versions and restore the chosen one.

    The menu is built from ``list_model_versions`` for the model selected in
    the scan-model combo box. The active version is shown disabled; picking
    a backup calls ``restore_model_version``. A failed restore is reported in
    the status bar instead of raising out of the Qt slot.

    Args:
        pos: click position in the combo box's coordinates, as delivered by
            ``customContextMenuRequested``.
    """
    from core.audio_scan import list_model_versions, restore_model_version
    sel = self._cmb_scan_model.currentText()
    if not sel or sel == "(no model)":
        return
    embed_name = None if sel == "(legacy)" else sel
    versions = list_model_versions(self._profile, embed_name)
    if len(versions) <= 1:
        self._show_status("No previous versions available")
        return

    from PyQt6.QtWidgets import QMenu
    menu = QMenu(self)
    for label, path in versions:
        if label == "current":
            act = menu.addAction("current (active)")
            act.setEnabled(False)
        else:
            # Format timestamp for display: 20260418_170800 → 2026-04-18 17:08
            display = f"{label[:4]}-{label[4:6]}-{label[6:8]} {label[9:11]}:{label[11:13]}"
            act = menu.addAction(f"Restore {display}")
            act.setData(path)

    chosen = menu.exec(self._cmb_scan_model.mapToGlobal(pos))
    if not chosen or not chosen.data():
        return
    try:
        restore_model_version(chosen.data(), self._profile, embed_name)
    except OSError as exc:
        # File operations can fail (backup deleted, permissions, disk full);
        # report it rather than letting the exception escape the slot.
        self._show_status(f"Could not restore model version: {exc}")
        return
    self._show_status("Restored model version — rescan to use it")
-> None: + """Save selected scan result timestamps as hard negatives for training.""" + if not self._file_path: + return + filename = os.path.basename(self._file_path) + self._db.add_hard_negatives(filename, self._profile, times, + source_path=self._file_path) + self._timeline.set_scan_regions( + self._scan_panel.current_regions(), + neg_times=self._scan_panel._neg_times, + ) + self._update_scan_export_count() + self._show_status(f"Added {len(times)} hard negative(s) for training") + + def _on_scan_negatives_removed(self, times: list) -> None: + """Remove hard negatives that were toggled off.""" + if not self._file_path: + return + filename = os.path.basename(self._file_path) + self._db.remove_hard_negatives(filename, self._profile, times) + self._timeline.set_scan_regions( + self._scan_panel.current_regions(), + neg_times=self._scan_panel._neg_times, + ) + self._update_scan_export_count() + self._show_status(f"Removed {len(times)} hard negative(s)") + # ── Scan All ─────────────────────────────────────────────── def _start_scan_all(self) -> None: @@ -3118,6 +3317,7 @@ class MainWindow(QMainWindow): neg_margin = dlg.neg_margin embed_model = dlg.embed_model video_dir = dlg.video_dir + inc_scan = dlg.include_scan_exports if not pos_folder: self._show_status("No positive class selected") return @@ -3129,6 +3329,7 @@ class MainWindow(QMainWindow): video_infos = self._db.get_training_data( self._profile, pos_folder, negative_folder=neg_folder, fallback_video_dir=video_dir, + include_scan_exports=inc_scan, ) if not video_infos: self._show_status("No training data found for this subprofile") @@ -3197,45 +3398,50 @@ class MainWindow(QMainWindow): self._scan_worker.start() @staticmethod - def _select_export_positions(regions: list[tuple[float, float, float]], - min_gap: float = 2.0, - cluster_fuse: float = 30.0, - ) -> list[float]: - """Cluster scan regions, then fill each cluster with clips spaced min_gap apart. 
+ def _build_export_spans(regions: list[tuple[float, float, float]], + fuse_gap: float = 30.0, + spread: float = 3.0, + min_dur: float = 8.0, + ) -> list[list[float]]: + """Build export position groups from fused scan regions. - 1. Merge overlapping regions into clusters, fusing clusters = min_gap for p in cluster_picks): - cluster_picks.append(start) - picked.extend(cluster_picks) + # Place clips within each span + groups: list[list[float]] = [] + step = max(spread, 1.0) + for s, e in spans: + dur = e - s + if dur < min_dur: + continue + clips: list[float] = [] + t = s + while t + min_dur <= e: + clips.append(t) + t += step + if clips: + groups.append(clips) - return sorted(picked) + return groups def _on_auto_scan_done(self, regions: list) -> None: self._btn_scan.setEnabled(True) @@ -3249,6 +3455,7 @@ class MainWindow(QMainWindow): if model_label and self._file_path: self._scan_panel.add_scan_results(model_label, regions) + self._auto_export_no_markers = True self._auto_export_regions(regions) def _auto_export_regions(self, regions: list) -> None: @@ -3258,23 +3465,24 @@ class MainWindow(QMainWindow): self._btn_auto_export.setEnabled(True) return - positions = self._select_export_positions( - regions, min_gap=2.0, cluster_fuse=self._spn_auto_fuse.value(), + spread = self._spn_spread.value() + groups = self._build_export_spans( + regions, fuse_gap=self._spn_auto_fuse.value(), + spread=spread, ) - if not positions: - self._show_status("Auto: no positions after NMS") + if not groups: + self._show_status("Auto: no regions >= 8s") self._btn_auto_export.setEnabled(True) return - # Build export jobs — one 8s clip per position folder = self._txt_folder.text() name = self._txt_name.text() or "clip" - self._auto_export_name = name fmt = self._cmb_format.currentText() image_sequence = fmt == "WebP sequence" + ext = "" if image_sequence else ".mp4" os.makedirs(folder, exist_ok=True) - # Find starting counter + # Find next counter following the normal order counter = 1 while 
True: if image_sequence: @@ -3285,18 +3493,19 @@ class MainWindow(QMainWindow): break counter += 1 + # One group folder for the whole scan batch + group_name = f"{name}_{counter:03d}" + group_dir = os.path.join(folder, group_name) + os.makedirs(group_dir, exist_ok=True) + jobs = [] - self._auto_export_positions = [] # stash for DB writes - for start_t in positions: - group_dir = os.path.join(folder, f"{name}_{counter:03d}") - os.makedirs(group_dir, exist_ok=True) - if image_sequence: - out = build_sequence_dir(folder, name, counter, sub=0) - else: - out = build_export_path(folder, name, counter, sub=0) - jobs.append((start_t, out, None, 0.5)) - self._auto_export_positions.append((start_t, counter)) - counter += 1 + self._auto_export_positions = [] + for area_idx, group in enumerate(groups, 1): + for sub, start_t in enumerate(group): + fname = f"{group_name}_a{area_idx}_{sub}{ext}" + out = os.path.join(group_dir, fname) + jobs.append((start_t, out, None, 0.5)) + self._auto_export_positions.append((start_t, out)) self._show_status(f"Auto: exporting {len(jobs)} clips...") @@ -3306,7 +3515,7 @@ class MainWindow(QMainWindow): self._export_crop_center = 0.5 self._export_format = fmt self._export_clip_count = 1 - self._export_spread = 0 + self._export_spread = spread self._export_folder = folder self._export_folder_suffix = "" self._export_profile = self._profile @@ -3333,20 +3542,17 @@ class MainWindow(QMainWindow): def _on_auto_clip_done(self, path: str): """Record each auto-exported clip to DB.""" - # Find the start_time for this clip from stashed positions - counter_str = os.path.basename(os.path.dirname(path)) # e.g. 
"clip_042" - name = getattr(self, '_auto_export_name', self._txt_name.text() or "clip") - start_t = None - for t, c in self._auto_export_positions: - if counter_str == f"{name}_{c:03d}": + start_t = 0.0 + for t, out in self._auto_export_positions: + if os.path.normpath(out) == os.path.normpath(path): start_t = t break - + is_scan = getattr(self, '_auto_export_no_markers', False) label = self._txt_label.currentText().strip() category = self._cmb_category.currentText() self._db.add( os.path.basename(self._file_path), - start_t or 0.0, + start_t, path, label=label, category=category, @@ -3355,11 +3561,13 @@ class MainWindow(QMainWindow): crop_center=0.5, fmt=self._export_format, clip_count=1, - spread=0, + spread=self._export_spread, profile=self._export_profile, source_path=self._file_path, + scan_export=is_scan, ) - upsert_clip_annotation(self._export_folder, path, label) + if not is_scan: + upsert_clip_annotation(self._export_folder, path, label) self._show_status(f"Auto: {os.path.basename(path)}") _log(f" auto clip done: {os.path.basename(path)}") @@ -3369,6 +3577,7 @@ class MainWindow(QMainWindow): self._btn_cancel.setEnabled(False) self._btn_export.setEnabled(True) self._set_subprofile_btns_enabled(True) + self._auto_export_no_markers = False self._refresh_markers() markers = self._db.get_markers(os.path.basename(self._file_path), self._profile) self._playlist.mark_done(self._file_path, len(markers))