feat: hard negative source_model tracking, training toggle
Add source_model column to hard_negatives table with migration. New get_hard_negatives() returns full rows, delete_hard_negatives_by_ids() for bulk deletion. get_training_data() gains use_hard_negatives param. TrainDialog has "Use hard negatives" checkbox. Scan panel passes current model name when marking negatives. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+44
-5
@@ -123,9 +123,19 @@ class ProcessedDB:
|
||||
" filename TEXT NOT NULL,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" start_time REAL NOT NULL,"
|
||||
" source_path TEXT NOT NULL DEFAULT ''"
|
||||
" source_path TEXT NOT NULL DEFAULT '',"
|
||||
" source_model TEXT NOT NULL DEFAULT ''"
|
||||
")"
|
||||
)
|
||||
# Migrate: add source_model column to existing hard_negatives tables
|
||||
hn_cols = {
|
||||
row[1]
|
||||
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||
}
|
||||
if "source_model" not in hn_cols:
|
||||
self._con.execute(
|
||||
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||
" ON hard_negatives(filename, profile)"
|
||||
@@ -353,6 +363,7 @@ class ProcessedDB:
|
||||
negative_folder: str = "",
|
||||
fallback_video_dir: str = "",
|
||||
include_scan_exports: bool = False,
|
||||
use_hard_negatives: bool = True,
|
||||
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||
"""Build training video_infos from DB data.
|
||||
|
||||
@@ -362,6 +373,7 @@ class ProcessedDB:
|
||||
negative_folder: export folder name for explicit negatives (optional)
|
||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||
include_scan_exports: if True, include auto-exported scan clips
|
||||
use_hard_negatives: if False, skip hard negatives from scan feedback
|
||||
|
||||
Returns:
|
||||
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||
@@ -400,6 +412,7 @@ class ProcessedDB:
|
||||
soft_by_video.setdefault(fn, set()).add(st)
|
||||
|
||||
# Include hard negatives from scan feedback
|
||||
if use_hard_negatives:
|
||||
hard_rows = self._con.execute(
|
||||
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||
" WHERE profile = ?",
|
||||
@@ -638,16 +651,18 @@ class ProcessedDB:
|
||||
return {r[0] for r in rows}
|
||||
|
||||
def add_hard_negatives(self, filename: str, profile: str,
|
||||
times: list[float], source_path: str = "") -> None:
|
||||
times: list[float], source_path: str = "",
|
||||
source_model: str = "") -> None:
|
||||
"""Save timestamps as hard-negative training examples."""
|
||||
if not self._enabled or not times:
|
||||
return
|
||||
with self._lock:
|
||||
for t in times:
|
||||
self._con.execute(
|
||||
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
|
||||
" VALUES (?, ?, ?, ?)",
|
||||
(filename, profile, t, source_path),
|
||||
"INSERT INTO hard_negatives"
|
||||
" (filename, profile, start_time, source_path, source_model)"
|
||||
" VALUES (?, ?, ?, ?, ?)",
|
||||
(filename, profile, t, source_path, source_model),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
@@ -662,6 +677,30 @@ class ProcessedDB:
|
||||
).fetchall()
|
||||
return {r[0] for r in rows}
|
||||
|
||||
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||
"""Return all hard negatives for a profile with full details."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT id, filename, start_time, source_path, source_model"
|
||||
" FROM hard_negatives WHERE profile = ?"
|
||||
" ORDER BY filename, start_time",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||
|
||||
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||
"""Delete hard negatives by row IDs."""
|
||||
if not self._enabled or not ids:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||
ids,
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def remove_hard_negatives(self, filename: str, profile: str,
|
||||
times: list[float]) -> None:
|
||||
"""Remove specific hard-negative timestamps."""
|
||||
|
||||
@@ -372,6 +372,14 @@ class TrainDialog(QDialog):
|
||||
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
|
||||
form.addRow("", self._chk_scan_exports)
|
||||
|
||||
self._chk_hard_negatives = QCheckBox("Use hard negatives in training")
|
||||
self._chk_hard_negatives.setChecked(True)
|
||||
self._chk_hard_negatives.setToolTip(
|
||||
"When unchecked, manually marked hard negatives are excluded from training.\n"
|
||||
"Useful when training a new model type where old negatives may not apply.")
|
||||
self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start())
|
||||
form.addRow("", self._chk_hard_negatives)
|
||||
|
||||
# Video source directory (fallback for old DB rows without source_path)
|
||||
self._txt_video_dir = QLineEdit(video_dir)
|
||||
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
||||
@@ -464,15 +472,18 @@ class TrainDialog(QDialog):
|
||||
return
|
||||
neg_folder = self._cmb_negative.currentData() or ""
|
||||
inc_scan = self._chk_scan_exports.isChecked()
|
||||
use_neg = self._chk_hard_negatives.isChecked()
|
||||
# First check without fallback to see if source_paths are sufficient
|
||||
video_infos_no_fb = self._db.get_training_data(
|
||||
self._profile, folder, negative_folder=neg_folder,
|
||||
include_scan_exports=inc_scan,
|
||||
use_hard_negatives=use_neg,
|
||||
)
|
||||
video_infos = self._db.get_training_data(
|
||||
self._profile, folder, negative_folder=neg_folder,
|
||||
fallback_video_dir=self._txt_video_dir.text(),
|
||||
include_scan_exports=inc_scan,
|
||||
use_hard_negatives=use_neg,
|
||||
)
|
||||
# Show video dir field only when the fallback helps find extra videos
|
||||
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
||||
@@ -526,6 +537,10 @@ class TrainDialog(QDialog):
|
||||
def include_scan_exports(self) -> bool:
|
||||
return self._chk_scan_exports.isChecked()
|
||||
|
||||
@property
|
||||
def use_hard_negatives(self) -> bool:
|
||||
return self._chk_hard_negatives.isChecked()
|
||||
|
||||
|
||||
class TrainWorker(QThread):
|
||||
"""Trains an audio classifier off the main thread."""
|
||||
@@ -4007,8 +4022,10 @@ class MainWindow(QMainWindow):
|
||||
if not self._file_path:
|
||||
return
|
||||
filename = os.path.basename(self._file_path)
|
||||
source_model = self._scan_panel.current_model_name()
|
||||
self._db.add_hard_negatives(filename, self._profile, times,
|
||||
source_path=self._file_path)
|
||||
source_path=self._file_path,
|
||||
source_model=source_model)
|
||||
self._timeline.set_scan_regions(
|
||||
self._scan_panel.current_regions_with_orig(),
|
||||
neg_times=self._scan_panel._neg_times,
|
||||
@@ -4228,6 +4245,7 @@ class MainWindow(QMainWindow):
|
||||
embed_model = dlg.embed_model
|
||||
video_dir = dlg.video_dir
|
||||
inc_scan = dlg.include_scan_exports
|
||||
use_neg = dlg.use_hard_negatives
|
||||
if not pos_folder:
|
||||
self._show_status("No positive class selected")
|
||||
return
|
||||
@@ -4240,6 +4258,7 @@ class MainWindow(QMainWindow):
|
||||
self._profile, pos_folder, negative_folder=neg_folder,
|
||||
fallback_video_dir=video_dir,
|
||||
include_scan_exports=inc_scan,
|
||||
use_hard_negatives=use_neg,
|
||||
)
|
||||
if not video_infos:
|
||||
self._show_status("No training data found for this subprofile")
|
||||
|
||||
@@ -50,3 +50,60 @@ def test_scan_result_history():
|
||||
assert len(results.get("MODEL_A", [])) == 1
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_hard_negatives_source_model():
|
||||
"""Hard negatives should store source_model."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||
rows = db.get_hard_negatives("test")
|
||||
assert len(rows) == 2
|
||||
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_training_data_skips_hard_negatives():
|
||||
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
# Create a source file that "exists" — use the temp db file itself
|
||||
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||
source_path=path)
|
||||
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||
# With hard negatives
|
||||
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||
# Without hard negatives
|
||||
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||
assert len(data_with) >= 1
|
||||
# The "with" case should have the hard negative time in neg list
|
||||
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_delete_hard_negatives_by_ids():
|
||||
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||
source_path="/a.mp4")
|
||||
rows = db.get_hard_negatives("test")
|
||||
assert len(rows) == 3
|
||||
# Delete first two
|
||||
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||
remaining = db.get_hard_negatives("test")
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0]["start_time"] == 30.0
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
Reference in New Issue
Block a user