feat: hard negative source_model tracking, training toggle
Add source_model column to hard_negatives table with migration. New get_hard_negatives() returns full rows, delete_hard_negatives_by_ids() for bulk deletion. get_training_data() gains use_hard_negatives param. TrainDialog has "Use hard negatives" checkbox. Scan panel passes current model name when marking negatives. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+44
-5
@@ -123,9 +123,19 @@ class ProcessedDB:
|
|||||||
" filename TEXT NOT NULL,"
|
" filename TEXT NOT NULL,"
|
||||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
" start_time REAL NOT NULL,"
|
" start_time REAL NOT NULL,"
|
||||||
" source_path TEXT NOT NULL DEFAULT ''"
|
" source_path TEXT NOT NULL DEFAULT '',"
|
||||||
|
" source_model TEXT NOT NULL DEFAULT ''"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
# Migrate: add source_model column to existing hard_negatives tables
|
||||||
|
hn_cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||||
|
}
|
||||||
|
if "source_model" not in hn_cols:
|
||||||
|
self._con.execute(
|
||||||
|
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||||
|
)
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||||
" ON hard_negatives(filename, profile)"
|
" ON hard_negatives(filename, profile)"
|
||||||
@@ -353,6 +363,7 @@ class ProcessedDB:
|
|||||||
negative_folder: str = "",
|
negative_folder: str = "",
|
||||||
fallback_video_dir: str = "",
|
fallback_video_dir: str = "",
|
||||||
include_scan_exports: bool = False,
|
include_scan_exports: bool = False,
|
||||||
|
use_hard_negatives: bool = True,
|
||||||
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
"""Build training video_infos from DB data.
|
"""Build training video_infos from DB data.
|
||||||
|
|
||||||
@@ -362,6 +373,7 @@ class ProcessedDB:
|
|||||||
negative_folder: export folder name for explicit negatives (optional)
|
negative_folder: export folder name for explicit negatives (optional)
|
||||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||||
include_scan_exports: if True, include auto-exported scan clips
|
include_scan_exports: if True, include auto-exported scan clips
|
||||||
|
use_hard_negatives: if False, skip hard negatives from scan feedback
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of (source_video_path, positive_times, soft_times, negative_times)
|
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||||
@@ -400,6 +412,7 @@ class ProcessedDB:
|
|||||||
soft_by_video.setdefault(fn, set()).add(st)
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
# Include hard negatives from scan feedback
|
# Include hard negatives from scan feedback
|
||||||
|
if use_hard_negatives:
|
||||||
hard_rows = self._con.execute(
|
hard_rows = self._con.execute(
|
||||||
"SELECT filename, start_time, source_path FROM hard_negatives"
|
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||||
" WHERE profile = ?",
|
" WHERE profile = ?",
|
||||||
@@ -638,16 +651,18 @@ class ProcessedDB:
|
|||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
def add_hard_negatives(self, filename: str, profile: str,
|
def add_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float], source_path: str = "") -> None:
|
times: list[float], source_path: str = "",
|
||||||
|
source_model: str = "") -> None:
|
||||||
"""Save timestamps as hard-negative training examples."""
|
"""Save timestamps as hard-negative training examples."""
|
||||||
if not self._enabled or not times:
|
if not self._enabled or not times:
|
||||||
return
|
return
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for t in times:
|
for t in times:
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
|
"INSERT INTO hard_negatives"
|
||||||
" VALUES (?, ?, ?, ?)",
|
" (filename, profile, start_time, source_path, source_model)"
|
||||||
(filename, profile, t, source_path),
|
" VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(filename, profile, t, source_path, source_model),
|
||||||
)
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
@@ -662,6 +677,30 @@ class ProcessedDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||||
|
"""Return all hard negatives for a profile with full details."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, filename, start_time, source_path, source_model"
|
||||||
|
" FROM hard_negatives WHERE profile = ?"
|
||||||
|
" ORDER BY filename, start_time",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||||
|
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||||
|
|
||||||
|
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||||
|
"""Delete hard negatives by row IDs."""
|
||||||
|
if not self._enabled or not ids:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||||
|
ids,
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
def remove_hard_negatives(self, filename: str, profile: str,
|
def remove_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float]) -> None:
|
times: list[float]) -> None:
|
||||||
"""Remove specific hard-negative timestamps."""
|
"""Remove specific hard-negative timestamps."""
|
||||||
|
|||||||
@@ -372,6 +372,14 @@ class TrainDialog(QDialog):
|
|||||||
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
|
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
|
||||||
form.addRow("", self._chk_scan_exports)
|
form.addRow("", self._chk_scan_exports)
|
||||||
|
|
||||||
|
self._chk_hard_negatives = QCheckBox("Use hard negatives in training")
|
||||||
|
self._chk_hard_negatives.setChecked(True)
|
||||||
|
self._chk_hard_negatives.setToolTip(
|
||||||
|
"When unchecked, manually marked hard negatives are excluded from training.\n"
|
||||||
|
"Useful when training a new model type where old negatives may not apply.")
|
||||||
|
self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start())
|
||||||
|
form.addRow("", self._chk_hard_negatives)
|
||||||
|
|
||||||
# Video source directory (fallback for old DB rows without source_path)
|
# Video source directory (fallback for old DB rows without source_path)
|
||||||
self._txt_video_dir = QLineEdit(video_dir)
|
self._txt_video_dir = QLineEdit(video_dir)
|
||||||
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
||||||
@@ -464,15 +472,18 @@ class TrainDialog(QDialog):
|
|||||||
return
|
return
|
||||||
neg_folder = self._cmb_negative.currentData() or ""
|
neg_folder = self._cmb_negative.currentData() or ""
|
||||||
inc_scan = self._chk_scan_exports.isChecked()
|
inc_scan = self._chk_scan_exports.isChecked()
|
||||||
|
use_neg = self._chk_hard_negatives.isChecked()
|
||||||
# First check without fallback to see if source_paths are sufficient
|
# First check without fallback to see if source_paths are sufficient
|
||||||
video_infos_no_fb = self._db.get_training_data(
|
video_infos_no_fb = self._db.get_training_data(
|
||||||
self._profile, folder, negative_folder=neg_folder,
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
video_infos = self._db.get_training_data(
|
video_infos = self._db.get_training_data(
|
||||||
self._profile, folder, negative_folder=neg_folder,
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
fallback_video_dir=self._txt_video_dir.text(),
|
fallback_video_dir=self._txt_video_dir.text(),
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
# Show video dir field only when the fallback helps find extra videos
|
# Show video dir field only when the fallback helps find extra videos
|
||||||
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
||||||
@@ -526,6 +537,10 @@ class TrainDialog(QDialog):
|
|||||||
def include_scan_exports(self) -> bool:
|
def include_scan_exports(self) -> bool:
|
||||||
return self._chk_scan_exports.isChecked()
|
return self._chk_scan_exports.isChecked()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_hard_negatives(self) -> bool:
|
||||||
|
return self._chk_hard_negatives.isChecked()
|
||||||
|
|
||||||
|
|
||||||
class TrainWorker(QThread):
|
class TrainWorker(QThread):
|
||||||
"""Trains an audio classifier off the main thread."""
|
"""Trains an audio classifier off the main thread."""
|
||||||
@@ -4007,8 +4022,10 @@ class MainWindow(QMainWindow):
|
|||||||
if not self._file_path:
|
if not self._file_path:
|
||||||
return
|
return
|
||||||
filename = os.path.basename(self._file_path)
|
filename = os.path.basename(self._file_path)
|
||||||
|
source_model = self._scan_panel.current_model_name()
|
||||||
self._db.add_hard_negatives(filename, self._profile, times,
|
self._db.add_hard_negatives(filename, self._profile, times,
|
||||||
source_path=self._file_path)
|
source_path=self._file_path,
|
||||||
|
source_model=source_model)
|
||||||
self._timeline.set_scan_regions(
|
self._timeline.set_scan_regions(
|
||||||
self._scan_panel.current_regions_with_orig(),
|
self._scan_panel.current_regions_with_orig(),
|
||||||
neg_times=self._scan_panel._neg_times,
|
neg_times=self._scan_panel._neg_times,
|
||||||
@@ -4228,6 +4245,7 @@ class MainWindow(QMainWindow):
|
|||||||
embed_model = dlg.embed_model
|
embed_model = dlg.embed_model
|
||||||
video_dir = dlg.video_dir
|
video_dir = dlg.video_dir
|
||||||
inc_scan = dlg.include_scan_exports
|
inc_scan = dlg.include_scan_exports
|
||||||
|
use_neg = dlg.use_hard_negatives
|
||||||
if not pos_folder:
|
if not pos_folder:
|
||||||
self._show_status("No positive class selected")
|
self._show_status("No positive class selected")
|
||||||
return
|
return
|
||||||
@@ -4240,6 +4258,7 @@ class MainWindow(QMainWindow):
|
|||||||
self._profile, pos_folder, negative_folder=neg_folder,
|
self._profile, pos_folder, negative_folder=neg_folder,
|
||||||
fallback_video_dir=video_dir,
|
fallback_video_dir=video_dir,
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
if not video_infos:
|
if not video_infos:
|
||||||
self._show_status("No training data found for this subprofile")
|
self._show_status("No training data found for this subprofile")
|
||||||
|
|||||||
@@ -50,3 +50,60 @@ def test_scan_result_history():
|
|||||||
assert len(results.get("MODEL_A", [])) == 1
|
assert len(results.get("MODEL_A", [])) == 1
|
||||||
finally:
|
finally:
|
||||||
os.unlink(path)
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hard_negatives_source_model():
|
||||||
|
"""Hard negatives should store source_model."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||||
|
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_data_skips_hard_negatives():
|
||||||
|
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Create a source file that "exists" — use the temp db file itself
|
||||||
|
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||||
|
source_path=path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||||
|
# With hard negatives
|
||||||
|
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||||
|
# Without hard negatives
|
||||||
|
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||||
|
assert len(data_with) >= 1
|
||||||
|
# The "with" case should have the hard negative time in neg list
|
||||||
|
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||||
|
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||||
|
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hard_negatives_by_ids():
|
||||||
|
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||||
|
source_path="/a.mp4")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 3
|
||||||
|
# Delete first two
|
||||||
|
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||||
|
remaining = db.get_hard_negatives("test")
|
||||||
|
assert len(remaining) == 1
|
||||||
|
assert remaining[0]["start_time"] == 30.0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|||||||
Reference in New Issue
Block a user