Compare commits
13 Commits
52aa982aa2
...
5d45b8d8eb
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d45b8d8eb | |||
| e6db83f00b | |||
| edc5784ba6 | |||
| 8ed9fbf557 | |||
| 4fb2ae144f | |||
| 2614a765d5 | |||
| c020c0dfec | |||
| e7b791fbfa | |||
| f5361a963e | |||
| 8fb8581816 | |||
| 5b25e85e98 | |||
| e3f133ef84 | |||
| 4736f150b0 |
+150
-5
@@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
|||||||
_w2v_model = None
|
_w2v_model = None
|
||||||
_w2v_device = None
|
_w2v_device = None
|
||||||
_w2v_model_name = None
|
_w2v_model_name = None
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
|
||||||
# Supported embedding models — name → embed_dim
|
# Supported embedding models — name → embed_dim
|
||||||
_EMBED_MODELS = {
|
_EMBED_MODELS = {
|
||||||
@@ -55,6 +56,15 @@ _EMBED_MODELS = {
|
|||||||
"HUBERT_LARGE": 1024,
|
"HUBERT_LARGE": 1024,
|
||||||
"HUBERT_XLARGE": 1280,
|
"HUBERT_XLARGE": 1280,
|
||||||
"BEATS": 768,
|
"BEATS": 768,
|
||||||
|
# Multi-layer variants (4 quartile layers concatenated)
|
||||||
|
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
"EAT": 768,
|
||||||
}
|
}
|
||||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
@@ -70,11 +80,14 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
global _w2v_model, _w2v_device, _w2v_model_name
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
model_name = _DEFAULT_EMBED_MODEL
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
if _w2v_model is None or _w2v_model_name != model_name:
|
# Multi-layer variants use the same base model weights
|
||||||
|
ml = _ml_config(model_name)
|
||||||
|
load_name = ml[0] if ml else model_name
|
||||||
|
if _w2v_model is None or _w2v_model_name != load_name:
|
||||||
import torch
|
import torch
|
||||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
if model_name == "BEATS":
|
if load_name == "BEATS":
|
||||||
from .beats_model import BEATs, BEATsConfig
|
from .beats_model import BEATs, BEATsConfig
|
||||||
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||||
weights_only=False)
|
weights_only=False)
|
||||||
@@ -82,17 +95,61 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
_w2v_model = BEATs(cfg)
|
_w2v_model = BEATs(cfg)
|
||||||
_w2v_model.load_state_dict(checkpoint['model'])
|
_w2v_model.load_state_dict(checkpoint['model'])
|
||||||
_w2v_model.to(_w2v_device)
|
_w2v_model.to(_w2v_device)
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
|
elif load_name == "EAT":
|
||||||
|
from transformers import AutoModel
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||||
|
trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
else:
|
else:
|
||||||
import torchaudio
|
import torchaudio
|
||||||
bundle = getattr(torchaudio.pipelines, model_name)
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
|
||||||
_w2v_model.eval()
|
_w2v_model.eval()
|
||||||
_w2v_model_name = model_name
|
_w2v_model_name = load_name
|
||||||
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
|
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||||
return _w2v_model, _w2v_device
|
return _w2v_model, _w2v_device
|
||||||
|
|
||||||
|
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||||
|
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||||
|
|
||||||
|
Returns tensor of shape [B, 1, T, 128].
|
||||||
|
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
TARGET_LEN = 1024
|
||||||
|
MEAN, STD = -4.268, 4.569
|
||||||
|
|
||||||
|
mels = []
|
||||||
|
for chunk in chunks:
|
||||||
|
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||||
|
fbank = kaldi.fbank(
|
||||||
|
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||||
|
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||||
|
)
|
||||||
|
# Pad or truncate to TARGET_LEN
|
||||||
|
if fbank.shape[0] < TARGET_LEN:
|
||||||
|
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||||
|
else:
|
||||||
|
fbank = fbank[:TARGET_LEN]
|
||||||
|
fbank = (fbank - MEAN) / (STD * 2)
|
||||||
|
mels.append(fbank)
|
||||||
|
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||||
|
|
||||||
|
|
||||||
def _embed_dim(model_name: str | None = None) -> int:
|
def _embed_dim(model_name: str | None = None) -> int:
|
||||||
"""Return embedding dimension for a model name."""
|
"""Return embedding dimension for a model name."""
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
@@ -100,6 +157,31 @@ def _embed_dim(model_name: str | None = None) -> int:
|
|||||||
return _EMBED_MODELS.get(model_name, 768)
|
return _EMBED_MODELS.get(model_name, 768)
|
||||||
|
|
||||||
|
|
||||||
|
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||||
|
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||||
|
|
||||||
|
Returns None for single-layer models.
|
||||||
|
Layer indices are 0-based into the list returned by extract_features().
|
||||||
|
"""
|
||||||
|
if not model_name.endswith("_ML"):
|
||||||
|
return None
|
||||||
|
base = model_name[:-3] # strip "_ML"
|
||||||
|
if base not in _EMBED_MODELS:
|
||||||
|
return None
|
||||||
|
# Layer counts per model family
|
||||||
|
layer_counts = {
|
||||||
|
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||||
|
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||||
|
"AST": 12,
|
||||||
|
}
|
||||||
|
n = layer_counts.get(base)
|
||||||
|
if n is None:
|
||||||
|
return None
|
||||||
|
# Select 4 layers at quartile boundaries (0-indexed)
|
||||||
|
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||||
|
return base, indices
|
||||||
|
|
||||||
|
|
||||||
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||||
model_name: str | None = None) -> str:
|
model_name: str | None = None) -> str:
|
||||||
"""Return cache file path for a video's embeddings (includes model name)."""
|
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||||
@@ -171,6 +253,9 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
import torch
|
import torch
|
||||||
model, device = _get_w2v_model(model_name)
|
model, device = _get_w2v_model(model_name)
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
# Auto-size batches based on available GPU memory
|
# Auto-size batches based on available GPU memory
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
@@ -195,10 +280,33 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
start = i * hop_samples
|
start = i * hop_samples
|
||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
if is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
if is_beats:
|
if is_beats:
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
else:
|
else:
|
||||||
features, _ = model(waveforms)
|
features, _ = model(waveforms)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
@@ -274,6 +382,9 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
embeddings_list: list[np.ndarray] = []
|
embeddings_list: list[np.ndarray] = []
|
||||||
|
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_times), batch_size):
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
batch_end = min(batch_start + batch_size, len(valid_times))
|
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||||
@@ -283,10 +394,33 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
timestamps_list.append(float(t))
|
timestamps_list.append(float(t))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
if is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
if is_beats:
|
if is_beats:
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
else:
|
else:
|
||||||
features, _ = model(waveforms)
|
features, _ = model(waveforms)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
@@ -428,6 +562,17 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
|||||||
clf.fit(X[train_idx], y_arr[train_idx])
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
_log("audio_scan: classifier trained")
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
# Calibrate probabilities for better threshold behavior
|
||||||
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
min_class = min(int(n_pos), int(n_neg_sample))
|
||||||
|
if min_class >= 6:
|
||||||
|
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||||
|
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
clf = cal_clf
|
||||||
|
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||||
|
|
||||||
model = {"classifier": clf, "n_features": X.shape[1],
|
model = {"classifier": clf, "n_features": X.shape[1],
|
||||||
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
||||||
|
|
||||||
|
|||||||
+153
-22
@@ -94,7 +94,8 @@ class ProcessedDB:
|
|||||||
" score REAL NOT NULL,"
|
" score REAL NOT NULL,"
|
||||||
" disabled INTEGER NOT NULL DEFAULT 0,"
|
" disabled INTEGER NOT NULL DEFAULT 0,"
|
||||||
" orig_start_time REAL,"
|
" orig_start_time REAL,"
|
||||||
" orig_end_time REAL"
|
" orig_end_time REAL,"
|
||||||
|
" scan_timestamp TEXT NOT NULL DEFAULT ''"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
# Migrate: add new columns to existing scan_results tables
|
# Migrate: add new columns to existing scan_results tables
|
||||||
@@ -106,6 +107,7 @@ class ProcessedDB:
|
|||||||
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
||||||
("orig_start_time", "REAL"),
|
("orig_start_time", "REAL"),
|
||||||
("orig_end_time", "REAL"),
|
("orig_end_time", "REAL"),
|
||||||
|
("scan_timestamp", "TEXT NOT NULL DEFAULT ''"),
|
||||||
]:
|
]:
|
||||||
if col not in sr_cols:
|
if col not in sr_cols:
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
@@ -121,9 +123,19 @@ class ProcessedDB:
|
|||||||
" filename TEXT NOT NULL,"
|
" filename TEXT NOT NULL,"
|
||||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
" start_time REAL NOT NULL,"
|
" start_time REAL NOT NULL,"
|
||||||
" source_path TEXT NOT NULL DEFAULT ''"
|
" source_path TEXT NOT NULL DEFAULT '',"
|
||||||
|
" source_model TEXT NOT NULL DEFAULT ''"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
# Migrate: add source_model column to existing hard_negatives tables
|
||||||
|
hn_cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||||
|
}
|
||||||
|
if "source_model" not in hn_cols:
|
||||||
|
self._con.execute(
|
||||||
|
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||||
|
)
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||||
" ON hard_negatives(filename, profile)"
|
" ON hard_negatives(filename, profile)"
|
||||||
@@ -291,7 +303,35 @@ class ProcessedDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return [r[0] for r in rows]
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
def get_export_folders(self, profile: str = "default") -> list[str]:
|
def get_max_counter(self, folder: str, name: str) -> int:
|
||||||
|
"""Return the highest counter N found in output_paths matching folder/name_NNN*.
|
||||||
|
|
||||||
|
Parses the group directory component (e.g. 'clip_035') from stored
|
||||||
|
output_path values. Returns 0 if no matches exist.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return 0
|
||||||
|
prefix = os.path.join(folder, name + "_")
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed"
|
||||||
|
" WHERE output_path LIKE ?",
|
||||||
|
(prefix + "%",),
|
||||||
|
).fetchall()
|
||||||
|
max_n = 0
|
||||||
|
for (op,) in rows:
|
||||||
|
# output_path: .../folder/name_NNN/name_NNN_sub.ext
|
||||||
|
parent = os.path.basename(os.path.dirname(op))
|
||||||
|
# parent should be "name_NNN"
|
||||||
|
parts = parent.rsplit("_", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
try:
|
||||||
|
max_n = max(max_n, int(parts[1]))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return max_n
|
||||||
|
|
||||||
|
def get_export_folders(self, profile: str = "default",
|
||||||
|
include_scan_exports: bool = False) -> list[str]:
|
||||||
"""Return distinct export folder names found in output_paths for a profile.
|
"""Return distinct export folder names found in output_paths for a profile.
|
||||||
|
|
||||||
Export paths follow the structure:
|
Export paths follow the structure:
|
||||||
@@ -301,10 +341,17 @@ class ProcessedDB:
|
|||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return []
|
return []
|
||||||
|
if include_scan_exports:
|
||||||
rows = self._con.execute(
|
rows = self._con.execute(
|
||||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
(profile,),
|
(profile,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed"
|
||||||
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
folder_names: set[str] = set()
|
folder_names: set[str] = set()
|
||||||
for (op,) in rows:
|
for (op,) in rows:
|
||||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
@@ -316,6 +363,7 @@ class ProcessedDB:
|
|||||||
negative_folder: str = "",
|
negative_folder: str = "",
|
||||||
fallback_video_dir: str = "",
|
fallback_video_dir: str = "",
|
||||||
include_scan_exports: bool = False,
|
include_scan_exports: bool = False,
|
||||||
|
use_hard_negatives: bool = True,
|
||||||
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
"""Build training video_infos from DB data.
|
"""Build training video_infos from DB data.
|
||||||
|
|
||||||
@@ -325,6 +373,7 @@ class ProcessedDB:
|
|||||||
negative_folder: export folder name for explicit negatives (optional)
|
negative_folder: export folder name for explicit negatives (optional)
|
||||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||||
include_scan_exports: if True, include auto-exported scan clips
|
include_scan_exports: if True, include auto-exported scan clips
|
||||||
|
use_hard_negatives: if False, skip hard negatives from scan feedback
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of (source_video_path, positive_times, soft_times, negative_times)
|
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||||
@@ -363,6 +412,7 @@ class ProcessedDB:
|
|||||||
soft_by_video.setdefault(fn, set()).add(st)
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
# Include hard negatives from scan feedback
|
# Include hard negatives from scan feedback
|
||||||
|
if use_hard_negatives:
|
||||||
hard_rows = self._con.execute(
|
hard_rows = self._con.execute(
|
||||||
"SELECT filename, start_time, source_path FROM hard_negatives"
|
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||||
" WHERE profile = ?",
|
" WHERE profile = ?",
|
||||||
@@ -429,7 +479,7 @@ class ProcessedDB:
|
|||||||
" WHERE profile = ? AND scan_export = 0",
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
(profile,),
|
(profile,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
folders = self.get_export_folders(profile)
|
folders = self.get_export_folders(profile, include_scan_exports=include_scan_exports)
|
||||||
stats: dict[str, dict] = {}
|
stats: dict[str, dict] = {}
|
||||||
for folder_name in folders:
|
for folder_name in folders:
|
||||||
videos: set[str] = set()
|
videos: set[str] = set()
|
||||||
@@ -440,49 +490,104 @@ class ProcessedDB:
|
|||||||
videos.add(fn)
|
videos.add(fn)
|
||||||
clips += 1
|
clips += 1
|
||||||
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
||||||
return stats
|
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||||
|
|
||||||
# ── Scan results ─────────────────────────────────────────────
|
# ── Scan results ─────────────────────────────────────────────
|
||||||
|
|
||||||
def save_scan_results(self, filename: str, profile: str, model: str,
|
def save_scan_results(self, filename: str, profile: str, model: str,
|
||||||
regions: list[tuple[float, float, float]]) -> None:
|
regions: list[tuple[float, float, float]],
|
||||||
"""Replace scan results for (filename, profile, model) with new regions.
|
max_versions: int = 5) -> None:
|
||||||
|
"""Save scan results as a new version for (filename, profile, model).
|
||||||
|
|
||||||
regions: list of (start_time, end_time, score).
|
regions: list of (start_time, end_time, score).
|
||||||
|
Keeps up to max_versions; oldest are pruned automatically.
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return
|
return
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._con.execute(
|
|
||||||
"DELETE FROM scan_results"
|
|
||||||
" WHERE filename = ? AND profile = ? AND model = ?",
|
|
||||||
(filename, profile, model),
|
|
||||||
)
|
|
||||||
self._con.executemany(
|
self._con.executemany(
|
||||||
"INSERT INTO scan_results"
|
"INSERT INTO scan_results"
|
||||||
" (filename, profile, model, start_time, end_time, score,"
|
" (filename, profile, model, start_time, end_time, score,"
|
||||||
" orig_start_time, orig_end_time)"
|
" orig_start_time, orig_end_time, scan_timestamp)"
|
||||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
[(filename, profile, model, s, e, sc, s, e) for s, e, sc in regions],
|
[(filename, profile, model, s, e, sc, s, e, ts)
|
||||||
|
for s, e, sc in regions],
|
||||||
|
)
|
||||||
|
# Prune old versions beyond max_versions
|
||||||
|
versions = self._con.execute(
|
||||||
|
"SELECT DISTINCT scan_timestamp FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
if len(versions) > max_versions:
|
||||||
|
old_ts = [v[0] for v in versions[max_versions:]]
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
f" AND scan_timestamp IN ({','.join('?' * len(old_ts))})",
|
||||||
|
(filename, profile, model, *old_ts),
|
||||||
)
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def get_scan_results(self, filename: str, profile: str
|
def get_scan_versions(self, filename: str, profile: str, model: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Return list of scan versions for (filename, profile, model).
|
||||||
|
|
||||||
|
Returns [{timestamp, count, max_score}, ...] ordered newest first.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT scan_timestamp, COUNT(*), MAX(score)"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" AND scan_timestamp != ''"
|
||||||
|
" GROUP BY scan_timestamp"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
return [{"timestamp": ts, "count": cnt, "max_score": sc}
|
||||||
|
for ts, cnt, sc in rows]
|
||||||
|
|
||||||
|
def get_scan_results(self, filename: str, profile: str,
|
||||||
|
scan_timestamp: str | None = None
|
||||||
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
||||||
"""Return scan results grouped by model.
|
"""Return scan results grouped by model.
|
||||||
|
|
||||||
|
If scan_timestamp is given, returns only that version's rows.
|
||||||
|
Otherwise returns the latest version per model.
|
||||||
|
|
||||||
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
||||||
sorted by start_time.
|
sorted by start_time.
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return {}
|
return {}
|
||||||
|
if scan_timestamp:
|
||||||
rows = self._con.execute(
|
rows = self._con.execute(
|
||||||
"SELECT id, model, start_time, end_time, score, disabled,"
|
"SELECT id, model, start_time, end_time, score, disabled,"
|
||||||
" orig_start_time, orig_end_time"
|
" orig_start_time, orig_end_time"
|
||||||
" FROM scan_results"
|
" FROM scan_results"
|
||||||
" WHERE filename = ? AND profile = ?"
|
" WHERE filename = ? AND profile = ? AND scan_timestamp = ?"
|
||||||
" ORDER BY model, start_time",
|
" ORDER BY model, start_time",
|
||||||
(filename, profile),
|
(filename, profile, scan_timestamp),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
# For each model, get rows from the latest timestamp only
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT r.id, r.model, r.start_time, r.end_time, r.score,"
|
||||||
|
" r.disabled, r.orig_start_time, r.orig_end_time"
|
||||||
|
" FROM scan_results r"
|
||||||
|
" INNER JOIN ("
|
||||||
|
" SELECT model, MAX(scan_timestamp) AS latest"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ?"
|
||||||
|
" GROUP BY model"
|
||||||
|
" ) m ON r.model = m.model AND r.scan_timestamp = m.latest"
|
||||||
|
" WHERE r.filename = ? AND r.profile = ?"
|
||||||
|
" ORDER BY r.model, r.start_time",
|
||||||
|
(filename, profile, filename, profile),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
||||||
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
||||||
@@ -546,16 +651,18 @@ class ProcessedDB:
|
|||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
def add_hard_negatives(self, filename: str, profile: str,
|
def add_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float], source_path: str = "") -> None:
|
times: list[float], source_path: str = "",
|
||||||
|
source_model: str = "") -> None:
|
||||||
"""Save timestamps as hard-negative training examples."""
|
"""Save timestamps as hard-negative training examples."""
|
||||||
if not self._enabled or not times:
|
if not self._enabled or not times:
|
||||||
return
|
return
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for t in times:
|
for t in times:
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
|
"INSERT INTO hard_negatives"
|
||||||
" VALUES (?, ?, ?, ?)",
|
" (filename, profile, start_time, source_path, source_model)"
|
||||||
(filename, profile, t, source_path),
|
" VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(filename, profile, t, source_path, source_model),
|
||||||
)
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
@@ -570,6 +677,30 @@ class ProcessedDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||||
|
"""Return all hard negatives for a profile with full details."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, filename, start_time, source_path, source_model"
|
||||||
|
" FROM hard_negatives WHERE profile = ?"
|
||||||
|
" ORDER BY filename, start_time",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||||
|
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||||
|
|
||||||
|
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||||
|
"""Delete hard negatives by row IDs."""
|
||||||
|
if not self._enabled or not ids:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||||
|
ids,
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
def remove_hard_negatives(self, filename: str, profile: str,
|
def remove_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float]) -> None:
|
times: list[float]) -> None:
|
||||||
"""Remove specific hard-negative timestamps."""
|
"""Remove specific hard-negative timestamps."""
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
# Scan History & Hard Negative Management Design
|
||||||
|
|
||||||
|
Date: 2026-04-19
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
1. Keep scan result history per `(file, model)` so users can track classifier improvement across training iterations
|
||||||
|
2. Make hard negatives manageable — viewable, removable, and optionally disabled per training run
|
||||||
|
3. Fix latent bug: `get_export_folders()` doesn't filter by `scan_export`
|
||||||
|
|
||||||
|
## 1. Scan Result History
|
||||||
|
|
||||||
|
### Current behavior
|
||||||
|
|
||||||
|
`save_scan_results()` **replaces** all results for `(filename, profile, model)` on every scan. No history is preserved.
|
||||||
|
|
||||||
|
### Change
|
||||||
|
|
||||||
|
Keep the last N scan results per `(filename, profile, model)` with timestamps. The most recent is the "active" result displayed in the panel; older versions are accessible for comparison.
|
||||||
|
|
||||||
|
### Schema change
|
||||||
|
|
||||||
|
Add column to `scan_results`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
ALTER TABLE scan_results ADD COLUMN scan_timestamp TEXT NOT NULL DEFAULT '';
|
||||||
|
```
|
||||||
|
|
||||||
|
All rows from the same scan share the same timestamp string (e.g. `"20260419_143022"`).
|
||||||
|
|
||||||
|
### save_scan_results changes
|
||||||
|
|
||||||
|
Instead of `DELETE ... WHERE filename=? AND profile=? AND model=?`, the new flow:
|
||||||
|
|
||||||
|
1. Insert new rows with current timestamp
|
||||||
|
2. Count distinct timestamps for this `(filename, profile, model)`
|
||||||
|
3. If count > N (default 5), delete rows belonging to the oldest timestamps
|
||||||
|
|
||||||
|
### UI changes
|
||||||
|
|
||||||
|
Add a small version dropdown/selector in `ScanResultsPanel` per model tab — shows timestamps of available scan versions. Selecting a version loads that version's results into the tab. The most recent is selected by default.
|
||||||
|
|
||||||
|
The tab label shows the active version's region count, e.g. `HUBERT_XLARGE (12) [v3]`.
|
||||||
|
|
||||||
|
### Cache interaction
|
||||||
|
|
||||||
|
Embedding cache is per `(file, model)` and doesn't change across scans. Only the classifier output changes. History stores the classified regions (start, end, score), not embeddings.
|
||||||
|
|
||||||
|
## 2. Hard Negative Management
|
||||||
|
|
||||||
|
### Current behavior
|
||||||
|
|
||||||
|
- Hard negatives stored in `hard_negatives` table: `(filename, profile, start_time, source_path)`
|
||||||
|
- No model column — applied globally within a profile
|
||||||
|
- Removable one-by-one via N toggle in scan panel, but no bulk management
|
||||||
|
- Always used in training — no way to disable
|
||||||
|
|
||||||
|
### Changes
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
|
||||||
|
Add `source_model TEXT NOT NULL DEFAULT ''` column to `hard_negatives`. Populated when marking negatives from scan results (we know which model tab is active).
|
||||||
|
|
||||||
|
#### Training toggle
|
||||||
|
|
||||||
|
New checkbox in `TrainDialog`: **"Use hard negatives"** (default checked). When unchecked, `get_training_data()` skips the `hard_negatives` query entirely. Non-destructive — negatives remain in DB.
|
||||||
|
|
||||||
|
#### Management dialog
|
||||||
|
|
||||||
|
New `HardNegativesDialog` accessible from Train dialog via "Manage..." button next to the checkbox. Shows:
|
||||||
|
|
||||||
|
- Table: filename, start time, source model, date added (if we add created_at)
|
||||||
|
- Filter by source model (dropdown)
|
||||||
|
- Multi-select + Delete button
|
||||||
|
- "Clear All" button with confirmation
|
||||||
|
- Count summary at top
|
||||||
|
|
||||||
|
### Training integration
|
||||||
|
|
||||||
|
`get_training_data()` gets a new `use_hard_negatives: bool = True` parameter. When False, the hard negatives query (lines 365-374 of db.py) is skipped entirely.
|
||||||
|
|
||||||
|
## 3. Ghost Folder Fix
|
||||||
|
|
||||||
|
### Bug
|
||||||
|
|
||||||
|
`get_export_folders()` queries all `output_path` rows without filtering `scan_export`. Folders that only contain scan-exported clips appear in training dropdowns with 0 clips.
|
||||||
|
|
||||||
|
### Fix
|
||||||
|
|
||||||
|
Add `include_scan_exports` parameter to `get_export_folders()`. When False (default), only query rows with `scan_export = 0`. Also filter out folders with 0 clips from `get_training_stats()` result dict.
|
||||||
@@ -0,0 +1,714 @@
|
|||||||
|
# Scan History & Hard Negative Management Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Add scan result versioning, hard negative management dialog with training toggle, and fix ghost folder bug.
|
||||||
|
|
||||||
|
**Architecture:** DB schema changes in `core/db.py` (new columns, new queries). UI changes in `main.py` (version selector in ScanResultsPanel, management dialog, training toggle). No changes to `core/audio_scan.py`.
|
||||||
|
|
||||||
|
**Tech Stack:** SQLite (existing), PyQt6 (existing)
|
||||||
|
|
||||||
|
**Key design notes:**
|
||||||
|
- Scan history stores N versions per `(filename, profile, model)` using a `scan_timestamp` column. All rows from one scan share the same timestamp.
|
||||||
|
- Hard negatives gain a `source_model` column (informational) and training gains a `use_hard_negatives` toggle.
|
||||||
|
- `get_export_folders()` must respect `scan_export` filter to prevent ghost folders.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Fix ghost folder bug in get_export_folders
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/db.py:294-313` (get_export_folders)
|
||||||
|
- Modify: `core/db.py:410-443` (get_training_stats — filter out 0-clip folders)
|
||||||
|
- Test: `tests/test_db.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_export_folders_excludes_scan_exports():
|
||||||
|
"""Scan-export-only folders should not appear when include_scan_exports=False."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Manual export
|
||||||
|
db.add("a.mp4", 10.0, "/out/mp4_Intense/g1/clip.mp4", profile="test")
|
||||||
|
# Scan export to different folder
|
||||||
|
db.add("a.mp4", 20.0, "/out/mp4_ScanOnly/g1/clip.mp4", profile="test",
|
||||||
|
scan_export=True)
|
||||||
|
folders = db.get_export_folders("test")
|
||||||
|
assert "mp4_Intense" in folders
|
||||||
|
assert "mp4_ScanOnly" not in folders, "scan-only folder should be excluded"
|
||||||
|
# With include_scan_exports=True, both should appear
|
||||||
|
folders_all = db.get_export_folders("test", include_scan_exports=True)
|
||||||
|
assert "mp4_ScanOnly" in folders_all
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Fix get_export_folders**
|
||||||
|
|
||||||
|
Add `include_scan_exports` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_export_folders(self, profile: str = "default",
|
||||||
|
include_scan_exports: bool = False) -> list[str]:
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
if include_scan_exports:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed"
|
||||||
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
folder_names: set[str] = set()
|
||||||
|
for (op,) in rows:
|
||||||
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
|
if grandparent:
|
||||||
|
folder_names.add(grandparent)
|
||||||
|
return sorted(folder_names)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Update get_training_stats to pass through**
|
||||||
|
|
||||||
|
```python
|
||||||
|
folders = self.get_export_folders(profile, include_scan_exports=include_scan_exports)
|
||||||
|
```
|
||||||
|
|
||||||
|
And filter out empty folders at the end:
|
||||||
|
|
||||||
|
```python
|
||||||
|
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run tests, commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/ -v
|
||||||
|
git add core/db.py tests/test_db.py
|
||||||
|
git commit -m "fix: get_export_folders respects scan_export filter"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Scan result history — schema and DB methods
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/db.py:86-98` (scan_results schema — add scan_timestamp column)
|
||||||
|
- Modify: `core/db.py:100-113` (migration — add scan_timestamp to existing tables)
|
||||||
|
- Modify: `core/db.py:447-468` (save_scan_results — version management)
|
||||||
|
- Add: `core/db.py` (get_scan_versions, load_scan_version, delete_scan_version)
|
||||||
|
- Test: `tests/test_db.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_scan_result_history():
|
||||||
|
"""save_scan_results should keep multiple versions."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Save three versions
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(0, 8, 0.9)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(0, 8, 0.8), (10, 18, 0.7)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(5, 13, 0.95)])
|
||||||
|
versions = db.get_scan_versions("v.mp4", "test", "MODEL_A")
|
||||||
|
assert len(versions) == 3
|
||||||
|
# Most recent first
|
||||||
|
assert versions[0]["count"] == 1 # latest: 1 region
|
||||||
|
assert versions[1]["count"] == 2 # middle: 2 regions
|
||||||
|
assert versions[2]["count"] == 1 # oldest: 1 region
|
||||||
|
# get_scan_results returns latest version by default
|
||||||
|
results = db.get_scan_results("v.mp4", "test")
|
||||||
|
assert len(results.get("MODEL_A", [])) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add scan_timestamp column**
|
||||||
|
|
||||||
|
In the CREATE TABLE (line 87-98), add:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
scan_timestamp TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
In the migration block (lines 100-113), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
("scan_timestamp", "TEXT NOT NULL DEFAULT ''"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Modify save_scan_results**
|
||||||
|
|
||||||
|
Replace the current DELETE+INSERT with versioned insert + cleanup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def save_scan_results(self, filename: str, profile: str, model: str,
|
||||||
|
regions: list[tuple[float, float, float]],
|
||||||
|
max_versions: int = 5) -> None:
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
from datetime import datetime
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
with self._lock:
|
||||||
|
self._con.executemany(
|
||||||
|
"INSERT INTO scan_results"
|
||||||
|
" (filename, profile, model, start_time, end_time, score,"
|
||||||
|
" orig_start_time, orig_end_time, scan_timestamp)"
|
||||||
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
[(filename, profile, model, s, e, sc, s, e, ts)
|
||||||
|
for s, e, sc in regions],
|
||||||
|
)
|
||||||
|
# Prune old versions beyond max_versions
|
||||||
|
versions = self._con.execute(
|
||||||
|
"SELECT DISTINCT scan_timestamp FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
if len(versions) > max_versions:
|
||||||
|
old_ts = [v[0] for v in versions[max_versions:]]
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
f" AND scan_timestamp IN ({','.join('?' * len(old_ts))})",
|
||||||
|
(filename, profile, model, *old_ts),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add get_scan_versions**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_scan_versions(self, filename: str, profile: str, model: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Return list of scan versions for (filename, profile, model).
|
||||||
|
|
||||||
|
Returns [{timestamp, count, max_score}, ...] ordered newest first.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT scan_timestamp, COUNT(*), MAX(score)"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" AND scan_timestamp != ''"
|
||||||
|
" GROUP BY scan_timestamp"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
return [{"timestamp": ts, "count": cnt, "max_score": sc}
|
||||||
|
for ts, cnt, sc in rows]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Modify get_scan_results to support version selection**
|
||||||
|
|
||||||
|
Add optional `scan_timestamp` parameter. When None (default), returns latest version:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_scan_results(self, filename: str, profile: str,
|
||||||
|
scan_timestamp: str | None = None
|
||||||
|
) -> dict[str, list[tuple]]:
|
||||||
|
if not self._enabled:
|
||||||
|
return {}
|
||||||
|
if scan_timestamp:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, model, start_time, end_time, score, disabled,"
|
||||||
|
" orig_start_time, orig_end_time"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND scan_timestamp = ?"
|
||||||
|
" ORDER BY model, start_time",
|
||||||
|
(filename, profile, scan_timestamp),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
# For each model, get rows from the latest timestamp only
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT r.id, r.model, r.start_time, r.end_time, r.score,"
|
||||||
|
" r.disabled, r.orig_start_time, r.orig_end_time"
|
||||||
|
" FROM scan_results r"
|
||||||
|
" INNER JOIN ("
|
||||||
|
" SELECT model, MAX(scan_timestamp) AS latest"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ?"
|
||||||
|
" GROUP BY model"
|
||||||
|
" ) m ON r.model = m.model AND r.scan_timestamp = m.latest"
|
||||||
|
" WHERE r.filename = ? AND r.profile = ?"
|
||||||
|
" ORDER BY r.model, r.start_time",
|
||||||
|
(filename, profile, filename, profile),
|
||||||
|
).fetchall()
|
||||||
|
result: dict[str, list] = {}
|
||||||
|
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
||||||
|
result.setdefault(model, []).append(
|
||||||
|
(row_id, s, e, sc, bool(dis),
|
||||||
|
os_ if os_ is not None else s,
|
||||||
|
oe if oe is not None else e))
|
||||||
|
return result
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important:** Legacy rows (before this change) have `scan_timestamp = ''`. The `MAX(scan_timestamp)` query handles this correctly — empty string sorts before any real timestamp, so legacy rows are returned when they're the only version. The `get_scan_versions` query filters `scan_timestamp != ''` so legacy rows don't appear as named versions.
|
||||||
|
|
||||||
|
**Step 6: Run tests, commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/ -v
|
||||||
|
git add core/db.py tests/test_db.py
|
||||||
|
git commit -m "feat: scan result history — keep N versions per (file, model)"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Scan history UI — version selector in ScanResultsPanel
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (ScanResultsPanel — add version combo per tab)
|
||||||
|
- Modify: `main.py` (ScanResultsPanel.load_for_file — populate versions)
|
||||||
|
|
||||||
|
**Step 1: Add version combo to tab UI**
|
||||||
|
|
||||||
|
In `ScanResultsPanel._add_tab()`, add a small QComboBox above the table. When no history exists, hide it. When versions exist, populate with timestamps and connect to a slot that reloads the tab with that version.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In _add_tab, create a container widget with version combo + table
|
||||||
|
container = QWidget()
|
||||||
|
layout = QVBoxLayout(container)
|
||||||
|
layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
cmb_version = QComboBox()
|
||||||
|
cmb_version.setMaximumWidth(200)
|
||||||
|
cmb_version.setToolTip("Scan version history")
|
||||||
|
cmb_version.hide() # Hidden when only 1 version
|
||||||
|
layout.addWidget(cmb_version)
|
||||||
|
layout.addWidget(table)
|
||||||
|
|
||||||
|
self._tabs.addTab(container, label)
|
||||||
|
```
|
||||||
|
|
||||||
|
Store the combo and table as properties on the container widget for later access.
|
||||||
|
|
||||||
|
**Step 2: Populate versions in load_for_file**
|
||||||
|
|
||||||
|
After creating each model tab, query `get_scan_versions()`. If > 1 version, show the combo with entries like `"2026-04-19 14:30 (12 regions, best: 0.95)"`. Connect `currentIndexChanged` to reload that version's results.
|
||||||
|
|
||||||
|
**Step 3: Version switching slot**
|
||||||
|
|
||||||
|
When user selects a different version from the combo:
|
||||||
|
1. Call `db.get_scan_results(filename, profile, scan_timestamp=selected_ts)`
|
||||||
|
2. Repopulate the table with that version's rows
|
||||||
|
3. Update timeline regions
|
||||||
|
|
||||||
|
**Step 4: Test manually, commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: scan version selector in results panel"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Hard negatives — schema and training toggle
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/db.py:118-130` (hard_negatives schema — add source_model column)
|
||||||
|
- Modify: `core/db.py:548-560` (add_hard_negatives — accept source_model)
|
||||||
|
- Modify: `core/db.py:365-374` (get_training_data — use_hard_negatives parameter)
|
||||||
|
- Modify: `main.py` (TrainDialog — add "Use hard negatives" checkbox)
|
||||||
|
- Modify: `main.py` (_open_train_dialog — pass use_hard_negatives to get_training_data)
|
||||||
|
- Test: `tests/test_db.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_hard_negatives_source_model():
|
||||||
|
"""Hard negatives should store source_model."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||||
|
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_training_data_skips_hard_negatives():
|
||||||
|
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||||
|
source_path="/videos/a.mp4")
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [500.0], source_path="/videos/a.mp4")
|
||||||
|
# With hard negatives
|
||||||
|
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||||
|
# Without hard negatives
|
||||||
|
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||||
|
# Both should find the video, but negative counts differ
|
||||||
|
assert len(data_with) >= 1
|
||||||
|
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||||
|
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||||
|
assert neg_with > neg_without or neg_with == neg_without # depends on margin
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add source_model column to hard_negatives**
|
||||||
|
|
||||||
|
In CREATE TABLE (line 119-125), add:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
source_model TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
In migration section, add after the hard_negatives table creation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
hn_cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||||
|
}
|
||||||
|
if "source_model" not in hn_cols:
|
||||||
|
self._con.execute(
|
||||||
|
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Update add_hard_negatives to accept source_model**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def add_hard_negatives(self, filename: str, profile: str,
|
||||||
|
times: list[float], source_path: str = "",
|
||||||
|
source_model: str = "") -> None:
|
||||||
|
if not self._enabled or not times:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
for t in times:
|
||||||
|
self._con.execute(
|
||||||
|
"INSERT INTO hard_negatives"
|
||||||
|
" (filename, profile, start_time, source_path, source_model)"
|
||||||
|
" VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(filename, profile, t, source_path, source_model),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add get_hard_negatives (full rows for management dialog)**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||||
|
"""Return all hard negatives for a profile with full details."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, filename, start_time, source_path, source_model"
|
||||||
|
" FROM hard_negatives WHERE profile = ?"
|
||||||
|
" ORDER BY filename, start_time",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||||
|
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Add delete_hard_negatives_by_ids**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||||
|
"""Delete hard negatives by row IDs."""
|
||||||
|
if not self._enabled or not ids:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||||
|
ids,
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: Add use_hard_negatives parameter to get_training_data**
|
||||||
|
|
||||||
|
In `get_training_data()` (line 315), add parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_training_data(self, profile: str, positive_folder: str,
|
||||||
|
negative_folder: str = "",
|
||||||
|
fallback_video_dir: str = "",
|
||||||
|
include_scan_exports: bool = False,
|
||||||
|
use_hard_negatives: bool = True,
|
||||||
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
|
```
|
||||||
|
|
||||||
|
Then wrap the hard negatives query (lines 365-374) in a conditional:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if use_hard_negatives:
|
||||||
|
hard_rows = self._con.execute(
|
||||||
|
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||||
|
" WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
for fn, st, sp in hard_rows:
|
||||||
|
neg_by_video.setdefault(fn, set()).add(st)
|
||||||
|
if sp:
|
||||||
|
source_by_filename.setdefault(fn, sp)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 7: Pass source_model when marking negatives from scan panel**
|
||||||
|
|
||||||
|
In `main.py`, `_on_scan_negatives()` needs to pass the current scan model. The scan panel knows which tab is active:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _on_scan_negatives(self, times: list) -> None:
|
||||||
|
if not self._file_path:
|
||||||
|
return
|
||||||
|
filename = os.path.basename(self._file_path)
|
||||||
|
# Get current model tab name for source_model
|
||||||
|
source_model = self._scan_panel.current_model_name()
|
||||||
|
self._db.add_hard_negatives(filename, self._profile, times,
|
||||||
|
source_path=self._file_path,
|
||||||
|
source_model=source_model)
|
||||||
|
```
|
||||||
|
|
||||||
|
Add `current_model_name()` to ScanResultsPanel:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def current_model_name(self) -> str:
|
||||||
|
"""Return the model name of the currently active tab."""
|
||||||
|
idx = self._tabs.currentIndex()
|
||||||
|
if idx >= 0:
|
||||||
|
return self._tabs.tabText(idx).split(" (")[0] # strip count suffix
|
||||||
|
return ""
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 8: Add training toggle to TrainDialog**
|
||||||
|
|
||||||
|
After the existing `_chk_scan_exports` checkbox:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self._chk_hard_negatives = QCheckBox("Use hard negatives in training")
|
||||||
|
self._chk_hard_negatives.setChecked(True)
|
||||||
|
self._chk_hard_negatives.setToolTip(
|
||||||
|
"When unchecked, manually marked hard negatives are excluded from training.\n"
|
||||||
|
"Useful when training a new model type where old negatives may not apply.")
|
||||||
|
self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start())
|
||||||
|
form.addRow("", self._chk_hard_negatives)
|
||||||
|
```
|
||||||
|
|
||||||
|
Add property:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@property
|
||||||
|
def use_hard_negatives(self) -> bool:
|
||||||
|
return self._chk_hard_negatives.isChecked()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 9: Wire toggle through _open_train_dialog**
|
||||||
|
|
||||||
|
In `_open_train_dialog()`, pass the flag:
|
||||||
|
|
||||||
|
```python
|
||||||
|
video_infos = self._db.get_training_data(
|
||||||
|
self._profile, pos_folder, negative_folder=neg_folder,
|
||||||
|
fallback_video_dir=video_dir,
|
||||||
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=dlg.use_hard_negatives,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Also update `_update_stats()` in TrainDialog to pass it through for accurate counts:
|
||||||
|
|
||||||
|
```python
|
||||||
|
use_neg = self._chk_hard_negatives.isChecked() if hasattr(self, '_chk_hard_negatives') else True
|
||||||
|
video_infos = self._db.get_training_data(
|
||||||
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
|
fallback_video_dir=self._txt_video_dir.text(),
|
||||||
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 10: Run tests, commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/ -v
|
||||||
|
git add core/db.py main.py tests/test_db.py
|
||||||
|
git commit -m "feat: hard negative source_model tracking, training toggle"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Hard negatives management dialog
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (add HardNegativesDialog class)
|
||||||
|
- Modify: `main.py` (TrainDialog — add "Manage..." button)
|
||||||
|
|
||||||
|
**Step 1: Create HardNegativesDialog**
|
||||||
|
|
||||||
|
Place before TrainDialog class:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class HardNegativesDialog(QDialog):
|
||||||
|
"""View and manage hard negative training examples."""
|
||||||
|
|
||||||
|
def __init__(self, db: ProcessedDB, profile: str, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.setWindowTitle("Hard Negatives")
|
||||||
|
self.setMinimumSize(600, 400)
|
||||||
|
self._db = db
|
||||||
|
self._profile = profile
|
||||||
|
|
||||||
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
|
# Filter row
|
||||||
|
filter_row = QHBoxLayout()
|
||||||
|
filter_row.addWidget(QLabel("Filter model:"))
|
||||||
|
self._cmb_filter = QComboBox()
|
||||||
|
self._cmb_filter.addItem("(all)")
|
||||||
|
self._cmb_filter.currentIndexChanged.connect(self._apply_filter)
|
||||||
|
filter_row.addWidget(self._cmb_filter, 1)
|
||||||
|
layout.addLayout(filter_row)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
self._lbl_summary = QLabel()
|
||||||
|
layout.addWidget(self._lbl_summary)
|
||||||
|
|
||||||
|
# Table
|
||||||
|
self._table = QTableWidget(0, 4)
|
||||||
|
self._table.setHorizontalHeaderLabels(
|
||||||
|
["File", "Time", "Source Model", "ID"])
|
||||||
|
self._table.horizontalHeader().setSectionResizeMode(
|
||||||
|
0, QHeaderView.ResizeMode.Stretch)
|
||||||
|
self._table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
|
||||||
|
self._table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
|
||||||
|
self._table.setColumnHidden(3, True) # hide ID column
|
||||||
|
layout.addWidget(self._table)
|
||||||
|
|
||||||
|
# Buttons
|
||||||
|
btn_row = QHBoxLayout()
|
||||||
|
btn_delete = QPushButton("Delete Selected")
|
||||||
|
btn_delete.clicked.connect(self._delete_selected)
|
||||||
|
btn_row.addWidget(btn_delete)
|
||||||
|
btn_clear = QPushButton("Clear All")
|
||||||
|
btn_clear.clicked.connect(self._clear_all)
|
||||||
|
btn_row.addWidget(btn_clear)
|
||||||
|
btn_row.addStretch()
|
||||||
|
btn_close = QPushButton("Close")
|
||||||
|
btn_close.clicked.connect(self.close)
|
||||||
|
btn_row.addWidget(btn_close)
|
||||||
|
layout.addLayout(btn_row)
|
||||||
|
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _load(self):
|
||||||
|
rows = self._db.get_hard_negatives(self._profile)
|
||||||
|
models = sorted(set(r["source_model"] for r in rows if r["source_model"]))
|
||||||
|
self._cmb_filter.blockSignals(True)
|
||||||
|
self._cmb_filter.clear()
|
||||||
|
self._cmb_filter.addItem("(all)")
|
||||||
|
for m in models:
|
||||||
|
self._cmb_filter.addItem(m)
|
||||||
|
self._cmb_filter.blockSignals(False)
|
||||||
|
|
||||||
|
self._table.setRowCount(len(rows))
|
||||||
|
for i, r in enumerate(rows):
|
||||||
|
self._table.setItem(i, 0, QTableWidgetItem(r["filename"]))
|
||||||
|
self._table.setItem(i, 1, QTableWidgetItem(f'{r["start_time"]:.1f}s'))
|
||||||
|
self._table.setItem(i, 2, QTableWidgetItem(r["source_model"]))
|
||||||
|
item = QTableWidgetItem(str(r["id"]))
|
||||||
|
self._table.setItem(i, 3, item)
|
||||||
|
self._lbl_summary.setText(f"<b>{len(rows)}</b> hard negatives")
|
||||||
|
|
||||||
|
def _apply_filter(self):
|
||||||
|
model = self._cmb_filter.currentText()
|
||||||
|
for row in range(self._table.rowCount()):
|
||||||
|
if model == "(all)":
|
||||||
|
self._table.setRowHidden(row, False)
|
||||||
|
else:
|
||||||
|
src = self._table.item(row, 2).text()
|
||||||
|
self._table.setRowHidden(row, src != model)
|
||||||
|
|
||||||
|
def _delete_selected(self):
|
||||||
|
ids = []
|
||||||
|
for row in sorted(set(i.row() for i in self._table.selectedItems()), reverse=True):
|
||||||
|
if not self._table.isRowHidden(row):
|
||||||
|
ids.append(int(self._table.item(row, 3).text()))
|
||||||
|
if ids:
|
||||||
|
self._db.delete_hard_negatives_by_ids(ids)
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _clear_all(self):
|
||||||
|
reply = QMessageBox.question(
|
||||||
|
self, "Clear All",
|
||||||
|
f"Delete all hard negatives for profile '{self._profile}'?",
|
||||||
|
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
|
||||||
|
)
|
||||||
|
if reply == QMessageBox.StandardButton.Yes:
|
||||||
|
all_rows = self._db.get_hard_negatives(self._profile)
|
||||||
|
self._db.delete_hard_negatives_by_ids([r["id"] for r in all_rows])
|
||||||
|
self._load()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add "Manage..." button to TrainDialog**
|
||||||
|
|
||||||
|
After the hard negatives checkbox, add a button:
|
||||||
|
|
||||||
|
```python
|
||||||
|
neg_row = QHBoxLayout()
|
||||||
|
neg_row.addWidget(self._chk_hard_negatives)
|
||||||
|
btn_manage_neg = QPushButton("Manage…")
|
||||||
|
btn_manage_neg.setFixedWidth(80)
|
||||||
|
btn_manage_neg.clicked.connect(self._manage_negatives)
|
||||||
|
neg_row.addWidget(btn_manage_neg)
|
||||||
|
form.addRow("", neg_row) # replaces the standalone checkbox addRow
|
||||||
|
```
|
||||||
|
|
||||||
|
Add handler:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _manage_negatives(self):
|
||||||
|
dlg = HardNegativesDialog(self._db, self._profile, parent=self)
|
||||||
|
dlg.exec()
|
||||||
|
self._debounce.start() # refresh stats after potential deletions
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Test manually, commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/ -v
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: hard negatives management dialog with filter and bulk delete"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Final integration test and push
|
||||||
|
|
||||||
|
**Step 1: Manual test checklist**
|
||||||
|
|
||||||
|
- [ ] Open Train dialog — verify no ghost folders appear
|
||||||
|
- [ ] Train with "Use hard negatives" unchecked — verify training works
|
||||||
|
- [ ] Train with "Use hard negatives" checked — verify negatives are used
|
||||||
|
- [ ] Open Manage dialog — verify negatives listed with source model
|
||||||
|
- [ ] Delete selected negatives — verify they're removed
|
||||||
|
- [ ] Scan a video — verify results saved with timestamp
|
||||||
|
- [ ] Rescan same video — verify version history appears
|
||||||
|
- [ ] Switch version in scan panel — verify correct results display
|
||||||
|
- [ ] Mark negative from scan results — verify source_model stored
|
||||||
|
|
||||||
|
**Step 2: Push**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git push
|
||||||
|
```
|
||||||
@@ -8,6 +8,7 @@ import random
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from PyQt6.QtWidgets import (
|
from PyQt6.QtWidgets import (
|
||||||
@@ -318,6 +319,114 @@ class DatasetStatsDialog(QDialog):
|
|||||||
layout.addWidget(btns)
|
layout.addWidget(btns)
|
||||||
|
|
||||||
|
|
||||||
|
class HardNegativesDialog(QDialog):
|
||||||
|
"""View and manage hard negative training examples."""
|
||||||
|
|
||||||
|
def __init__(self, db: ProcessedDB, profile: str, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self.setWindowTitle("Hard Negatives")
|
||||||
|
self.setMinimumSize(600, 400)
|
||||||
|
self._db = db
|
||||||
|
self._profile = profile
|
||||||
|
|
||||||
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
|
# Filter row
|
||||||
|
filter_row = QHBoxLayout()
|
||||||
|
filter_row.addWidget(QLabel("Filter model:"))
|
||||||
|
self._cmb_filter = QComboBox()
|
||||||
|
self._cmb_filter.addItem("(all)")
|
||||||
|
self._cmb_filter.currentIndexChanged.connect(self._apply_filter)
|
||||||
|
filter_row.addWidget(self._cmb_filter, 1)
|
||||||
|
layout.addLayout(filter_row)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
self._lbl_summary = QLabel()
|
||||||
|
layout.addWidget(self._lbl_summary)
|
||||||
|
|
||||||
|
# Table
|
||||||
|
self._table = QTableWidget(0, 4)
|
||||||
|
self._table.setHorizontalHeaderLabels(
|
||||||
|
["File", "Time", "Source Model", "ID"])
|
||||||
|
self._table.horizontalHeader().setSectionResizeMode(
|
||||||
|
0, QHeaderView.ResizeMode.Stretch)
|
||||||
|
self._table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
|
||||||
|
self._table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
|
||||||
|
self._table.setColumnHidden(3, True) # hide ID column
|
||||||
|
layout.addWidget(self._table)
|
||||||
|
|
||||||
|
# Buttons
|
||||||
|
btn_row = QHBoxLayout()
|
||||||
|
btn_delete = QPushButton("Delete Selected")
|
||||||
|
btn_delete.clicked.connect(self._delete_selected)
|
||||||
|
btn_row.addWidget(btn_delete)
|
||||||
|
btn_clear = QPushButton("Clear All")
|
||||||
|
btn_clear.clicked.connect(self._clear_all)
|
||||||
|
btn_row.addWidget(btn_clear)
|
||||||
|
btn_row.addStretch()
|
||||||
|
btn_close = QPushButton("Close")
|
||||||
|
btn_close.clicked.connect(self.close)
|
||||||
|
btn_row.addWidget(btn_close)
|
||||||
|
layout.addLayout(btn_row)
|
||||||
|
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _load(self):
|
||||||
|
rows = self._db.get_hard_negatives(self._profile)
|
||||||
|
models = sorted(set(r["source_model"] for r in rows if r["source_model"]))
|
||||||
|
self._cmb_filter.blockSignals(True)
|
||||||
|
self._cmb_filter.clear()
|
||||||
|
self._cmb_filter.addItem("(all)")
|
||||||
|
for m in models:
|
||||||
|
self._cmb_filter.addItem(m)
|
||||||
|
self._cmb_filter.blockSignals(False)
|
||||||
|
|
||||||
|
self._table.setRowCount(len(rows))
|
||||||
|
for i, r in enumerate(rows):
|
||||||
|
self._table.setItem(i, 0, QTableWidgetItem(r["filename"]))
|
||||||
|
self._table.setItem(i, 1, QTableWidgetItem(f'{r["start_time"]:.1f}s'))
|
||||||
|
self._table.setItem(i, 2, QTableWidgetItem(r["source_model"]))
|
||||||
|
self._table.setItem(i, 3, QTableWidgetItem(str(r["id"])))
|
||||||
|
self._lbl_summary.setText(f"<b>{len(rows)}</b> hard negatives")
|
||||||
|
|
||||||
|
def _apply_filter(self):
|
||||||
|
model = self._cmb_filter.currentText()
|
||||||
|
for row in range(self._table.rowCount()):
|
||||||
|
if model == "(all)":
|
||||||
|
self._table.setRowHidden(row, False)
|
||||||
|
else:
|
||||||
|
src = self._table.item(row, 2).text()
|
||||||
|
self._table.setRowHidden(row, src != model)
|
||||||
|
|
||||||
|
def _delete_selected(self):
|
||||||
|
ids = []
|
||||||
|
for row in sorted(set(i.row() for i in self._table.selectedItems()), reverse=True):
|
||||||
|
if not self._table.isRowHidden(row):
|
||||||
|
ids.append(int(self._table.item(row, 3).text()))
|
||||||
|
if ids:
|
||||||
|
self._db.delete_hard_negatives_by_ids(ids)
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _clear_all(self):
|
||||||
|
all_rows = self._db.get_hard_negatives(self._profile)
|
||||||
|
model_filter = self._cmb_filter.currentText()
|
||||||
|
if model_filter != "(all)":
|
||||||
|
target = [r for r in all_rows if r["source_model"] == model_filter]
|
||||||
|
msg = f"Delete {len(target)} hard negatives for model '{model_filter}'?"
|
||||||
|
else:
|
||||||
|
target = all_rows
|
||||||
|
msg = f"Delete all {len(target)} hard negatives for profile '{self._profile}'?"
|
||||||
|
if not target:
|
||||||
|
return
|
||||||
|
reply = QMessageBox.question(
|
||||||
|
self, "Clear All", msg,
|
||||||
|
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
|
||||||
|
)
|
||||||
|
if reply == QMessageBox.StandardButton.Yes:
|
||||||
|
self._db.delete_hard_negatives_by_ids([r["id"] for r in target])
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
|
||||||
class TrainDialog(QDialog):
|
class TrainDialog(QDialog):
|
||||||
"""Dialog for configuring and launching classifier training."""
|
"""Dialog for configuring and launching classifier training."""
|
||||||
|
|
||||||
@@ -372,6 +481,20 @@ class TrainDialog(QDialog):
|
|||||||
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
|
self._chk_scan_exports.stateChanged.connect(lambda: self._debounce.start())
|
||||||
form.addRow("", self._chk_scan_exports)
|
form.addRow("", self._chk_scan_exports)
|
||||||
|
|
||||||
|
self._chk_hard_negatives = QCheckBox("Use hard negatives in training")
|
||||||
|
self._chk_hard_negatives.setChecked(True)
|
||||||
|
self._chk_hard_negatives.setToolTip(
|
||||||
|
"When unchecked, manually marked hard negatives are excluded from training.\n"
|
||||||
|
"Useful when training a new model type where old negatives may not apply.")
|
||||||
|
self._chk_hard_negatives.stateChanged.connect(lambda: self._debounce.start())
|
||||||
|
neg_row = QHBoxLayout()
|
||||||
|
neg_row.addWidget(self._chk_hard_negatives)
|
||||||
|
btn_manage_neg = QPushButton("Manage\u2026")
|
||||||
|
btn_manage_neg.setFixedWidth(80)
|
||||||
|
btn_manage_neg.clicked.connect(self._manage_negatives)
|
||||||
|
neg_row.addWidget(btn_manage_neg)
|
||||||
|
form.addRow("", neg_row)
|
||||||
|
|
||||||
# Video source directory (fallback for old DB rows without source_path)
|
# Video source directory (fallback for old DB rows without source_path)
|
||||||
self._txt_video_dir = QLineEdit(video_dir)
|
self._txt_video_dir = QLineEdit(video_dir)
|
||||||
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
|
||||||
@@ -427,6 +550,11 @@ class TrainDialog(QDialog):
|
|||||||
if d:
|
if d:
|
||||||
self._txt_video_dir.setText(d)
|
self._txt_video_dir.setText(d)
|
||||||
|
|
||||||
|
def _manage_negatives(self):
|
||||||
|
dlg = HardNegativesDialog(self._db, self._profile, parent=self)
|
||||||
|
dlg.exec()
|
||||||
|
self._debounce.start() # refresh stats after potential deletions
|
||||||
|
|
||||||
def _populate_folder_combos(self):
|
def _populate_folder_combos(self):
|
||||||
"""Rebuild positive/negative combo box items from DB stats."""
|
"""Rebuild positive/negative combo box items from DB stats."""
|
||||||
inc_scan = getattr(self, '_chk_scan_exports', None)
|
inc_scan = getattr(self, '_chk_scan_exports', None)
|
||||||
@@ -464,15 +592,18 @@ class TrainDialog(QDialog):
|
|||||||
return
|
return
|
||||||
neg_folder = self._cmb_negative.currentData() or ""
|
neg_folder = self._cmb_negative.currentData() or ""
|
||||||
inc_scan = self._chk_scan_exports.isChecked()
|
inc_scan = self._chk_scan_exports.isChecked()
|
||||||
|
use_neg = self._chk_hard_negatives.isChecked()
|
||||||
# First check without fallback to see if source_paths are sufficient
|
# First check without fallback to see if source_paths are sufficient
|
||||||
video_infos_no_fb = self._db.get_training_data(
|
video_infos_no_fb = self._db.get_training_data(
|
||||||
self._profile, folder, negative_folder=neg_folder,
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
video_infos = self._db.get_training_data(
|
video_infos = self._db.get_training_data(
|
||||||
self._profile, folder, negative_folder=neg_folder,
|
self._profile, folder, negative_folder=neg_folder,
|
||||||
fallback_video_dir=self._txt_video_dir.text(),
|
fallback_video_dir=self._txt_video_dir.text(),
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
# Show video dir field only when the fallback helps find extra videos
|
# Show video dir field only when the fallback helps find extra videos
|
||||||
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
needs_fallback = len(video_infos) > len(video_infos_no_fb) or len(video_infos_no_fb) == 0
|
||||||
@@ -526,6 +657,10 @@ class TrainDialog(QDialog):
|
|||||||
def include_scan_exports(self) -> bool:
|
def include_scan_exports(self) -> bool:
|
||||||
return self._chk_scan_exports.isChecked()
|
return self._chk_scan_exports.isChecked()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_hard_negatives(self) -> bool:
|
||||||
|
return self._chk_hard_negatives.isChecked()
|
||||||
|
|
||||||
|
|
||||||
class TrainWorker(QThread):
|
class TrainWorker(QThread):
|
||||||
"""Trains an audio classifier off the main thread."""
|
"""Trains an audio classifier off the main thread."""
|
||||||
@@ -629,6 +764,28 @@ class ScanResultsPanel(QWidget):
|
|||||||
pass
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _current_table(self) -> QTableWidget | None:
|
||||||
|
"""Return the QTableWidget from the active tab (unwrapping container)."""
|
||||||
|
w = self._tabs.currentWidget()
|
||||||
|
if isinstance(w, QTableWidget):
|
||||||
|
return w
|
||||||
|
if w is not None:
|
||||||
|
table = w.findChild(QTableWidget)
|
||||||
|
if table is not None:
|
||||||
|
return table
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _tab_table(self, index: int) -> QTableWidget | None:
|
||||||
|
"""Return the QTableWidget from a tab by index."""
|
||||||
|
w = self._tabs.widget(index)
|
||||||
|
if isinstance(w, QTableWidget):
|
||||||
|
return w
|
||||||
|
if w is not None:
|
||||||
|
table = w.findChild(QTableWidget)
|
||||||
|
if table is not None:
|
||||||
|
return table
|
||||||
|
return None
|
||||||
|
|
||||||
def load_for_file(self, filename: str, profile: str) -> None:
|
def load_for_file(self, filename: str, profile: str) -> None:
|
||||||
"""Load saved scan results from DB for a file."""
|
"""Load saved scan results from DB for a file."""
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
@@ -638,6 +795,7 @@ class ScanResultsPanel(QWidget):
|
|||||||
results = self._db.get_scan_results(filename, profile)
|
results = self._db.get_scan_results(filename, profile)
|
||||||
for model, rows in results.items():
|
for model, rows in results.items():
|
||||||
self._add_tab(model, rows)
|
self._add_tab(model, rows)
|
||||||
|
self._populate_version_combos()
|
||||||
|
|
||||||
def add_scan_results(self, model: str,
|
def add_scan_results(self, model: str,
|
||||||
regions: list[tuple[float, float, float]]) -> None:
|
regions: list[tuple[float, float, float]]) -> None:
|
||||||
@@ -650,6 +808,7 @@ class ScanResultsPanel(QWidget):
|
|||||||
self._tabs.removeTab(i)
|
self._tabs.removeTab(i)
|
||||||
break
|
break
|
||||||
self._add_tab(model, rows)
|
self._add_tab(model, rows)
|
||||||
|
self._populate_version_combos()
|
||||||
for i in range(self._tabs.count()):
|
for i in range(self._tabs.count()):
|
||||||
if self._tabs.tabText(i).rsplit(" (", 1)[0] == model:
|
if self._tabs.tabText(i).rsplit(" (", 1)[0] == model:
|
||||||
self._tabs.setCurrentIndex(i)
|
self._tabs.setCurrentIndex(i)
|
||||||
@@ -657,10 +816,23 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def _add_tab(self, model: str,
|
def _add_tab(self, model: str,
|
||||||
rows: list[tuple[int, float, float, float, bool, float, float]]) -> None:
|
rows: list[tuple[int, float, float, float, bool, float, float]]) -> None:
|
||||||
"""Create a table tab.
|
"""Create a table tab wrapped in a container with a version combo.
|
||||||
|
|
||||||
rows: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]
|
rows: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]
|
||||||
"""
|
"""
|
||||||
|
container = QWidget()
|
||||||
|
container_layout = QVBoxLayout(container)
|
||||||
|
container_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
container_layout.setSpacing(2)
|
||||||
|
|
||||||
|
cmb_version = QComboBox()
|
||||||
|
cmb_version.setMaximumWidth(260)
|
||||||
|
cmb_version.setToolTip("Scan version history")
|
||||||
|
cmb_version.hide() # Hidden when only 1 version
|
||||||
|
cmb_version.currentIndexChanged.connect(
|
||||||
|
lambda idx, m=model: self._on_version_changed(m, idx))
|
||||||
|
container_layout.addWidget(cmb_version)
|
||||||
|
|
||||||
table = QTableWidget(len(rows), 3)
|
table = QTableWidget(len(rows), 3)
|
||||||
table.setHorizontalHeaderLabels(["Time", "End", "Score"])
|
table.setHorizontalHeaderLabels(["Time", "End", "Score"])
|
||||||
table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
|
table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
|
||||||
@@ -706,7 +878,94 @@ class ScanResultsPanel(QWidget):
|
|||||||
lambda t=table: self._on_selection_changed(t))
|
lambda t=table: self._on_selection_changed(t))
|
||||||
table.cellChanged.connect(
|
table.cellChanged.connect(
|
||||||
lambda r, c, t=table: self._on_cell_changed(t, r, c))
|
lambda r, c, t=table: self._on_cell_changed(t, r, c))
|
||||||
self._tabs.addTab(table, f"{model} ({len(rows)})")
|
container_layout.addWidget(table)
|
||||||
|
self._tabs.addTab(container, f"{model} ({len(rows)})")
|
||||||
|
|
||||||
|
def _populate_version_combos(self) -> None:
|
||||||
|
"""Populate version combo boxes for all tabs from DB."""
|
||||||
|
for i in range(self._tabs.count()):
|
||||||
|
w = self._tabs.widget(i)
|
||||||
|
if w is None:
|
||||||
|
continue
|
||||||
|
cmb = w.findChild(QComboBox)
|
||||||
|
if cmb is None:
|
||||||
|
continue
|
||||||
|
model = self._tabs.tabText(i).rsplit(" (", 1)[0]
|
||||||
|
versions = self._db.get_scan_versions(
|
||||||
|
self._filename, self._profile, model)
|
||||||
|
cmb.blockSignals(True)
|
||||||
|
cmb.clear()
|
||||||
|
for v in versions:
|
||||||
|
ts = v["timestamp"]
|
||||||
|
# Parse timestamp to readable date string
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(ts[:15], "%Y%m%d_%H%M%S")
|
||||||
|
date_str = dt.strftime("%Y-%m-%d %H:%M")
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
date_str = ts
|
||||||
|
label = (f"{date_str}"
|
||||||
|
f" ({v['count']} regions, best: {v['max_score']:.2f})")
|
||||||
|
cmb.addItem(label, userData=ts)
|
||||||
|
cmb.blockSignals(False)
|
||||||
|
cmb.setVisible(cmb.count() > 1)
|
||||||
|
|
||||||
|
def _on_version_changed(self, model: str, idx: int) -> None:
|
||||||
|
"""Reload a tab's results when the user selects a different version."""
|
||||||
|
if idx < 0:
|
||||||
|
return
|
||||||
|
self._undo_stack.clear() # version context changed, old undo entries invalid
|
||||||
|
# Find the tab for this model
|
||||||
|
for i in range(self._tabs.count()):
|
||||||
|
if self._tabs.tabText(i).rsplit(" (", 1)[0] == model:
|
||||||
|
w = self._tabs.widget(i)
|
||||||
|
cmb = w.findChild(QComboBox) if w else None
|
||||||
|
if cmb is None:
|
||||||
|
return
|
||||||
|
ts = cmb.itemData(idx)
|
||||||
|
if ts is None:
|
||||||
|
return
|
||||||
|
results = self._db.get_scan_results(
|
||||||
|
self._filename, self._profile, scan_timestamp=ts)
|
||||||
|
rows = results.get(model, [])
|
||||||
|
# Replace the table contents
|
||||||
|
table = self._tab_table(i)
|
||||||
|
if table is None:
|
||||||
|
return
|
||||||
|
self._editing = True
|
||||||
|
table.setRowCount(len(rows))
|
||||||
|
red = QColor(220, 60, 60)
|
||||||
|
gray = QColor(100, 100, 100)
|
||||||
|
for r, (row_id, start, end, score, disabled, os_, oe) in enumerate(rows):
|
||||||
|
t_item = QTableWidgetItem(format_time(start))
|
||||||
|
t_item.setData(Qt.ItemDataRole.UserRole, row_id)
|
||||||
|
t_item.setData(Qt.ItemDataRole.UserRole + 1, start)
|
||||||
|
t_item.setData(Qt.ItemDataRole.UserRole + 2, disabled)
|
||||||
|
t_item.setData(Qt.ItemDataRole.UserRole + 3, os_)
|
||||||
|
t_item.setData(Qt.ItemDataRole.UserRole + 4, oe)
|
||||||
|
table.setItem(r, 0, t_item)
|
||||||
|
e_item = QTableWidgetItem(format_time(end))
|
||||||
|
e_item.setData(Qt.ItemDataRole.UserRole, end)
|
||||||
|
table.setItem(r, 1, e_item)
|
||||||
|
sc_item = QTableWidgetItem(f"{score:.2f}")
|
||||||
|
sc_item.setFlags(sc_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
|
||||||
|
table.setItem(r, 2, sc_item)
|
||||||
|
if disabled:
|
||||||
|
for col in range(3):
|
||||||
|
table.item(r, col).setForeground(gray)
|
||||||
|
elif start in self._neg_times:
|
||||||
|
for col in range(3):
|
||||||
|
table.item(r, col).setForeground(red)
|
||||||
|
self._editing = False
|
||||||
|
self._tabs.setTabText(i, f"{model} ({len(rows)})")
|
||||||
|
self.regions_edited.emit()
|
||||||
|
return
|
||||||
|
|
||||||
|
def current_model_name(self) -> str:
|
||||||
|
"""Return the model name of the currently active tab."""
|
||||||
|
idx = self._tabs.currentIndex()
|
||||||
|
if idx >= 0:
|
||||||
|
return self._tabs.tabText(idx).split(" (")[0]
|
||||||
|
return ""
|
||||||
|
|
||||||
def _on_selection_changed(self, table: QTableWidget) -> None:
|
def _on_selection_changed(self, table: QTableWidget) -> None:
|
||||||
items = table.selectedItems()
|
items = table.selectedItems()
|
||||||
@@ -735,7 +994,7 @@ class ScanResultsPanel(QWidget):
|
|||||||
self._editing = False
|
self._editing = False
|
||||||
return
|
return
|
||||||
# Record undo: (action, tab_index, row, col, old_value)
|
# Record undo: (action, tab_index, row, col, old_value)
|
||||||
tab_idx = self._tabs.indexOf(table)
|
tab_idx = self._tabs.indexOf(table.parent() or table)
|
||||||
self._undo_stack.append(("resize", tab_idx, row, col, float(old_val)))
|
self._undo_stack.append(("resize", tab_idx, row, col, float(old_val)))
|
||||||
# Update stored data
|
# Update stored data
|
||||||
self._editing = True
|
self._editing = True
|
||||||
@@ -755,8 +1014,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def toggle_disable_selected(self) -> None:
|
def toggle_disable_selected(self) -> None:
|
||||||
"""Toggle disabled state on selected rows."""
|
"""Toggle disabled state on selected rows."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
selected_rows = sorted({idx.row() for idx in table.selectedIndexes()})
|
selected_rows = sorted({idx.row() for idx in table.selectedIndexes()})
|
||||||
if not selected_rows:
|
if not selected_rows:
|
||||||
@@ -791,8 +1050,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def delete_selected(self) -> None:
|
def delete_selected(self) -> None:
|
||||||
"""Permanently delete selected rows from active tab and DB."""
|
"""Permanently delete selected rows from active tab and DB."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
rows_to_delete = sorted(
|
rows_to_delete = sorted(
|
||||||
{idx.row() for idx in table.selectedIndexes()}, reverse=True)
|
{idx.row() for idx in table.selectedIndexes()}, reverse=True)
|
||||||
@@ -810,8 +1069,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
def filter_by_threshold(self, threshold: float) -> None:
|
def filter_by_threshold(self, threshold: float) -> None:
|
||||||
"""Show/hide rows based on score threshold across all tabs."""
|
"""Show/hide rows based on score threshold across all tabs."""
|
||||||
for i in range(self._tabs.count()):
|
for i in range(self._tabs.count()):
|
||||||
table = self._tabs.widget(i)
|
table = self._tab_table(i)
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
continue
|
continue
|
||||||
visible = 0
|
visible = 0
|
||||||
for row in range(table.rowCount()):
|
for row in range(table.rowCount()):
|
||||||
@@ -844,8 +1103,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def current_regions_with_orig(self) -> list[tuple[float, float, float, float, float]]:
|
def current_regions_with_orig(self) -> list[tuple[float, float, float, float, float]]:
|
||||||
"""Return (start, end, score, orig_start, orig_end) for enabled, visible rows."""
|
"""Return (start, end, score, orig_start, orig_end) for enabled, visible rows."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return []
|
return []
|
||||||
regions = []
|
regions = []
|
||||||
for row in range(table.rowCount()):
|
for row in range(table.rowCount()):
|
||||||
@@ -870,8 +1129,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
def update_region_times(self, start_match: float, end_match: float,
|
def update_region_times(self, start_match: float, end_match: float,
|
||||||
new_start: float, new_end: float) -> None:
|
new_start: float, new_end: float) -> None:
|
||||||
"""Update the table row matching (start, end) with new times. Called from timeline drag."""
|
"""Update the table row matching (start, end) with new times. Called from timeline drag."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
for row in range(table.rowCount()):
|
for row in range(table.rowCount()):
|
||||||
item0 = table.item(row, 0)
|
item0 = table.item(row, 0)
|
||||||
@@ -881,7 +1140,7 @@ class ScanResultsPanel(QWidget):
|
|||||||
continue
|
continue
|
||||||
if abs(float(s) - start_match) < 0.01 and abs(float(e) - end_match) < 0.01:
|
if abs(float(s) - start_match) < 0.01 and abs(float(e) - end_match) < 0.01:
|
||||||
# Record undo
|
# Record undo
|
||||||
tab_idx = self._tabs.indexOf(table)
|
tab_idx = self._tabs.currentIndex()
|
||||||
self._undo_stack.append(("drag", tab_idx, row, float(s), float(e)))
|
self._undo_stack.append(("drag", tab_idx, row, float(s), float(e)))
|
||||||
# Update stored values
|
# Update stored values
|
||||||
self._editing = True
|
self._editing = True
|
||||||
@@ -898,8 +1157,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def _on_add_negatives(self) -> None:
|
def _on_add_negatives(self) -> None:
|
||||||
"""Toggle selected rows as hard negatives (red = negative, toggle off to remove)."""
|
"""Toggle selected rows as hard negatives (red = negative, toggle off to remove)."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
selected_rows = sorted({idx.row() for idx in table.selectedIndexes()})
|
selected_rows = sorted({idx.row() for idx in table.selectedIndexes()})
|
||||||
if not selected_rows:
|
if not selected_rows:
|
||||||
@@ -938,8 +1197,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
self.negatives_removed.emit(remove_times)
|
self.negatives_removed.emit(remove_times)
|
||||||
|
|
||||||
def _on_export(self) -> None:
|
def _on_export(self) -> None:
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
# _get_tab_regions already skips disabled; also skip negatives
|
# _get_tab_regions already skips disabled; also skip negatives
|
||||||
regions = [r for r in self._get_tab_regions(table) if r[0] not in self._neg_times]
|
regions = [r for r in self._get_tab_regions(table) if r[0] not in self._neg_times]
|
||||||
@@ -948,22 +1207,22 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
def current_regions(self) -> list[tuple[float, float, float]]:
|
def current_regions(self) -> list[tuple[float, float, float]]:
|
||||||
"""Return (start, end, score) for enabled rows in the active tab."""
|
"""Return (start, end, score) for enabled rows in the active tab."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return []
|
return []
|
||||||
return self._get_tab_regions(table)
|
return self._get_tab_regions(table)
|
||||||
|
|
||||||
def all_regions(self) -> list[tuple[float, float, float]]:
|
def all_regions(self) -> list[tuple[float, float, float]]:
|
||||||
"""Return (start, end, score) for ALL rows including disabled."""
|
"""Return (start, end, score) for ALL rows including disabled."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return []
|
return []
|
||||||
return self._get_tab_regions(table, include_disabled=True)
|
return self._get_tab_regions(table, include_disabled=True)
|
||||||
|
|
||||||
def highlight_time(self, t: float) -> None:
|
def highlight_time(self, t: float) -> None:
|
||||||
"""Select the row containing time t, scrolling to it."""
|
"""Select the row containing time t, scrolling to it."""
|
||||||
table = self._tabs.currentWidget()
|
table = self._current_table()
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
for row in range(table.rowCount()):
|
for row in range(table.rowCount()):
|
||||||
start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
|
start = table.item(row, 0).data(Qt.ItemDataRole.UserRole + 1)
|
||||||
@@ -994,8 +1253,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
kind = action[0]
|
kind = action[0]
|
||||||
if kind == "disable":
|
if kind == "disable":
|
||||||
_, tab_idx, prev = action
|
_, tab_idx, prev = action
|
||||||
table = self._tabs.widget(tab_idx)
|
table = self._tab_table(tab_idx)
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
gray = QColor(100, 100, 100)
|
gray = QColor(100, 100, 100)
|
||||||
red = QColor(220, 60, 60)
|
red = QColor(220, 60, 60)
|
||||||
@@ -1021,8 +1280,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
elif kind == "resize":
|
elif kind == "resize":
|
||||||
_, tab_idx, row, col, old_val = action
|
_, tab_idx, row, col, old_val = action
|
||||||
table = self._tabs.widget(tab_idx)
|
table = self._tab_table(tab_idx)
|
||||||
if not isinstance(table, QTableWidget) or row >= table.rowCount():
|
if table is None or row >= table.rowCount():
|
||||||
return
|
return
|
||||||
self._editing = True
|
self._editing = True
|
||||||
if col == 0:
|
if col == 0:
|
||||||
@@ -1041,8 +1300,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
elif kind == "drag":
|
elif kind == "drag":
|
||||||
_, tab_idx, row, old_start, old_end = action
|
_, tab_idx, row, old_start, old_end = action
|
||||||
table = self._tabs.widget(tab_idx)
|
table = self._tab_table(tab_idx)
|
||||||
if not isinstance(table, QTableWidget) or row >= table.rowCount():
|
if table is None or row >= table.rowCount():
|
||||||
return
|
return
|
||||||
self._editing = True
|
self._editing = True
|
||||||
table.item(row, 0).setData(Qt.ItemDataRole.UserRole + 1, old_start)
|
table.item(row, 0).setData(Qt.ItemDataRole.UserRole + 1, old_start)
|
||||||
@@ -1057,8 +1316,8 @@ class ScanResultsPanel(QWidget):
|
|||||||
|
|
||||||
elif kind == "neg":
|
elif kind == "neg":
|
||||||
_, tab_idx, was_neg = action
|
_, tab_idx, was_neg = action
|
||||||
table = self._tabs.widget(tab_idx)
|
table = self._tab_table(tab_idx)
|
||||||
if not isinstance(table, QTableWidget):
|
if table is None:
|
||||||
return
|
return
|
||||||
add_back: list[float] = []
|
add_back: list[float] = []
|
||||||
remove_back: list[float] = []
|
remove_back: list[float] = []
|
||||||
@@ -3889,8 +4148,10 @@ class MainWindow(QMainWindow):
|
|||||||
if not self._file_path:
|
if not self._file_path:
|
||||||
return
|
return
|
||||||
filename = os.path.basename(self._file_path)
|
filename = os.path.basename(self._file_path)
|
||||||
|
source_model = self._scan_panel.current_model_name()
|
||||||
self._db.add_hard_negatives(filename, self._profile, times,
|
self._db.add_hard_negatives(filename, self._profile, times,
|
||||||
source_path=self._file_path)
|
source_path=self._file_path,
|
||||||
|
source_model=source_model)
|
||||||
self._timeline.set_scan_regions(
|
self._timeline.set_scan_regions(
|
||||||
self._scan_panel.current_regions_with_orig(),
|
self._scan_panel.current_regions_with_orig(),
|
||||||
neg_times=self._scan_panel._neg_times,
|
neg_times=self._scan_panel._neg_times,
|
||||||
@@ -4110,6 +4371,7 @@ class MainWindow(QMainWindow):
|
|||||||
embed_model = dlg.embed_model
|
embed_model = dlg.embed_model
|
||||||
video_dir = dlg.video_dir
|
video_dir = dlg.video_dir
|
||||||
inc_scan = dlg.include_scan_exports
|
inc_scan = dlg.include_scan_exports
|
||||||
|
use_neg = dlg.use_hard_negatives
|
||||||
if not pos_folder:
|
if not pos_folder:
|
||||||
self._show_status("No positive class selected")
|
self._show_status("No positive class selected")
|
||||||
return
|
return
|
||||||
@@ -4122,6 +4384,7 @@ class MainWindow(QMainWindow):
|
|||||||
self._profile, pos_folder, negative_folder=neg_folder,
|
self._profile, pos_folder, negative_folder=neg_folder,
|
||||||
fallback_video_dir=video_dir,
|
fallback_video_dir=video_dir,
|
||||||
include_scan_exports=inc_scan,
|
include_scan_exports=inc_scan,
|
||||||
|
use_hard_negatives=use_neg,
|
||||||
)
|
)
|
||||||
if not video_infos:
|
if not video_infos:
|
||||||
self._show_status("No training data found for this subprofile")
|
self._show_status("No training data found for this subprofile")
|
||||||
@@ -4409,8 +4672,11 @@ class MainWindow(QMainWindow):
|
|||||||
folder = self._txt_folder.text()
|
folder = self._txt_folder.text()
|
||||||
name = self._txt_name.text() or "clip"
|
name = self._txt_name.text() or "clip"
|
||||||
is_seq = self._cmb_format.currentText() == "WebP sequence"
|
is_seq = self._cmb_format.currentText() == "WebP sequence"
|
||||||
# Find the first counter whose group folder does not exist on disk.
|
# Start from the highest counter the DB knows about, so we never
|
||||||
self._export_counter = 1
|
# reuse a counter if the folder is temporarily empty / unmounted.
|
||||||
|
db_max = self._db.get_max_counter(folder, name) if self._db else 0
|
||||||
|
self._export_counter = max(1, db_max + 1)
|
||||||
|
# Then also skip any directories that exist on disk.
|
||||||
while True:
|
while True:
|
||||||
group_dir = os.path.join(folder, f"{name}_{self._export_counter:03d}")
|
group_dir = os.path.join(folder, f"{name}_{self._export_counter:03d}")
|
||||||
if not os.path.exists(group_dir):
|
if not os.path.exists(group_dir):
|
||||||
@@ -4482,7 +4748,8 @@ class MainWindow(QMainWindow):
|
|||||||
n_clips = self._spn_clips.value()
|
n_clips = self._spn_clips.value()
|
||||||
# For subprofile exports, calculate counter independently.
|
# For subprofile exports, calculate counter independently.
|
||||||
if folder_suffix:
|
if folder_suffix:
|
||||||
counter = 1
|
db_max_sub = self._db.get_max_counter(folder, name) if self._db else 0
|
||||||
|
counter = max(1, db_max_sub + 1)
|
||||||
while True:
|
while True:
|
||||||
if image_sequence:
|
if image_sequence:
|
||||||
p = build_sequence_dir(folder, name, counter, sub=0)
|
p = build_sequence_dir(folder, name, counter, sub=0)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ soundfile>=0.12
|
|||||||
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
torch>=2.0
|
torch>=2.0
|
||||||
torchaudio>=2.0
|
torchaudio>=2.0
|
||||||
|
transformers>=4.30
|
||||||
|
timm>=0.9
|
||||||
|
|
||||||
# Object detection
|
# Object detection
|
||||||
ultralytics>=8.0
|
ultralytics>=8.0
|
||||||
|
|||||||
@@ -25,6 +25,39 @@ def test_default_model_path_contains_profile():
|
|||||||
assert path.endswith(".joblib")
|
assert path.endswith(".joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_multi_layer():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
# Multi-layer models should report concatenated dimension
|
||||||
|
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||||
|
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||||
|
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||||
|
# Single-layer unchanged
|
||||||
|
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||||
|
|
||||||
|
|
||||||
|
def test_ml_config():
|
||||||
|
from core.audio_scan import _ml_config
|
||||||
|
assert _ml_config("HUBERT_XLARGE") is None
|
||||||
|
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||||
|
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||||
|
assert base == "HUBERT_XLARGE"
|
||||||
|
assert layers == [11, 23, 35, 47]
|
||||||
|
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||||
|
assert base == "HUBERT_BASE"
|
||||||
|
assert layers == [2, 5, 8, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_ast():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("AST") == 768
|
||||||
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_eat():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("EAT") == 768
|
||||||
|
|
||||||
|
|
||||||
def test_db_get_all_export_paths():
|
def test_db_get_all_export_paths():
|
||||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
path = f.name
|
path = f.name
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_folders_excludes_scan_exports():
|
||||||
|
"""Scan-export-only folders should not appear when include_scan_exports=False."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Manual export
|
||||||
|
db.add("a.mp4", 10.0, "/out/mp4_Intense/g1/clip.mp4", profile="test")
|
||||||
|
# Scan export to different folder
|
||||||
|
db.add("a.mp4", 20.0, "/out/mp4_ScanOnly/g1/clip.mp4", profile="test",
|
||||||
|
scan_export=True)
|
||||||
|
folders = db.get_export_folders("test")
|
||||||
|
assert "mp4_Intense" in folders
|
||||||
|
assert "mp4_ScanOnly" not in folders, "scan-only folder should be excluded"
|
||||||
|
# With include_scan_exports=True, both should appear
|
||||||
|
folders_all = db.get_export_folders("test", include_scan_exports=True)
|
||||||
|
assert "mp4_ScanOnly" in folders_all
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_result_history():
|
||||||
|
"""save_scan_results should keep multiple versions."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Save three versions (microsecond-precision timestamps avoid collisions)
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(0, 8, 0.9)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(0, 8, 0.8), (10, 18, 0.7)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(5, 13, 0.95)])
|
||||||
|
versions = db.get_scan_versions("v.mp4", "test", "MODEL_A")
|
||||||
|
assert len(versions) == 3
|
||||||
|
# Most recent first
|
||||||
|
assert versions[0]["count"] == 1 # latest: 1 region
|
||||||
|
assert versions[1]["count"] == 2 # middle: 2 regions
|
||||||
|
assert versions[2]["count"] == 1 # oldest: 1 region
|
||||||
|
# get_scan_results returns latest version by default
|
||||||
|
results = db.get_scan_results("v.mp4", "test")
|
||||||
|
assert len(results.get("MODEL_A", [])) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hard_negatives_source_model():
|
||||||
|
"""Hard negatives should store source_model."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||||
|
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_data_skips_hard_negatives():
|
||||||
|
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Create a source file that "exists" — use the temp db file itself
|
||||||
|
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||||
|
source_path=path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||||
|
# With hard negatives
|
||||||
|
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||||
|
# Without hard negatives
|
||||||
|
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||||
|
assert len(data_with) >= 1
|
||||||
|
# The "with" case should have the hard negative time in neg list
|
||||||
|
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||||
|
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||||
|
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hard_negatives_by_ids():
|
||||||
|
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||||
|
source_path="/a.mp4")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 3
|
||||||
|
# Delete first two
|
||||||
|
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||||
|
remaining = db.get_hard_negatives("test")
|
||||||
|
assert len(remaining) == 1
|
||||||
|
assert remaining[0]["start_time"] == 30.0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
Reference in New Issue
Block a user