feat: hard negative source_model tracking, training toggle
Add a source_model column to the hard_negatives table, with a migration. Add get_hard_negatives(), which returns full rows, and delete_hard_negatives_by_ids() for bulk deletion. get_training_data() gains a use_hard_negatives parameter. TrainDialog gets a "Use hard negatives" checkbox. The scan panel passes the current model name when marking negatives. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -50,3 +50,60 @@ def test_scan_result_history():
|
||||
assert len(results.get("MODEL_A", [])) == 1
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_hard_negatives_source_model():
    """Rows from get_hard_negatives() keep the source_model they were added with."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        db_path = tmp.name
    try:
        store = ProcessedDB(db_path)
        store.add_hard_negatives(
            "a.mp4",
            "test",
            [10.0, 20.0],
            source_path="/a.mp4",
            source_model="HUBERT_XLARGE",
        )
        negatives = store.get_hard_negatives("test")
        # Two timestamps were added above, so two rows are expected back.
        assert len(negatives) == 2
        for row in negatives:
            assert row["source_model"] == "HUBERT_XLARGE"
    finally:
        # The temp file was created with delete=False; remove it by hand.
        os.unlink(db_path)
|
||||
|
||||
|
||||
def test_training_data_skips_hard_negatives():
    """get_training_data(use_hard_negatives=False) should return fewer negatives."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        db_path = tmp.name
    try:
        store = ProcessedDB(db_path)
        # The temp db file doubles as the source file so that source_path
        # points at something that actually exists on disk.
        store.add(
            "a.mp4",
            10.0,
            "/out/folder/g/clip.mp4",
            profile="test",
            source_path=db_path,
        )
        store.add_hard_negatives("a.mp4", "test", [500.0], source_path=db_path)

        def count_negatives(entries):
            # Index 3 of each entry is treated as its negatives list.
            return sum(len(entry[3]) for entry in entries)

        included = store.get_training_data(
            "test", "folder", use_hard_negatives=True
        )
        excluded = store.get_training_data(
            "test", "folder", use_hard_negatives=False
        )

        assert len(included) >= 1
        # Enabling hard negatives must contribute at least one extra negative.
        neg_with = count_negatives(included)
        neg_without = count_negatives(excluded)
        assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
    finally:
        # The temp file was created with delete=False; remove it by hand.
        os.unlink(db_path)
|
||||
|
||||
|
||||
def test_delete_hard_negatives_by_ids():
    """delete_hard_negatives_by_ids should remove exactly the rows whose ids are given.

    Three hard negatives are added (start times 10.0, 20.0, 30.0); the rows
    for 10.0 and 20.0 are deleted by id and only the 30.0 row must remain.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        path = f.name
    try:
        db = ProcessedDB(path)
        db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
                              source_path="/a.mp4")
        rows = db.get_hard_negatives("test")
        assert len(rows) == 3
        # Select the rows to delete by start_time rather than by list
        # position, so the test does not depend on whatever order
        # get_hard_negatives() happens to return rows in.
        ids_to_delete = [r["id"] for r in rows if r["start_time"] in (10.0, 20.0)]
        assert len(ids_to_delete) == 2
        db.delete_hard_negatives_by_ids(ids_to_delete)
        remaining = db.get_hard_negatives("test")
        assert len(remaining) == 1
        assert remaining[0]["start_time"] == 30.0
    finally:
        # The temp file was created with delete=False; remove it by hand.
        os.unlink(path)
|
||||
|
||||
Reference in New Issue
Block a user