diff --git a/core/db.py b/core/db.py index f05e3a5..3ba060b 100644 --- a/core/db.py +++ b/core/db.py @@ -119,13 +119,23 @@ class ProcessedDB: ) self._con.execute( "CREATE TABLE IF NOT EXISTS hard_negatives (" - " id INTEGER PRIMARY KEY AUTOINCREMENT," - " filename TEXT NOT NULL," - " profile TEXT NOT NULL DEFAULT 'default'," - " start_time REAL NOT NULL," - " source_path TEXT NOT NULL DEFAULT ''" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " filename TEXT NOT NULL," + " profile TEXT NOT NULL DEFAULT 'default'," + " start_time REAL NOT NULL," + " source_path TEXT NOT NULL DEFAULT ''," + " source_model TEXT NOT NULL DEFAULT ''" ")" ) + # Migrate: add source_model column to existing hard_negatives tables + hn_cols = { + row[1] + for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall() + } + if "source_model" not in hn_cols: + self._con.execute( + "ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''" + ) self._con.execute( "CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile" " ON hard_negatives(filename, profile)" @@ -353,6 +363,7 @@ class ProcessedDB: negative_folder: str = "", fallback_video_dir: str = "", include_scan_exports: bool = False, + use_hard_negatives: bool = True, ) -> list[tuple[str, list[float], list[float], list[float]]]: """Build training video_infos from DB data. @@ -362,6 +373,7 @@ class ProcessedDB: negative_folder: export folder name for explicit negatives (optional) fallback_video_dir: if source_path is empty, try filename in this dir include_scan_exports: if True, include auto-exported scan clips + use_hard_negatives: if False, skip hard negatives from scan feedback Returns: list of (source_video_path, positive_times, soft_times, negative_times) @@ -400,15 +412,16 @@ class ProcessedDB: soft_by_video.setdefault(fn, set()).add(st) # Include hard negatives from scan feedback - hard_rows = self._con.execute( - "SELECT filename, start_time, source_path FROM hard_negatives" - " WHERE profile = ?", - (profile,), - ).fetchall() - for fn, st, sp in hard_rows: - neg_by_video.setdefault(fn, set()).add(st) - if sp: - source_by_filename.setdefault(fn, sp) + if use_hard_negatives: + hard_rows = self._con.execute( + "SELECT filename, start_time, source_path FROM hard_negatives" + " WHERE profile = ?", + (profile,), + ).fetchall() + for fn, st, sp in hard_rows: + neg_by_video.setdefault(fn, set()).add(st) + if sp: + source_by_filename.setdefault(fn, sp) # Remove positive times from soft/neg to avoid conflicting labels for fn in pos_by_video: @@ -638,16 +651,18 @@ class ProcessedDB: return {r[0] for r in rows} def add_hard_negatives(self, filename: str, profile: str, - times: list[float], source_path: str = "") -> None: + times: list[float], source_path: str = "", + source_model: str = "") -> None: """Save timestamps as hard-negative training examples.""" if not self._enabled or not times: return with self._lock: for t in times: self._con.execute( - "INSERT INTO hard_negatives (filename, profile, start_time, source_path)" - " VALUES (?, ?, ?, ?)", - (filename, profile, t, source_path), + "INSERT INTO hard_negatives" + " (filename, profile, start_time, source_path, source_model)" + " VALUES (?, ?, ?, ?, ?)", + (filename, profile, t, source_path, source_model), ) self._con.commit() @@ -662,6 +677,30 @@ class ProcessedDB: ).fetchall() return {r[0] for r in rows} + def get_hard_negatives(self, profile: str) -> list[dict]: + """Return all hard negatives for a profile with full details.""" + if not self._enabled: + return [] + rows = self._con.execute( + "SELECT id, filename, start_time, source_path, source_model" + " FROM hard_negatives WHERE profile = ?" + " ORDER BY filename, start_time", + (profile,), + ).fetchall() + return [{"id": r[0], "filename": r[1], "start_time": r[2], + "source_path": r[3], "source_model": r[4]} for r in rows] + + def delete_hard_negatives_by_ids(self, ids: list[int]) -> None: + """Delete hard negatives by row IDs.""" + if not self._enabled or not ids: + return + with self._lock: + self._con.execute( + f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})", + ids, + ) + self._con.commit() + def remove_hard_negatives(self, filename: str, profile: str, times: list[float]) -> None: """Remove specific hard-negative timestamps.""" diff --git a/main.py b/main.py index 0bf71d6..c646f25 100755 --- a/main.py +++ b/main.py @@ -372,6 +372,14 @@ class TrainDialog(QDialog): self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start()) form.addRow("", self._chk_scan_exports) + self._chk_hard_negatives = QCheckBox("Use hard negatives in training") + self._chk_hard_negatives.setChecked(True) + self._chk_hard_negatives.setToolTip( + "When unchecked, manually marked hard negatives are excluded from training.\n" + "Useful when training a new model type where old negatives may not apply.") + self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start()) + form.addRow("", self._chk_hard_negatives) + # Video source directory (fallback for old DB rows without source_path) self._txt_video_dir = QLineEdit(video_dir) self._txt_video_dir.setPlaceholderText("Directory containing source videos") @@ -464,15 +472,18 @@ class TrainDialog(QDialog): return neg_folder = self._cmb_negative.currentData() or "" inc_scan = self._chk_scan_exports.isChecked() + use_neg = self._chk_hard_negatives.isChecked() # First check without fallback to see if source_paths are sufficient video_infos_no_fb = self._db.get_training_data( self._profile, folder, negative_folder=neg_folder, include_scan_exports=inc_scan, + use_hard_negatives=use_neg, ) video_infos = self._db.get_training_data( self._profile, folder, negative_folder=neg_folder, fallback_video_dir=self._txt_video_dir.text(), include_scan_exports=inc_scan, + use_hard_negatives=use_neg, ) # Show video dir field only when the fallback helps find extra videos needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0 @@ -526,6 +537,10 @@ class TrainDialog(QDialog): def include_scan_exports(self) -> bool: return self._chk_scan_exports.isChecked() + @property + def use_hard_negatives(self) -> bool: + return self._chk_hard_negatives.isChecked() + class TrainWorker(QThread): """Trains an audio classifier off the main thread.""" @@ -4007,8 +4022,10 @@ class MainWindow(QMainWindow): if not self._file_path: return filename = os.path.basename(self._file_path) + source_model = self._scan_panel.current_model_name() self._db.add_hard_negatives(filename, self._profile, times, - source_path=self._file_path) + source_path=self._file_path, + source_model=source_model) self._timeline.set_scan_regions( self._scan_panel.current_regions_with_orig(), neg_times=self._scan_panel._neg_times, @@ -4228,6 +4245,7 @@ class MainWindow(QMainWindow): embed_model = dlg.embed_model video_dir = dlg.video_dir inc_scan = dlg.include_scan_exports + use_neg = dlg.use_hard_negatives if not pos_folder: self._show_status("No positive class selected") return @@ -4240,6 +4258,7 @@ class MainWindow(QMainWindow): self._profile, pos_folder, negative_folder=neg_folder, fallback_video_dir=video_dir, include_scan_exports=inc_scan, + use_hard_negatives=use_neg, ) if not video_infos: self._show_status("No training data found for this subprofile") diff --git a/tests/test_db.py b/tests/test_db.py index 3ab3d41..51023d0 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -50,3 +50,60 @@ def test_scan_result_history(): assert len(results.get("MODEL_A", [])) == 1 finally: os.unlink(path) + + +def test_hard_negatives_source_model(): + """Hard negatives should store source_model.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + path = f.name + try: + db = ProcessedDB(path) + db.add_hard_negatives("a.mp4", "test", [10.0, 20.0], + source_path="/a.mp4", source_model="HUBERT_XLARGE") + rows = db.get_hard_negatives("test") + assert len(rows) == 2 + assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows) + finally: + os.unlink(path) + + +def test_training_data_skips_hard_negatives(): + """get_training_data with use_hard_negatives=False should skip them.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + path = f.name + try: + db = ProcessedDB(path) + # Create a source file that "exists" — use the temp db file itself + db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test", + source_path=path) + db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path) + # With hard negatives + data_with = db.get_training_data("test", "folder", use_hard_negatives=True) + # Without hard negatives + data_without = db.get_training_data("test", "folder", use_hard_negatives=False) + assert len(data_with) >= 1 + # The "with" case should have the hard negative time in neg list + neg_with = sum(len(vi[3]) for vi in data_with) + neg_without = sum(len(vi[3]) for vi in data_without) + assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False" + finally: + os.unlink(path) + + +def test_delete_hard_negatives_by_ids(): + """delete_hard_negatives_by_ids should remove specific rows.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + path = f.name + try: + db = ProcessedDB(path) + db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0], + source_path="/a.mp4") + rows = db.get_hard_negatives("test") + assert len(rows) == 3 + # Delete first two + db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]]) + remaining = db.get_hard_negatives("test") + assert len(remaining) == 1 + assert remaining[0]["start_time"] == 30.0 + finally: + os.unlink(path)