Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1bdeb33a6f | |||
| 387ed7bc6a | |||
| f268d61fe4 | |||
| 24db32c09f | |||
| 0f6ae88ea6 | |||
| 4d99cf6015 | |||
| b75fa85ff5 | |||
| e7d47331c6 | |||
| 7cd31ebe55 | |||
| 3a37dddfd9 | |||
| b249705506 | |||
| aaf405dd3d | |||
| cb2060beb8 | |||
| 0db412baf4 | |||
| 876026d1f6 | |||
| 6c1d42adfe | |||
| d8b3972bdc | |||
| bd345abca2 | |||
| 7d6fee9df1 | |||
| fd043f4172 | |||
| 3c3b1d74bb | |||
| a3c657c66e | |||
| 5d45b8d8eb | |||
| e6db83f00b | |||
| edc5784ba6 | |||
| 8ed9fbf557 | |||
| 4fb2ae144f | |||
| 2614a765d5 | |||
| c020c0dfec | |||
| e7b791fbfa | |||
| f5361a963e | |||
| 8fb8581816 | |||
| 5b25e85e98 | |||
| e3f133ef84 | |||
| 4736f150b0 | |||
| 52aa982aa2 | |||
| 07457d0d6f | |||
| c5d613fc5f |
+166
-18
@@ -45,6 +45,7 @@ os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
|||||||
_w2v_model = None
|
_w2v_model = None
|
||||||
_w2v_device = None
|
_w2v_device = None
|
||||||
_w2v_model_name = None
|
_w2v_model_name = None
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
|
||||||
# Supported embedding models — name → embed_dim
|
# Supported embedding models — name → embed_dim
|
||||||
_EMBED_MODELS = {
|
_EMBED_MODELS = {
|
||||||
@@ -55,6 +56,16 @@ _EMBED_MODELS = {
|
|||||||
"HUBERT_LARGE": 1024,
|
"HUBERT_LARGE": 1024,
|
||||||
"HUBERT_XLARGE": 1280,
|
"HUBERT_XLARGE": 1280,
|
||||||
"BEATS": 768,
|
"BEATS": 768,
|
||||||
|
# Multi-layer variants (4 quartile layers concatenated)
|
||||||
|
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
"EAT": 768,
|
||||||
|
"EAT_LARGE": 1024,
|
||||||
}
|
}
|
||||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
@@ -70,11 +81,14 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
global _w2v_model, _w2v_device, _w2v_model_name
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
model_name = _DEFAULT_EMBED_MODEL
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
if _w2v_model is None or _w2v_model_name != model_name:
|
# Multi-layer variants use the same base model weights
|
||||||
|
ml = _ml_config(model_name)
|
||||||
|
load_name = ml[0] if ml else model_name
|
||||||
|
if _w2v_model is None or _w2v_model_name != load_name:
|
||||||
import torch
|
import torch
|
||||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
if model_name == "BEATS":
|
if load_name == "BEATS":
|
||||||
from .beats_model import BEATs, BEATsConfig
|
from .beats_model import BEATs, BEATsConfig
|
||||||
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||||
weights_only=False)
|
weights_only=False)
|
||||||
@@ -82,17 +96,63 @@ def _get_w2v_model(model_name: str | None = None):
|
|||||||
_w2v_model = BEATs(cfg)
|
_w2v_model = BEATs(cfg)
|
||||||
_w2v_model.load_state_dict(checkpoint['model'])
|
_w2v_model.load_state_dict(checkpoint['model'])
|
||||||
_w2v_model.to(_w2v_device)
|
_w2v_model.to(_w2v_device)
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
|
elif load_name in ("EAT", "EAT_LARGE"):
|
||||||
|
from transformers import AutoModel
|
||||||
|
eat_repo = ("worstchan/EAT-large_epoch20_finetune_AS2M"
|
||||||
|
if load_name == "EAT_LARGE"
|
||||||
|
else "worstchan/EAT-base_epoch30_finetune_AS2M")
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
eat_repo, trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
else:
|
else:
|
||||||
import torchaudio
|
import torchaudio
|
||||||
bundle = getattr(torchaudio.pipelines, model_name)
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
|
||||||
_w2v_model.eval()
|
_w2v_model.eval()
|
||||||
_w2v_model_name = model_name
|
_w2v_model_name = load_name
|
||||||
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
|
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||||
return _w2v_model, _w2v_device
|
return _w2v_model, _w2v_device
|
||||||
|
|
||||||
|
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||||
|
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||||
|
|
||||||
|
Returns tensor of shape [B, 1, T, 128].
|
||||||
|
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
TARGET_LEN = 1024
|
||||||
|
MEAN, STD = -4.268, 4.569
|
||||||
|
|
||||||
|
mels = []
|
||||||
|
for chunk in chunks:
|
||||||
|
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||||
|
fbank = kaldi.fbank(
|
||||||
|
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||||
|
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||||
|
)
|
||||||
|
# Pad or truncate to TARGET_LEN
|
||||||
|
if fbank.shape[0] < TARGET_LEN:
|
||||||
|
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||||
|
else:
|
||||||
|
fbank = fbank[:TARGET_LEN]
|
||||||
|
fbank = (fbank - MEAN) / (STD * 2)
|
||||||
|
mels.append(fbank)
|
||||||
|
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||||
|
|
||||||
|
|
||||||
def _embed_dim(model_name: str | None = None) -> int:
|
def _embed_dim(model_name: str | None = None) -> int:
|
||||||
"""Return embedding dimension for a model name."""
|
"""Return embedding dimension for a model name."""
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
@@ -100,6 +160,31 @@ def _embed_dim(model_name: str | None = None) -> int:
|
|||||||
return _EMBED_MODELS.get(model_name, 768)
|
return _EMBED_MODELS.get(model_name, 768)
|
||||||
|
|
||||||
|
|
||||||
|
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||||
|
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||||
|
|
||||||
|
Returns None for single-layer models.
|
||||||
|
Layer indices are 0-based into the list returned by extract_features().
|
||||||
|
"""
|
||||||
|
if not model_name.endswith("_ML"):
|
||||||
|
return None
|
||||||
|
base = model_name[:-3] # strip "_ML"
|
||||||
|
if base not in _EMBED_MODELS:
|
||||||
|
return None
|
||||||
|
# Layer counts per model family
|
||||||
|
layer_counts = {
|
||||||
|
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||||
|
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||||
|
"AST": 12,
|
||||||
|
}
|
||||||
|
n = layer_counts.get(base)
|
||||||
|
if n is None:
|
||||||
|
return None
|
||||||
|
# Select 4 layers at quartile boundaries (0-indexed)
|
||||||
|
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||||
|
return base, indices
|
||||||
|
|
||||||
|
|
||||||
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||||
model_name: str | None = None) -> str:
|
model_name: str | None = None) -> str:
|
||||||
"""Return cache file path for a video's embeddings (includes model name)."""
|
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||||
@@ -171,6 +256,9 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
import torch
|
import torch
|
||||||
model, device = _get_w2v_model(model_name)
|
model, device = _get_w2v_model(model_name)
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
# Auto-size batches based on available GPU memory
|
# Auto-size batches based on available GPU memory
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
@@ -195,13 +283,36 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
start = i * hop_samples
|
start = i * hop_samples
|
||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
if is_ast:
|
||||||
if is_beats:
|
inputs = _ast_feature_extractor(
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
else:
|
else:
|
||||||
features, _ = model(waveforms)
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
embeddings.append(batch_emb)
|
embeddings.append(batch_emb)
|
||||||
|
|
||||||
result_ts = timestamps
|
result_ts = timestamps
|
||||||
@@ -274,6 +385,9 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
embeddings_list: list[np.ndarray] = []
|
embeddings_list: list[np.ndarray] = []
|
||||||
|
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_times), batch_size):
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
batch_end = min(batch_start + batch_size, len(valid_times))
|
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||||
@@ -283,13 +397,36 @@ def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
|||||||
chunks.append(y[start:start + win_samples])
|
chunks.append(y[start:start + win_samples])
|
||||||
timestamps_list.append(float(t))
|
timestamps_list.append(float(t))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
if is_ast:
|
||||||
if is_beats:
|
inputs = _ast_feature_extractor(
|
||||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
else:
|
else:
|
||||||
features, _ = model(waveforms)
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
embeddings_list.append(batch_emb)
|
embeddings_list.append(batch_emb)
|
||||||
|
|
||||||
timestamps = np.array(timestamps_list)
|
timestamps = np.array(timestamps_list)
|
||||||
@@ -428,6 +565,17 @@ def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
|||||||
clf.fit(X[train_idx], y_arr[train_idx])
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
_log("audio_scan: classifier trained")
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
# Calibrate probabilities for better threshold behavior
|
||||||
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
min_class = min(int(n_pos), int(n_neg_sample))
|
||||||
|
if min_class >= 6:
|
||||||
|
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||||
|
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
clf = cal_clf
|
||||||
|
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||||
|
|
||||||
model = {"classifier": clf, "n_features": X.shape[1],
|
model = {"classifier": clf, "n_features": X.shape[1],
|
||||||
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
||||||
|
|
||||||
@@ -589,7 +737,7 @@ def prefetch_audio(video_path: str, embed_model: str | None = None,
|
|||||||
def scan_video(
|
def scan_video(
|
||||||
video_path: str,
|
video_path: str,
|
||||||
model: dict = None,
|
model: dict = None,
|
||||||
threshold: float = 0.30,
|
threshold: float = 0.50,
|
||||||
hop: float = 1.0,
|
hop: float = 1.0,
|
||||||
window: float = _WINDOW,
|
window: float = _WINDOW,
|
||||||
cancel_flag: object = None,
|
cancel_flag: object = None,
|
||||||
|
|||||||
+326
-46
@@ -94,7 +94,8 @@ class ProcessedDB:
|
|||||||
" score REAL NOT NULL,"
|
" score REAL NOT NULL,"
|
||||||
" disabled INTEGER NOT NULL DEFAULT 0,"
|
" disabled INTEGER NOT NULL DEFAULT 0,"
|
||||||
" orig_start_time REAL,"
|
" orig_start_time REAL,"
|
||||||
" orig_end_time REAL"
|
" orig_end_time REAL,"
|
||||||
|
" scan_timestamp TEXT NOT NULL DEFAULT ''"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
# Migrate: add new columns to existing scan_results tables
|
# Migrate: add new columns to existing scan_results tables
|
||||||
@@ -106,6 +107,7 @@ class ProcessedDB:
|
|||||||
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
||||||
("orig_start_time", "REAL"),
|
("orig_start_time", "REAL"),
|
||||||
("orig_end_time", "REAL"),
|
("orig_end_time", "REAL"),
|
||||||
|
("scan_timestamp", "TEXT NOT NULL DEFAULT ''"),
|
||||||
]:
|
]:
|
||||||
if col not in sr_cols:
|
if col not in sr_cols:
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
@@ -117,18 +119,114 @@ class ProcessedDB:
|
|||||||
)
|
)
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS hard_negatives ("
|
"CREATE TABLE IF NOT EXISTS hard_negatives ("
|
||||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||||
" filename TEXT NOT NULL,"
|
" filename TEXT NOT NULL,"
|
||||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
" start_time REAL NOT NULL,"
|
" start_time REAL NOT NULL,"
|
||||||
" source_path TEXT NOT NULL DEFAULT ''"
|
" source_path TEXT NOT NULL DEFAULT '',"
|
||||||
|
" source_model TEXT NOT NULL DEFAULT ''"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
# Migrate: add source_model column to existing hard_negatives tables
|
||||||
|
hn_cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||||
|
}
|
||||||
|
if "source_model" not in hn_cols:
|
||||||
|
self._con.execute(
|
||||||
|
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||||
|
)
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||||
" ON hard_negatives(filename, profile)"
|
" ON hard_negatives(filename, profile)"
|
||||||
)
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
self._migrate_vid_folders()
|
||||||
|
|
||||||
|
def _migrate_vid_folders(self) -> None:
|
||||||
|
"""Migrate old clip_NNN group dirs → vid_NNN per-video folders.
|
||||||
|
|
||||||
|
Old layout: export_folder/clip_NNN/clip_NNN_sub.mp4
|
||||||
|
New layout: export_folder/vid_NNN/clip_NNN_sub.mp4
|
||||||
|
|
||||||
|
Rewrites output_path in DB and moves files on disk.
|
||||||
|
"""
|
||||||
|
# Check if any rows still use the old clip_NNN parent dir layout
|
||||||
|
row = self._con.execute(
|
||||||
|
"SELECT id FROM processed WHERE output_path LIKE '%/clip_%/%' LIMIT 1"
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return
|
||||||
|
|
||||||
|
_log("Migrating old clip group dirs → vid folders …")
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, filename, profile, output_path FROM processed"
|
||||||
|
" ORDER BY profile, filename, output_path"
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
# Assign vid_NNN per (profile, export_folder, filename)
|
||||||
|
vid_map: dict[tuple, str] = {}
|
||||||
|
vid_counters: dict[tuple, int] = {}
|
||||||
|
|
||||||
|
for rid, filename, profile, op in rows:
|
||||||
|
parent = os.path.dirname(op)
|
||||||
|
export_folder = os.path.dirname(parent)
|
||||||
|
key = (profile, export_folder, filename)
|
||||||
|
if key not in vid_map:
|
||||||
|
counter_key = (profile, export_folder)
|
||||||
|
n = vid_counters.get(counter_key, 1)
|
||||||
|
vid_map[key] = f"vid_{n:03d}"
|
||||||
|
vid_counters[counter_key] = n + 1
|
||||||
|
|
||||||
|
updates: list[tuple[str, int]] = []
|
||||||
|
moves: list[tuple[str, str]] = []
|
||||||
|
dirs_to_create: set[str] = set()
|
||||||
|
old_dirs: set[str] = set()
|
||||||
|
|
||||||
|
for rid, filename, profile, op in rows:
|
||||||
|
parent = os.path.dirname(op)
|
||||||
|
parent_name = os.path.basename(parent)
|
||||||
|
# Skip rows already using vid_NNN layout
|
||||||
|
if parent_name.startswith("vid_"):
|
||||||
|
continue
|
||||||
|
export_folder = os.path.dirname(parent)
|
||||||
|
key = (profile, export_folder, filename)
|
||||||
|
vid_name = vid_map[key]
|
||||||
|
new_path = os.path.join(export_folder, vid_name, os.path.basename(op))
|
||||||
|
updates.append((new_path, rid))
|
||||||
|
dirs_to_create.add(os.path.join(export_folder, vid_name))
|
||||||
|
old_dirs.add(parent)
|
||||||
|
if os.path.exists(op):
|
||||||
|
moves.append((op, new_path))
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create vid directories
|
||||||
|
for d in sorted(dirs_to_create):
|
||||||
|
os.makedirs(d, exist_ok=True)
|
||||||
|
|
||||||
|
# Move files
|
||||||
|
import shutil
|
||||||
|
for old, new in moves:
|
||||||
|
if os.path.exists(old) and not os.path.exists(new):
|
||||||
|
shutil.move(old, new)
|
||||||
|
|
||||||
|
# Update DB
|
||||||
|
self._con.executemany(
|
||||||
|
"UPDATE processed SET output_path = ? WHERE id = ?", updates
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
# Remove empty old group directories
|
||||||
|
for d in sorted(old_dirs, reverse=True):
|
||||||
|
try:
|
||||||
|
if os.path.isdir(d) and not os.listdir(d):
|
||||||
|
os.rmdir(d)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
_log(f"Migrated {len(updates)} rows, moved {len(moves)} files to vid folders")
|
||||||
|
|
||||||
def add(self, filename: str, start_time: float, output_path: str,
|
def add(self, filename: str, start_time: float, output_path: str,
|
||||||
label: str = "", category: str = "",
|
label: str = "", category: str = "",
|
||||||
@@ -291,20 +389,118 @@ class ProcessedDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return [r[0] for r in rows]
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
def get_export_folders(self, profile: str = "default") -> list[str]:
|
def get_max_counter(self, folder: str, name: str) -> int:
|
||||||
|
"""Return the highest counter N found in output_paths matching folder/name_NNN*.
|
||||||
|
|
||||||
|
Parses the counter from filenames (e.g. 'clip_035_0.mp4' → 35).
|
||||||
|
*folder* is typically the vid folder. Returns 0 if no matches exist.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return 0
|
||||||
|
prefix = os.path.join(folder, name + "_")
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed"
|
||||||
|
" WHERE output_path LIKE ?",
|
||||||
|
(prefix + "%",),
|
||||||
|
).fetchall()
|
||||||
|
max_n = 0
|
||||||
|
name_prefix = name + "_"
|
||||||
|
for (op,) in rows:
|
||||||
|
stem = os.path.splitext(os.path.basename(op))[0]
|
||||||
|
# stem: "clip_035_0" or "clip_036_a1_0"
|
||||||
|
if not stem.startswith(name_prefix):
|
||||||
|
continue
|
||||||
|
rest = stem[len(name_prefix):] # "035_0" or "036_a1_0"
|
||||||
|
counter_str = rest.split("_")[0]
|
||||||
|
try:
|
||||||
|
max_n = max(max_n, int(counter_str))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return max_n
|
||||||
|
|
||||||
|
def delete_scan_exports(self, filename: str, profile: str) -> int:
|
||||||
|
"""Delete all scan_export entries for *filename* in *profile*.
|
||||||
|
|
||||||
|
Returns the number of rows deleted.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return 0
|
||||||
|
cur = self._con.execute(
|
||||||
|
"DELETE FROM processed"
|
||||||
|
" WHERE filename = ? AND profile = ? AND scan_export = 1",
|
||||||
|
(filename, profile),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
return cur.rowcount
|
||||||
|
|
||||||
|
def get_vid_folder(self, filename: str, profile: str,
|
||||||
|
export_folder: str) -> str:
|
||||||
|
"""Return the vid_NNN folder name for a source video.
|
||||||
|
|
||||||
|
Checks existing DB output_paths first; if the video already has a
|
||||||
|
vid_NNN folder, returns it. Otherwise assigns max(existing) + 1,
|
||||||
|
also checking disk for orphan vid folders.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return "vid_001"
|
||||||
|
# Use the most recent entry (ORDER BY rowid DESC) for determinism
|
||||||
|
# when a file has entries across multiple vid folders.
|
||||||
|
row = self._con.execute(
|
||||||
|
"SELECT output_path FROM processed"
|
||||||
|
" WHERE filename = ? AND profile = ?"
|
||||||
|
" ORDER BY rowid DESC LIMIT 1",
|
||||||
|
(filename, profile),
|
||||||
|
).fetchone()
|
||||||
|
if row:
|
||||||
|
parent = os.path.basename(os.path.dirname(row[0]))
|
||||||
|
if parent.startswith("vid_"):
|
||||||
|
return parent
|
||||||
|
# Collect max vid_NNN number from DB + disk (never reuse old numbers)
|
||||||
|
max_n = 0
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
for (op,) in rows:
|
||||||
|
p = os.path.basename(os.path.dirname(op))
|
||||||
|
if p.startswith("vid_"):
|
||||||
|
try:
|
||||||
|
max_n = max(max_n, int(p.split("_")[1]))
|
||||||
|
except (IndexError, ValueError):
|
||||||
|
pass
|
||||||
|
if os.path.isdir(export_folder):
|
||||||
|
for d in os.listdir(export_folder):
|
||||||
|
if d.startswith("vid_") and os.path.isdir(
|
||||||
|
os.path.join(export_folder, d)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
max_n = max(max_n, int(d.split("_")[1]))
|
||||||
|
except (IndexError, ValueError):
|
||||||
|
pass
|
||||||
|
return f"vid_{max_n + 1:03d}"
|
||||||
|
|
||||||
|
def get_export_folders(self, profile: str = "default",
|
||||||
|
include_scan_exports: bool = False) -> list[str]:
|
||||||
"""Return distinct export folder names found in output_paths for a profile.
|
"""Return distinct export folder names found in output_paths for a profile.
|
||||||
|
|
||||||
Export paths follow the structure:
|
Export paths follow the structure:
|
||||||
.../export_folder/group_dir/clip.mp4
|
.../export_folder/vid_NNN/clip.mp4
|
||||||
The export folder is 2 levels up from the clip file.
|
The export folder is 2 levels up from the clip file.
|
||||||
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
|
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return []
|
return []
|
||||||
rows = self._con.execute(
|
if include_scan_exports:
|
||||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
rows = self._con.execute(
|
||||||
(profile,),
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
).fetchall()
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed"
|
||||||
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
folder_names: set[str] = set()
|
folder_names: set[str] = set()
|
||||||
for (op,) in rows:
|
for (op,) in rows:
|
||||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
@@ -316,6 +512,7 @@ class ProcessedDB:
|
|||||||
negative_folder: str = "",
|
negative_folder: str = "",
|
||||||
fallback_video_dir: str = "",
|
fallback_video_dir: str = "",
|
||||||
include_scan_exports: bool = False,
|
include_scan_exports: bool = False,
|
||||||
|
use_hard_negatives: bool = True,
|
||||||
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
"""Build training video_infos from DB data.
|
"""Build training video_infos from DB data.
|
||||||
|
|
||||||
@@ -325,6 +522,7 @@ class ProcessedDB:
|
|||||||
negative_folder: export folder name for explicit negatives (optional)
|
negative_folder: export folder name for explicit negatives (optional)
|
||||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||||
include_scan_exports: if True, include auto-exported scan clips
|
include_scan_exports: if True, include auto-exported scan clips
|
||||||
|
use_hard_negatives: if False, skip hard negatives from scan feedback
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of (source_video_path, positive_times, soft_times, negative_times)
|
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||||
@@ -363,15 +561,16 @@ class ProcessedDB:
|
|||||||
soft_by_video.setdefault(fn, set()).add(st)
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
# Include hard negatives from scan feedback
|
# Include hard negatives from scan feedback
|
||||||
hard_rows = self._con.execute(
|
if use_hard_negatives:
|
||||||
"SELECT filename, start_time, source_path FROM hard_negatives"
|
hard_rows = self._con.execute(
|
||||||
" WHERE profile = ?",
|
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||||
(profile,),
|
" WHERE profile = ?",
|
||||||
).fetchall()
|
(profile,),
|
||||||
for fn, st, sp in hard_rows:
|
).fetchall()
|
||||||
neg_by_video.setdefault(fn, set()).add(st)
|
for fn, st, sp in hard_rows:
|
||||||
if sp:
|
neg_by_video.setdefault(fn, set()).add(st)
|
||||||
source_by_filename.setdefault(fn, sp)
|
if sp:
|
||||||
|
source_by_filename.setdefault(fn, sp)
|
||||||
|
|
||||||
# Remove positive times from soft/neg to avoid conflicting labels
|
# Remove positive times from soft/neg to avoid conflicting labels
|
||||||
for fn in pos_by_video:
|
for fn in pos_by_video:
|
||||||
@@ -429,7 +628,7 @@ class ProcessedDB:
|
|||||||
" WHERE profile = ? AND scan_export = 0",
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
(profile,),
|
(profile,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
folders = self.get_export_folders(profile)
|
folders = self.get_export_folders(profile, include_scan_exports=include_scan_exports)
|
||||||
stats: dict[str, dict] = {}
|
stats: dict[str, dict] = {}
|
||||||
for folder_name in folders:
|
for folder_name in folders:
|
||||||
videos: set[str] = set()
|
videos: set[str] = set()
|
||||||
@@ -440,50 +639,105 @@ class ProcessedDB:
|
|||||||
videos.add(fn)
|
videos.add(fn)
|
||||||
clips += 1
|
clips += 1
|
||||||
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
||||||
return stats
|
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||||
|
|
||||||
# ── Scan results ─────────────────────────────────────────────
|
# ── Scan results ─────────────────────────────────────────────
|
||||||
|
|
||||||
def save_scan_results(self, filename: str, profile: str, model: str,
|
def save_scan_results(self, filename: str, profile: str, model: str,
|
||||||
regions: list[tuple[float, float, float]]) -> None:
|
regions: list[tuple[float, float, float]],
|
||||||
"""Replace scan results for (filename, profile, model) with new regions.
|
max_versions: int = 5) -> None:
|
||||||
|
"""Save scan results as a new version for (filename, profile, model).
|
||||||
|
|
||||||
regions: list of (start_time, end_time, score).
|
regions: list of (start_time, end_time, score).
|
||||||
|
Keeps up to max_versions; oldest are pruned automatically.
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return
|
return
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._con.execute(
|
|
||||||
"DELETE FROM scan_results"
|
|
||||||
" WHERE filename = ? AND profile = ? AND model = ?",
|
|
||||||
(filename, profile, model),
|
|
||||||
)
|
|
||||||
self._con.executemany(
|
self._con.executemany(
|
||||||
"INSERT INTO scan_results"
|
"INSERT INTO scan_results"
|
||||||
" (filename, profile, model, start_time, end_time, score,"
|
" (filename, profile, model, start_time, end_time, score,"
|
||||||
" orig_start_time, orig_end_time)"
|
" orig_start_time, orig_end_time, scan_timestamp)"
|
||||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
[(filename, profile, model, s, e, sc, s, e) for s, e, sc in regions],
|
[(filename, profile, model, s, e, sc, s, e, ts)
|
||||||
|
for s, e, sc in regions],
|
||||||
)
|
)
|
||||||
|
# Prune old versions beyond max_versions
|
||||||
|
versions = self._con.execute(
|
||||||
|
"SELECT DISTINCT scan_timestamp FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
if len(versions) > max_versions:
|
||||||
|
old_ts = [v[0] for v in versions[max_versions:]]
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
f" AND scan_timestamp IN ({','.join('?' * len(old_ts))})",
|
||||||
|
(filename, profile, model, *old_ts),
|
||||||
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def get_scan_results(self, filename: str, profile: str
|
def get_scan_versions(self, filename: str, profile: str, model: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Return list of scan versions for (filename, profile, model).
|
||||||
|
|
||||||
|
Returns [{timestamp, count, max_score}, ...] ordered newest first.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT scan_timestamp, COUNT(*), MAX(score)"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||||
|
" AND scan_timestamp != ''"
|
||||||
|
" GROUP BY scan_timestamp"
|
||||||
|
" ORDER BY scan_timestamp DESC",
|
||||||
|
(filename, profile, model),
|
||||||
|
).fetchall()
|
||||||
|
return [{"timestamp": ts, "count": cnt, "max_score": sc}
|
||||||
|
for ts, cnt, sc in rows]
|
||||||
|
|
||||||
|
def get_scan_results(self, filename: str, profile: str,
|
||||||
|
scan_timestamp: str | None = None
|
||||||
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
||||||
"""Return scan results grouped by model.
|
"""Return scan results grouped by model.
|
||||||
|
|
||||||
|
If scan_timestamp is given, returns only that version's rows.
|
||||||
|
Otherwise returns the latest version per model.
|
||||||
|
|
||||||
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
||||||
sorted by start_time.
|
sorted by start_time.
|
||||||
"""
|
"""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return {}
|
return {}
|
||||||
rows = self._con.execute(
|
if scan_timestamp:
|
||||||
"SELECT id, model, start_time, end_time, score, disabled,"
|
rows = self._con.execute(
|
||||||
" orig_start_time, orig_end_time"
|
"SELECT id, model, start_time, end_time, score, disabled,"
|
||||||
" FROM scan_results"
|
" orig_start_time, orig_end_time"
|
||||||
" WHERE filename = ? AND profile = ?"
|
" FROM scan_results"
|
||||||
" ORDER BY model, start_time",
|
" WHERE filename = ? AND profile = ? AND scan_timestamp = ?"
|
||||||
(filename, profile),
|
" ORDER BY model, start_time",
|
||||||
).fetchall()
|
(filename, profile, scan_timestamp),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
# For each model, get rows from the latest timestamp only
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT r.id, r.model, r.start_time, r.end_time, r.score,"
|
||||||
|
" r.disabled, r.orig_start_time, r.orig_end_time"
|
||||||
|
" FROM scan_results r"
|
||||||
|
" INNER JOIN ("
|
||||||
|
" SELECT model, MAX(scan_timestamp) AS latest"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ?"
|
||||||
|
" GROUP BY model"
|
||||||
|
" ) m ON r.model = m.model AND r.scan_timestamp = m.latest"
|
||||||
|
" WHERE r.filename = ? AND r.profile = ?"
|
||||||
|
" ORDER BY r.model, r.start_time",
|
||||||
|
(filename, profile, filename, profile),
|
||||||
|
).fetchall()
|
||||||
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
||||||
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
||||||
# Fall back to current bounds for legacy rows without orig
|
# Fall back to current bounds for legacy rows without orig
|
||||||
@@ -546,16 +800,18 @@ class ProcessedDB:
|
|||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
def add_hard_negatives(self, filename: str, profile: str,
|
def add_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float], source_path: str = "") -> None:
|
times: list[float], source_path: str = "",
|
||||||
|
source_model: str = "") -> None:
|
||||||
"""Save timestamps as hard-negative training examples."""
|
"""Save timestamps as hard-negative training examples."""
|
||||||
if not self._enabled or not times:
|
if not self._enabled or not times:
|
||||||
return
|
return
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for t in times:
|
for t in times:
|
||||||
self._con.execute(
|
self._con.execute(
|
||||||
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
|
"INSERT INTO hard_negatives"
|
||||||
" VALUES (?, ?, ?, ?)",
|
" (filename, profile, start_time, source_path, source_model)"
|
||||||
(filename, profile, t, source_path),
|
" VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(filename, profile, t, source_path, source_model),
|
||||||
)
|
)
|
||||||
self._con.commit()
|
self._con.commit()
|
||||||
|
|
||||||
@@ -570,6 +826,30 @@ class ProcessedDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return {r[0] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||||
|
"""Return all hard negatives for a profile with full details."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, filename, start_time, source_path, source_model"
|
||||||
|
" FROM hard_negatives WHERE profile = ?"
|
||||||
|
" ORDER BY filename, start_time",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||||
|
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||||
|
|
||||||
|
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||||
|
"""Delete hard negatives by row IDs."""
|
||||||
|
if not self._enabled or not ids:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||||
|
ids,
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
def remove_hard_negatives(self, filename: str, profile: str,
|
def remove_hard_negatives(self, filename: str, profile: str,
|
||||||
times: list[float]) -> None:
|
times: list[float]) -> None:
|
||||||
"""Remove specific hard-negative timestamps."""
|
"""Remove specific hard-negative timestamps."""
|
||||||
|
|||||||
+10
-1
@@ -128,7 +128,16 @@ def build_ffmpeg_command(
|
|||||||
os.path.join(output_path, "frame_%04d.webp"),
|
os.path.join(output_path, "frame_%04d.webp"),
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
cmd += ["-c:v", encoder]
|
||||||
|
if "nvenc" in encoder:
|
||||||
|
cmd += ["-preset", "p4", "-cq", "28"]
|
||||||
|
elif "vaapi" in encoder:
|
||||||
|
cmd += ["-qp", "28"]
|
||||||
|
elif "qsv" in encoder:
|
||||||
|
cmd += ["-global_quality", "28"]
|
||||||
|
elif "amf" in encoder:
|
||||||
|
cmd += ["-qp_i", "28", "-qp_p", "28"]
|
||||||
|
cmd += ["-c:a", "pcm_s16le", output_path]
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+18
-8
@@ -24,16 +24,26 @@ def _log(*args) -> None:
|
|||||||
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
def build_export_path(folder: str, basename: str, counter: int,
|
||||||
group = f"{basename}_{counter:03d}"
|
sub: int | None = None, tag: str | None = None) -> str:
|
||||||
name = f"{group}_{sub}" if sub is not None else group
|
"""Build clip output path. *folder* should be the vid folder (e.g. .../mp4/vid_001)."""
|
||||||
return os.path.join(folder, group, name + ".mp4")
|
name = f"{basename}_{counter:03d}"
|
||||||
|
if tag is not None:
|
||||||
|
name = f"{name}_{tag}"
|
||||||
|
if sub is not None:
|
||||||
|
name = f"{name}_{sub}"
|
||||||
|
return os.path.join(folder, name + ".mp4")
|
||||||
|
|
||||||
|
|
||||||
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
def build_sequence_dir(folder: str, basename: str, counter: int,
|
||||||
group = f"{basename}_{counter:03d}"
|
sub: int | None = None, tag: str | None = None) -> str:
|
||||||
name = f"{group}_{sub}" if sub is not None else group
|
"""Build WebP sequence output dir. *folder* should be the vid folder."""
|
||||||
return os.path.join(folder, group, name)
|
name = f"{basename}_{counter:03d}"
|
||||||
|
if tag is not None:
|
||||||
|
name = f"{name}_{tag}"
|
||||||
|
if sub is not None:
|
||||||
|
name = f"{name}_{sub}"
|
||||||
|
return os.path.join(folder, name)
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
|
def format_time(seconds: float) -> str:
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# Audio Pipeline Improvements Design
|
||||||
|
|
||||||
|
Date: 2026-04-19
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Improve audio scan classification accuracy, especially for non-speech sounds (suction, gagging, impacts), through three changes:
|
||||||
|
|
||||||
|
1. Multi-layer feature extraction from existing HuBERT/Wav2Vec2 models
|
||||||
|
2. Two new embedding models: AST (AudioSet-supervised) and EAT (self-supervised + AudioSet finetuned)
|
||||||
|
3. Calibrated classifier for better threshold behavior
|
||||||
|
|
||||||
|
## 1. Multi-Layer Feature Extraction
|
||||||
|
|
||||||
|
### Current behavior
|
||||||
|
|
||||||
|
`model(waveforms)` extracts embeddings from the **last transformer layer only**.
|
||||||
|
|
||||||
|
### Change
|
||||||
|
|
||||||
|
Use `model.extract_features(waveforms)` (torchaudio API) to get all layer outputs. Select layers at quartile boundaries, mean-pool each over time, concatenate.
|
||||||
|
|
||||||
|
| Model | Layers | Single-layer dim | Multi-layer dim (4 quartiles) |
|
||||||
|
|-------|--------|-------------------|-------------------------------|
|
||||||
|
| HUBERT_XLARGE | 48 | 1280 | 5120 |
|
||||||
|
| HUBERT_LARGE | 24 | 1024 | 4096 |
|
||||||
|
| HUBERT_BASE | 12 | 768 | 3072 |
|
||||||
|
| WAV2VEC2_BASE | 12 | 768 | 3072 |
|
||||||
|
|
||||||
|
### Implementation
|
||||||
|
|
||||||
|
- New entries in `_EMBED_MODELS`: `"HUBERT_XLARGE_ML"` -> 5120, etc.
|
||||||
|
- `_extract_w2v_windows`: when model name ends with `_ML`, call `extract_features()` instead of `model()`, select quartile layers, concat
|
||||||
|
- Cache key: model name includes `_ML` suffix -> separate cache files
|
||||||
|
- No change to classifier or training pipeline (HistGBT handles high-dim fine)
|
||||||
|
|
||||||
|
## 2. AST (Audio Spectrogram Transformer)
|
||||||
|
|
||||||
|
### What
|
||||||
|
|
||||||
|
`MIT/ast-finetuned-audioset-10-10-0.4593` via HuggingFace `transformers`. 86M params, 768-dim, supervised on AudioSet 527 sound classes.
|
||||||
|
|
||||||
|
### Integration
|
||||||
|
|
||||||
|
- Load: `ASTModel.from_pretrained()` + `ASTFeatureExtractor`
|
||||||
|
- Preprocessing: `ASTFeatureExtractor` handles mel spectrogram from 16kHz raw audio
|
||||||
|
- Batching: prepare `input_values` per window, stack into batch, forward through model
|
||||||
|
- Multi-layer: `output_hidden_states=True` returns 13 layers; `AST_ML` variant concats quartile layers -> 3072-dim
|
||||||
|
- Model cached via `_get_w2v_model()` same lazy-load pattern
|
||||||
|
|
||||||
|
### Entries
|
||||||
|
|
||||||
|
- `"AST"` -> 768
|
||||||
|
- `"AST_ML"` -> 3072
|
||||||
|
|
||||||
|
## 3. EAT (Efficient Audio Transformer)
|
||||||
|
|
||||||
|
### What
|
||||||
|
|
||||||
|
`worstchan/EAT-base_epoch30_finetune_AS2M` via HuggingFace with `trust_remote_code=True`. 88M params, 768-dim, self-supervised + AudioSet finetuned.
|
||||||
|
|
||||||
|
### Integration
|
||||||
|
|
||||||
|
- Load: `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
||||||
|
- Preprocessing: manual 128-bin Kaldi fbank mel spectrogram via torchaudio, normalize with EAT constants `(mel - (-4.268)) / (4.569 * 2)`, reshape to `[B, 1, T, 128]`
|
||||||
|
- Feature extraction: `model.extract_features(mel)` returns `[B, seq, 768]`; CLS token `[:, 0, :]` for utterance-level, or mean-pool `[:, 1:, :]` for frame-level. Use mean-pool for consistency with other models.
|
||||||
|
- Multi-layer: not natively supported, skip for now
|
||||||
|
|
||||||
|
### Entry
|
||||||
|
|
||||||
|
- `"EAT"` -> 768
|
||||||
|
|
||||||
|
## 4. Calibrated Classifier
|
||||||
|
|
||||||
|
Wrap `HistGradientBoostingClassifier` in `CalibratedClassifierCV(clf, cv=3, method='isotonic')` after fitting. Gives well-calibrated probabilities -> threshold slider maps more linearly to precision/recall.
|
||||||
|
|
||||||
|
One change in `train_classifier()`, no UI changes needed.
|
||||||
|
|
||||||
|
## 5. Requirements
|
||||||
|
|
||||||
|
Add to `requirements.txt`:
|
||||||
|
```
|
||||||
|
transformers>=4.30
|
||||||
|
timm>=0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
Both AST and EAT need `transformers`. EAT additionally needs `timm` (used internally by its custom model code). Both setup scripts (`setup_env.sh`, `setup-windows.ps1`) install from `requirements.txt` so no changes needed there.
|
||||||
|
|
||||||
|
## Cache Compatibility
|
||||||
|
|
||||||
|
- All new model variants get distinct cache keys via model name in the hash
|
||||||
|
- Existing caches for HUBERT_XLARGE, BEATs, etc. remain valid and untouched
|
||||||
|
- New models create new `.npz` files in the same `cache/w2v/` directory
|
||||||
|
|
||||||
|
## UI Changes
|
||||||
|
|
||||||
|
- `_EMBED_MODELS` dict additions appear automatically in Train dialog model dropdown and scan model dropdown
|
||||||
|
- No other UI changes needed
|
||||||
@@ -0,0 +1,588 @@
|
|||||||
|
# Audio Pipeline Improvements Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Improve audio scan accuracy with multi-layer extraction, AST/EAT models, and calibrated classifier.
|
||||||
|
|
||||||
|
**Architecture:** All changes are in `core/audio_scan.py`. The embedding extraction functions gain new model-type branches (AST, EAT, multi-layer). The classifier gets a calibration wrapper. `_EMBED_MODELS` dict and `_get_w2v_model()` are extended. No UI changes needed — new models appear automatically in dropdowns.
|
||||||
|
|
||||||
|
**Tech Stack:** torchaudio (existing), transformers (new dep), timm (new dep), sklearn.calibration (existing dep)
|
||||||
|
|
||||||
|
**Key design notes:**
|
||||||
|
- `_get_w2v_model()` resolves `_ML` suffixed names to their base model for loading (e.g. `HUBERT_XLARGE_ML` loads `HUBERT_XLARGE`). Both share the same GPU model — only the extraction path differs (last-layer vs multi-layer). The global `_w2v_model_name` stores the **base** name so switching between `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` does NOT trigger a reload.
|
||||||
|
- Cache keys use the **full** model name (including `_ML`), so single-layer and multi-layer caches coexist as separate `.npz` files.
|
||||||
|
- AST and EAT are separate model types that do NOT share the torchaudio loading path — they get their own `elif` branches in `_get_w2v_model()`.
|
||||||
|
- Both `_extract_w2v_windows` and `_extract_w2v_targeted` need identical changes to their batch inference blocks. Keep them in sync.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add transformers and timm to requirements
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `requirements.txt`
|
||||||
|
|
||||||
|
**Step 1: Add dependencies**
|
||||||
|
|
||||||
|
Add after the `torchaudio` line in `requirements.txt`:
|
||||||
|
|
||||||
|
```
|
||||||
|
transformers>=4.30
|
||||||
|
timm>=0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify install**
|
||||||
|
|
||||||
|
Run: `pip install transformers timm`
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add requirements.txt
|
||||||
|
git commit -m "deps: add transformers and timm for AST/EAT models"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Multi-layer extraction for torchaudio models
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-58` (_EMBED_MODELS dict)
|
||||||
|
- Modify: `core/audio_scan.py:96-100` (_embed_dim)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model)
|
||||||
|
- Modify: `core/audio_scan.py:189-205` (_extract_w2v_windows batch loop)
|
||||||
|
- Modify: `core/audio_scan.py:278-293` (_extract_w2v_targeted batch loop)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_multi_layer():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
# Multi-layer models should report concatenated dimension
|
||||||
|
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||||
|
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||||
|
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||||
|
# Single-layer unchanged
|
||||||
|
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run test to verify it fails**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||||
|
Expected: FAIL — `_embed_dim("HUBERT_XLARGE_ML")` returns 768 (default fallback)
|
||||||
|
|
||||||
|
**Step 3: Add multi-layer entries to _EMBED_MODELS**
|
||||||
|
|
||||||
|
In `core/audio_scan.py:50-58`, add after existing entries:
|
||||||
|
|
||||||
|
```python
|
||||||
|
_EMBED_MODELS = {
|
||||||
|
"WAV2VEC2_BASE": 768,
|
||||||
|
"WAV2VEC2_LARGE": 1024,
|
||||||
|
"WAV2VEC2_LARGE_LV60K": 1024,
|
||||||
|
"HUBERT_BASE": 768,
|
||||||
|
"HUBERT_LARGE": 1024,
|
||||||
|
"HUBERT_XLARGE": 1280,
|
||||||
|
"BEATS": 768,
|
||||||
|
# Multi-layer variants (4 quartile layers concatenated)
|
||||||
|
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run test to verify it passes**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 5: Add helper to resolve base model and layer indices**
|
||||||
|
|
||||||
|
Add after `_embed_dim()` (around line 101):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||||
|
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||||
|
|
||||||
|
Returns None for single-layer models.
|
||||||
|
Layer indices are 0-based into the list returned by extract_features().
|
||||||
|
"""
|
||||||
|
if not model_name.endswith("_ML"):
|
||||||
|
return None
|
||||||
|
base = model_name[:-3] # strip "_ML"
|
||||||
|
if base not in _EMBED_MODELS:
|
||||||
|
return None
|
||||||
|
# Layer counts per model family
|
||||||
|
layer_counts = {
|
||||||
|
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||||
|
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||||
|
"AST": 12,
|
||||||
|
}
|
||||||
|
n = layer_counts.get(base)
|
||||||
|
if n is None:
|
||||||
|
return None
|
||||||
|
# Select 4 layers at quartile boundaries (0-indexed)
|
||||||
|
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||||
|
return base, indices
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: AST is included in the layer_counts dict here already so Task 3 doesn't need to modify it again.
|
||||||
|
|
||||||
|
**Step 6: Write test for _ml_config**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_ml_config():
|
||||||
|
from core.audio_scan import _ml_config
|
||||||
|
assert _ml_config("HUBERT_XLARGE") is None
|
||||||
|
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||||
|
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||||
|
assert base == "HUBERT_XLARGE"
|
||||||
|
assert layers == [11, 23, 35, 47]
|
||||||
|
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||||
|
assert base == "HUBERT_BASE"
|
||||||
|
assert layers == [2, 5, 8, 11]
|
||||||
|
```
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_ml_config -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 7: Modify _get_w2v_model to resolve ML base names**
|
||||||
|
|
||||||
|
In `_get_w2v_model()` (line 68), the comparison key must use the resolved base name so that `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` share the same loaded model without reloading:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _get_w2v_model(model_name: str | None = None):
|
||||||
|
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||||
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
# Multi-layer variants use the same base model weights
|
||||||
|
ml = _ml_config(model_name)
|
||||||
|
load_name = ml[0] if ml else model_name
|
||||||
|
if _w2v_model is None or _w2v_model_name != load_name:
|
||||||
|
import torch
|
||||||
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if load_name == "BEATS":
|
||||||
|
... # existing BEATs code unchanged
|
||||||
|
else:
|
||||||
|
import torchaudio
|
||||||
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
_w2v_model.eval()
|
||||||
|
_w2v_model_name = load_name
|
||||||
|
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||||
|
return _w2v_model, _w2v_device
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 8: Modify _extract_w2v_windows batch inference**
|
||||||
|
|
||||||
|
In `_extract_w2v_windows`, compute `ml_cfg` **once** before the batch loop (after line 173 `is_beats = ...`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then replace the batch inference block (lines 197-204):
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings.append(batch_emb)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 9: Modify _extract_w2v_targeted batch inference (keep in sync)**
|
||||||
|
|
||||||
|
In `_extract_w2v_targeted`, add `ml_cfg` computation after line 276 `is_beats = ...`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then replace the batch inference block (lines 285-292) with the same branching logic as Step 8:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings_list.append(batch_emb)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: `_extract_w2v_targeted` appends to `embeddings_list` (not `embeddings`).
|
||||||
|
|
||||||
|
**Step 10: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 11: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: multi-layer extraction for HuBERT/Wav2Vec2 models"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: AST model integration
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add AST entries)
|
||||||
|
- Modify: `core/audio_scan.py:45-47` (add _ast_feature_extractor global)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add AST loading branch)
|
||||||
|
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add AST inference branch)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_ast():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("AST") == 768
|
||||||
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
```
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_ast -v`
|
||||||
|
Expected: FAIL
|
||||||
|
|
||||||
|
**Step 2: Add AST entries to _EMBED_MODELS**
|
||||||
|
|
||||||
|
Add to the dict (after the ML entries):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
```
|
||||||
|
|
||||||
|
Run test again — should PASS now.
|
||||||
|
|
||||||
|
**Step 3: Add module-level global for AST feature extractor**
|
||||||
|
|
||||||
|
Near line 47 (after `_w2v_model_name = None`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add AST loading branch in _get_w2v_model**
|
||||||
|
|
||||||
|
In `_get_w2v_model()`, add an `elif` branch **before** the torchaudio fallback `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: `_ast_feature_extractor` is recreated on every model load (not cached separately) — simple and correct since the feature extractor is lightweight and model reloads are rare.
|
||||||
|
|
||||||
|
**Step 5: Add AST inference branch in both extraction functions**
|
||||||
|
|
||||||
|
In both `_extract_w2v_windows` and `_extract_w2v_targeted`, compute `is_ast` once before the loop:
|
||||||
|
|
||||||
|
```python
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in the batch inference block, add after the `elif ml_cfg` branch and before `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif is_ast:
|
||||||
|
# AST uses its own feature extractor for mel spectrogram
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
```
|
||||||
|
|
||||||
|
Important: `chunks` is already a list of numpy arrays (built in the loop at lines 194-196). Pass it directly as `list(chunks)` — the `ASTFeatureExtractor` accepts a list of numpy arrays and handles batching/padding internally. Verified: `ASTFeatureExtractor([np.array, np.array, ...], sampling_rate=16000, return_tensors="pt", padding=True)` returns `input_values` of shape `[B, 1024, 128]`.
|
||||||
|
|
||||||
|
**Step 6: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 7: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add AST (Audio Spectrogram Transformer) embedding model"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: EAT model integration
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add EAT entry)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add EAT loading branch)
|
||||||
|
- Add: `core/audio_scan.py` (_eat_preprocess helper function)
|
||||||
|
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add EAT inference branch)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_eat():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("EAT") == 768
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add EAT entry to _EMBED_MODELS**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"EAT": 768,
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: No `EAT_ML` variant — EAT's `extract_features()` does not natively support multi-layer output. Can be added later if needed by monkey-patching.
|
||||||
|
|
||||||
|
**Step 3: Add EAT loading branch in _get_w2v_model**
|
||||||
|
|
||||||
|
Add after the AST branch, before the torchaudio `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif load_name == "EAT":
|
||||||
|
from transformers import AutoModel
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||||
|
trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add EAT preprocessing helper**
|
||||||
|
|
||||||
|
Add as a module-level function near `_get_w2v_model`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||||
|
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||||
|
|
||||||
|
Returns tensor of shape [B, 1, T, 128].
|
||||||
|
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
TARGET_LEN = 1024
|
||||||
|
MEAN, STD = -4.268, 4.569
|
||||||
|
|
||||||
|
mels = []
|
||||||
|
for chunk in chunks:
|
||||||
|
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||||
|
fbank = kaldi.fbank(
|
||||||
|
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||||
|
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||||
|
)
|
||||||
|
# Pad or truncate to TARGET_LEN
|
||||||
|
if fbank.shape[0] < TARGET_LEN:
|
||||||
|
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||||
|
else:
|
||||||
|
fbank = fbank[:TARGET_LEN]
|
||||||
|
fbank = (fbank - MEAN) / (STD * 2)
|
||||||
|
mels.append(fbank)
|
||||||
|
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Add EAT inference branch in both extraction functions**
|
||||||
|
|
||||||
|
Compute `is_eat` once before the loop:
|
||||||
|
|
||||||
|
```python
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in the batch inference block, add after the `elif is_ast` branch and before `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
# Mean-pool frame-level tokens (skip CLS at index 0)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
```
|
||||||
|
|
||||||
|
Important: `model.extract_features()` returns a plain `torch.Tensor` of shape `[B, 513, 768]` (not a tuple). Index 0 is the CLS token, indices 1-512 are frame-level patch embeddings. We mean-pool the frame tokens for consistency with how other models are pooled.
|
||||||
|
|
||||||
|
**Step 6: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 7: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add EAT (Efficient Audio Transformer) embedding model"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Calibrated classifier
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:424-429` (train_classifier, wrap clf)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Modify train_classifier**
|
||||||
|
|
||||||
|
After the existing `clf.fit()` call (line 428), add calibration with a safe guard:
|
||||||
|
|
||||||
|
```python
|
||||||
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
# Calibrate probabilities for better threshold behavior
|
||||||
|
# Requires at least 6 samples per class for stable 3-fold isotonic calibration
|
||||||
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
min_class = min(int(n_pos), int(n_neg_sample))
|
||||||
|
if min_class >= 6:
|
||||||
|
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||||
|
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
clf = cal_clf
|
||||||
|
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||||
|
```
|
||||||
|
|
||||||
|
Why `min_class >= 6`: `CalibratedClassifierCV` uses stratified k-fold internally. With `cv=3`, each fold needs at least 2 samples per class. `min_class >= 6` guarantees this. With fewer samples, the uncalibrated HistGBT probabilities are still reasonable — calibration is an enhancement, not a requirement.
|
||||||
|
|
||||||
|
Previous plan bug: `cv=min(3, n_pos, n_neg_sample)` could produce `cv=1` when `n_pos=1`, which raises `ValueError` (minimum is 2). Even `cv=2` with 2 positives causes one fold to have only 1 positive, making isotonic regression unstable. The `>= 6` guard avoids all these edge cases.
|
||||||
|
|
||||||
|
**Step 2: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py
|
||||||
|
git commit -m "feat: calibrate classifier probabilities with isotonic regression"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Integration test with real model (manual)
|
||||||
|
|
||||||
|
This task is manual — it requires GPU and a real video file.
|
||||||
|
|
||||||
|
**Step 1: Test multi-layer extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='HUBERT_XLARGE_ML')
|
||||||
|
print(f'HUBERT_XLARGE_ML: {emb.shape}') # expect (13, 5120)
|
||||||
|
assert emb.shape[1] == _embed_dim('HUBERT_XLARGE_ML')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Test AST extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='AST')
|
||||||
|
print(f'AST: {emb.shape}') # expect (13, 768)
|
||||||
|
assert emb.shape[1] == _embed_dim('AST')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Test AST multi-layer**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='AST_ML')
|
||||||
|
print(f'AST_ML: {emb.shape}') # expect (13, 3072)
|
||||||
|
assert emb.shape[1] == _embed_dim('AST_ML')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Test EAT extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='EAT')
|
||||||
|
print(f'EAT: {emb.shape}') # expect (13, 768)
|
||||||
|
assert emb.shape[1] == _embed_dim('EAT')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Test model switching doesn't reload unnecessarily**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _get_w2v_model
|
||||||
|
import core.audio_scan as m
|
||||||
|
# Load HUBERT_XLARGE
|
||||||
|
_get_w2v_model('HUBERT_XLARGE')
|
||||||
|
name1 = m._w2v_model_name
|
||||||
|
# Switch to ML variant — should NOT reload
|
||||||
|
_get_w2v_model('HUBERT_XLARGE_ML')
|
||||||
|
name2 = m._w2v_model_name
|
||||||
|
assert name1 == name2 == 'HUBERT_XLARGE', f'Expected no reload, got {name1} -> {name2}'
|
||||||
|
print('PASS: no reload on ML switch')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: Test full train+scan cycle in app**
|
||||||
|
|
||||||
|
Load app, select each new model from scan model dropdown, scan a video, train, verify results display correctly.
|
||||||
|
|
||||||
|
**Step 7: Final commit and push**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git push
|
||||||
|
```
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
# ComfyUI-8cut Node Pack Design
|
||||||
|
|
||||||
|
Date: 2026-04-19
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Port 8-cut's video scanning, training, review, and export workflow to a ComfyUI node pack. The primary motivation is **remote access** — ComfyUI's web UI allows browser-based operation over the network, and HTML5 `<video>` handles streaming compression natively. No tensor-based image pipeline; videos stay as file paths throughout.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Approach
|
||||||
|
|
||||||
|
Monolithic Review Node + simple pipeline nodes. One central **VideoReview** node embeds the full interactive player/timeline/region table as a large DOM widget. Other nodes (Scan, Train, Export) are headless pipeline nodes that pass lightweight metadata.
|
||||||
|
|
||||||
|
### Core reuse
|
||||||
|
|
||||||
|
The entire `8-cut/core/` package is Qt-free and reusable as-is:
|
||||||
|
- `core/audio_scan.py` — `scan_video()`, `train_classifier()`, `load_classifier()`
|
||||||
|
- `core/db.py` — `ProcessedDB` (SQLite, all scan/training/export persistence)
|
||||||
|
- `core/ffmpeg.py` — `build_ffmpeg_command()` (clip export)
|
||||||
|
- `core/tracking.py` — YOLO-based subject tracking
|
||||||
|
- `core/paths.py` — path helpers, `format_time()`
|
||||||
|
|
||||||
|
No porting required — these are imported directly.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Node Pack Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
ComfyUI-8cut/
|
||||||
|
__init__.py # NODE_CLASS_MAPPINGS, WEB_DIRECTORY
|
||||||
|
core/ # symlink or copy of 8-cut/core/
|
||||||
|
data/
|
||||||
|
8cut.db # separate SQLite DB (can copy from ~/.8cut.db)
|
||||||
|
models/ # trained classifiers (.joblib)
|
||||||
|
nodes/
|
||||||
|
load_video.py
|
||||||
|
audio_scan.py
|
||||||
|
video_review.py
|
||||||
|
train_model.py
|
||||||
|
export_clips.py
|
||||||
|
server_routes.py # custom API routes
|
||||||
|
web/
|
||||||
|
js/
|
||||||
|
video_review.js # timeline + player + scan panel widget
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Custom Types
|
||||||
|
|
||||||
|
No tensors anywhere in the pipeline. All data flows as lightweight metadata:
|
||||||
|
|
||||||
|
| Type | Python value | Purpose |
|
||||||
|
|------|-------------|---------|
|
||||||
|
| `VIDEO_PATH` | `str` (absolute path) | Video file reference |
|
||||||
|
| `SCAN_REGIONS` | `list[dict]` with start/end/score/model/disabled | Scan output / review edits |
|
||||||
|
| `SCAN_MODEL` | `str` (path to .joblib) | Trained classifier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Nodes
|
||||||
|
|
||||||
|
### LoadVideo
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `video_path` (STRING, file browser), `profile` (STRING combo from DB profiles) |
|
||||||
|
| **Output** | `VIDEO_PATH`, `filename` (STRING) |
|
||||||
|
| **Logic** | Validates path exists, returns it. Populates profile combo via API route. |
|
||||||
|
|
||||||
|
### AudioScan
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_MODEL`, `threshold` (FLOAT 0-1), `hop` (FLOAT) |
|
||||||
|
| **Output** | `SCAN_REGIONS` |
|
||||||
|
| **Logic** | Calls `core.audio_scan.scan_video()` directly. Progress via `PromptServer.send_sync("progress", ...)`. |
|
||||||
|
|
||||||
|
### VideoReview (interactive, blocking)
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS` (optional) |
|
||||||
|
| **Output** | `SCAN_REGIONS` (edited) |
|
||||||
|
| **OUTPUT_NODE** | `True` |
|
||||||
|
| **Logic** | Execution pauses here. User interacts via the widget. Clicks "Continue" to pass edited regions downstream. |
|
||||||
|
|
||||||
|
The widget layout:
|
||||||
|
|
||||||
|
```
|
||||||
|
+-------------------------------------+
|
||||||
|
| [video player (HTML5 <video>)] |
|
||||||
|
| +- timeline with scan regions ----+|
|
||||||
|
| | cursor + region drag/resize ||
|
||||||
|
| +---------------------------------+|
|
||||||
|
| +- model tabs [EAT_LARGE][HuBERT]+|
|
||||||
|
| | Time | End | Score ||
|
||||||
|
| | 1:23 | 1:31 | 0.92 ||
|
||||||
|
| | 3:45 | 3:53 | 0.87 ||
|
||||||
|
| | [Add Negative] [Export] [Continue]|
|
||||||
|
| +---------------------------------+|
|
||||||
|
+-------------------------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
Widget size: ~640x500px minimum, resizable via LiteGraph.
|
||||||
|
|
||||||
|
**Blocking mechanism**: The node's `run()` method blocks on a server-side event/queue. The frontend signals completion via `POST /8cut/review_done/{node_id}`, which unblocks `run()` and returns the edited `SCAN_REGIONS`.
|
||||||
|
|
||||||
|
### TrainModel
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `profile` (STRING combo), `positive_folder` (STRING combo), `negative_folder` (STRING combo, optional), `embed_model` (STRING combo from `_EMBED_MODELS`), `use_hard_negatives` (BOOL) |
|
||||||
|
| **Output** | `SCAN_MODEL` |
|
||||||
|
| **Logic** | Queries `db.get_training_data()` to assemble `video_infos`, calls `core.audio_scan.train_classifier()`. Saves to `models/{profile}_{embed_model}.joblib` with version rotation. Progress via ComfyUI progress bar. |
|
||||||
|
|
||||||
|
### ExportClips
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS`, `output_folder` (STRING), `short_side` (INT), `format` (combo MP4/WEBM), `spread` (FLOAT), `clip_count` (INT), `fuse_gap` (FLOAT) |
|
||||||
|
| **Output** | exported file paths (list) |
|
||||||
|
| **Logic** | Region fusion via `_build_export_spans()`, then `core.ffmpeg.build_ffmpeg_command()` per clip. Records each clip in DB via `db.add()`. |
|
||||||
|
|
||||||
|
### Typical workflow
|
||||||
|
|
||||||
|
```
|
||||||
|
[LoadVideo] --> [AudioScan] --> [VideoReview] --> [ExportClips]
|
||||||
|
^
|
||||||
|
[TrainModel]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training loop (hard negatives round-trip)
|
||||||
|
|
||||||
|
1. Scan with existing model -> regions in VideoReview
|
||||||
|
2. Review -> mark false positives as negatives (DB)
|
||||||
|
3. Train -> new model uses hard negatives
|
||||||
|
4. Rescan -> better results
|
||||||
|
5. Repeat
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Routes
|
||||||
|
|
||||||
|
### Video serving
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/video` | GET | Serve raw video file via `web.FileResponse`. Query param: `path`. Browser decodes mp4/h264 natively — key for remote streaming. |
|
||||||
|
| `/8cut/video_transcode` | GET | Fallback: transcode to webm on-the-fly via ffmpeg `StreamResponse` for browser-incompatible formats (some MKV, odd codecs). |
|
||||||
|
|
||||||
|
### Region editing (from VideoReview widget)
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/toggle_region` | POST | `toggle_scan_result_disabled()` |
|
||||||
|
| `/8cut/resize_region` | POST | `update_scan_result()` |
|
||||||
|
| `/8cut/delete_region` | POST | `delete_scan_result()` |
|
||||||
|
| `/8cut/add_negatives` | POST | `add_hard_negatives()` |
|
||||||
|
| `/8cut/scan_versions` | GET | `get_scan_versions()` |
|
||||||
|
| `/8cut/review_done/{node_id}` | POST | Unblock the VideoReview node's `run()`, pass final regions |
|
||||||
|
|
||||||
|
### Data queries (for combo widget population)
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/profiles` | GET | `db.get_profiles()` |
|
||||||
|
| `/8cut/export_folders` | GET | `db.get_export_folders()` |
|
||||||
|
| `/8cut/models` | GET | List available `.joblib` models |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Frontend JS Widget (`web/js/video_review.js`)
|
||||||
|
|
||||||
|
Registered via `app.registerExtension()`. Hooks into the VideoReview node's `onNodeCreated` and `onExecuted` callbacks.
|
||||||
|
|
||||||
|
### Components
|
||||||
|
|
||||||
|
1. **Video player** — HTML5 `<video>` element, src pointed at `/8cut/video?path=...`
|
||||||
|
2. **Timeline** — `<canvas>` overlay below the video. Renders:
|
||||||
|
- Scan region rectangles (color-coded by score, red for negatives, gray for disabled)
|
||||||
|
- Cursor line (click to seek)
|
||||||
|
- Drag handles on region edges (resize)
|
||||||
|
- Waveform (optional, fetched via separate route)
|
||||||
|
3. **Region table** — HTML table with model tabs. Click row to seek. Columns: Time, End, Score.
|
||||||
|
4. **Action buttons** — Add Negative, Export, Continue
|
||||||
|
5. **Version combo** — dropdown to switch scan history versions
|
||||||
|
|
||||||
|
### Interaction flow
|
||||||
|
|
||||||
|
- Widget activates when `onExecuted` fires with scan regions
|
||||||
|
- User clicks/drags timeline, edits regions, marks negatives
|
||||||
|
- Each edit hits an API route (immediate DB persistence)
|
||||||
|
- "Continue" sends `POST /8cut/review_done/{node_id}` with final region state
|
||||||
|
- Node's `run()` unblocks, passes `SCAN_REGIONS` downstream
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## DB
|
||||||
|
|
||||||
|
Separate SQLite DB at `ComfyUI-8cut/data/8cut.db`. Uses the existing `ProcessedDB` class unchanged — same schema, same migration code. Users can copy their existing `~/.8cut.db` to carry over scan history, training data, and hard negatives.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
Same as 8-cut's `requirements.txt` minus PyQt6/python-mpv:
|
||||||
|
- `torch`, `torchaudio`, `torchvision` (from CUDA index)
|
||||||
|
- `transformers>=4.30,<5.0`, `timm>=0.9`
|
||||||
|
- `librosa`, `scikit-learn`, `joblib`, `soundfile`, `numpy`
|
||||||
|
- `ultralytics` (YOLO tracking)
|
||||||
|
|
||||||
|
ComfyUI already provides torch. The node pack's install script just needs the audio/ML extras.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Priority
|
||||||
|
|
||||||
|
1. **Node pack skeleton** — structure, `__init__.py`, custom types, API routes for video serving
|
||||||
|
2. **LoadVideo + AudioScan** — headless nodes, no widget needed yet
|
||||||
|
3. **VideoReview widget (minimal)** — video player + static region display + Continue button
|
||||||
|
4. **VideoReview interactivity** — timeline click/drag, region editing, negative marking
|
||||||
|
5. **TrainModel + ExportClips** — complete the pipeline
|
||||||
|
6. **Polish** — version history, waveform overlay, transcode fallback
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,205 @@
|
|||||||
|
# Scan History & Hard Negative Management — Final Design
|
||||||
|
|
||||||
|
Date: 2026-04-19 (implemented on `feat/training-ui`)
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
1. Keep scan result history per `(file, model)` so users can track classifier improvement across training iterations
|
||||||
|
2. Make hard negatives manageable — viewable, removable, and optionally disabled per training run
|
||||||
|
3. Fix latent bug: `get_export_folders()` doesn't filter by `scan_export`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Ghost Folder Fix
|
||||||
|
|
||||||
|
### Bug
|
||||||
|
|
||||||
|
`get_export_folders()` queried all `output_path` rows without filtering `scan_export`. Folders that only contained scan-exported clips appeared in training dropdowns with 0 clips.
|
||||||
|
|
||||||
|
### Implementation (`core/db.py`)
|
||||||
|
|
||||||
|
**`get_export_folders(profile, include_scan_exports=False)`** — new parameter. When `False` (default), the SQL query adds `AND scan_export = 0` to exclude scan-only folders. The `get_training_stats()` method passes this through and also filters its return dict to remove folders with 0 clips:
|
||||||
|
|
||||||
|
```python
|
||||||
|
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test
|
||||||
|
|
||||||
|
`tests/test_db.py::test_export_folders_excludes_scan_exports` — verifies scan-only folders are excluded by default and included when `include_scan_exports=True`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Scan Result History
|
||||||
|
|
||||||
|
### Schema
|
||||||
|
|
||||||
|
Added column to `scan_results`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
scan_timestamp TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
All rows from the same scan share one timestamp string with **microsecond precision** (`%Y%m%d_%H%M%S_%f`, e.g. `"20260419_143022_123456"`). Microsecond precision prevents version collisions on fast successive scans.
|
||||||
|
|
||||||
|
Migration adds the column via `ALTER TABLE` for existing databases. Legacy rows keep `scan_timestamp = ''`.
|
||||||
|
|
||||||
|
### DB methods (`core/db.py`)
|
||||||
|
|
||||||
|
**`save_scan_results(filename, profile, model, regions, max_versions=5)`**
|
||||||
|
1. Inserts new rows with current microsecond-precision timestamp
|
||||||
|
2. Counts distinct timestamps for this `(filename, profile, model)`
|
||||||
|
3. Prunes oldest timestamps beyond `max_versions`
|
||||||
|
|
||||||
|
No more DELETE-then-INSERT — all versions coexist in the table.
|
||||||
|
|
||||||
|
**`get_scan_versions(filename, profile, model)`**
|
||||||
|
Returns `[{timestamp, count, max_score}, ...]` ordered newest first. Filters `scan_timestamp != ''` so legacy rows don't appear as named versions.
|
||||||
|
|
||||||
|
**`get_scan_results(filename, profile, scan_timestamp=None)`**
|
||||||
|
- With `scan_timestamp`: returns rows matching that exact version
|
||||||
|
- Without (default): uses `INNER JOIN` subquery with `MAX(scan_timestamp)` per model to return only the latest version. Legacy rows (empty timestamp) sort before any real timestamp, so they're returned when no versioned scans exist.
|
||||||
|
|
||||||
|
### UI (`main.py` — `ScanResultsPanel`)
|
||||||
|
|
||||||
|
Each model tab wraps its `QTableWidget` in a container `QWidget` with a `QComboBox` for version selection:
|
||||||
|
|
||||||
|
```
|
||||||
|
container (QWidget)
|
||||||
|
├── cmb_version (QComboBox) — hidden when ≤ 1 version
|
||||||
|
└── table (QTableWidget)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Helper methods** unwrap this container:
|
||||||
|
- `_current_table()` — returns `QTableWidget` from active tab (handles both raw table and container)
|
||||||
|
- `_tab_table(index)` — same by tab index
|
||||||
|
|
||||||
|
**Version combo** is populated by `_populate_version_combos()` after every `load_for_file()` and `add_scan_results()` call. Labels use `datetime.strptime` parsing with try/except fallback for robustness:
|
||||||
|
|
||||||
|
```
|
||||||
|
2026-04-19 14:30 (12 regions, best: 0.95)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Version switching** via `_on_version_changed(model, idx)`:
|
||||||
|
1. Reads `scan_timestamp` from combo's `userData`
|
||||||
|
2. Calls `get_scan_results(filename, profile, scan_timestamp=ts)`
|
||||||
|
3. Repopulates the table in-place
|
||||||
|
4. **Clears the undo stack** — stale undo entries from a different version would corrupt data
|
||||||
|
5. Emits `regions_edited` to refresh the timeline
|
||||||
|
|
||||||
|
**Tab switch** connects `tab_changed` signal to `_on_scan_regions_edited` (not just `_update_scan_export_count`), so the timeline updates scan regions when switching model tabs.
|
||||||
|
|
||||||
|
### Cache interaction
|
||||||
|
|
||||||
|
Embedding cache is per `(file, model)` and doesn't change across scans. History stores classified regions (start, end, score), not embeddings.
|
||||||
|
|
||||||
|
### Test
|
||||||
|
|
||||||
|
`tests/test_db.py::test_scan_result_history` — saves 3 versions, verifies counts, ordering, and latest-by-default behavior.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Hard Negative Management
|
||||||
|
|
||||||
|
### Schema
|
||||||
|
|
||||||
|
Added column to `hard_negatives`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
source_model TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
Migration adds the column via `ALTER TABLE` for existing databases.
|
||||||
|
|
||||||
|
### DB methods (`core/db.py`)
|
||||||
|
|
||||||
|
**`add_hard_negatives(filename, profile, times, source_path="", source_model="")`** — now stores which embedding model produced the scan that led to the negative marking.
|
||||||
|
|
||||||
|
**`get_hard_negatives(profile)`** — returns all rows as `[{id, filename, start_time, source_path, source_model}, ...]` for the management dialog.
|
||||||
|
|
||||||
|
**`delete_hard_negatives_by_ids(ids)`** — bulk delete by row IDs.
|
||||||
|
|
||||||
|
**`get_training_data(..., use_hard_negatives=True)`** — new parameter. When `False`, the hard negatives query is skipped entirely. Non-destructive — negatives remain in DB.
|
||||||
|
|
||||||
|
### Source model tracking (`main.py`)
|
||||||
|
|
||||||
|
`_on_scan_negatives()` now passes `source_model=self._scan_panel.current_model_name()` when marking negatives from scan results. `current_model_name()` extracts the model name from the active tab text (stripping the count suffix).
|
||||||
|
|
||||||
|
### Training toggle (`main.py` — `TrainDialog`)
|
||||||
|
|
||||||
|
Checkbox **"Use hard negatives in training"** (default checked) with "Manage..." button in an HBox layout. The toggle:
|
||||||
|
- Updates live training stats preview via debounced `_update_stats()`
|
||||||
|
- Passes `use_hard_negatives` through `_open_train_dialog()` to `get_training_data()`
|
||||||
|
|
||||||
|
### Management dialog (`main.py` — `HardNegativesDialog`)
|
||||||
|
|
||||||
|
Accessible from TrainDialog's "Manage..." button. Features:
|
||||||
|
|
||||||
|
| Component | Details |
|
||||||
|
|-----------|---------|
|
||||||
|
| **Filter combo** | `(all)` + each distinct `source_model` found in data |
|
||||||
|
| **Summary label** | `<b>N</b> hard negatives` |
|
||||||
|
| **Table** | File, Time (`{:.1f}s`), Source Model, hidden ID column |
|
||||||
|
| **Delete Selected** | Multi-select aware, skips hidden (filtered) rows |
|
||||||
|
| **Clear All** | **Filter-aware**: if a model filter is active, only deletes negatives for that model with an appropriate confirmation message. If `(all)`, deletes everything. |
|
||||||
|
| **Close** | Closes dialog, triggers stats refresh in parent TrainDialog |
|
||||||
|
|
||||||
|
`blockSignals(True)` guards prevent spurious filter callbacks during `_load()` repopulation.
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- `test_hard_negatives_source_model` — verifies source_model stored and retrieved
|
||||||
|
- `test_training_data_skips_hard_negatives` — verifies `use_hard_negatives=False` excludes them
|
||||||
|
- `test_delete_hard_negatives_by_ids` — verifies bulk deletion by ID
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Runtime Fixes (discovered during testing)
|
||||||
|
|
||||||
|
### EAT/torchvision ABI mismatch
|
||||||
|
|
||||||
|
**Problem:** `torchvision` installed from PyPI (CPU build) was incompatible with `torch` from CUDA wheel index, causing `operator torchvision::nms does not exist`.
|
||||||
|
|
||||||
|
**Fix:** Added `torchvision` to the explicit torch install line in both setup scripts:
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
```
|
||||||
|
|
||||||
|
Also added `--extra-index-url "$TORCH_INDEX"` to the `pip install -r requirements.txt` line to prevent transitive dependencies (timm, ultralytics) from pulling CPU-only torch packages.
|
||||||
|
|
||||||
|
Applied to: `setup_env.sh` (both conda and venv paths), `setup-windows.ps1`.
|
||||||
|
|
||||||
|
### EAT / transformers 5.x incompatibility
|
||||||
|
|
||||||
|
**Problem:** transformers 5.x broke EAT's remote model code (`'EATModel' object has no attribute 'all_tied_weights_keys'`).
|
||||||
|
|
||||||
|
**Fix:** Pinned `transformers>=4.30,<5.0` in `requirements.txt`.
|
||||||
|
|
||||||
|
### NumPy non-writable array warning
|
||||||
|
|
||||||
|
**Problem:** Cached HuBERT/EAT embeddings loaded from disk are read-only numpy arrays. `torch.from_numpy()` on a non-writable array triggers a deprecation warning.
|
||||||
|
|
||||||
|
**Fix:** In `core/audio_scan.py`, changed EAT preprocessing to copy the array:
|
||||||
|
```python
|
||||||
|
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Timeline not updating on tab switch
|
||||||
|
|
||||||
|
**Problem:** Switching model tabs in the scan results panel didn't refresh the timeline's highlighted regions because `tab_changed` was only connected to `_update_scan_export_count`.
|
||||||
|
|
||||||
|
**Fix:** Connected `tab_changed` to `_on_scan_regions_edited` instead, which handles both timeline refresh and export count update.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## File Summary
|
||||||
|
|
||||||
|
| File | Changes |
|
||||||
|
|------|---------|
|
||||||
|
| `core/db.py` | Schema migrations, `get_export_folders` filter, versioned `save_scan_results`, `get_scan_versions`, version-aware `get_scan_results`, `add_hard_negatives` with `source_model`, `get_hard_negatives`, `delete_hard_negatives_by_ids`, `get_training_data` with `use_hard_negatives` |
|
||||||
|
| `main.py` | `HardNegativesDialog` class, `TrainDialog` hard neg toggle + manage button, `ScanResultsPanel` container/combo architecture, version combo population and switching, `current_model_name()`, tab-switch timeline fix |
|
||||||
|
| `core/audio_scan.py` | `np.array(chunk)` copy for read-only numpy arrays in EAT preprocessing |
|
||||||
|
| `requirements.txt` | `transformers>=4.30,<5.0` pin |
|
||||||
|
| `setup_env.sh` | `torchvision` in torch install, `--extra-index-url` on requirements install |
|
||||||
|
| `setup-windows.ps1` | `torchvision` in torch install, `--extra-index-url` on requirements install, removed skip-if-exists guard |
|
||||||
|
| `tests/test_db.py` | 5 tests covering all DB-layer changes |
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
# Scan History & Hard Negative Management — Implementation Log
|
||||||
|
|
||||||
|
> All tasks complete. See the design doc for the final specification.
|
||||||
|
|
||||||
|
**Branch:** `feat/training-ui`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Fix ghost folder bug in get_export_folders -- DONE
|
||||||
|
|
||||||
|
**Commit:** `2614a76 fix: get_export_folders respects scan_export filter`
|
||||||
|
|
||||||
|
- `core/db.py` — `get_export_folders(profile, include_scan_exports=False)`: filters `scan_export = 0` by default
|
||||||
|
- `core/db.py` — `get_training_stats()`: passes `include_scan_exports` through, filters out 0-clip folders
|
||||||
|
- `tests/test_db.py` — `test_export_folders_excludes_scan_exports`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Scan result history — schema and DB methods -- DONE
|
||||||
|
|
||||||
|
**Commit:** `4fb2ae1 feat: scan result history — keep N versions per (file, model)`
|
||||||
|
|
||||||
|
- `core/db.py` — added `scan_timestamp TEXT NOT NULL DEFAULT ''` column with migration
|
||||||
|
- `core/db.py` — `save_scan_results()`: versioned insert with microsecond-precision timestamp (`%Y%m%d_%H%M%S_%f`), auto-prunes beyond `max_versions=5`
|
||||||
|
- `core/db.py` — `get_scan_versions()`: returns `[{timestamp, count, max_score}, ...]` newest first
|
||||||
|
- `core/db.py` — `get_scan_results(scan_timestamp=None)`: `INNER JOIN` subquery with `MAX(scan_timestamp)` for latest-by-default
|
||||||
|
- `tests/test_db.py` — `test_scan_result_history`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Scan history UI — version selector in ScanResultsPanel -- DONE
|
||||||
|
|
||||||
|
**Commit:** `8ed9fbf feat: scan version selector in results panel`
|
||||||
|
|
||||||
|
- `main.py` — `_add_tab()`: wraps table in container `QWidget` with version `QComboBox` (hidden when ≤ 1 version)
|
||||||
|
- `main.py` — `_current_table()` / `_tab_table(idx)`: unwrap container to get `QTableWidget`
|
||||||
|
- `main.py` — `_populate_version_combos()`: queries `get_scan_versions()`, formats labels with `datetime.strptime` + try/except fallback
|
||||||
|
- `main.py` — `_on_version_changed()`: reloads table from specific version, clears undo stack, emits `regions_edited`
|
||||||
|
- `main.py` — `current_model_name()`: extracts model name from tab text
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Hard negatives — schema and training toggle -- DONE
|
||||||
|
|
||||||
|
**Commit:** `edc5784 feat: hard negative source_model tracking, training toggle`
|
||||||
|
|
||||||
|
- `core/db.py` — added `source_model TEXT NOT NULL DEFAULT ''` column to `hard_negatives` with migration
|
||||||
|
- `core/db.py` — `add_hard_negatives(source_model="")`: stores originating model
|
||||||
|
- `core/db.py` — `get_hard_negatives(profile)`: returns full rows as list of dicts
|
||||||
|
- `core/db.py` — `delete_hard_negatives_by_ids(ids)`: bulk delete by row IDs
|
||||||
|
- `core/db.py` — `get_training_data(use_hard_negatives=True)`: conditionally skips hard negatives query
|
||||||
|
- `main.py` — `TrainDialog`: "Use hard negatives" checkbox + "Manage..." button in HBox layout
|
||||||
|
- `main.py` — `_on_scan_negatives()`: passes `source_model=self._scan_panel.current_model_name()`
|
||||||
|
- `tests/test_db.py` — `test_hard_negatives_source_model`, `test_training_data_skips_hard_negatives`, `test_delete_hard_negatives_by_ids`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Hard negatives management dialog -- DONE
|
||||||
|
|
||||||
|
**Commit:** `e6db83f feat: hard negatives management dialog with filter and bulk delete`
|
||||||
|
|
||||||
|
- `main.py` — `HardNegativesDialog`: table with File/Time/Source Model/hidden ID columns, model filter combo, delete selected, filter-aware clear all, close button
|
||||||
|
- Filter-aware "Clear All": respects active model filter, shows appropriate confirmation message
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Code review fixes -- DONE
|
||||||
|
|
||||||
|
**Commit:** `5d45b8d fix: timestamp collision, undo stack invalidation, label parsing, filter-aware clear`
|
||||||
|
|
||||||
|
Four issues found during code review:
|
||||||
|
1. **Timestamp collision** — second-precision timestamps could merge versions on sub-second calls. Fixed with microsecond precision `%f`
|
||||||
|
2. **Undo stack invalidation** — switching scan versions left stale undo entries. Fixed by clearing undo stack in `_on_version_changed()`
|
||||||
|
3. **Timestamp label fragile parsing** — hard-coded string slicing. Fixed with `datetime.strptime` + try/except fallback
|
||||||
|
4. **Clear All ignoring filter** — deleted all negatives regardless of model filter. Fixed to respect active filter
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Runtime fixes (discovered during manual testing)
|
||||||
|
|
||||||
|
| Commit | Fix |
|
||||||
|
|--------|-----|
|
||||||
|
| `a3c657c` | Install `torchvision` from CUDA wheel index (was pulling CPU build from PyPI) |
|
||||||
|
| `3c3b1d7` | Remove "skip if torch exists" guard in Windows setup so re-runs fix broken envs |
|
||||||
|
| `fd043f4` | Pin `transformers>=4.30,<5.0` — EAT remote model code incompatible with transformers 5.x |
|
||||||
|
| `7d6fee9` | Copy read-only numpy array before `torch.from_numpy()` in EAT preprocessing |
|
||||||
|
| `bd345ab` | Connect `tab_changed` to `_on_scan_regions_edited` so timeline refreshes on tab switch |
|
||||||
|
| `d8b3972` | Add `--extra-index-url` to `pip install -r requirements.txt` in both setup scripts |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Test results
|
||||||
|
|
||||||
|
All 68 tests pass (5 new DB tests + 63 existing).
|
||||||
@@ -13,6 +13,8 @@ soundfile>=0.12
|
|||||||
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
torch>=2.0
|
torch>=2.0
|
||||||
torchaudio>=2.0
|
torchaudio>=2.0
|
||||||
|
transformers>=4.30,<5.0 # EAT remote model code incompatible with transformers 5.x
|
||||||
|
timm>=0.9
|
||||||
|
|
||||||
# Object detection
|
# Object detection
|
||||||
ultralytics>=8.0
|
ultralytics>=8.0
|
||||||
|
|||||||
+12
-14
@@ -22,25 +22,23 @@ if (Test-Path (Join-Path $venvDir "Scripts\python.exe")) {
|
|||||||
& "$venvDir\Scripts\Activate.ps1"
|
& "$venvDir\Scripts\Activate.ps1"
|
||||||
|
|
||||||
# ── PyTorch ───────────────────────────────────────────────
|
# ── PyTorch ───────────────────────────────────────────────
|
||||||
$hasTorch = python -c "import torch" 2>&1
|
# Detect NVIDIA GPU via nvidia-smi
|
||||||
if ($LASTEXITCODE -eq 0) {
|
$hasNvidia = Get-Command nvidia-smi -ErrorAction SilentlyContinue
|
||||||
Write-Host "`nPyTorch already installed, skipping." -ForegroundColor Green
|
if ($hasNvidia) {
|
||||||
|
$torchIndex = "https://download.pytorch.org/whl/cu128"
|
||||||
|
Write-Host "`nNVIDIA GPU detected — using CUDA 12.8 PyTorch index" -ForegroundColor Green
|
||||||
} else {
|
} else {
|
||||||
# Detect NVIDIA GPU via nvidia-smi
|
$torchIndex = "https://download.pytorch.org/whl/cpu"
|
||||||
$hasNvidia = Get-Command nvidia-smi -ErrorAction SilentlyContinue
|
Write-Host "`nNo NVIDIA GPU detected — using CPU-only PyTorch index" -ForegroundColor Yellow
|
||||||
if ($hasNvidia) {
|
|
||||||
Write-Host "`nNVIDIA GPU detected — installing PyTorch with CUDA 12.8..." -ForegroundColor Green
|
|
||||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
|
||||||
} else {
|
|
||||||
Write-Host "`nNo NVIDIA GPU detected — installing CPU-only PyTorch..." -ForegroundColor Yellow
|
|
||||||
Write-Host "(Audio scanning will work but will be slower without GPU)" -ForegroundColor Yellow
|
|
||||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
# Always install/upgrade torch stack from correct index
|
||||||
|
# (pip install is a no-op if already at the right version)
|
||||||
|
Write-Host "Installing PyTorch + torchaudio + torchvision..."
|
||||||
|
pip install torch torchaudio torchvision --index-url $torchIndex
|
||||||
|
|
||||||
# ── Python deps ───────────────────────────────────────────
|
# ── Python deps ───────────────────────────────────────────
|
||||||
Write-Host "`nInstalling project dependencies..."
|
Write-Host "`nInstalling project dependencies..."
|
||||||
pip install -r (Join-Path $root "requirements.txt")
|
pip install -r (Join-Path $root "requirements.txt") --extra-index-url $torchIndex
|
||||||
|
|
||||||
# ── libmpv ────────────────────────────────────────────────
|
# ── libmpv ────────────────────────────────────────────────
|
||||||
$mpvDll = Join-Path $root "libmpv-2.dll"
|
$mpvDll = Join-Path $root "libmpv-2.dll"
|
||||||
|
|||||||
+4
-4
@@ -66,10 +66,10 @@ setup_conda() {
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
pip install torch torchaudio --index-url "$TORCH_INDEX"
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
echo " Installing project dependencies..."
|
echo " Installing project dependencies..."
|
||||||
pip install -r "$SCRIPT_DIR/requirements.txt"
|
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "Done! Activate with:"
|
echo "Done! Activate with:"
|
||||||
@@ -91,10 +91,10 @@ setup_venv() {
|
|||||||
source "$VENV_DIR/bin/activate"
|
source "$VENV_DIR/bin/activate"
|
||||||
|
|
||||||
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
pip install torch torchaudio --index-url "$TORCH_INDEX"
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
echo " Installing project dependencies..."
|
echo " Installing project dependencies..."
|
||||||
pip install -r "$SCRIPT_DIR/requirements.txt"
|
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "Done! Activate with:"
|
echo "Done! Activate with:"
|
||||||
|
|||||||
@@ -25,6 +25,39 @@ def test_default_model_path_contains_profile():
|
|||||||
assert path.endswith(".joblib")
|
assert path.endswith(".joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_multi_layer():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
# Multi-layer models should report concatenated dimension
|
||||||
|
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||||
|
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||||
|
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||||
|
# Single-layer unchanged
|
||||||
|
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||||
|
|
||||||
|
|
||||||
|
def test_ml_config():
|
||||||
|
from core.audio_scan import _ml_config
|
||||||
|
assert _ml_config("HUBERT_XLARGE") is None
|
||||||
|
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||||
|
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||||
|
assert base == "HUBERT_XLARGE"
|
||||||
|
assert layers == [11, 23, 35, 47]
|
||||||
|
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||||
|
assert base == "HUBERT_BASE"
|
||||||
|
assert layers == [2, 5, 8, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_ast():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("AST") == 768
|
||||||
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_eat():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("EAT") == 768
|
||||||
|
|
||||||
|
|
||||||
def test_db_get_all_export_paths():
|
def test_db_get_all_export_paths():
|
||||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
path = f.name
|
path = f.name
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_folders_excludes_scan_exports():
|
||||||
|
"""Scan-export-only folders should not appear when include_scan_exports=False."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Manual export
|
||||||
|
db.add("a.mp4", 10.0, "/out/mp4_Intense/g1/clip.mp4", profile="test")
|
||||||
|
# Scan export to different folder
|
||||||
|
db.add("a.mp4", 20.0, "/out/mp4_ScanOnly/g1/clip.mp4", profile="test",
|
||||||
|
scan_export=True)
|
||||||
|
folders = db.get_export_folders("test")
|
||||||
|
assert "mp4_Intense" in folders
|
||||||
|
assert "mp4_ScanOnly" not in folders, "scan-only folder should be excluded"
|
||||||
|
# With include_scan_exports=True, both should appear
|
||||||
|
folders_all = db.get_export_folders("test", include_scan_exports=True)
|
||||||
|
assert "mp4_ScanOnly" in folders_all
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_result_history():
|
||||||
|
"""save_scan_results should keep multiple versions."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Save three versions (microsecond-precision timestamps avoid collisions)
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(0, 8, 0.9)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(0, 8, 0.8), (10, 18, 0.7)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(5, 13, 0.95)])
|
||||||
|
versions = db.get_scan_versions("v.mp4", "test", "MODEL_A")
|
||||||
|
assert len(versions) == 3
|
||||||
|
# Most recent first
|
||||||
|
assert versions[0]["count"] == 1 # latest: 1 region
|
||||||
|
assert versions[1]["count"] == 2 # middle: 2 regions
|
||||||
|
assert versions[2]["count"] == 1 # oldest: 1 region
|
||||||
|
# get_scan_results returns latest version by default
|
||||||
|
results = db.get_scan_results("v.mp4", "test")
|
||||||
|
assert len(results.get("MODEL_A", [])) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hard_negatives_source_model():
|
||||||
|
"""Hard negatives should store source_model."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||||
|
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_data_skips_hard_negatives():
|
||||||
|
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Create a source file that "exists" — use the temp db file itself
|
||||||
|
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||||
|
source_path=path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||||
|
# With hard negatives
|
||||||
|
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||||
|
# Without hard negatives
|
||||||
|
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||||
|
assert len(data_with) >= 1
|
||||||
|
# The "with" case should have the hard negative time in neg list
|
||||||
|
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||||
|
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||||
|
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hard_negatives_by_ids():
|
||||||
|
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||||
|
source_path="/a.mp4")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 3
|
||||||
|
# Delete first two
|
||||||
|
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||||
|
remaining = db.get_hard_negatives("test")
|
||||||
|
assert len(remaining) == 1
|
||||||
|
assert remaining[0]["start_time"] == 30.0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
+23
-23
@@ -5,21 +5,21 @@ from main import ProcessedDB
|
|||||||
|
|
||||||
|
|
||||||
def test_build_export_path_first():
|
def test_build_export_path_first():
|
||||||
assert build_export_path("/out", "clip", 1) == "/out/clip_001/clip_001.mp4"
|
assert build_export_path("/out", "clip", 1) == "/out/clip_001.mp4"
|
||||||
|
|
||||||
def test_build_export_path_counter():
|
def test_build_export_path_counter():
|
||||||
assert build_export_path("/out", "clip", 42) == "/out/clip_042/clip_042.mp4"
|
assert build_export_path("/out", "clip", 42) == "/out/clip_042.mp4"
|
||||||
|
|
||||||
def test_build_export_path_deep_counter():
|
def test_build_export_path_deep_counter():
|
||||||
assert build_export_path("/out", "shot", 999) == "/out/shot_999/shot_999.mp4"
|
assert build_export_path("/out", "shot", 999) == "/out/shot_999.mp4"
|
||||||
|
|
||||||
def test_build_export_path_sub():
|
def test_build_export_path_sub():
|
||||||
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0.mp4"
|
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001_0.mp4"
|
||||||
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001/clip_001_2.mp4"
|
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001_2.mp4"
|
||||||
|
|
||||||
def test_build_sequence_dir_sub():
|
def test_build_sequence_dir_sub():
|
||||||
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0"
|
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001_0"
|
||||||
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001/clip_001_1"
|
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001_1"
|
||||||
|
|
||||||
def test_format_time_seconds():
|
def test_format_time_seconds():
|
||||||
assert format_time(0.0) == "0:00.0"
|
assert format_time(0.0) == "0:00.0"
|
||||||
@@ -178,10 +178,10 @@ def test_audio_extract_timing():
|
|||||||
|
|
||||||
|
|
||||||
def test_build_sequence_dir_basic():
|
def test_build_sequence_dir_basic():
|
||||||
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001/clip_001"
|
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001"
|
||||||
|
|
||||||
def test_build_sequence_dir_counter():
|
def test_build_sequence_dir_counter():
|
||||||
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042/clip_042"
|
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042"
|
||||||
|
|
||||||
def test_ffmpeg_command_image_sequence():
|
def test_ffmpeg_command_image_sequence():
|
||||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/seq_001", image_sequence=True)
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/seq_001", image_sequence=True)
|
||||||
@@ -265,13 +265,13 @@ def test_db_get_group_returns_all_sub_clips():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_2.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_2.mp4")
|
||||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(group) == 3
|
assert len(group) == 3
|
||||||
assert "/out/clip_001/clip_001_0.mp4" in group
|
assert "/out/vid_001/clip_001_0.mp4" in group
|
||||||
assert "/out/clip_001/clip_001_2.mp4" in group
|
assert "/out/vid_001/clip_001_2.mp4" in group
|
||||||
finally:
|
finally:
|
||||||
os.unlink(path)
|
os.unlink(path)
|
||||||
|
|
||||||
@@ -281,10 +281,10 @@ def test_db_get_group_isolates_by_start_time():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(group) == 2
|
assert len(group) == 2
|
||||||
finally:
|
finally:
|
||||||
os.unlink(path)
|
os.unlink(path)
|
||||||
@@ -295,10 +295,10 @@ def test_db_delete_group_removes_all():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||||
deleted = db.delete_group("/out/clip_001/clip_001_0.mp4")
|
deleted = db.delete_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(deleted) == 2
|
assert len(deleted) == 2
|
||||||
# clip_002 should still exist
|
# clip_002 should still exist
|
||||||
markers = db.get_markers("video.mp4")
|
markers = db.get_markers("video.mp4")
|
||||||
|
|||||||
Reference in New Issue
Block a user