16 Commits

Author SHA1 Message Date
Ethanfel 12ed183f1b feat: integrate training UI, BEATs model, and clean up legacy code
- Remove legacy distance-mode scanning (build_profile, _similarity, etc.)
  and hand-crafted intensity features — pipeline is now embedding-only
- Integrate Microsoft BEATs as embedding option alongside wav2vec2/HuBERT
- Add TrainDialog with positive class selector, model picker, video dir
  fallback, and live training stats
- Add TrainWorker QThread with cancel support and proper lifecycle cleanup
- Add source_path column to DB for robust source video tracking
- Add get_export_folders/get_training_data/get_training_stats to DB
- Wire source_path in all export DB writes (_on_clip_done, _on_auto_clip_done)
- Cancel scan/train workers in closeEvent to prevent use-after-free crashes
- Add setup_env.sh supporting both conda and python venv (CUDA 12.8)
- Update requirements.txt with all actual dependencies
- Update 8cut_train.py with --positive flag for new DB-driven training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-18 11:52:27 +02:00
Ethanfel f2c38aee79 feat: rewrite audio scan with MFCC+delta+spectral contrast pipeline
Root cause of poor discrimination: MFCC[0] (energy) dominated the
feature vector, making cosine similarity see all audio as similar.

Changes:
- Skip MFCC[0], use 12 coefficients instead of 20
- Add delta MFCCs for temporal dynamics
- Add 7-band spectral contrast for tonal vs noise quality
- Switch from cosine similarity to euclidean-distance-based score
- Pre-compute STFT once for whole file (10-20x faster)
- Vectorized sliding window via cumulative sums (no Python loop)
- Lower sample rate 22050→16000 Hz (faster, no quality loss)
- 62-dim feature vector (was 40-dim mean+std of raw MFCCs)
- Default threshold 0.05 (new similarity scale)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 15:28:44 +02:00
Ethanfel 8ab5bdba77 fix: use mean+std MFCC vectors (40-dim) for better discrimination
Mean-only vectors were too similar across different audio segments,
causing everything to match even at threshold 0.99. Adding std
captures temporal dynamics and makes the similarity scores much
more spread out.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 09:27:11 +02:00
Ethanfel c6c5934fe8 fix: threshold step 0.05 → 0.01 for finer control
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 09:21:14 +02:00
Ethanfel 73d5367424 fix: three audio scan bugs — signal shadow, re-entrancy, S-key jump
1. Rename ScanWorker.finished → scan_done to stop shadowing
   QThread.finished. Previously, cancelled scans leaked the QThread
   because the custom signal was never emitted.

2. Block signals on combobox reset in _on_scan_ref_changed to
   prevent re-entrant call when user cancels folder dialog.

3. Merge overlapping scan regions into clusters before S-key
   navigation so it jumps to the next distinct match, not 1s forward
   through overlapping windows.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 09:12:24 +02:00
Ethanfel 1e2cebd424 fix: prevent deleteLater on still-running ScanWorker QThread
When cancelling a scan during file change, connect finished signal
to deleteLater instead of calling it immediately on a running thread.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 09:02:35 +02:00
Ethanfel c439aca9b9 feat: add S shortcut and clear scan on file change
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:59:47 +02:00
Ethanfel afda9b2d9f feat: add scan UI controls and start_scan handler
Add Scan button, threshold spinner, mode combobox, and reference source
combobox to the settings row. Implement handler methods for starting scans,
handling results/errors, cleanup of workers, and reference folder selection.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:57:56 +02:00
Ethanfel fd42791c9f feat: add get_all_export_paths to ProcessedDB
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:55:39 +02:00
Ethanfel 4cf54f2642 feat: add ScanWorker QThread for background scanning
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:54:20 +02:00
Ethanfel e7f4de9ec1 feat: timeline scan region rendering
Add scan region storage and rendering to TimelineWidget:
- _scan_regions list in __init__ for (start, end, score) tuples
- set_scan_regions() and clear_scan_regions() methods
- paintEvent draws semi-transparent blue rectangles with score-based opacity

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:53:18 +02:00
Ethanfel 9cf9e3233f feat: add scan_video with average and nearest modes
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:50:47 +02:00
Ethanfel e17d8f67aa feat: add audio_scan module with build_profile
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:48:18 +02:00
Ethanfel b1980de6d1 fix: 9 bugs in audio scan implementation plan
- Swap Task 5/6 order so get_all_export_paths exists before UI uses it
- Remove cosine similarity clamping to preserve anti-correlation signal
- Use os.path.exists instead of os.path.isfile (handles image sequences)
- Add worker cleanup to disconnect stale signals before new scan
- Remove lock from get_all_export_paths (matches read-only convention)
- Always use get_all_export_paths for Current Profile (not current-file-first)
- Filter export paths with os.path.exists for deleted files
- Use abs() for float comparison in tests instead of ==
- Add cancel_flag to ScanWorker and scan_video for interruptible scans

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:43:53 +02:00
Ethanfel 85e0641440 docs: add audio scan implementation plan
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:36:56 +02:00
Ethanfel 834b89b682 docs: add audio similarity scanning design
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-17 08:33:25 +02:00
13 changed files with 3660 additions and 7 deletions
+255
View File
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""Calibration — per-video normalized features + classifier."""
import sys, os, time, warnings
sys.path.insert(0, os.path.dirname(__file__))
warnings.filterwarnings("ignore")
import numpy as np
import librosa
from sklearn.ensemble import GradientBoostingClassifier
from core.audio_scan import _SR, _WINDOW
_HOP_LENGTH = 1024
_N_FFT = 2048
from core.db import ProcessedDB
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
PROFILE_NAME = "JAV_missionary"
TOLERANCE = 12.0
NEG_MARGIN = 120.0
def extract_rich_features(y, sr=_SR):
"""Per-frame features: onset, energy, spectral shape, mel bands (22 features)."""
hop = _HOP_LENGTH
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2
rms = librosa.feature.rms(S=S, hop_length=hop)
cent = librosa.feature.spectral_centroid(S=S, sr=sr)
bw = librosa.feature.spectral_bandwidth(S=S, sr=sr)
rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr)
flatness = librosa.feature.spectral_flatness(S=S)
zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop)
onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1)
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128)
mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
bands = [(0, 100), (100, 300), (300, 600), (600, 1200),
(1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
band_feats = []
for flo, fhi in bands:
mask = (mel_freqs >= flo) & (mel_freqs < fhi)
if mask.sum() > 0:
band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10))
else:
band_feats.append(np.zeros((1, mel_S.shape[1])))
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop)
min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1],
band_feats[0].shape[1])
return np.vstack([
rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t],
flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t],
] + [b[:, :min_t] for b in band_feats]
+ [sc[:, :min_t]])
def compute_window_stats(feat, hop=1.0):
"""Sliding window mean/std → (timestamps, feature_vectors)."""
n_feats, T = feat.shape
fps = _SR / _HOP_LENGTH
win_frames = int(_WINDOW * fps)
hop_frames = int(hop * fps)
if win_frames > T:
return np.array([]), np.array([])
cumsum = np.zeros((n_feats, T + 1))
cumsum[:, 1:] = np.cumsum(feat, axis=1)
cumsq = np.zeros((n_feats, T + 1))
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
starts = np.arange(0, T - win_frames + 1, hop_frames)
ends = starts + win_frames
sums = cumsum[:, ends] - cumsum[:, starts]
sq_sums = cumsq[:, ends] - cumsq[:, starts]
means = sums / win_frames
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
return starts / fps, np.vstack([means, stds]).T
def label_windows(timestamps, gt_intense, gt_soft):
all_gt = list(gt_intense) + list(gt_soft)
labels = np.zeros(len(timestamps), dtype=int)
for i, t in enumerate(timestamps):
di = min((abs(t - g) for g in gt_intense), default=9999)
da = min((abs(t - g) for g in all_gt), default=9999)
if di < TOLERANCE:
labels[i] = 1
elif da > NEG_MARGIN:
labels[i] = -1
return labels
def main():
db = ProcessedDB()
rows = db._con.execute(
"SELECT filename, start_time, output_path FROM processed WHERE profile = ?",
(PROFILE_NAME,),
).fetchall()
intense_by_video, soft_by_video = {}, {}
for fn, st, op in rows:
if '/mp4_Intense/' in op:
intense_by_video.setdefault(fn, set()).add(st)
elif '/mp4_Soft/' in op:
soft_by_video.setdefault(fn, set()).add(st)
videos = [fn for fn in intense_by_video
if os.path.exists(os.path.join(PLEX_DIR, fn))]
n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
videos = videos[:n_vids]
print(f"Processing {len(videos)} videos...")
all_data_raw = [] # raw features
all_data_norm = [] # per-video z-scored features
for vi, vname in enumerate(videos):
vpath = os.path.join(PLEX_DIR, vname)
gt_intense = sorted(intense_by_video.get(vname, set()))
gt_soft = sorted(soft_by_video.get(vname, set()))
t0 = time.time()
y, _ = librosa.load(vpath, sr=_SR, mono=True)
feat = extract_rich_features(y)
timestamps, window_vectors = compute_window_stats(feat, hop=1.0)
dt = time.time() - t0
if len(timestamps) == 0:
continue
labels = label_windows(timestamps, gt_intense, gt_soft)
# Per-video z-score normalization
vid_mean = window_vectors.mean(axis=0)
vid_std = window_vectors.std(axis=0)
vid_std = np.maximum(vid_std, 1e-6)
normed = (window_vectors - vid_mean) / vid_std
n_pos = (labels == 1).sum()
n_neg = (labels == -1).sum()
print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")
all_data_raw.append((vi, vname, timestamps, window_vectors, labels))
all_data_norm.append((vi, vname, timestamps, normed, labels))
# Run CV for both raw and normalized
for label, data in [("RAW features", all_data_raw),
("PER-VIDEO NORMALIZED features", all_data_norm)]:
print(f"\n{'='*70}")
print(f" {label}")
print(f"{'='*70}")
all_y_true, all_y_prob = [], []
for test_idx in range(len(data)):
_, vname, _, test_X, test_labels = data[test_idx]
test_mask = test_labels != 0
if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0:
continue
X_test = test_X[test_mask]
y_test = (test_labels[test_mask] == 1).astype(int)
X_parts, y_parts = [], []
for i, (_, _, _, feats, labs) in enumerate(data):
if i == test_idx:
continue
m = labs != 0
if m.sum() == 0:
continue
X_parts.append(feats[m])
y_parts.append((labs[m] == 1).astype(int))
if not X_parts:
continue
X_train = np.vstack(X_parts)
y_train = np.concatenate(y_parts)
pos_idx = np.where(y_train == 1)[0]
neg_idx = np.where(y_train == 0)[0]
if len(pos_idx) == 0 or len(neg_idx) == 0:
continue
rng = np.random.RandomState(42)
n_neg = min(len(neg_idx), len(pos_idx) * 3)
neg_sample = rng.choice(neg_idx, n_neg, replace=False)
train_idx = np.concatenate([pos_idx, neg_sample])
clf = GradientBoostingClassifier(
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
)
clf.fit(X_train[train_idx], y_train[train_idx])
probs = clf.predict_proba(X_test)[:, 1]
tp = ((probs >= 0.5) & (y_test == 1)).sum()
fp = ((probs >= 0.5) & (y_test == 0)).sum()
fn_count = ((probs < 0.5) & (y_test == 1)).sum()
pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0
neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0
print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")
all_y_true.extend(y_test)
all_y_prob.extend(probs)
if not all_y_true:
print(" No test results.")
continue
y_true = np.array(all_y_true)
y_prob = np.array(all_y_prob)
pos_probs = y_prob[y_true == 1]
neg_probs = y_prob[y_true == 0]
if len(pos_probs) > 0 and len(neg_probs) > 0:
print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")
best_f1, best_thr = 0, 0
print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
for thr in np.arange(0.10, 0.91, 0.05):
tp = ((y_prob >= thr) & (y_true == 1)).sum()
fp = ((y_prob >= thr) & (y_true == 0)).sum()
fn_count = ((y_prob < thr) & (y_true == 1)).sum()
prec = tp / (tp + fp) if (tp + fp) > 0 else 0
rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
if f1 > best_f1:
best_f1, best_thr = f1, thr
print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")
# Feature importance
X_all = np.vstack([f[l != 0] for _, _, _, f, l in data])
y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
pos_idx = np.where(y_all == 1)[0]
neg_idx = np.where(y_all == 0)[0]
rng = np.random.RandomState(42)
neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])])
feat_names = (
["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
+ [f"mel{i}" for i in range(8)]
+ [f"sc{i}" for i in range(7)]
)
stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
imp = clf.feature_importances_
top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10]
print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
if __name__ == "__main__":
main()
+92
View File
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Train an audio scan classifier from DB ground truth.
Usage:
python 8cut_train.py # default model, auto-detect positive
python 8cut_train.py --model BEATS # specific embedding model
python 8cut_train.py --positive mp4_Intense # explicit positive folder
python 8cut_train.py --positive mp4_Intense --model BEATS # both
"""
import sys, os, warnings
sys.path.insert(0, os.path.dirname(__file__))
warnings.filterwarnings("ignore")
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
from core.db import ProcessedDB
PROFILE_NAME = "JAV_missionary"
# Fallback for old DB rows without source_path
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
def main():
embed_model = None
if "--model" in sys.argv:
idx = sys.argv.index("--model")
if idx + 1 < len(sys.argv):
embed_model = sys.argv[idx + 1]
if embed_model not in _EMBED_MODELS:
print(f"Unknown model: {embed_model}")
print(f"Available: {', '.join(_EMBED_MODELS)}")
sys.exit(1)
positive_suffix = None
if "--positive" in sys.argv:
idx = sys.argv.index("--positive")
if idx + 1 < len(sys.argv):
positive_suffix = sys.argv[idx + 1]
db = ProcessedDB()
# If --positive given, use the new DB helper
if positive_suffix:
video_infos = db.get_training_data(
PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
)
if not video_infos:
print(f"No training data found for positive='{positive_suffix}'")
sys.exit(1)
else:
# Legacy fallback: classify by folder path pattern
rows = db._con.execute(
"SELECT filename, start_time, output_path, source_path"
" FROM processed WHERE profile = ?",
(PROFILE_NAME,),
).fetchall()
intense_by_video, soft_by_video = {}, {}
source_by_fn = {}
for fn, st, op, sp in rows:
if sp:
source_by_fn[fn] = sp
if "/mp4_Intense/" in op or "_Intense/" in op:
intense_by_video.setdefault(fn, set()).add(st)
elif "/mp4_Soft/" in op or "_Soft/" in op:
soft_by_video.setdefault(fn, set()).add(st)
video_infos = []
for fn in intense_by_video:
# Try source_path from DB first, fall back to PLEX_DIR
vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
if not os.path.exists(vpath):
print(f" skip (not found): {fn}")
continue
gt_intense = sorted(intense_by_video[fn])
gt_soft = sorted(soft_by_video.get(fn, set()))
video_infos.append((vpath, gt_intense, gt_soft))
label = embed_model or "WAV2VEC2_BASE"
print(f"Training {label} model on {len(video_infos)} videos...")
model_path = default_model_path(PROFILE_NAME)
result = train_classifier(
video_infos, model_path=model_path, embed_model=embed_model,
)
if result is None:
print("Training failed: no valid samples or missing class balance")
sys.exit(1)
print(f"Model saved to {model_path}")
if __name__ == "__main__":
main()
+394
View File
@@ -0,0 +1,394 @@
"""Audio scanning — embedding-based classifier for audio event detection."""
import hashlib
import os
import numpy as np
import librosa
from .paths import _log
_SR = 16000 # lower sr = faster
_WINDOW = 8.0 # seconds
_MODEL_DIR = os.path.join(os.path.expanduser("~"), ".8cut_models")
_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v")
# ---------------------------------------------------------------------------
# Embedding extraction (lazy-loaded)
# ---------------------------------------------------------------------------
_w2v_model = None
_w2v_device = None
_w2v_model_name = None
# Supported embedding models — name → embed_dim
_EMBED_MODELS = {
"WAV2VEC2_BASE": 768,
"WAV2VEC2_LARGE": 1024,
"WAV2VEC2_LARGE_LV60K":1024,
"HUBERT_BASE": 768,
"HUBERT_LARGE": 1024,
"HUBERT_XLARGE": 1280,
"BEATS": 768,
}
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
_BEATS_CHECKPOINT = os.path.join(
os.path.expanduser("~"), ".cache", "huggingface", "hub",
"models--lpepino--beats_ckpts", "snapshots",
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
)
def _get_w2v_model(model_name: str | None = None):
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
global _w2v_model, _w2v_device, _w2v_model_name
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
if _w2v_model is None or _w2v_model_name != model_name:
import torch
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "BEATS":
from .beats_model import BEATs, BEATsConfig
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
weights_only=False)
cfg = BEATsConfig(checkpoint['cfg'])
_w2v_model = BEATs(cfg)
_w2v_model.load_state_dict(checkpoint['model'])
_w2v_model.to(_w2v_device)
else:
import torchaudio
bundle = getattr(torchaudio.pipelines, model_name)
_w2v_model = bundle.get_model().to(_w2v_device)
_w2v_model.eval()
_w2v_model_name = model_name
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
return _w2v_model, _w2v_device
def _embed_dim(model_name: str | None = None) -> int:
"""Return embedding dimension for a model name."""
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
return _EMBED_MODELS.get(model_name, 768)
def _w2v_cache_path(video_path: str, hop: float, window: float,
model_name: str | None = None) -> str:
"""Return cache file path for a video's embeddings (includes model name)."""
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
abspath = os.path.abspath(video_path)
mtime = os.path.getmtime(abspath)
key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}"
h = hashlib.sha256(key.encode()).hexdigest()[:16]
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
hop: float = 1.0, window: float = _WINDOW,
video_path: str | None = None,
cancel_flag: object = None,
model_name: str | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract embeddings for all sliding windows using a torchaudio model.
If video_path is given, results are cached to disk for fast re-scans.
Returns (timestamps, embeddings) where embeddings is (N, D).
"""
edim = _embed_dim(model_name)
# Try loading from cache
cache_file = None
if video_path:
try:
cache_file = _w2v_cache_path(video_path, hop, window, model_name)
if os.path.exists(cache_file):
data = np.load(cache_file)
_log(f"audio_scan: cache hit ({cache_file})")
return data["timestamps"], data["embeddings"]
except Exception as e:
_log(f"audio_scan: cache read failed: {e}")
win_samples = int(window * sr)
hop_samples = int(hop * sr)
n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)
if n_windows == 0:
return np.array([]), np.empty((0, edim))
import torch
model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
batch_size = 16
timestamps = np.arange(n_windows) * hop
embeddings = []
for batch_start in range(0, n_windows, batch_size):
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return np.array([]), np.empty((0, edim))
batch_end = min(batch_start + batch_size, n_windows)
chunks = []
for i in range(batch_start, batch_end):
start = i * hop_samples
chunks.append(y[start:start + win_samples])
with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings.append(batch_emb)
result_ts = timestamps
result_emb = np.vstack(embeddings)
# Save to cache
if cache_file:
try:
os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
_log(f"audio_scan: w2v cache saved ({cache_file})")
except Exception as e:
_log(f"audio_scan: cache write failed: {e}")
return result_ts, result_emb
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
gt_soft: list[float], tolerance: float = 12.0,
neg_margin: float = 120.0,
model_name: str | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract embeddings only near positives and distant negatives.
Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig.
"""
edim = _embed_dim(model_name)
duration = len(y) / sr
win_samples = int(_WINDOW * sr)
all_gt = list(gt_intense) + list(gt_soft)
# Positive windows: every second near intense markers
pos_times = set()
for gt in gt_intense:
for offset in range(-int(tolerance), int(tolerance) + 1):
t = gt + offset
if 0 <= t <= duration - _WINDOW:
pos_times.add(int(t))
# Negative windows: every 4s, far from any marker
neg_times = set()
for t in range(0, int(duration - _WINDOW), 4):
if min((abs(t - g) for g in all_gt), default=9999) > neg_margin:
neg_times.add(t)
all_times = sorted(pos_times | neg_times)
# Filter out windows that go past the end
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
if not valid_times:
return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)
import torch
model, device = _get_w2v_model(model_name)
batch_size = 16
timestamps_list: list[float] = []
embeddings_list: list[np.ndarray] = []
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
for batch_start in range(0, len(valid_times), batch_size):
batch_end = min(batch_start + batch_size, len(valid_times))
chunks = []
for t in valid_times[batch_start:batch_end]:
start = int(t * sr)
chunks.append(y[start:start + win_samples])
timestamps_list.append(float(t))
with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings_list.append(batch_emb)
timestamps = np.array(timestamps_list)
embeddings = np.vstack(embeddings_list)
labels = np.zeros(len(timestamps), dtype=int)
for i, t in enumerate(timestamps):
di = min((abs(t - g) for g in gt_intense), default=9999)
da = min((abs(t - g) for g in all_gt), default=9999)
if di < tolerance:
labels[i] = 1
elif da > neg_margin:
labels[i] = -1
return timestamps, embeddings, labels
# ---------------------------------------------------------------------------
# Classifier mode — train / save / load / scan
# ---------------------------------------------------------------------------
def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
model_path: str | None = None,
tolerance: float = 12.0,
neg_margin: float = 120.0,
embed_model: str | None = None) -> dict:
"""Train a classifier from labeled videos.
Args:
video_infos: list of (video_path, intense_times, soft_times)
model_path: if given, save model to this path
tolerance/neg_margin: labeling parameters
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
Returns:
dict with 'classifier', 'embed_model', and metadata, or None on failure.
"""
from sklearn.ensemble import GradientBoostingClassifier
all_X, all_y = [], []
for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos):
_log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}")
y, _ = librosa.load(vpath, sr=_SR, mono=True)
timestamps, embeddings, labels = _extract_w2v_targeted(
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
model_name=embed_model,
)
if len(timestamps) == 0:
continue
# Per-video z-score normalize
vid_mean = embeddings.mean(axis=0)
vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
normed = (embeddings - vid_mean) / vid_std
for i in range(len(labels)):
if labels[i] == 1:
all_X.append(normed[i])
all_y.append(1)
elif labels[i] == -1:
all_X.append(normed[i])
all_y.append(0)
if not all_X:
_log("audio_scan: no training samples collected")
return None
X = np.stack(all_X)
y_arr = np.array(all_y)
n_pos = (y_arr == 1).sum()
n_neg = (y_arr == 0).sum()
_log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")
if n_pos == 0 or n_neg == 0:
_log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
return None
# Subsample negatives for balance
rng = np.random.RandomState(42)
pos_idx = np.where(y_arr == 1)[0]
neg_idx = np.where(y_arr == 0)[0]
n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
train_idx = np.concatenate([pos_idx, neg_sample])
rng.shuffle(train_idx)
clf = GradientBoostingClassifier(
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42,
)
clf.fit(X[train_idx], y_arr[train_idx])
_log("audio_scan: classifier trained")
model = {"classifier": clf, "n_features": X.shape[1],
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
if model_path:
import joblib
parent = os.path.dirname(model_path)
if parent:
os.makedirs(parent, exist_ok=True)
joblib.dump(model, model_path)
_log(f"audio_scan: model saved to {model_path}")
return model
def load_classifier(model_path: str) -> dict | None:
"""Load a saved classifier model."""
if not os.path.exists(model_path):
return None
import joblib
return joblib.load(model_path)
def default_model_path(profile_name: str = "default") -> str:
"""Return the default path for a profile's classifier model."""
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
# ---------------------------------------------------------------------------
# Scanning
# ---------------------------------------------------------------------------
def scan_video(
video_path: str,
model: dict = None,
threshold: float = 0.30,
hop: float = 1.0,
window: float = _WINDOW,
cancel_flag: object = None,
) -> list[tuple[float, float, float]]:
"""Scan a video for matching audio regions using a trained classifier.
Returns list of (start_time, end_time, score) above threshold.
"""
if model is None:
_log("audio_scan: no model provided")
return []
_log(f"audio_scan: loading {video_path}")
y, sr = librosa.load(video_path, sr=_SR, mono=True)
duration = len(y) / sr
_log(f"audio_scan: {duration:.1f}s loaded, extracting features...")
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return []
clf = model["classifier"]
embed_model = model.get("embed_model")
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
timestamps, window_vectors = _extract_w2v_windows(
y, sr, hop=hop, window=window, video_path=video_path,
cancel_flag=cancel_flag, model_name=embed_model,
)
if len(timestamps) == 0:
_log("audio_scan: video shorter than window")
return []
# Per-video z-score normalize
vid_mean = window_vectors.mean(axis=0)
vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
normed = (window_vectors - vid_mean) / vid_std
_log(f"audio_scan: classifying {len(normed)} windows...")
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return []
probs = clf.predict_proba(normed)[:, 1]
mask = probs >= threshold
results = [
(timestamps[i], timestamps[i] + window, float(probs[i]))
for i in np.nonzero(mask)[0]
]
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
return results
+783
View File
@@ -0,0 +1,783 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import numpy as np
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Parameter
from .beats_modules import (
GradMultiply,
SamePad,
get_activation_fn,
GLU_Linear,
quant_noise,
)
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
deep_norm=args.deep_norm,
has_relative_attention_bias=self.relative_position_embedding,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
encoder_layers=args.encoder_layers,
)
for i in range(args.encoder_layers)
]
)
if self.relative_position_embedding:
for i in range(1, args.encoder_layers):
del self.layers[i].self_attn.relative_attention_bias
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
if args.deep_norm:
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
for i in range(args.encoder_layers):
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
if self.layer_wise_gradient_decay_ratio != 1.0:
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
deep_norm: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
encoder_layers: int = 0,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = LayerNorm(self.embedding_dim)
self.deep_norm = deep_norm
if self.deep_norm:
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
else:
self.deep_norm_alpha = 1
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None
):
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual * self.deep_norm_alpha + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual * self.deep_norm_alpha + x
x = self.final_layer_norm(x)
return x, attn, pos_bias
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(
relative_position,
bidirectional=True
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
alpha = 32
q *= 1 / alpha
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.k_head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
attn_weights = attn_weights + attn_mask_rel_pos
attn_weights_float = F.softmax(
attn_weights, dim=-1
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
+179
View File
@@ -0,0 +1,179 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi
from .beats_backbone import (
TransformerEncoder,
)
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class BEATsConfig:
def __init__(self, cfg=None):
self.input_patch_size: int = -1 # path size of patch embedding
self.embed_dim: int = 512 # patch embedding dimension
self.conv_bias: bool = False # include bias in conv encoder
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = False # apply deep_norm first in the transformer
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
# label predictor
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
self.predictor_class: int = 527 # target class number for the predictor
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class BEATs(nn.Module):
def __init__(
self,
cfg: BEATsConfig,
) -> None:
super().__init__()
logger.info(f"BEATs Config: {cfg.__dict__}")
self.cfg = cfg
self.embed = cfg.embed_dim
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.dropout_input = nn.Dropout(cfg.dropout_input)
assert not cfg.deep_norm or not cfg.layer_norm_first
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
if cfg.finetuned_model:
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
else:
self.predictor = None
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(fbank, padding_mask)
fbank = fbank.unsqueeze(1)
features = self.patch_embedding(fbank)
features = features.reshape(features.shape[0], features.shape[1], -1)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x = self.dropout_input(features)
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
)
if self.predictor is not None:
x = self.predictor_dropout(x)
logits = self.predictor(x)
if padding_mask is not None and padding_mask.any():
logits[padding_mask] = 0
logits = logits.sum(dim=1)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
else:
logits = logits.mean(dim=1)
lprobs = torch.sigmoid(logits)
return lprobs, padding_mask
else:
return x, padding_mask
+219
View File
@@ -0,0 +1,219 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import warnings
import torch
from torch import Tensor, nn
import torch.nn.functional as F
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
else:
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
return x
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
+116 -5
View File
@@ -1,3 +1,4 @@
import os
import sqlite3 import sqlite3
import threading import threading
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -7,7 +8,7 @@ from .paths import _log
class ProcessedDB: class ProcessedDB:
_SCHEMA_VERSION = 3 # bump when schema changes _SCHEMA_VERSION = 4 # bump when schema changes
def __init__(self, db_path: str | None = None): def __init__(self, db_path: str | None = None):
if db_path is None: if db_path is None:
@@ -47,6 +48,7 @@ class ProcessedDB:
" clip_count INTEGER NOT NULL DEFAULT 3," " clip_count INTEGER NOT NULL DEFAULT 3,"
" spread REAL NOT NULL DEFAULT 3.0," " spread REAL NOT NULL DEFAULT 3.0,"
" profile TEXT NOT NULL DEFAULT 'default'," " profile TEXT NOT NULL DEFAULT 'default',"
" source_path TEXT NOT NULL DEFAULT '',"
" processed_at TEXT NOT NULL" " processed_at TEXT NOT NULL"
")" ")"
) )
@@ -62,6 +64,7 @@ class ProcessedDB:
"clip_count": "INTEGER NOT NULL DEFAULT 3", "clip_count": "INTEGER NOT NULL DEFAULT 3",
"spread": "REAL NOT NULL DEFAULT 3.0", "spread": "REAL NOT NULL DEFAULT 3.0",
"profile": "TEXT NOT NULL DEFAULT 'default'", "profile": "TEXT NOT NULL DEFAULT 'default'",
"source_path": "TEXT NOT NULL DEFAULT ''",
} }
for col, typedef in new_cols.items(): for col, typedef in new_cols.items():
if col not in cols: if col not in cols:
@@ -85,7 +88,7 @@ class ProcessedDB:
short_side: int | None = None, portrait_ratio: str = "", short_side: int | None = None, portrait_ratio: str = "",
crop_center: float = 0.5, fmt: str = "MP4", crop_center: float = 0.5, fmt: str = "MP4",
clip_count: int = 3, spread: float = 3.0, clip_count: int = 3, spread: float = 3.0,
profile: str = "default") -> None: profile: str = "default", source_path: str = "") -> None:
if not self._enabled: if not self._enabled:
return return
with self._lock: with self._lock:
@@ -93,11 +96,11 @@ class ProcessedDB:
"INSERT INTO processed" "INSERT INTO processed"
" (filename, start_time, output_path, label, category," " (filename, start_time, output_path, label, category,"
" short_side, portrait_ratio, crop_center, format," " short_side, portrait_ratio, crop_center, format,"
" clip_count, spread, profile, processed_at)" " clip_count, spread, profile, source_path, processed_at)"
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(filename, start_time, output_path, label, category, (filename, start_time, output_path, label, category,
short_side, portrait_ratio, crop_center, fmt, short_side, portrait_ratio, crop_center, fmt,
clip_count, spread, profile, clip_count, spread, profile, source_path,
datetime.now(timezone.utc).isoformat()), datetime.now(timezone.utc).isoformat()),
) )
self._con.commit() self._con.commit()
@@ -213,6 +216,114 @@ class ProcessedDB:
).fetchall() ).fetchall()
return [r[0] for r in rows] return [r[0] for r in rows]
def get_all_export_paths(self, profile: str = "default") -> list[str]:
"""Return all unique output_path values for a given profile."""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
return [r[0] for r in rows]
def get_export_folders(self, profile: str = "default") -> list[str]:
"""Return distinct export folder names found in output_paths for a profile.
Export paths follow the structure:
.../export_folder/group_dir/clip.mp4
The export folder is 2 levels up from the clip file.
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
"""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
folder_names: set[str] = set()
for (op,) in rows:
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent:
folder_names.add(grandparent)
return sorted(folder_names)
def get_training_data(self, profile: str, positive_folder: str,
fallback_video_dir: str = "",
) -> list[tuple[str, list[float], list[float]]]:
"""Build training video_infos from DB data.
Args:
profile: profile name
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
fallback_video_dir: if source_path is empty, try filename in this dir
Returns:
list of (source_video_path, positive_times, soft_times) per video.
Soft times = clips from any other export folder.
"""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT filename, start_time, output_path, source_path"
" FROM processed WHERE profile = ?",
(profile,),
).fetchall()
# Collect times by video, split by positive vs other folders
pos_by_video: dict[str, set[float]] = {}
soft_by_video: dict[str, set[float]] = {}
source_by_filename: dict[str, str] = {}
for fn, st, op, sp in rows:
if sp:
source_by_filename[fn] = sp
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent == positive_folder:
pos_by_video.setdefault(fn, set()).add(st)
else:
soft_by_video.setdefault(fn, set()).add(st)
result = []
for fn in pos_by_video:
sp = source_by_filename.get(fn, "")
if not sp or not os.path.exists(sp):
# Fallback: try video_dir / filename
if fallback_video_dir:
sp = os.path.join(fallback_video_dir, fn)
if not sp or not os.path.exists(sp):
continue
gt_pos = sorted(pos_by_video[fn])
gt_soft = sorted(soft_by_video.get(fn, set()))
result.append((sp, gt_pos, gt_soft))
return result
def get_training_stats(self, profile: str) -> dict[str, dict]:
"""Return per-subprofile stats for training readiness display.
Returns dict mapping subprofile_name → {
'videos': number of distinct source videos,
'clips': total clip count,
}
"""
if not self._enabled:
return {}
rows = self._con.execute(
"SELECT filename, output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
folders = self.get_export_folders(profile)
stats: dict[str, dict] = {}
for folder_name in folders:
videos: set[str] = set()
clips = 0
for fn, op in rows:
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent == folder_name:
videos.add(fn)
clips += 1
stats[folder_name] = {"videos": len(videos), "clips": clips}
return stats
def hide_file(self, filename: str, profile: str = "default") -> None: def hide_file(self, filename: str, profile: str = "default") -> None:
if not self._enabled: if not self._enabled:
return return
@@ -0,0 +1,97 @@
# Audio Similarity Scanning — Design
**Goal:** Scan a video's audio track and highlight segments that match the sound profile of existing reference clips, so the user can quickly find similar moments without scrubbing manually.
**Runs in:** Python/Qt client (`main.py`), not the server.
---
## Core Module: `core/audio_scan.py`
New module alongside `core/tracking.py`. Two main functions:
- `build_profile(clip_paths: list[str]) -> dict` — extracts MFCCs (20 coefficients) from each clip using `librosa`, returns a profile containing both the averaged vector and individual clip vectors.
- `scan_video(video_path: str, profile: dict, mode: str, threshold: float, hop: float) -> list[tuple[float, float, float]]` — slides an 8s window across the video's audio, returns `(start_time, end_time, score)` tuples for segments above threshold.
### Feature Extraction
- Audio loaded via `librosa.load()` (handles video files directly, mono, 22050Hz).
- MFCCs: `librosa.feature.mfcc(n_mfcc=20)`, averaged over time axis to produce a single vector per window/clip.
- Similarity: cosine similarity (`numpy` dot product on L2-normalized vectors).
### Matching Modes
- **Average mode:** Compare each window to the mean of all reference MFCC vectors. Fast, good when references are homogeneous.
- **Nearest mode:** Compare each window to every reference vector, take the max score. Better when references have variety within the style.
### Parameters
- `threshold` (float, 0.01.0): minimum cosine similarity to include a segment. Default 0.7.
- `hop` (float, seconds): step size for the sliding window. Default 1.0s.
- Window size fixed at 8s to match reference clip length.
---
## UI Integration in `main.py`
### Controls
Added near the existing tracking checkbox area:
- **"Scan" button** — triggers audio scan on current video.
- **Threshold slider** (0.01.0, step 0.05) — controls match strictness.
- **Mode combobox** — "Average" / "Nearest".
- **Reference source combobox** — "Current Profile" / "Custom Folder" (shows folder picker when "Custom Folder" selected).
### Scan Workflow
1. User clicks Scan.
2. Reference clips collected: either all export `output_path` values from the current profile (via DB) or all audio/video files in a custom folder.
3. Scan runs in a `QThread` so UI stays responsive.
4. On completion, results sent to Timeline widget via signal.
### Timeline Display
- New `set_scan_regions(regions: list[tuple[float, float, float]])` method on Timeline.
- Drawn as semi-transparent colored rectangles behind existing markers.
- Color intensity proportional to score (brighter = higher match).
- Cleared on file change or re-scan.
### Keyboard Shortcut
- `S` — jump cursor to the next scan region (similar to `M` for next marker).
---
## Data Flow
```
Reference clips (DB export paths or folder)
|
librosa.load() each -> MFCC vectors (20-dim)
|
Profile: { mean_vector, clip_vectors[] }
|
Current video -> librosa.load() full audio (mono 22050Hz)
|
Sliding 8s window (hop=1s) -> MFCC per window
|
Cosine similarity vs profile -> score per position
|
Threshold filter -> [(start, end, score), ...]
|
Timeline: semi-transparent highlight regions
```
## Performance
- 2-hour video at 22050Hz mono ~ 380MB memory.
- MFCC extraction + sliding window: ~10-30s.
- QThread keeps UI responsive.
## What This Does NOT Do
- No DB schema changes — scan results are ephemeral (visual only).
- No auto-export — user decides what to cut.
- No server integration — runs entirely in the Python client.
- No GPU/ML model dependency — just librosa + numpy.
@@ -0,0 +1,739 @@
# Audio Similarity Scanning — Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Scan a video's audio track to find segments matching a reference sound profile, displayed as highlighted regions on the timeline.
**Architecture:** New `core/audio_scan.py` module extracts MFCC features from reference clips and slides an 8s window across the target video's audio, scoring each position via cosine similarity. A `ScanWorker` QThread runs the scan in the background, and results are drawn as semi-transparent rectangles on the existing Timeline widget.
**Tech Stack:** Python 3, librosa 0.11, numpy, PyQt6
---
### Task 1: Core audio_scan module — build_profile
**Files:**
- Create: `core/audio_scan.py`
- Create: `tests/test_audio_scan.py`
**Step 1: Write the tests**
```python
# tests/test_audio_scan.py
import tempfile, os
import numpy as np
from core.audio_scan import build_profile, _extract_mfcc
def _make_wav(path: str, duration: float = 8.0, sr: int = 22050):
"""Create a short sine-wave WAV file for testing."""
import soundfile as sf
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
audio = 0.5 * np.sin(2 * np.pi * 440 * t)
sf.write(path, audio, sr)
def test_extract_mfcc_returns_1d_vector():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
vec = _extract_mfcc(f.name)
assert vec.shape == (20,)
assert not np.isnan(vec).any()
finally:
os.unlink(f.name)
def test_build_profile_single_clip():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
profile = build_profile([f.name])
assert "mean_vector" in profile
assert "clip_vectors" in profile
assert profile["mean_vector"].shape == (20,)
assert len(profile["clip_vectors"]) == 1
finally:
os.unlink(f.name)
def test_build_profile_multiple_clips():
paths = []
try:
for i in range(3):
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
freq = 440 + i * 200
import soundfile as sf
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
sf.write(f.name, 0.5 * np.sin(2 * np.pi * freq * t), 22050)
paths.append(f.name)
f.close()
profile = build_profile(paths)
assert len(profile["clip_vectors"]) == 3
assert profile["mean_vector"].shape == (20,)
finally:
for p in paths:
os.unlink(p)
def test_build_profile_skips_missing_files():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
profile = build_profile([f.name, "/no/such/file.wav"])
assert len(profile["clip_vectors"]) == 1
finally:
os.unlink(f.name)
def test_build_profile_empty_returns_none():
result = build_profile([])
assert result is None
```
**Step 2: Run tests to verify they fail**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
Expected: FAIL with `ModuleNotFoundError: No module named 'core.audio_scan'`
**Step 3: Write the implementation**
```python
# core/audio_scan.py
"""Audio similarity scanning — MFCC-based profile matching."""
import numpy as np
import librosa
from .paths import _log
_N_MFCC = 20
_SR = 22050
def _extract_mfcc(path: str, sr: int = _SR) -> np.ndarray:
"""Load audio from a file and return a mean MFCC vector (20-dim)."""
y, _ = librosa.load(path, sr=sr, mono=True)
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=_N_MFCC)
return mfcc.mean(axis=1) # average over time → (20,)
def build_profile(clip_paths: list[str]) -> dict | None:
"""Extract MFCCs from reference clips.
Returns dict with:
- mean_vector: averaged MFCC across all clips (20,)
- clip_vectors: list of individual MFCC vectors
Returns None if no clips could be loaded.
"""
vectors = []
for p in clip_paths:
try:
vec = _extract_mfcc(p)
vectors.append(vec)
except Exception as e:
_log(f"audio_scan: skip {p}: {e}")
if not vectors:
return None
arr = np.stack(vectors)
return {
"mean_vector": arr.mean(axis=0),
"clip_vectors": vectors,
}
```
**Step 4: Run tests to verify they pass**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
Expected: all 5 PASS
**Step 5: Commit**
```bash
git add core/audio_scan.py tests/test_audio_scan.py
git commit -m "feat: add audio_scan module with build_profile"
```
---
### Task 2: Core audio_scan module — scan_video
**Files:**
- Modify: `core/audio_scan.py`
- Modify: `tests/test_audio_scan.py`
**Step 1: Write the tests**
Add to `tests/test_audio_scan.py`:
```python
from core.audio_scan import scan_video
def test_scan_video_finds_matching_region():
"""A video made of the same sine wave as the reference should match."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
_make_wav(vid.name, duration=20.0)
try:
profile = build_profile([ref.name])
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0)
assert len(regions) > 0
for start, end, score in regions:
assert abs((end - start) - 8.0) < 1e-9
assert score >= 0.5
assert score >= 0.5
finally:
os.unlink(ref.name)
os.unlink(vid.name)
def test_scan_video_nearest_mode():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
_make_wav(vid.name, duration=20.0)
try:
profile = build_profile([ref.name])
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.5, hop=1.0)
assert len(regions) > 0
finally:
os.unlink(ref.name)
os.unlink(vid.name)
def test_scan_video_high_threshold_no_match():
"""Different frequencies with very high threshold should not match."""
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
sf.write(ref.name, 0.5 * np.sin(2 * np.pi * 440 * t), 22050)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
# White noise — very different from sine wave
sf.write(vid.name, np.random.randn(22050 * 20).astype(np.float32) * 0.1, 22050)
try:
profile = build_profile([ref.name])
regions = scan_video(vid.name, profile, mode="average", threshold=0.99, hop=1.0)
assert len(regions) == 0
finally:
os.unlink(ref.name)
os.unlink(vid.name)
```
**Step 2: Run tests to verify they fail**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_scan_video_finds_matching_region -v`
Expected: FAIL with `ImportError: cannot import name 'scan_video'`
**Step 3: Write the implementation**
Add to `core/audio_scan.py`:
```python
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Cosine similarity between two vectors.
Returns value in [-1, 1]. Negative means anti-correlated (very
dissimilar). For threshold filtering this is fine — negative scores
never exceed the threshold. Scores near 0 may be uncorrelated or
weakly anti-correlated.
"""
na = np.linalg.norm(a)
nb = np.linalg.norm(b)
if na == 0 or nb == 0:
return 0.0
return float(np.dot(a, b) / (na * nb))
def scan_video(
video_path: str,
profile: dict,
mode: str = "average",
threshold: float = 0.7,
hop: float = 1.0,
window: float = 8.0,
cancel_flag: object = None,
) -> list[tuple[float, float, float]]:
"""Slide a window across the video audio and score against the profile.
Args:
video_path: path to video/audio file
profile: dict from build_profile()
mode: "average" (compare to mean) or "nearest" (max over all clips)
threshold: minimum cosine similarity to include
hop: step size in seconds
window: window size in seconds (default 8s)
cancel_flag: object with _cancel bool attribute; checked each iteration
Returns:
list of (start_time, end_time, score) for regions above threshold
"""
_log(f"audio_scan: loading {video_path}")
y, sr = librosa.load(video_path, sr=_SR, mono=True)
duration = len(y) / sr
_log(f"audio_scan: {duration:.1f}s loaded, scanning with hop={hop}s")
win_samples = int(window * sr)
hop_samples = int(hop * sr)
results = []
pos = 0
while pos + win_samples <= len(y):
if cancel_flag and getattr(cancel_flag, '_cancel', False):
_log("audio_scan: cancelled")
return results
chunk = y[pos : pos + win_samples]
mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=_N_MFCC)
vec = mfcc.mean(axis=1)
if mode == "nearest":
score = max(
_cosine_similarity(vec, cv) for cv in profile["clip_vectors"]
)
else: # average
score = _cosine_similarity(vec, profile["mean_vector"])
if score >= threshold:
start_t = pos / sr
results.append((start_t, start_t + window, score))
pos += hop_samples
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
return results
```
**Step 4: Run tests to verify they pass**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
Expected: all 8 PASS
**Step 5: Commit**
```bash
git add core/audio_scan.py tests/test_audio_scan.py
git commit -m "feat: add scan_video with average and nearest modes"
```
---
### Task 3: Timeline — draw scan regions
**Files:**
- Modify: `main.py` (Timeline class, around lines 209-260 and 300-375)
**Step 1: Add scan region storage to Timeline.__init__**
In `main.py`, find the Timeline class `__init__` method (around line 198). After `self._markers` initialization (line 209), add:
```python
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
```
**Step 2: Add set_scan_regions method**
After the `set_markers` method (line 249-252), add:
```python
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
"""regions: list of (start_time, end_time, score)"""
self._scan_regions = regions
self.update()
def clear_scan_regions(self) -> None:
self._scan_regions = []
self.update()
```
**Step 3: Draw scan regions in paintEvent**
In `paintEvent` (starts around line 282), find the marker drawing section (line 363, comment `# ── export markers`). BEFORE that section, add:
```python
# ── scan regions ──────────────────────────────────────────────
if self._scan_regions and self._duration > 0:
for (start, end, score) in self._scan_regions:
x1 = int(start / self._duration * w)
x2 = int(end / self._duration * w)
alpha = int(40 + score * 80) # 40120 opacity
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
```
**Step 4: Verify manually**
Run: `cd /media/p5/8-cut && python main.py`
Expected: app starts without errors. No scan regions visible yet (none set).
**Step 5: Commit**
```bash
git add main.py
git commit -m "feat: timeline scan region rendering"
```
---
### Task 4: ScanWorker QThread
**Files:**
- Modify: `main.py` (add ScanWorker class, after ExportWorker around line 165)
**Step 1: Add the ScanWorker class**
After the `ExportWorker` class (ends around line 165), add:
```python
class ScanWorker(QThread):
"""Runs audio similarity scan off the main thread."""
finished = pyqtSignal(list) # emits list of (start, end, score)
error = pyqtSignal(str)
progress = pyqtSignal(str) # status message
def __init__(self, video_path: str, clip_paths: list[str],
mode: str = "average", threshold: float = 0.7):
super().__init__()
self._video_path = video_path
self._clip_paths = clip_paths
self._mode = mode
self._threshold = threshold
self._cancel = False
def cancel(self) -> None:
self._cancel = True
def run(self):
from core.audio_scan import build_profile, scan_video
try:
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
profile = build_profile(self._clip_paths)
if self._cancel:
return
if profile is None:
self.error.emit("No valid reference clips found")
return
self.progress.emit("Scanning audio...")
regions = scan_video(
self._video_path, profile,
mode=self._mode, threshold=self._threshold,
cancel_flag=self,
)
if not self._cancel:
self.finished.emit(regions)
except Exception as e:
if not self._cancel:
self.error.emit(str(e))
```
**Step 2: Verify import works**
Run: `cd /media/p5/8-cut && python -c "from main import ScanWorker; print('ok')"`
Expected: `ok`
**Step 3: Commit**
```bash
git add main.py
git commit -m "feat: add ScanWorker QThread for background scanning"
```
---
### Task 5: DB helper — get_all_export_paths
**Files:**
- Modify: `core/db.py`
- Modify: `tests/test_audio_scan.py`
**Step 1: Write the test**
Add to `tests/test_audio_scan.py`:
```python
def test_db_get_all_export_paths():
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name
try:
from core.db import ProcessedDB
db = ProcessedDB(path)
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
paths = db.get_all_export_paths("test")
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
finally:
os.unlink(path)
```
**Step 2: Run test to verify it fails**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
Expected: FAIL with `AttributeError: 'ProcessedDB' object has no attribute 'get_all_export_paths'`
**Step 3: Write the implementation**
Add to `core/db.py`, after the `get_markers` method. Note: no lock needed — follows
the codebase convention where read-only methods don't acquire the lock.
```python
def get_all_export_paths(self, profile: str = "default") -> list[str]:
"""Return all unique output_path values for a given profile."""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
return [r[0] for r in rows]
```
**Step 4: Run test to verify it passes**
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
Expected: PASS
**Step 5: Commit**
```bash
git add core/db.py tests/test_audio_scan.py
git commit -m "feat: add get_all_export_paths to ProcessedDB"
```
---
### Task 6: UI controls for audio scanning
**Files:**
- Modify: `main.py` (MainWindow class — control creation ~1490-1575, layout ~1620-1640)
**Step 1: Add scan control widgets**
In the MainWindow `__init__`, find the control creation section. After `self._chk_track` (around line 1501), add:
```python
# ── audio scan controls ──────────────────────────────────────
self._btn_scan = QPushButton("Scan")
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
self._btn_scan.clicked.connect(self._start_scan)
self._sld_threshold = QDoubleSpinBox()
self._sld_threshold.setRange(0.0, 1.0)
self._sld_threshold.setSingleStep(0.05)
self._sld_threshold.setValue(0.7)
self._sld_threshold.setPrefix("Thr: ")
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
self._cmb_scan_mode = QComboBox()
self._cmb_scan_mode.addItems(["Average", "Nearest"])
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
self._cmb_scan_ref = QComboBox()
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
self._scan_folder: str = ""
self._scan_worker: ScanWorker | None = None
```
**Step 2: Add controls to settings_row layout**
Find the `settings_row` assembly (around line 1620). Before `settings_row.addStretch()` (around line 1635), add:
```python
settings_row.addWidget(self._btn_scan)
settings_row.addWidget(self._sld_threshold)
settings_row.addWidget(self._cmb_scan_mode)
settings_row.addWidget(self._cmb_scan_ref)
```
**Step 3: Add handler methods**
Add these methods to MainWindow (after `_jump_to_next_marker` around line 2410):
```python
def _on_scan_ref_changed(self, index: int) -> None:
if index == 1: # Custom Folder
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
if folder:
self._scan_folder = folder
else:
self._cmb_scan_ref.setCurrentIndex(0)
def _cleanup_scan_worker(self) -> None:
"""Disconnect signals and schedule deletion of old scan worker."""
if self._scan_worker is not None:
try:
self._scan_worker.finished.disconnect()
self._scan_worker.error.disconnect()
self._scan_worker.progress.disconnect()
except TypeError:
pass # already disconnected
self._scan_worker.deleteLater()
self._scan_worker = None
def _start_scan(self) -> None:
if not self._file_path:
self._show_status("No video loaded")
return
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
# Clean up previous worker
self._cleanup_scan_worker()
# Collect reference clip paths
if self._cmb_scan_ref.currentIndex() == 0:
# Current profile — all exports across all files in this profile
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
if os.path.exists(p)]
else:
# Custom folder
if not self._scan_folder:
self._show_status("No reference folder selected")
return
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
clip_paths = [
os.path.join(self._scan_folder, f)
for f in sorted(os.listdir(self._scan_folder))
if f.lower().endswith(exts)
]
if not clip_paths:
self._show_status("No reference clips found")
return
mode = self._cmb_scan_mode.currentText().lower()
threshold = self._sld_threshold.value()
self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path # remember which file we're scanning
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
self._scan_worker.finished.connect(self._on_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
self._scan_worker.start()
def _on_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
# Ignore stale results if the user switched files during scan
if self._file_path != getattr(self, '_scan_file_path', None):
return
self._timeline.set_scan_regions(regions)
self._show_status(f"Scan complete: {len(regions)} matching regions")
def _on_scan_error(self, msg: str) -> None:
self._btn_scan.setEnabled(True)
self._show_status(f"Scan error: {msg}")
```
**Step 4: Verify manually**
Run: `cd /media/p5/8-cut && python main.py`
Expected: Scan button, threshold spinner, mode dropdown, and reference source dropdown visible in the settings row. Clicking Scan with no file loaded shows "No video loaded" in status.
**Step 5: Commit**
```bash
git add main.py
git commit -m "feat: add scan UI controls and start_scan handler"
```
---
### Task 7: Keyboard shortcut — jump to next scan region
**Files:**
- Modify: `main.py`
**Step 1: Add the keyboard shortcut**
Find the shortcut definitions (around line 1728, where `QShortcut(QKeySequence("M"), ...)` is defined). Add after it:
```python
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
```
**Step 2: Add the jump method**
After `_on_scan_error` (or after `_jump_to_next_marker`), add:
```python
def _jump_to_next_scan_region(self) -> None:
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
if not regions:
return
for (start, _end, _score) in regions:
if start > self._cursor + 0.1:
self._step_cursor(start - self._cursor)
return
# Wrap to first region
self._step_cursor(regions[0][0] - self._cursor)
```
**Step 3: Update help text**
Find the help/shortcuts tooltip (around line 1757). Add a row:
```python
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
```
**Step 4: Clear scan regions and cancel running scan on file change**
Find `_load_file` method (around line 1931). After the existing marker/state resets, add:
```python
self._timeline.clear_scan_regions()
if self._scan_worker and self._scan_worker.isRunning():
self._scan_worker.cancel()
self._cleanup_scan_worker()
self._btn_scan.setEnabled(True)
```
**Step 5: Verify manually**
Run: `cd /media/p5/8-cut && python main.py`
Expected: S key does nothing when no scan regions exist. After a scan, S jumps through matched regions.
**Step 6: Commit**
```bash
git add main.py
git commit -m "feat: add S shortcut and clear scan on file change"
```
---
### Task 8: Final integration test
**Step 1: End-to-end manual test**
1. Open the app: `cd /media/p5/8-cut && python main.py`
2. Load a video file
3. Export a few clips (these become the reference)
4. Set reference source to "Current Profile"
5. Click "Scan"
6. Verify: status shows progress messages, then "Scan complete: N matching regions"
7. Verify: cyan-tinted regions appear on the timeline
8. Press S to jump through scan regions
9. Change threshold and re-scan — verify different number of regions
10. Switch mode to "Nearest" and re-scan
11. Switch reference to "Custom Folder", pick a folder with clips
12. Re-scan and verify results
**Step 2: Run all tests**
Run: `cd /media/p5/8-cut && python -m pytest tests/ -v`
Expected: all tests PASS
**Step 3: Final commit**
```bash
git add -A
git commit -m "feat: audio similarity scanning complete"
```
+616 -1
View File
@@ -15,7 +15,7 @@ from PyQt6.QtWidgets import (
QLabel, QPushButton, QLineEdit, QFileDialog, QLabel, QPushButton, QLineEdit, QFileDialog,
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip, QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
QMessageBox, QInputDialog, QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
) )
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
@@ -185,6 +185,183 @@ class FrameGrabber(QThread):
pass pass
class ScanWorker(QThread):
"""Runs audio similarity scan off the main thread."""
scan_done = pyqtSignal(list) # emits list of (start, end, score)
error = pyqtSignal(str)
progress = pyqtSignal(str) # status message
def __init__(self, video_path: str, model: dict,
threshold: float = 0.30):
super().__init__()
self._video_path = video_path
self._model = model
self._threshold = threshold
self._cancel = False
def cancel(self) -> None:
self._cancel = True
def run(self):
from core.audio_scan import scan_video
try:
self.progress.emit("Scanning audio...")
regions = scan_video(
self._video_path, model=self._model,
threshold=self._threshold, cancel_flag=self,
)
if not self._cancel:
self.scan_done.emit(regions)
except Exception as e:
if not self._cancel:
self.error.emit(str(e))
class TrainDialog(QDialog):
"""Dialog for configuring and launching classifier training."""
def __init__(self, db: ProcessedDB, profile: str, video_dir: str = "",
parent=None):
super().__init__(parent)
self.setWindowTitle("Train Classifier")
self.setMinimumWidth(400)
from core.audio_scan import _EMBED_MODELS
self._db = db
self._profile = profile
self._video_dir = video_dir
layout = QVBoxLayout(self)
form = QFormLayout()
# Positive class selector — lists export folders
self._cmb_positive = QComboBox()
stats = db.get_training_stats(profile)
if not stats:
form.addRow("", QLabel("No exported clips found for this profile."))
for folder_name, info in stats.items():
label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)"
self._cmb_positive.addItem(label, userData=folder_name)
form.addRow("Positive class:", self._cmb_positive)
# Model selector
self._cmb_model = QComboBox()
for name in _EMBED_MODELS:
self._cmb_model.addItem(name)
self._cmb_model.setCurrentText("WAV2VEC2_BASE")
form.addRow("Model:", self._cmb_model)
# Video source directory (fallback for old DB rows without source_path)
self._txt_video_dir = QLineEdit(video_dir)
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
self._debounce = QTimer(self)
self._debounce.setSingleShot(True)
self._debounce.setInterval(400)
self._debounce.timeout.connect(self._update_stats)
self._txt_video_dir.textChanged.connect(lambda: self._debounce.start())
vid_row = QHBoxLayout()
vid_row.addWidget(self._txt_video_dir)
btn_browse = QPushButton("...")
btn_browse.setFixedWidth(30)
btn_browse.clicked.connect(self._browse_video_dir)
vid_row.addWidget(btn_browse)
form.addRow("Video dir:", vid_row)
layout.addLayout(form)
# Stats summary
self._lbl_stats = QLabel()
self._update_stats()
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
layout.addWidget(self._lbl_stats)
# Buttons
btns = QDialogButtonBox(
QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel
)
btns.button(QDialogButtonBox.StandardButton.Ok).setText("Train")
btns.button(QDialogButtonBox.StandardButton.Ok).setEnabled(
self._cmb_positive.count() > 0
)
btns.accepted.connect(self.accept)
btns.rejected.connect(self.reject)
layout.addWidget(btns)
def _browse_video_dir(self):
d = QFileDialog.getExistingDirectory(self, "Select video source directory")
if d:
self._txt_video_dir.setText(d)
def _update_stats(self):
folder = self._cmb_positive.currentData()
if not folder:
self._lbl_stats.setText("No export folder data available.")
return
video_infos = self._db.get_training_data(
self._profile, folder,
fallback_video_dir=self._txt_video_dir.text(),
)
n_videos = len(video_infos)
n_pos = sum(len(gt) for _, gt, _ in video_infos)
n_soft = sum(len(s) for _, _, s in video_infos)
lines = [f"<b>{n_videos}</b> videos with positive clips"]
lines.append(f"<b>{n_pos}</b> positive markers, <b>{n_soft}</b> soft/buffer markers")
if n_videos == 0:
lines.append("<i>No source videos found. Set Video dir above.</i>")
elif n_videos < 3:
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
self._lbl_stats.setText("<br>".join(lines))
@property
def positive_folder(self) -> str:
return self._cmb_positive.currentData() or ""
@property
def embed_model(self) -> str:
return self._cmb_model.currentText()
@property
def video_dir(self) -> str:
return self._txt_video_dir.text()
class TrainWorker(QThread):
"""Trains an audio classifier off the main thread."""
train_done = pyqtSignal(str) # emits model path on success
error = pyqtSignal(str)
progress = pyqtSignal(str) # per-video status
def __init__(self, video_infos: list, model_path: str,
embed_model: str | None = None):
super().__init__()
self._video_infos = video_infos
self._model_path = model_path
self._embed_model = embed_model
self._cancel = False
def cancel(self) -> None:
self._cancel = True
def run(self):
from core.audio_scan import train_classifier
try:
self.progress.emit(f"Training on {len(self._video_infos)} videos...")
result = train_classifier(
self._video_infos,
model_path=self._model_path,
embed_model=self._embed_model,
)
if self._cancel:
return
if result is None:
self.error.emit("Training failed: not enough data or missing class balance")
else:
self.train_done.emit(self._model_path)
except Exception as e:
if not self._cancel:
self.error.emit(str(e))
class TimelineWidget(QWidget): class TimelineWidget(QWidget):
cursor_changed = pyqtSignal(float) # emits position in seconds cursor_changed = pyqtSignal(float) # emits position in seconds
seek_changed = pyqtSignal(float) # emits seek position (lock mode) seek_changed = pyqtSignal(float) # emits seek position (lock mode)
@@ -208,6 +385,7 @@ class TimelineWidget(QWidget):
self._crop_keyframes: list[tuple[float, float, str | None, bool, bool]] = [] self._crop_keyframes: list[tuple[float, float, str | None, bool, bool]] = []
self._markers: list[tuple[float, int, str]] = [] self._markers: list[tuple[float, int, str]] = []
self._hover_cache: list[tuple[float, str]] = [] # (t/duration, path) self._hover_cache: list[tuple[float, str]] = [] # (t/duration, path)
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
# Cached paint resources — created once, reused every frame # Cached paint resources — created once, reused every frame
self._cursor_pen = QPen(QColor(255, 210, 0)) self._cursor_pen = QPen(QColor(255, 210, 0))
@@ -252,6 +430,15 @@ class TimelineWidget(QWidget):
self._rebuild_hover_cache() self._rebuild_hover_cache()
self.update() self.update()
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
"""regions: list of (start_time, end_time, score)"""
self._scan_regions = regions
self.update()
def clear_scan_regions(self) -> None:
self._scan_regions = []
self.update()
def set_play_position(self, t: float | None) -> None: def set_play_position(self, t: float | None) -> None:
# In lock mode, ignore mpv position updates while the user is dragging # In lock mode, ignore mpv position updates while the user is dragging
# — the async seek hasn't caught up yet, so mpv reports stale values. # — the async seek hasn't caught up yet, so mpv reports stale values.
@@ -360,6 +547,14 @@ class TimelineWidget(QWidget):
p.drawLine(x_start, rh, x_start, h) p.drawLine(x_start, rh, x_start, h)
p.drawLine(x_end, rh, x_end, h) p.drawLine(x_end, rh, x_end, h)
# ── scan regions ──────────────────────────────────────────────
if self._scan_regions and self._duration > 0:
for (start, end, score) in self._scan_regions:
x1 = int(start / self._duration * w)
x2 = int(end / self._duration * w)
alpha = int(40 + score * 80) # 40120 opacity
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
# ── export markers ──────────────────────────────────────────── # ── export markers ────────────────────────────────────────────
p.setFont(self._marker_font) p.setFont(self._marker_font)
for (t, num, _path) in self._markers: for (t, num, _path) in self._markers:
@@ -1500,6 +1695,42 @@ class MainWindow(QMainWindow):
lambda v: self._settings.setValue("track_subject", "true" if v else "false") lambda v: self._settings.setValue("track_subject", "true" if v else "false")
) )
# ── audio scan controls ──────────────────────────────────────
self._btn_scan = QPushButton("Scan")
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
self._btn_scan.clicked.connect(self._start_scan)
self._btn_auto_export = QPushButton("Auto")
self._btn_auto_export.setToolTip("Scan + auto-export best 8s clips")
self._btn_auto_export.clicked.connect(self._auto_export)
self._btn_train = QPushButton("Train")
self._btn_train.setToolTip("Train audio classifier from exported clips")
self._btn_train.clicked.connect(self._open_train_dialog)
self._train_worker: TrainWorker | None = None
self._spn_auto_fuse = QDoubleSpinBox()
self._spn_auto_fuse.setDecimals(1)
self._spn_auto_fuse.setRange(0.0, 60.0)
self._spn_auto_fuse.setSingleStep(1.0)
self._spn_auto_fuse.setValue(float(self._settings.value("auto_fuse", "4.0")))
self._spn_auto_fuse.setPrefix("Fuse: ")
self._spn_auto_fuse.setSuffix("s")
self._spn_auto_fuse.setToolTip("Max gap between scan regions to merge into one cluster")
self._spn_auto_fuse.valueChanged.connect(
lambda v: self._settings.setValue("auto_fuse", str(v))
)
self._sld_threshold = QDoubleSpinBox()
self._sld_threshold.setDecimals(2)
self._sld_threshold.setRange(0.0, 1.0)
self._sld_threshold.setSingleStep(0.01)
self._sld_threshold.setValue(0.30)
self._sld_threshold.setPrefix("Thr: ")
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
self._scan_worker: ScanWorker | None = None
cpu_count = os.cpu_count() or 2 cpu_count = os.cpu_count() or 2
self._spn_workers = QSpinBox() self._spn_workers = QSpinBox()
self._spn_workers.setRange(1, cpu_count) self._spn_workers.setRange(1, cpu_count)
@@ -1632,6 +1863,11 @@ class MainWindow(QMainWindow):
settings_row.addWidget(self._chk_rand_portrait) settings_row.addWidget(self._chk_rand_portrait)
settings_row.addWidget(self._chk_rand_square) settings_row.addWidget(self._chk_rand_square)
settings_row.addWidget(self._chk_track) settings_row.addWidget(self._chk_track)
settings_row.addWidget(self._btn_scan)
settings_row.addWidget(self._btn_auto_export)
settings_row.addWidget(self._spn_auto_fuse)
settings_row.addWidget(self._sld_threshold)
settings_row.addWidget(self._btn_train)
settings_row.addStretch() settings_row.addStretch()
self._lbl_status = QLabel() self._lbl_status = QLabel()
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;") self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
@@ -1726,6 +1962,7 @@ class MainWindow(QMainWindow):
lambda _, idx=i - 1: self._export_subprofile(idx) lambda _, idx=i - 1: self._export_subprofile(idx)
) )
QShortcut(QKeySequence("M"), self, context=ctx).activated.connect(self._jump_to_next_marker) QShortcut(QKeySequence("M"), self, context=ctx).activated.connect(self._jump_to_next_marker)
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
QShortcut(QKeySequence("N"), self, context=ctx).activated.connect(self._playlist.advance) QShortcut(QKeySequence("N"), self, context=ctx).activated.connect(self._playlist.advance)
QShortcut(QKeySequence("G"), self, context=ctx).activated.connect(self._btn_lock.toggle) QShortcut(QKeySequence("G"), self, context=ctx).activated.connect(self._btn_lock.toggle)
QShortcut(QKeySequence("A"), self, context=ctx).activated.connect(self._autoclip) QShortcut(QKeySequence("A"), self, context=ctx).activated.connect(self._autoclip)
@@ -1755,6 +1992,7 @@ class MainWindow(QMainWindow):
"<tr><td><b>E</b></td><td>Export</td></tr>" "<tr><td><b>E</b></td><td>Export</td></tr>"
"<tr><td><b>19</b></td><td>Export to subprofile 19</td></tr>" "<tr><td><b>19</b></td><td>Export to subprofile 19</td></tr>"
"<tr><td><b>M</b></td><td>Jump to next marker</td></tr>" "<tr><td><b>M</b></td><td>Jump to next marker</td></tr>"
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
"<tr><td><b>N</b></td><td>Next file in playlist</td></tr>" "<tr><td><b>N</b></td><td>Next file in playlist</td></tr>"
"<tr><td><b>G</b></td><td>Toggle cursor lock</td></tr>" "<tr><td><b>G</b></td><td>Toggle cursor lock</td></tr>"
"<tr><td><b>A</b></td><td>Autoclip — fit clip count to pause position</td></tr>" "<tr><td><b>A</b></td><td>Autoclip — fit clip count to pause position</td></tr>"
@@ -1941,6 +2179,11 @@ class MainWindow(QMainWindow):
self._btn_lock.setChecked(False) self._btn_lock.setChecked(False)
self._crop_keyframes.clear() self._crop_keyframes.clear()
self._timeline.set_crop_keyframes([]) self._timeline.set_crop_keyframes([])
self._timeline.clear_scan_regions()
if self._scan_worker and self._scan_worker.isRunning():
self._scan_worker.cancel()
self._cleanup_scan_worker()
self._btn_scan.setEnabled(True)
dur = self._mpv.get_duration() dur = self._mpv.get_duration()
self._timeline.set_duration(dur) self._timeline.set_duration(dur)
@@ -2409,6 +2652,372 @@ class MainWindow(QMainWindow):
return return
self._step_cursor(markers[0][0] - self._cursor) # wrap to first self._step_cursor(markers[0][0] - self._cursor) # wrap to first
def _cleanup_scan_worker(self) -> None:
"""Disconnect signals and schedule deletion of old scan worker."""
if self._scan_worker is not None:
try:
self._scan_worker.scan_done.disconnect()
self._scan_worker.error.disconnect()
self._scan_worker.progress.disconnect()
except TypeError:
pass # already disconnected
if self._scan_worker.isRunning():
# QThread.finished fires when run() returns, even on cancel
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
else:
self._scan_worker.deleteLater()
self._scan_worker = None
def _start_scan(self) -> None:
if not self._file_path:
self._show_status("No video loaded")
return
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
# Clean up previous worker
self._cleanup_scan_worker()
threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
if model is None:
self._show_status("No trained model — click Train first")
return
self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path
self._show_status("Scanning...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
self._scan_worker.scan_done.connect(self._on_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
self._scan_worker.start()
def _on_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
# Ignore stale results if the user switched files during scan
if self._file_path != getattr(self, '_scan_file_path', None):
return
self._timeline.set_scan_regions(regions)
self._show_status(f"Scan complete: {len(regions)} matching regions")
def _on_scan_error(self, msg: str) -> None:
self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._show_status(f"Scan error: {msg}")
# ── Training ────────────────────────────────────────────────
def _cleanup_train_worker(self) -> None:
"""Disconnect signals and schedule deletion of old train worker."""
if self._train_worker is not None:
try:
self._train_worker.train_done.disconnect()
self._train_worker.error.disconnect()
self._train_worker.progress.disconnect()
except TypeError:
pass
if self._train_worker.isRunning():
self._train_worker.cancel()
self._train_worker.finished.connect(self._train_worker.deleteLater)
else:
self._train_worker.deleteLater()
self._train_worker = None
def _open_train_dialog(self):
"""Show the training config dialog and start training if accepted."""
if self._train_worker and self._train_worker.isRunning():
self._show_status("Training already in progress…")
return
# Default video dir: parent of currently loaded file, or saved setting
default_dir = ""
if self._file_path:
default_dir = os.path.dirname(self._file_path)
saved_dir = self._settings.value("train_video_dir", default_dir)
dlg = TrainDialog(self._db, self._profile,
video_dir=saved_dir or default_dir, parent=self)
if dlg.exec() != QDialog.DialogCode.Accepted:
return
pos_folder = dlg.positive_folder
embed_model = dlg.embed_model
video_dir = dlg.video_dir
if not pos_folder:
self._show_status("No positive class selected")
return
# Persist video dir for next time
if video_dir:
self._settings.setValue("train_video_dir", video_dir)
video_infos = self._db.get_training_data(
self._profile, pos_folder, fallback_video_dir=video_dir,
)
if not video_infos:
self._show_status("No training data found for this subprofile")
return
from core.audio_scan import default_model_path
model_path = default_model_path(self._profile)
self._cleanup_train_worker()
self._btn_train.setEnabled(False)
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
self._train_worker = TrainWorker(video_infos, model_path, embed_model)
self._train_worker.train_done.connect(self._on_train_done)
self._train_worker.error.connect(self._on_train_error)
self._train_worker.progress.connect(self._show_status)
self._train_worker.start()
def _on_train_done(self, model_path: str):
self._btn_train.setEnabled(True)
self._show_status(f"Model trained and saved")
_log(f"Training complete: {model_path}")
def _on_train_error(self, msg: str):
self._btn_train.setEnabled(True)
self._show_status(f"Training error: {msg}")
# ── Auto-export ─────────────────────────────────────────────
def _auto_export(self) -> None:
"""Scan → NMS → export one 8s clip per selected position."""
if not self._file_path:
self._show_status("No video loaded")
return
if self._export_worker and self._export_worker.isRunning():
self._show_status("Export already running…")
return
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
self._cleanup_scan_worker()
self._btn_auto_export.setEnabled(False)
self._btn_scan.setEnabled(False)
threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
if model is not None:
self._scan_file_path = self._file_path
self._show_status("Auto: scanning with classifier...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
else:
self._show_status("Auto: no trained model — click Train first")
self._btn_auto_export.setEnabled(True)
self._btn_scan.setEnabled(True)
return
self._scan_worker.scan_done.connect(self._on_auto_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
self._scan_worker.start()
@staticmethod
def _select_export_positions(regions: list[tuple[float, float, float]],
min_gap: float = 2.0,
cluster_fuse: float = 30.0,
) -> list[float]:
"""Cluster scan regions, then fill each cluster with clips spaced min_gap apart.
1. Merge overlapping regions into clusters, fusing clusters <cluster_fuse apart.
2. Within each cluster, greedily pick positions by score, min_gap apart.
"""
if not regions:
return []
# Build clusters — merge overlapping + fuse if gap < cluster_fuse
sorted_r = sorted(regions, key=lambda r: r[0])
clusters: list[list[tuple[float, float, float]]] = []
cur_start, cur_end = sorted_r[0][0], sorted_r[0][1]
cur_regions = [sorted_r[0]]
for start, end, score in sorted_r[1:]:
if start - cur_end <= cluster_fuse:
cur_end = max(cur_end, end)
cur_regions.append((start, end, score))
else:
clusters.append(cur_regions)
cur_start, cur_end = start, end
cur_regions = [(start, end, score)]
clusters.append(cur_regions)
# Within each cluster, NMS by score with min_gap
picked: list[float] = []
for cluster in clusters:
by_score = sorted(cluster, key=lambda r: -r[2])
cluster_picks: list[float] = []
for start, _end, _score in by_score:
if all(abs(start - p) >= min_gap for p in cluster_picks):
cluster_picks.append(start)
picked.extend(cluster_picks)
return sorted(picked)
def _on_auto_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
if self._file_path != getattr(self, '_scan_file_path', None):
self._btn_auto_export.setEnabled(True)
return
self._timeline.set_scan_regions(regions)
if not regions:
self._show_status("Auto: no regions found")
self._btn_auto_export.setEnabled(True)
return
positions = self._select_export_positions(
regions, min_gap=2.0, cluster_fuse=self._spn_auto_fuse.value(),
)
if not positions:
self._show_status("Auto: no positions after NMS")
self._btn_auto_export.setEnabled(True)
return
# Build export jobs — one 8s clip per position
folder = self._txt_folder.text()
name = self._txt_name.text() or "clip"
fmt = self._cmb_format.currentText()
image_sequence = fmt == "WebP sequence"
os.makedirs(folder, exist_ok=True)
# Find starting counter
counter = 1
while True:
if image_sequence:
p = build_sequence_dir(folder, name, counter, sub=0)
else:
p = build_export_path(folder, name, counter, sub=0)
if not os.path.exists(p):
break
counter += 1
jobs = []
self._auto_export_positions = [] # stash for DB writes
for start_t in positions:
group_dir = os.path.join(folder, f"{name}_{counter:03d}")
os.makedirs(group_dir, exist_ok=True)
if image_sequence:
out = build_sequence_dir(folder, name, counter, sub=0)
else:
out = build_export_path(folder, name, counter, sub=0)
jobs.append((start_t, out, None, 0.5))
self._auto_export_positions.append((start_t, counter))
counter += 1
self._show_status(f"Auto: exporting {len(jobs)} clips...")
short_side = self._spn_resize.value() or None
self._export_short_side = short_side
self._export_portrait = "Off"
self._export_format = fmt
self._export_clip_count = 1
self._export_spread = 0
self._export_folder = folder
self._export_folder_suffix = ""
hw_on = self._chk_hw.isChecked() and self._hw_encoders
encoder = self._hw_encoders[0] if hw_on else "libx264"
max_workers = min(self._spn_workers.value(), 3) if hw_on else self._spn_workers.value()
self._export_worker = ExportWorker(
self._file_path, jobs,
short_side=short_side,
image_sequence=image_sequence,
max_workers=max_workers,
encoder=encoder,
)
self._export_worker.finished.connect(self._on_auto_clip_done)
self._export_worker.all_done.connect(self._on_auto_batch_done)
self._export_worker.error.connect(self._on_export_error)
self._export_worker.cancelled.connect(self._on_export_cancelled)
self._btn_cancel.setEnabled(True)
self._btn_export.setEnabled(False)
self._set_subprofile_btns_enabled(False)
self._export_worker.start()
def _on_auto_clip_done(self, path: str):
"""Record each auto-exported clip to DB."""
# Find the start_time for this clip from stashed positions
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
name = self._txt_name.text() or "clip"
start_t = None
for t, c in self._auto_export_positions:
if counter_str == f"{name}_{c:03d}":
start_t = t
break
label = self._txt_label.currentText().strip()
category = self._cmb_category.currentText()
self._db.add(
os.path.basename(self._file_path),
start_t or 0.0,
path,
label=label,
category=category,
short_side=self._export_short_side,
portrait_ratio="",
crop_center=0.5,
fmt=self._export_format,
clip_count=1,
spread=0,
profile=self._profile,
source_path=self._file_path,
)
upsert_clip_annotation(self._export_folder, path, label)
self._show_status(f"Auto: {os.path.basename(path)}")
_log(f" auto clip done: {os.path.basename(path)}")
def _on_auto_batch_done(self):
n = len(self._auto_export_positions)
self._btn_auto_export.setEnabled(True)
self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True)
self._set_subprofile_btns_enabled(True)
self._refresh_markers()
markers = self._db.get_markers(os.path.basename(self._file_path), self._profile)
self._playlist.mark_done(self._file_path, len(markers))
self._update_next_label()
self._show_status(f"Auto export complete: {n} clips")
_log(f"Auto export complete: {n} clips")
def _jump_to_next_scan_region(self) -> None:
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
if not regions:
return
# Merge overlapping regions into clusters so S jumps past each group
clusters: list[tuple[float, float]] = []
for (start, end, _score) in regions:
if clusters and start <= clusters[-1][1]:
clusters[-1] = (clusters[-1][0], max(clusters[-1][1], end))
else:
clusters.append((start, end))
# Jump to the start of the next cluster after cursor
for (start, _end) in clusters:
if start > self._cursor + 0.1:
self._step_cursor(start - self._cursor)
return
# Wrap to first cluster
self._step_cursor(clusters[0][0] - self._cursor)
# --- Export --- # --- Export ---
def _pick_folder(self): def _pick_folder(self):
@@ -2616,6 +3225,7 @@ class MainWindow(QMainWindow):
clip_count=self._export_clip_count, clip_count=self._export_clip_count,
spread=self._export_spread, spread=self._export_spread,
profile=self._profile, profile=self._profile,
source_path=self._file_path,
) )
upsert_clip_annotation(self._export_folder, path, label) upsert_clip_annotation(self._export_folder, path, label)
self._last_export_path = path self._last_export_path = path
@@ -2655,6 +3265,7 @@ class MainWindow(QMainWindow):
_log(f"Export error: {msg}") _log(f"Export error: {msg}")
self._btn_cancel.setEnabled(False) self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True) self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True) self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export") self._btn_export.setText("Export")
self._btn_export.setStyleSheet("") self._btn_export.setStyleSheet("")
@@ -2670,6 +3281,7 @@ class MainWindow(QMainWindow):
def _on_export_cancelled(self): def _on_export_cancelled(self):
_log("Export cancelled") _log("Export cancelled")
self._btn_export.setEnabled(True) self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True) self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export") self._btn_export.setText("Export")
self._btn_export.setStyleSheet("") self._btn_export.setStyleSheet("")
@@ -2690,6 +3302,9 @@ class MainWindow(QMainWindow):
_log("Shutting down…") _log("Shutting down…")
# Save session playlist for resume. # Save session playlist for resume.
self._settings.setValue("session_files", self._playlist._paths) self._settings.setValue("session_files", self._playlist._paths)
# Cancel background workers to prevent callbacks into dead objects.
self._cleanup_scan_worker()
self._cleanup_train_worker()
# Stop timers first to prevent callbacks into dead objects. # Stop timers first to prevent callbacks into dead objects.
self._preview_timer.stop() self._preview_timer.stop()
self._mpv._render_timer.stop() self._mpv._render_timer.stop()
+22 -1
View File
@@ -1,4 +1,25 @@
# Core GUI
PyQt6>=6.4 PyQt6>=6.4
python-mpv>=1.0 python-mpv>=1.0
pytest>=7.0
# Audio & ML
librosa>=0.10
numpy>=1.24
scikit-learn>=1.3
joblib>=1.3
soundfile>=0.12
# Deep learning (torch installed separately for CUDA support)
# torch and torchaudio are installed via --index-url in setup_env.sh
torchaudio>=2.0
# Object detection
ultralytics>=8.0 ultralytics>=8.0
# Server API
fastapi>=0.100
pydantic>=2.0
uvicorn>=0.23
# Dev
pytest>=7.0
Executable
+108
View File
@@ -0,0 +1,108 @@
#!/usr/bin/env bash
set -euo pipefail
# ──────────────────────────────────────────────────────────────────────
# 8-cut environment setup — supports conda (miniforge) or python venv
#
# Usage:
# ./setup_env.sh # auto-detect (prefers conda if available)
# ./setup_env.sh --conda # force conda
# ./setup_env.sh --venv # force python venv
# ────────────────────────────────────────────────────────────────────
ENV_NAME="8cut"
PYTHON_VERSION="3.12"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
VENV_DIR="$SCRIPT_DIR/.venv"
# CUDA version for PyTorch index URL
TORCH_INDEX="https://download.pytorch.org/whl/cu128"
# ── Parse args ────────────────────────────────────────────────────────
MODE=""
for arg in "$@"; do
case "$arg" in
--conda) MODE="conda" ;;
--venv) MODE="venv" ;;
*) echo "Unknown arg: $arg"; exit 1 ;;
esac
done
if [ -z "$MODE" ]; then
if command -v conda &>/dev/null; then
MODE="conda"
else
MODE="venv"
fi
echo "Auto-detected mode: $MODE"
fi
# ── Conda setup ──────────────────────────────────────────────────────
setup_conda() {
echo "==> Setting up conda environment: $ENV_NAME"
# Source conda shell hooks if not already active
if ! command -v conda &>/dev/null; then
echo "conda not found in PATH"
exit 1
fi
eval "$(conda shell.bash hook)"
if conda env list | grep -qw "$ENV_NAME"; then
echo " Environment '$ENV_NAME' already exists, updating..."
conda activate "$ENV_NAME"
else
echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..."
conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION"
conda activate "$ENV_NAME"
fi
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
pip install torch torchaudio --index-url "$TORCH_INDEX"
echo " Installing project dependencies..."
pip install -r "$SCRIPT_DIR/requirements.txt"
echo ""
echo "Done! Activate with:"
echo " conda activate $ENV_NAME"
}
# ── Venv setup ───────────────────────────────────────────────────────
setup_venv() {
echo "==> Setting up Python venv at: $VENV_DIR"
if [ ! -d "$VENV_DIR" ]; then
python3 -m venv "$VENV_DIR"
echo " Created venv"
else
echo " Venv already exists, updating..."
fi
source "$VENV_DIR/bin/activate"
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
pip install torch torchaudio --index-url "$TORCH_INDEX"
echo " Installing project dependencies..."
pip install -r "$SCRIPT_DIR/requirements.txt"
echo ""
echo "Done! Activate with:"
echo " source $VENV_DIR/bin/activate"
}
# ── Run ───────────────────────────────────────────────────────────────
case "$MODE" in
conda) setup_conda ;;
venv) setup_venv ;;
esac
echo ""
echo "Verify with:"
echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\""
echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\""
+40
View File
@@ -0,0 +1,40 @@
import tempfile, os
import numpy as np
from core.audio_scan import scan_video, load_classifier, default_model_path
def test_scan_video_no_model_returns_empty():
"""scan_video with no model should return empty list."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
import soundfile as sf
sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000)
try:
regions = scan_video(vid.name, model=None)
assert regions == []
finally:
os.unlink(vid.name)
def test_load_classifier_missing_returns_none():
assert load_classifier("/no/such/model.joblib") is None
def test_default_model_path_contains_profile():
path = default_model_path("test_profile")
assert "test_profile" in path
assert path.endswith(".joblib")
def test_db_get_all_export_paths():
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
path = f.name
try:
from core.db import ProcessedDB
db = ProcessedDB(path)
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
paths = db.get_all_export_paths("test")
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
finally:
os.unlink(path)