feat: hard negative source_model tracking, training toggle

Add source_model column to hard_negatives table with migration. New
get_hard_negatives() returns full rows, delete_hard_negatives_by_ids()
for bulk deletion. get_training_data() gains use_hard_negatives param.
TrainDialog has "Use hard negatives" checkbox. Scan panel passes current
model name when marking negatives.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 15:27:11 +02:00
parent 8ed9fbf557
commit edc5784ba6
3 changed files with 134 additions and 19 deletions
+57 -18
View File
@@ -119,13 +119,23 @@ class ProcessedDB:
) )
self._con.execute( self._con.execute(
"CREATE TABLE IF NOT EXISTS hard_negatives (" "CREATE TABLE IF NOT EXISTS hard_negatives ("
" id INTEGER PRIMARY KEY AUTOINCREMENT," " id INTEGER PRIMARY KEY AUTOINCREMENT,"
" filename TEXT NOT NULL," " filename TEXT NOT NULL,"
" profile TEXT NOT NULL DEFAULT 'default'," " profile TEXT NOT NULL DEFAULT 'default',"
" start_time REAL NOT NULL," " start_time REAL NOT NULL,"
" source_path TEXT NOT NULL DEFAULT ''" " source_path TEXT NOT NULL DEFAULT '',"
" source_model TEXT NOT NULL DEFAULT ''"
")" ")"
) )
# Migrate: add source_model column to existing hard_negatives tables
hn_cols = {
row[1]
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
}
if "source_model" not in hn_cols:
self._con.execute(
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
)
self._con.execute( self._con.execute(
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile" "CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
" ON hard_negatives(filename, profile)" " ON hard_negatives(filename, profile)"
@@ -353,6 +363,7 @@ class ProcessedDB:
negative_folder: str = "", negative_folder: str = "",
fallback_video_dir: str = "", fallback_video_dir: str = "",
include_scan_exports: bool = False, include_scan_exports: bool = False,
use_hard_negatives: bool = True,
) -> list[tuple[str, list[float], list[float], list[float]]]: ) -> list[tuple[str, list[float], list[float], list[float]]]:
"""Build training video_infos from DB data. """Build training video_infos from DB data.
@@ -362,6 +373,7 @@ class ProcessedDB:
negative_folder: export folder name for explicit negatives (optional) negative_folder: export folder name for explicit negatives (optional)
fallback_video_dir: if source_path is empty, try filename in this dir fallback_video_dir: if source_path is empty, try filename in this dir
include_scan_exports: if True, include auto-exported scan clips include_scan_exports: if True, include auto-exported scan clips
use_hard_negatives: if False, skip hard negatives from scan feedback
Returns: Returns:
list of (source_video_path, positive_times, soft_times, negative_times) list of (source_video_path, positive_times, soft_times, negative_times)
@@ -400,15 +412,16 @@ class ProcessedDB:
soft_by_video.setdefault(fn, set()).add(st) soft_by_video.setdefault(fn, set()).add(st)
# Include hard negatives from scan feedback # Include hard negatives from scan feedback
hard_rows = self._con.execute( if use_hard_negatives:
"SELECT filename, start_time, source_path FROM hard_negatives" hard_rows = self._con.execute(
" WHERE profile = ?", "SELECT filename, start_time, source_path FROM hard_negatives"
(profile,), " WHERE profile = ?",
).fetchall() (profile,),
for fn, st, sp in hard_rows: ).fetchall()
neg_by_video.setdefault(fn, set()).add(st) for fn, st, sp in hard_rows:
if sp: neg_by_video.setdefault(fn, set()).add(st)
source_by_filename.setdefault(fn, sp) if sp:
source_by_filename.setdefault(fn, sp)
# Remove positive times from soft/neg to avoid conflicting labels # Remove positive times from soft/neg to avoid conflicting labels
for fn in pos_by_video: for fn in pos_by_video:
@@ -638,16 +651,18 @@ class ProcessedDB:
return {r[0] for r in rows} return {r[0] for r in rows}
def add_hard_negatives(self, filename: str, profile: str, def add_hard_negatives(self, filename: str, profile: str,
times: list[float], source_path: str = "") -> None: times: list[float], source_path: str = "",
source_model: str = "") -> None:
"""Save timestamps as hard-negative training examples.""" """Save timestamps as hard-negative training examples."""
if not self._enabled or not times: if not self._enabled or not times:
return return
with self._lock: with self._lock:
for t in times: for t in times:
self._con.execute( self._con.execute(
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)" "INSERT INTO hard_negatives"
" VALUES (?, ?, ?, ?)", " (filename, profile, start_time, source_path, source_model)"
(filename, profile, t, source_path), " VALUES (?, ?, ?, ?, ?)",
(filename, profile, t, source_path, source_model),
) )
self._con.commit() self._con.commit()
@@ -662,6 +677,30 @@ class ProcessedDB:
).fetchall() ).fetchall()
return {r[0] for r in rows} return {r[0] for r in rows}
def get_hard_negatives(self, profile: str) -> list[dict]:
"""Return all hard negatives for a profile with full details."""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT id, filename, start_time, source_path, source_model"
" FROM hard_negatives WHERE profile = ?"
" ORDER BY filename, start_time",
(profile,),
).fetchall()
return [{"id": r[0], "filename": r[1], "start_time": r[2],
"source_path": r[3], "source_model": r[4]} for r in rows]
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
"""Delete hard negatives by row IDs."""
if not self._enabled or not ids:
return
with self._lock:
self._con.execute(
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
ids,
)
self._con.commit()
def remove_hard_negatives(self, filename: str, profile: str, def remove_hard_negatives(self, filename: str, profile: str,
times: list[float]) -> None: times: list[float]) -> None:
"""Remove specific hard-negative timestamps.""" """Remove specific hard-negative timestamps."""
+20 -1
View File
@@ -372,6 +372,14 @@ class TrainDialog(QDialog):
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start()) self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
form.addRow("", self._chk_scan_exports) form.addRow("", self._chk_scan_exports)
self._chk_hard_negatives = QCheckBox("Use hard negatives in training")
self._chk_hard_negatives.setChecked(True)
self._chk_hard_negatives.setToolTip(
"When unchecked, manually marked hard negatives are excluded from training.\n"
"Useful when training a new model type where old negatives may not apply.")
self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start())
form.addRow("", self._chk_hard_negatives)
# Video source directory (fallback for old DB rows without source_path) # Video source directory (fallback for old DB rows without source_path)
self._txt_video_dir = QLineEdit(video_dir) self._txt_video_dir = QLineEdit(video_dir)
self._txt_video_dir.setPlaceholderText("Directory containing source videos") self._txt_video_dir.setPlaceholderText("Directory containing source videos")
@@ -464,15 +472,18 @@ class TrainDialog(QDialog):
return return
neg_folder = self._cmb_negative.currentData() or "" neg_folder = self._cmb_negative.currentData() or ""
inc_scan = self._chk_scan_exports.isChecked() inc_scan = self._chk_scan_exports.isChecked()
use_neg = self._chk_hard_negatives.isChecked()
# First check without fallback to see if source_paths are sufficient # First check without fallback to see if source_paths are sufficient
video_infos_no_fb = self._db.get_training_data( video_infos_no_fb = self._db.get_training_data(
self._profile, folder, negative_folder=neg_folder, self._profile, folder, negative_folder=neg_folder,
include_scan_exports=inc_scan, include_scan_exports=inc_scan,
use_hard_negatives=use_neg,
) )
video_infos = self._db.get_training_data( video_infos = self._db.get_training_data(
self._profile, folder, negative_folder=neg_folder, self._profile, folder, negative_folder=neg_folder,
fallback_video_dir=self._txt_video_dir.text(), fallback_video_dir=self._txt_video_dir.text(),
include_scan_exports=inc_scan, include_scan_exports=inc_scan,
use_hard_negatives=use_neg,
) )
# Show video dir field only when the fallback helps find extra videos # Show video dir field only when the fallback helps find extra videos
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0 needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
@@ -526,6 +537,10 @@ class TrainDialog(QDialog):
def include_scan_exports(self) -> bool: def include_scan_exports(self) -> bool:
return self._chk_scan_exports.isChecked() return self._chk_scan_exports.isChecked()
@property
def use_hard_negatives(self) -> bool:
return self._chk_hard_negatives.isChecked()
class TrainWorker(QThread): class TrainWorker(QThread):
"""Trains an audio classifier off the main thread.""" """Trains an audio classifier off the main thread."""
@@ -4007,8 +4022,10 @@ class MainWindow(QMainWindow):
if not self._file_path: if not self._file_path:
return return
filename = os.path.basename(self._file_path) filename = os.path.basename(self._file_path)
source_model = self._scan_panel.current_model_name()
self._db.add_hard_negatives(filename, self._profile, times, self._db.add_hard_negatives(filename, self._profile, times,
source_path=self._file_path) source_path=self._file_path,
source_model=source_model)
self._timeline.set_scan_regions( self._timeline.set_scan_regions(
self._scan_panel.current_regions_with_orig(), self._scan_panel.current_regions_with_orig(),
neg_times=self._scan_panel._neg_times, neg_times=self._scan_panel._neg_times,
@@ -4228,6 +4245,7 @@ class MainWindow(QMainWindow):
embed_model = dlg.embed_model embed_model = dlg.embed_model
video_dir = dlg.video_dir video_dir = dlg.video_dir
inc_scan = dlg.include_scan_exports inc_scan = dlg.include_scan_exports
use_neg = dlg.use_hard_negatives
if not pos_folder: if not pos_folder:
self._show_status("No positive class selected") self._show_status("No positive class selected")
return return
@@ -4240,6 +4258,7 @@ class MainWindow(QMainWindow):
self._profile, pos_folder, negative_folder=neg_folder, self._profile, pos_folder, negative_folder=neg_folder,
fallback_video_dir=video_dir, fallback_video_dir=video_dir,
include_scan_exports=inc_scan, include_scan_exports=inc_scan,
use_hard_negatives=use_neg,
) )
if not video_infos: if not video_infos:
self._show_status("No training data found for this subprofile") self._show_status("No training data found for this subprofile")
+57
View File
@@ -50,3 +50,60 @@ def test_scan_result_history():
assert len(results.get("MODEL_A", [])) == 1 assert len(results.get("MODEL_A", [])) == 1
finally: finally:
os.unlink(path) os.unlink(path)
def test_hard_negatives_source_model():
"""Hard negatives should store source_model."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name
try:
db = ProcessedDB(path)
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
source_path="/a.mp4", source_model="HUBERT_XLARGE")
rows = db.get_hard_negatives("test")
assert len(rows) == 2
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
finally:
os.unlink(path)
def test_training_data_skips_hard_negatives():
"""get_training_data with use_hard_negatives=False should skip them."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name
try:
db = ProcessedDB(path)
# Create a source file that "exists" — use the temp db file itself
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
source_path=path)
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
# With hard negatives
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
# Without hard negatives
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
assert len(data_with) >= 1
# The "with" case should have the hard negative time in neg list
neg_with = sum(len(vi[3]) for vi in data_with)
neg_without = sum(len(vi[3]) for vi in data_without)
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
finally:
os.unlink(path)
def test_delete_hard_negatives_by_ids():
"""delete_hard_negatives_by_ids should remove specific rows."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name
try:
db = ProcessedDB(path)
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
source_path="/a.mp4")
rows = db.get_hard_negatives("test")
assert len(rows) == 3
# Delete first two
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
remaining = db.get_hard_negatives("test")
assert len(remaining) == 1
assert remaining[0]["start_time"] == 30.0
finally:
os.unlink(path)