diff --git a/8cut_calibrate.py b/8cut_calibrate.py new file mode 100644 index 0000000..b7dcdc4 --- /dev/null +++ b/8cut_calibrate.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +"""Calibration — per-video normalized features + classifier.""" +import sys, os, time, warnings +sys.path.insert(0, os.path.dirname(__file__)) +warnings.filterwarnings("ignore") + +import numpy as np +import librosa +from sklearn.ensemble import GradientBoostingClassifier + +from core.audio_scan import _SR, _WINDOW + +_HOP_LENGTH = 1024 +_N_FFT = 2048 +from core.db import ProcessedDB + +PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/" +PROFILE_NAME = "JAV_missionary" +TOLERANCE = 12.0 +NEG_MARGIN = 120.0 + + +def extract_rich_features(y, sr=_SR): + """Per-frame features: onset, energy, spectral shape, mel bands (22 features).""" + hop = _HOP_LENGTH + S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2 + rms = librosa.feature.rms(S=S, hop_length=hop) + cent = librosa.feature.spectral_centroid(S=S, sr=sr) + bw = librosa.feature.spectral_bandwidth(S=S, sr=sr) + rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr) + flatness = librosa.feature.spectral_flatness(S=S) + zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop) + onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1) + + mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128) + mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2) + bands = [(0, 100), (100, 300), (300, 600), (600, 1200), + (1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)] + band_feats = [] + for flo, fhi in bands: + mask = (mel_freqs >= flo) & (mel_freqs < fhi) + if mask.sum() > 0: + band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10)) + else: + band_feats.append(np.zeros((1, mel_S.shape[1]))) + + sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop) + + min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1], + band_feats[0].shape[1]) + return np.vstack([ + rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t], + flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t], + ] + [b[:, :min_t] for b in band_feats] + + [sc[:, :min_t]]) + + +def compute_window_stats(feat, hop=1.0): + """Sliding window mean/std → (timestamps, feature_vectors).""" + n_feats, T = feat.shape + fps = _SR / _HOP_LENGTH + win_frames = int(_WINDOW * fps) + hop_frames = int(hop * fps) + if win_frames > T: + return np.array([]), np.array([]) + + cumsum = np.zeros((n_feats, T + 1)) + cumsum[:, 1:] = np.cumsum(feat, axis=1) + cumsq = np.zeros((n_feats, T + 1)) + cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1) + + starts = np.arange(0, T - win_frames + 1, hop_frames) + ends = starts + win_frames + sums = cumsum[:, ends] - cumsum[:, starts] + sq_sums = cumsq[:, ends] - cumsq[:, starts] + means = sums / win_frames + stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10) + + return starts / fps, np.vstack([means, stds]).T + + +def label_windows(timestamps, gt_intense, gt_soft): + all_gt = list(gt_intense) + list(gt_soft) + labels = np.zeros(len(timestamps), dtype=int) + for i, t in enumerate(timestamps): + di = min((abs(t - g) for g in gt_intense), default=9999) + da = min((abs(t - g) for g in all_gt), default=9999) + if di < TOLERANCE: + labels[i] = 1 + elif da > NEG_MARGIN: + labels[i] = -1 + return labels + + +def main(): + db = ProcessedDB() + rows = db._con.execute( + "SELECT filename, start_time, output_path FROM processed WHERE profile = ?", + (PROFILE_NAME,), + ).fetchall() + + intense_by_video, soft_by_video = {}, {} + for fn, st, op in rows: + if '/mp4_Intense/' in op: + intense_by_video.setdefault(fn, set()).add(st) + elif '/mp4_Soft/' in op: + soft_by_video.setdefault(fn, set()).add(st) + + videos = [fn for fn in intense_by_video + if os.path.exists(os.path.join(PLEX_DIR, fn))] + n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos) + videos = videos[:n_vids] + print(f"Processing {len(videos)} videos...") + + all_data_raw = [] # raw features + all_data_norm = [] # per-video z-scored features + + for vi, vname in enumerate(videos): + vpath = os.path.join(PLEX_DIR, vname) + gt_intense = sorted(intense_by_video.get(vname, set())) + gt_soft = sorted(soft_by_video.get(vname, set())) + + t0 = time.time() + y, _ = librosa.load(vpath, sr=_SR, mono=True) + feat = extract_rich_features(y) + timestamps, window_vectors = compute_window_stats(feat, hop=1.0) + dt = time.time() - t0 + + if len(timestamps) == 0: + continue + + labels = label_windows(timestamps, gt_intense, gt_soft) + + # Per-video z-score normalization + vid_mean = window_vectors.mean(axis=0) + vid_std = window_vectors.std(axis=0) + vid_std = np.maximum(vid_std, 1e-6) + normed = (window_vectors - vid_mean) / vid_std + + n_pos = (labels == 1).sum() + n_neg = (labels == -1).sum() + print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)") + + all_data_raw.append((vi, vname, timestamps, window_vectors, labels)) + all_data_norm.append((vi, vname, timestamps, normed, labels)) + + # Run CV for both raw and normalized + for label, data in [("RAW features", all_data_raw), + ("PER-VIDEO NORMALIZED features", all_data_norm)]: + print(f"\n{'='*70}") + print(f" {label}") + print(f"{'='*70}") + + all_y_true, all_y_prob = [], [] + + for test_idx in range(len(data)): + _, vname, _, test_X, test_labels = data[test_idx] + test_mask = test_labels != 0 + if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0: + continue + X_test = test_X[test_mask] + y_test = (test_labels[test_mask] == 1).astype(int) + + X_parts, y_parts = [], [] + for i, (_, _, _, feats, labs) in enumerate(data): + if i == test_idx: + continue + m = labs != 0 + if m.sum() == 0: + continue + X_parts.append(feats[m]) + y_parts.append((labs[m] == 1).astype(int)) + + if not X_parts: + continue + X_train = np.vstack(X_parts) + y_train = np.concatenate(y_parts) + + pos_idx = np.where(y_train == 1)[0] + neg_idx = np.where(y_train == 0)[0] + if len(pos_idx) == 0 or len(neg_idx) == 0: + continue + rng = np.random.RandomState(42) + n_neg = min(len(neg_idx), len(pos_idx) * 3) + neg_sample = rng.choice(neg_idx, n_neg, replace=False) + train_idx = np.concatenate([pos_idx, neg_sample]) + + clf = GradientBoostingClassifier( + n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42 + ) + clf.fit(X_train[train_idx], y_train[train_idx]) + probs = clf.predict_proba(X_test)[:, 1] + + tp = ((probs >= 0.5) & (y_test == 1)).sum() + fp = ((probs >= 0.5) & (y_test == 0)).sum() + fn_count = ((probs < 0.5) & (y_test == 1)).sum() + pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0 + neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0 + print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}") + + all_y_true.extend(y_test) + all_y_prob.extend(probs) + + if not all_y_true: + print(" No test results.") + continue + + y_true = np.array(all_y_true) + y_prob = np.array(all_y_prob) + pos_probs = y_prob[y_true == 1] + neg_probs = y_prob[y_true == 0] + + if len(pos_probs) > 0 and len(neg_probs) > 0: + print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}" + f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}") + print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}" + f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}") + + best_f1, best_thr = 0, 0 + print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}") + for thr in np.arange(0.10, 0.91, 0.05): + tp = ((y_prob >= thr) & (y_true == 1)).sum() + fp = ((y_prob >= thr) & (y_true == 0)).sum() + fn_count = ((y_prob < thr) & (y_true == 1)).sum() + prec = tp / (tp + fp) if (tp + fp) > 0 else 0 + rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0 + f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0 + if f1 > best_f1: + best_f1, best_thr = f1, thr + print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}") + print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}") + + # Feature importance + X_all = np.vstack([f[l != 0] for _, _, _, f, l in data]) + y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data]) + pos_idx = np.where(y_all == 1)[0] + neg_idx = np.where(y_all == 0)[0] + rng = np.random.RandomState(42) + neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False) + clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42) + clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])]) + + feat_names = ( + ["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"] + + [f"mel{i}" for i in range(8)] + + [f"sc{i}" for i in range(7)] + ) + stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names] + imp = clf.feature_importances_ + top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10] + print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}") + + +if __name__ == "__main__": + main() diff --git a/8cut_train.py b/8cut_train.py new file mode 100644 index 0000000..97fc012 --- /dev/null +++ b/8cut_train.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +"""Train an audio scan classifier from DB ground truth. + +Usage: + python 8cut_train.py # default model, auto-detect positive + python 8cut_train.py --model BEATS # specific embedding model + python 8cut_train.py --positive mp4_Intense # explicit positive folder + python 8cut_train.py --positive mp4_Intense --model BEATS # both +""" +import sys, os, warnings +sys.path.insert(0, os.path.dirname(__file__)) +warnings.filterwarnings("ignore") + +from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS +from core.db import ProcessedDB + +PROFILE_NAME = "JAV_missionary" + +# Fallback for old DB rows without source_path +PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/" + + +def main(): + embed_model = None + if "--model" in sys.argv: + idx = sys.argv.index("--model") + if idx + 1 < len(sys.argv): + embed_model = sys.argv[idx + 1] + if embed_model not in _EMBED_MODELS: + print(f"Unknown model: {embed_model}") + print(f"Available: {', '.join(_EMBED_MODELS)}") + sys.exit(1) + + positive_suffix = None + if "--positive" in sys.argv: + idx = sys.argv.index("--positive") + if idx + 1 < len(sys.argv): + positive_suffix = sys.argv[idx + 1] + + db = ProcessedDB() + + # If --positive given, use the new DB helper + if positive_suffix: + video_infos = db.get_training_data( + PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR, + ) + if not video_infos: + print(f"No training data found for positive='{positive_suffix}'") + sys.exit(1) + else: + # Legacy fallback: classify by folder path pattern + rows = db._con.execute( + "SELECT filename, start_time, output_path, source_path" + " FROM processed WHERE profile = ?", + (PROFILE_NAME,), + ).fetchall() + + intense_by_video, soft_by_video = {}, {} + source_by_fn = {} + for fn, st, op, sp in rows: + if sp: + source_by_fn[fn] = sp + if "/mp4_Intense/" in op or "_Intense/" in op: + intense_by_video.setdefault(fn, set()).add(st) + elif "/mp4_Soft/" in op or "_Soft/" in op: + soft_by_video.setdefault(fn, set()).add(st) + + video_infos = [] + for fn in intense_by_video: + # Try source_path from DB first, fall back to PLEX_DIR + vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn) + if not os.path.exists(vpath): + print(f" skip (not found): {fn}") + continue + gt_intense = sorted(intense_by_video[fn]) + gt_soft = sorted(soft_by_video.get(fn, set())) + video_infos.append((vpath, gt_intense, gt_soft)) + + label = embed_model or "WAV2VEC2_BASE" + print(f"Training {label} model on {len(video_infos)} videos...") + model_path = default_model_path(PROFILE_NAME) + result = train_classifier( + video_infos, model_path=model_path, embed_model=embed_model, + ) + if result is None: + print("Training failed: no valid samples or missing class balance") + sys.exit(1) + print(f"Model saved to {model_path}") + + +if __name__ == "__main__": + main() diff --git a/core/audio_scan.py b/core/audio_scan.py index 505b016..96956c0 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -1,105 +1,359 @@ -"""Audio similarity scanning — MFCC + spectral contrast profile matching.""" +"""Audio scanning — embedding-based classifier for audio event detection.""" +import hashlib +import os import numpy as np import librosa from .paths import _log -_N_MFCC = 13 # coefficients 0-12; we drop C0 → 12 usable -_SR = 16000 # lower sr = faster, no quality loss for style matching -_HOP_LENGTH = 1024 # STFT hop (~64ms frames at 16kHz) -_N_FFT = 2048 # STFT window +_SR = 16000 # lower sr = faster _WINDOW = 8.0 # seconds -_N_FEATURES = 62 # (12 mfcc + 12 delta + 7 sc) * 2 (mean + std) +_MODEL_DIR = os.path.join(os.path.expanduser("~"), ".8cut_models") +_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v") + +# --------------------------------------------------------------------------- +# Embedding extraction (lazy-loaded) +# --------------------------------------------------------------------------- + +_w2v_model = None +_w2v_device = None +_w2v_model_name = None + +# Supported embedding models — name → embed_dim +_EMBED_MODELS = { + "WAV2VEC2_BASE": 768, + "WAV2VEC2_LARGE": 1024, + "WAV2VEC2_LARGE_LV60K":1024, + "HUBERT_BASE": 768, + "HUBERT_LARGE": 1024, + "HUBERT_XLARGE": 1280, + "BEATS": 768, +} +_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE" + +_BEATS_CHECKPOINT = os.path.join( + os.path.expanduser("~"), ".cache", "huggingface", "hub", + "models--lpepino--beats_ckpts", "snapshots", + "5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt", +) -def _extract_features_from_signal(y: np.ndarray, sr: int = _SR) -> np.ndarray: - """Compute feature matrix (31 x T) from a raw audio signal. +def _get_w2v_model(model_name: str | None = None): + """Lazy-load an embedding model. Reloads if model_name differs from cached.""" + global _w2v_model, _w2v_device, _w2v_model_name + if model_name is None: + model_name = _DEFAULT_EMBED_MODEL + if _w2v_model is None or _w2v_model_name != model_name: + import torch + _w2v_device = "cuda" if torch.cuda.is_available() else "cpu" - Features per frame: 12 MFCCs (skip C0) + 12 delta MFCCs + 7 spectral contrast. + if model_name == "BEATS": + from .beats_model import BEATs, BEATsConfig + checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device, + weights_only=False) + cfg = BEATsConfig(checkpoint['cfg']) + _w2v_model = BEATs(cfg) + _w2v_model.load_state_dict(checkpoint['model']) + _w2v_model.to(_w2v_device) + else: + import torchaudio + bundle = getattr(torchaudio.pipelines, model_name) + _w2v_model = bundle.get_model().to(_w2v_device) + + _w2v_model.eval() + _w2v_model_name = model_name + _log(f"audio_scan: {model_name} loaded on {_w2v_device}") + return _w2v_model, _w2v_device + + +def _embed_dim(model_name: str | None = None) -> int: + """Return embedding dimension for a model name.""" + if model_name is None: + model_name = _DEFAULT_EMBED_MODEL + return _EMBED_MODELS.get(model_name, 768) + + +def _w2v_cache_path(video_path: str, hop: float, window: float, + model_name: str | None = None) -> str: + """Return cache file path for a video's embeddings (includes model name).""" + if model_name is None: + model_name = _DEFAULT_EMBED_MODEL + abspath = os.path.abspath(video_path) + mtime = os.path.getmtime(abspath) + key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}" + h = hashlib.sha256(key.encode()).hexdigest()[:16] + return os.path.join(_W2V_CACHE_DIR, f"{h}.npz") + + +def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, + hop: float = 1.0, window: float = _WINDOW, + video_path: str | None = None, + cancel_flag: object = None, + model_name: str | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Extract embeddings for all sliding windows using a torchaudio model. + + If video_path is given, results are cached to disk for fast re-scans. + Returns (timestamps, embeddings) where embeddings is (N, D). """ - S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=_HOP_LENGTH)) ** 2 - mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=_HOP_LENGTH) - mfcc = librosa.feature.mfcc(S=librosa.power_to_db(mel_S), sr=sr, n_mfcc=_N_MFCC) - mfcc = mfcc[1:] # drop C0 (energy) — dominates cosine sim, kills discrimination - delta = librosa.feature.delta(mfcc) - sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=_HOP_LENGTH) - return np.vstack([mfcc, delta, sc]) # (31, T) + edim = _embed_dim(model_name) - -def _aggregate(feature_matrix: np.ndarray) -> np.ndarray: - """Collapse a (31, T) feature matrix into a (62,) vector via mean + std.""" - return np.concatenate([ - feature_matrix.mean(axis=1), - feature_matrix.std(axis=1), - ]) - - -def _extract_features(path: str, sr: int = _SR) -> np.ndarray: - """Load audio from a file and return a 62-dim feature vector.""" - y, _ = librosa.load(path, sr=sr, mono=True) - feat = _extract_features_from_signal(y, sr) - return _aggregate(feat) - - -def build_profile(clip_paths: list[str]) -> dict | None: - """Extract features from reference clips. - - Returns dict with: - - mean_vector: averaged feature vector across all clips (62,) - - clip_vectors: list of individual feature vectors - Returns None if no clips could be loaded. - """ - vectors = [] - for p in clip_paths: + # Try loading from cache + cache_file = None + if video_path: try: - vec = _extract_features(p) - vectors.append(vec) + cache_file = _w2v_cache_path(video_path, hop, window, model_name) + if os.path.exists(cache_file): + data = np.load(cache_file) + _log(f"audio_scan: cache hit ({cache_file})") + return data["timestamps"], data["embeddings"] except Exception as e: - _log(f"audio_scan: skip {p}: {e}") - if not vectors: - return None - arr = np.stack(vectors) - return { - "mean_vector": arr.mean(axis=0), - "clip_vectors": vectors, - } + _log(f"audio_scan: cache read failed: {e}") + + win_samples = int(window * sr) + hop_samples = int(hop * sr) + n_windows = max(0, (len(y) - win_samples) // hop_samples + 1) + + if n_windows == 0: + return np.array([]), np.empty((0, edim)) + + import torch + model, device = _get_w2v_model(model_name) + is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + batch_size = 16 + timestamps = np.arange(n_windows) * hop + embeddings = [] + + for batch_start in range(0, n_windows, batch_size): + if cancel_flag and getattr(cancel_flag, '_cancel', False): + return np.array([]), np.empty((0, edim)) + batch_end = min(batch_start + batch_size, n_windows) + chunks = [] + for i in range(batch_start, batch_end): + start = i * hop_samples + chunks.append(y[start:start + win_samples]) + with torch.no_grad(): + waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) + if is_beats: + padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) + features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + else: + features, _ = model(waveforms) + batch_emb = features.mean(dim=1).cpu().numpy() + embeddings.append(batch_emb) + + result_ts = timestamps + result_emb = np.vstack(embeddings) + + # Save to cache + if cache_file: + try: + os.makedirs(_W2V_CACHE_DIR, exist_ok=True) + np.savez(cache_file, timestamps=result_ts, embeddings=result_emb) + _log(f"audio_scan: w2v cache saved ({cache_file})") + except Exception as e: + _log(f"audio_scan: cache write failed: {e}") + + return result_ts, result_emb -def _similarity(a: np.ndarray, b: np.ndarray) -> float: - """Euclidean-distance-based similarity in (0, 1]. +def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float], + gt_soft: list[float], tolerance: float = 12.0, + neg_margin: float = 120.0, + model_name: str | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract embeddings only near positives and distant negatives. - 1/(1+dist): identical → 1.0, very different → near 0. + Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig. """ - return float(1.0 / (1.0 + np.linalg.norm(a - b))) + edim = _embed_dim(model_name) + duration = len(y) / sr + win_samples = int(_WINDOW * sr) + all_gt = list(gt_intense) + list(gt_soft) + # Positive windows: every second near intense markers + pos_times = set() + for gt in gt_intense: + for offset in range(-int(tolerance), int(tolerance) + 1): + t = gt + offset + if 0 <= t <= duration - _WINDOW: + pos_times.add(int(t)) + + # Negative windows: every 4s, far from any marker + neg_times = set() + for t in range(0, int(duration - _WINDOW), 4): + if min((abs(t - g) for g in all_gt), default=9999) > neg_margin: + neg_times.add(t) + + all_times = sorted(pos_times | neg_times) + # Filter out windows that go past the end + valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)] + + if not valid_times: + return np.array([]), np.zeros((0, edim)), np.array([], dtype=int) + + import torch + model, device = _get_w2v_model(model_name) + batch_size = 16 + timestamps_list: list[float] = [] + embeddings_list: list[np.ndarray] = [] + + is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + + for batch_start in range(0, len(valid_times), batch_size): + batch_end = min(batch_start + batch_size, len(valid_times)) + chunks = [] + for t in valid_times[batch_start:batch_end]: + start = int(t * sr) + chunks.append(y[start:start + win_samples]) + timestamps_list.append(float(t)) + with torch.no_grad(): + waveforms = torch.from_numpy(np.stack(chunks)).float().to(device) + if is_beats: + padding_mask = torch.zeros_like(waveforms, dtype=torch.bool) + features, _ = model.extract_features(waveforms, padding_mask=padding_mask) + else: + features, _ = model(waveforms) + batch_emb = features.mean(dim=1).cpu().numpy() + embeddings_list.append(batch_emb) + + timestamps = np.array(timestamps_list) + embeddings = np.vstack(embeddings_list) + + labels = np.zeros(len(timestamps), dtype=int) + for i, t in enumerate(timestamps): + di = min((abs(t - g) for g in gt_intense), default=9999) + da = min((abs(t - g) for g in all_gt), default=9999) + if di < tolerance: + labels[i] = 1 + elif da > neg_margin: + labels[i] = -1 + return timestamps, embeddings, labels + + +# --------------------------------------------------------------------------- +# Classifier mode — train / save / load / scan +# --------------------------------------------------------------------------- + +def train_classifier(video_infos: list[tuple[str, list[float], list[float]]], + model_path: str | None = None, + tolerance: float = 12.0, + neg_margin: float = 120.0, + embed_model: str | None = None) -> dict: + """Train a classifier from labeled videos. + + Args: + video_infos: list of (video_path, intense_times, soft_times) + model_path: if given, save model to this path + tolerance/neg_margin: labeling parameters + embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE + + Returns: + dict with 'classifier', 'embed_model', and metadata, or None on failure. + """ + from sklearn.ensemble import GradientBoostingClassifier + + all_X, all_y = [], [] + + for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos): + _log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}") + y, _ = librosa.load(vpath, sr=_SR, mono=True) + + timestamps, embeddings, labels = _extract_w2v_targeted( + y, _SR, gt_intense, gt_soft, tolerance, neg_margin, + model_name=embed_model, + ) + if len(timestamps) == 0: + continue + # Per-video z-score normalize + vid_mean = embeddings.mean(axis=0) + vid_std = np.maximum(embeddings.std(axis=0), 1e-6) + normed = (embeddings - vid_mean) / vid_std + for i in range(len(labels)): + if labels[i] == 1: + all_X.append(normed[i]) + all_y.append(1) + elif labels[i] == -1: + all_X.append(normed[i]) + all_y.append(0) + + if not all_X: + _log("audio_scan: no training samples collected") + return None + + X = np.stack(all_X) + y_arr = np.array(all_y) + n_pos = (y_arr == 1).sum() + n_neg = (y_arr == 0).sum() + _log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative") + + if n_pos == 0 or n_neg == 0: + _log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg") + return None + + # Subsample negatives for balance + rng = np.random.RandomState(42) + pos_idx = np.where(y_arr == 1)[0] + neg_idx = np.where(y_arr == 0)[0] + n_neg_sample = min(len(neg_idx), len(pos_idx) * 3) + neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False) + train_idx = np.concatenate([pos_idx, neg_sample]) + rng.shuffle(train_idx) + + clf = GradientBoostingClassifier( + n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42, + ) + clf.fit(X[train_idx], y_arr[train_idx]) + _log("audio_scan: classifier trained") + + model = {"classifier": clf, "n_features": X.shape[1], + "embed_model": embed_model or _DEFAULT_EMBED_MODEL} + + if model_path: + import joblib + parent = os.path.dirname(model_path) + if parent: + os.makedirs(parent, exist_ok=True) + joblib.dump(model, model_path) + _log(f"audio_scan: model saved to {model_path}") + + return model + + +def load_classifier(model_path: str) -> dict | None: + """Load a saved classifier model.""" + if not os.path.exists(model_path): + return None + import joblib + return joblib.load(model_path) + + +def default_model_path(profile_name: str = "default") -> str: + """Return the default path for a profile's classifier model.""" + return os.path.join(_MODEL_DIR, f"{profile_name}.joblib") + + +# --------------------------------------------------------------------------- +# Scanning +# --------------------------------------------------------------------------- def scan_video( video_path: str, - profile: dict, - mode: str = "average", - threshold: float = 0.05, + model: dict = None, + threshold: float = 0.30, hop: float = 1.0, window: float = _WINDOW, cancel_flag: object = None, ) -> list[tuple[float, float, float]]: - """Slide a window across the video audio and score against the profile. + """Scan a video for matching audio regions using a trained classifier. - Pre-computes STFT once for the whole file, then uses vectorized - cumulative-sum sliding window for speed. - - Args: - video_path: path to video/audio file - profile: dict from build_profile() - mode: "average" (compare to mean) or "nearest" (max over all clips) - threshold: minimum similarity to include (0-1, default 0.05) - hop: step size in seconds - window: window size in seconds (default 8s) - cancel_flag: object with _cancel bool attribute; checked periodically - - Returns: - list of (start_time, end_time, score) for regions above threshold + Returns list of (start_time, end_time, score) above threshold. """ + if model is None: + _log("audio_scan: no model provided") + return [] + _log(f"audio_scan: loading {video_path}") y, sr = librosa.load(video_path, sr=_SR, mono=True) duration = len(y) / sr @@ -108,68 +362,33 @@ def scan_video( if cancel_flag and getattr(cancel_flag, '_cancel', False): return [] - # Compute features for the entire file at once (one STFT) - feat = _extract_features_from_signal(y, sr) # (31, T) - n_feats, T = feat.shape - fps = sr / _HOP_LENGTH # frames per second - win_frames = int(window * fps) - hop_frames = int(hop * fps) + clf = model["classifier"] + embed_model = model.get("embed_model") - if win_frames > T: + _log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...") + timestamps, window_vectors = _extract_w2v_windows( + y, sr, hop=hop, window=window, video_path=video_path, + cancel_flag=cancel_flag, model_name=embed_model, + ) + if len(timestamps) == 0: _log("audio_scan: video shorter than window") return [] - _log(f"audio_scan: scanning {T} frames, win={win_frames}, hop={hop_frames}") + # Per-video z-score normalize + vid_mean = window_vectors.mean(axis=0) + vid_std = np.maximum(window_vectors.std(axis=0), 1e-6) + normed = (window_vectors - vid_mean) / vid_std - # Vectorized sliding window via cumulative sums - cumsum = np.zeros((n_feats, T + 1)) - cumsum[:, 1:] = np.cumsum(feat, axis=1) - cumsq = np.zeros((n_feats, T + 1)) - cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1) - - starts = np.arange(0, T - win_frames + 1, hop_frames) - ends = starts + win_frames - - sums = cumsum[:, ends] - cumsum[:, starts] # (31, n_windows) - sq_sums = cumsq[:, ends] - cumsq[:, starts] - means = sums / win_frames - stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10) - - window_vectors = np.vstack([means, stds]).T # (n_windows, 62) + _log(f"audio_scan: classifying {len(normed)} windows...") if cancel_flag and getattr(cancel_flag, '_cancel', False): return [] - # Score all windows - if mode == "nearest": - # Compare each window to every clip vector, take max - clip_vecs = np.stack(profile["clip_vectors"]) # (n_clips, 62) - results = [] - # Process in batches to check cancel_flag periodically - batch = 500 - for i in range(0, len(window_vectors), batch): - if cancel_flag and getattr(cancel_flag, '_cancel', False): - _log("audio_scan: cancelled") - return results - chunk = window_vectors[i:i + batch] - # cdist: (batch, n_clips) distances - dists = np.linalg.norm(chunk[:, None, :] - clip_vecs[None, :, :], axis=2) - scores = 1.0 / (1.0 + dists.min(axis=1)) # min dist = max similarity - for j, score in enumerate(scores): - if score >= threshold: - idx = i + j - start_t = starts[idx] / fps - results.append((start_t, start_t + window, float(score))) - else: - # Average mode: compare to mean vector - ref = profile["mean_vector"] - dists = np.linalg.norm(window_vectors - ref, axis=1) - scores = 1.0 / (1.0 + dists) - mask = scores >= threshold - results = [ - (starts[i] / fps, starts[i] / fps + window, float(scores[i])) - for i in np.nonzero(mask)[0] - ] - + probs = clf.predict_proba(normed)[:, 1] + mask = probs >= threshold + results = [ + (timestamps[i], timestamps[i] + window, float(probs[i])) + for i in np.nonzero(mask)[0] + ] _log(f"audio_scan: {len(results)} regions above threshold {threshold}") return results diff --git a/core/beats_backbone.py b/core/beats_backbone.py new file mode 100644 index 0000000..c0c6c86 --- /dev/null +++ b/core/beats_backbone.py @@ -0,0 +1,783 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .beats_modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) diff --git a/core/beats_model.py b/core/beats_model.py new file mode 100644 index 0000000..002f7c2 --- /dev/null +++ b/core/beats_model.py @@ -0,0 +1,179 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +import torchaudio.compliance.kaldi as ta_kaldi + +from .beats_backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. + self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + if self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask diff --git a/core/beats_modules.py b/core/beats_modules.py new file mode 100644 index 0000000..7772b2d --- /dev/null +++ b/core/beats_modules.py @@ -0,0 +1,219 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + diff --git a/core/db.py b/core/db.py index 4818a02..3a970d0 100644 --- a/core/db.py +++ b/core/db.py @@ -1,3 +1,4 @@ +import os import sqlite3 import threading from datetime import datetime, timezone @@ -7,7 +8,7 @@ from .paths import _log class ProcessedDB: - _SCHEMA_VERSION = 3 # bump when schema changes + _SCHEMA_VERSION = 4 # bump when schema changes def __init__(self, db_path: str | None = None): if db_path is None: @@ -47,6 +48,7 @@ class ProcessedDB: " clip_count INTEGER NOT NULL DEFAULT 3," " spread REAL NOT NULL DEFAULT 3.0," " profile TEXT NOT NULL DEFAULT 'default'," + " source_path TEXT NOT NULL DEFAULT ''," " processed_at TEXT NOT NULL" ")" ) @@ -62,6 +64,7 @@ class ProcessedDB: "clip_count": "INTEGER NOT NULL DEFAULT 3", "spread": "REAL NOT NULL DEFAULT 3.0", "profile": "TEXT NOT NULL DEFAULT 'default'", + "source_path": "TEXT NOT NULL DEFAULT ''", } for col, typedef in new_cols.items(): if col not in cols: @@ -85,7 +88,7 @@ class ProcessedDB: short_side: int | None = None, portrait_ratio: str = "", crop_center: float = 0.5, fmt: str = "MP4", clip_count: int = 3, spread: float = 3.0, - profile: str = "default") -> None: + profile: str = "default", source_path: str = "") -> None: if not self._enabled: return with self._lock: @@ -93,11 +96,11 @@ class ProcessedDB: "INSERT INTO processed" " (filename, start_time, output_path, label, category," " short_side, portrait_ratio, crop_center, format," - " clip_count, spread, profile, processed_at)" - " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + " clip_count, spread, profile, source_path, processed_at)" + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (filename, start_time, output_path, label, category, short_side, portrait_ratio, crop_center, fmt, - clip_count, spread, profile, + clip_count, spread, profile, source_path, datetime.now(timezone.utc).isoformat()), ) self._con.commit() @@ -223,6 +226,104 @@ class ProcessedDB: ).fetchall() return [r[0] for r in rows] + def get_export_folders(self, profile: str = "default") -> list[str]: + """Return distinct export folder names found in output_paths for a profile. + + Export paths follow the structure: + .../export_folder/group_dir/clip.mp4 + The export folder is 2 levels up from the clip file. + Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]). + """ + if not self._enabled: + return [] + rows = self._con.execute( + "SELECT DISTINCT output_path FROM processed WHERE profile = ?", + (profile,), + ).fetchall() + folder_names: set[str] = set() + for (op,) in rows: + grandparent = os.path.basename(os.path.dirname(os.path.dirname(op))) + if grandparent: + folder_names.add(grandparent) + return sorted(folder_names) + + def get_training_data(self, profile: str, positive_folder: str, + fallback_video_dir: str = "", + ) -> list[tuple[str, list[float], list[float]]]: + """Build training video_infos from DB data. + + Args: + profile: profile name + positive_folder: export folder name for positive class (e.g. "mp4_Intense") + fallback_video_dir: if source_path is empty, try filename in this dir + + Returns: + list of (source_video_path, positive_times, soft_times) per video. + Soft times = clips from any other export folder. + """ + if not self._enabled: + return [] + rows = self._con.execute( + "SELECT filename, start_time, output_path, source_path" + " FROM processed WHERE profile = ?", + (profile,), + ).fetchall() + + # Collect times by video, split by positive vs other folders + pos_by_video: dict[str, set[float]] = {} + soft_by_video: dict[str, set[float]] = {} + source_by_filename: dict[str, str] = {} + + for fn, st, op, sp in rows: + if sp: + source_by_filename[fn] = sp + grandparent = os.path.basename(os.path.dirname(os.path.dirname(op))) + if grandparent == positive_folder: + pos_by_video.setdefault(fn, set()).add(st) + else: + soft_by_video.setdefault(fn, set()).add(st) + + result = [] + for fn in pos_by_video: + sp = source_by_filename.get(fn, "") + if not sp or not os.path.exists(sp): + # Fallback: try video_dir / filename + if fallback_video_dir: + sp = os.path.join(fallback_video_dir, fn) + if not sp or not os.path.exists(sp): + continue + gt_pos = sorted(pos_by_video[fn]) + gt_soft = sorted(soft_by_video.get(fn, set())) + result.append((sp, gt_pos, gt_soft)) + return result + + def get_training_stats(self, profile: str) -> dict[str, dict]: + """Return per-subprofile stats for training readiness display. + + Returns dict mapping subprofile_name → { + 'videos': number of distinct source videos, + 'clips': total clip count, + } + """ + if not self._enabled: + return {} + rows = self._con.execute( + "SELECT filename, output_path FROM processed WHERE profile = ?", + (profile,), + ).fetchall() + folders = self.get_export_folders(profile) + stats: dict[str, dict] = {} + for folder_name in folders: + videos: set[str] = set() + clips = 0 + for fn, op in rows: + grandparent = os.path.basename(os.path.dirname(os.path.dirname(op))) + if grandparent == folder_name: + videos.add(fn) + clips += 1 + stats[folder_name] = {"videos": len(videos), "clips": clips} + return stats + def hide_file(self, filename: str, profile: str = "default") -> None: if not self._enabled: return diff --git a/main.py b/main.py index 1db8868..c4e79e4 100755 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ from PyQt6.QtWidgets import ( QLabel, QPushButton, QLineEdit, QFileDialog, QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, - QMessageBox, QInputDialog, + QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout, ) from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut @@ -191,12 +191,11 @@ class ScanWorker(QThread): error = pyqtSignal(str) progress = pyqtSignal(str) # status message - def __init__(self, video_path: str, clip_paths: list[str], - mode: str = "average", threshold: float = 0.7): + def __init__(self, video_path: str, model: dict, + threshold: float = 0.30): super().__init__() self._video_path = video_path - self._clip_paths = clip_paths - self._mode = mode + self._model = model self._threshold = threshold self._cancel = False @@ -204,20 +203,12 @@ class ScanWorker(QThread): self._cancel = True def run(self): - from core.audio_scan import build_profile, scan_video + from core.audio_scan import scan_video try: - self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...") - profile = build_profile(self._clip_paths) - if self._cancel: - return - if profile is None: - self.error.emit("No valid reference clips found") - return self.progress.emit("Scanning audio...") regions = scan_video( - self._video_path, profile, - mode=self._mode, threshold=self._threshold, - cancel_flag=self, + self._video_path, model=self._model, + threshold=self._threshold, cancel_flag=self, ) if not self._cancel: self.scan_done.emit(regions) @@ -226,6 +217,151 @@ class ScanWorker(QThread): self.error.emit(str(e)) +class TrainDialog(QDialog): + """Dialog for configuring and launching classifier training.""" + + def __init__(self, db: ProcessedDB, profile: str, video_dir: str = "", + parent=None): + super().__init__(parent) + self.setWindowTitle("Train Classifier") + self.setMinimumWidth(400) + + from core.audio_scan import _EMBED_MODELS + self._db = db + self._profile = profile + self._video_dir = video_dir + + layout = QVBoxLayout(self) + form = QFormLayout() + + # Positive class selector — lists export folders + self._cmb_positive = QComboBox() + stats = db.get_training_stats(profile) + if not stats: + form.addRow("", QLabel("No exported clips found for this profile.")) + for folder_name, info in stats.items(): + label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)" + self._cmb_positive.addItem(label, userData=folder_name) + form.addRow("Positive class:", self._cmb_positive) + + # Model selector + self._cmb_model = QComboBox() + for name in _EMBED_MODELS: + self._cmb_model.addItem(name) + self._cmb_model.setCurrentText("WAV2VEC2_BASE") + form.addRow("Model:", self._cmb_model) + + # Video source directory (fallback for old DB rows without source_path) + self._txt_video_dir = QLineEdit(video_dir) + self._txt_video_dir.setPlaceholderText("Directory containing source videos") + self._debounce = QTimer(self) + self._debounce.setSingleShot(True) + self._debounce.setInterval(400) + self._debounce.timeout.connect(self._update_stats) + self._txt_video_dir.textChanged.connect(lambda: self._debounce.start()) + vid_row = QHBoxLayout() + vid_row.addWidget(self._txt_video_dir) + btn_browse = QPushButton("...") + btn_browse.setFixedWidth(30) + btn_browse.clicked.connect(self._browse_video_dir) + vid_row.addWidget(btn_browse) + form.addRow("Video dir:", vid_row) + + layout.addLayout(form) + + # Stats summary + self._lbl_stats = QLabel() + self._update_stats() + self._cmb_positive.currentIndexChanged.connect(self._update_stats) + layout.addWidget(self._lbl_stats) + + # Buttons + btns = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + btns.button(QDialogButtonBox.StandardButton.Ok).setText("Train") + btns.button(QDialogButtonBox.StandardButton.Ok).setEnabled( + self._cmb_positive.count() > 0 + ) + btns.accepted.connect(self.accept) + btns.rejected.connect(self.reject) + layout.addWidget(btns) + + def _browse_video_dir(self): + d = QFileDialog.getExistingDirectory(self, "Select video source directory") + if d: + self._txt_video_dir.setText(d) + + def _update_stats(self): + folder = self._cmb_positive.currentData() + if not folder: + self._lbl_stats.setText("No export folder data available.") + return + video_infos = self._db.get_training_data( + self._profile, folder, + fallback_video_dir=self._txt_video_dir.text(), + ) + n_videos = len(video_infos) + n_pos = sum(len(gt) for _, gt, _ in video_infos) + n_soft = sum(len(s) for _, _, s in video_infos) + lines = [f"{n_videos} videos with positive clips"] + lines.append(f"{n_pos} positive markers, {n_soft} soft/buffer markers") + if n_videos == 0: + lines.append("No source videos found. Set Video dir above.") + elif n_videos < 3: + lines.append("Recommend at least 3 videos for decent results.") + self._lbl_stats.setText("
".join(lines)) + + @property + def positive_folder(self) -> str: + return self._cmb_positive.currentData() or "" + + @property + def embed_model(self) -> str: + return self._cmb_model.currentText() + + @property + def video_dir(self) -> str: + return self._txt_video_dir.text() + + +class TrainWorker(QThread): + """Trains an audio classifier off the main thread.""" + train_done = pyqtSignal(str) # emits model path on success + error = pyqtSignal(str) + progress = pyqtSignal(str) # per-video status + + def __init__(self, video_infos: list, model_path: str, + embed_model: str | None = None): + super().__init__() + self._video_infos = video_infos + self._model_path = model_path + self._embed_model = embed_model + self._cancel = False + + def cancel(self) -> None: + self._cancel = True + + def run(self): + from core.audio_scan import train_classifier + try: + self.progress.emit(f"Training on {len(self._video_infos)} videos...") + result = train_classifier( + self._video_infos, + model_path=self._model_path, + embed_model=self._embed_model, + ) + if self._cancel: + return + if result is None: + self.error.emit("Training failed: not enough data or missing class balance") + else: + self.train_done.emit(self._model_path) + except Exception as e: + if not self._cancel: + self.error.emit(str(e)) + + class TimelineWidget(QWidget): cursor_changed = pyqtSignal(float) # emits position in seconds seek_changed = pyqtSignal(float) # emits seek position (lock mode) @@ -1564,23 +1700,35 @@ class MainWindow(QMainWindow): self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips") self._btn_scan.clicked.connect(self._start_scan) + self._btn_auto_export = QPushButton("Auto") + self._btn_auto_export.setToolTip("Scan + auto-export best 8s clips") + self._btn_auto_export.clicked.connect(self._auto_export) + + self._btn_train = QPushButton("Train") + self._btn_train.setToolTip("Train audio classifier from exported clips") + self._btn_train.clicked.connect(self._open_train_dialog) + self._train_worker: TrainWorker | None = None + + self._spn_auto_fuse = QDoubleSpinBox() + self._spn_auto_fuse.setDecimals(1) + self._spn_auto_fuse.setRange(0.0, 60.0) + self._spn_auto_fuse.setSingleStep(1.0) + self._spn_auto_fuse.setValue(float(self._settings.value("auto_fuse", "4.0"))) + self._spn_auto_fuse.setPrefix("Fuse: ") + self._spn_auto_fuse.setSuffix("s") + self._spn_auto_fuse.setToolTip("Max gap between scan regions to merge into one cluster") + self._spn_auto_fuse.valueChanged.connect( + lambda v: self._settings.setValue("auto_fuse", str(v)) + ) + self._sld_threshold = QDoubleSpinBox() self._sld_threshold.setDecimals(2) self._sld_threshold.setRange(0.0, 1.0) self._sld_threshold.setSingleStep(0.01) - self._sld_threshold.setValue(0.05) + self._sld_threshold.setValue(0.30) self._sld_threshold.setPrefix("Thr: ") self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)") - self._cmb_scan_mode = QComboBox() - self._cmb_scan_mode.addItems(["Average", "Nearest"]) - self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip") - - self._cmb_scan_ref = QComboBox() - self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"]) - self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed) - self._scan_folder: str = "" - self._scan_worker: ScanWorker | None = None cpu_count = os.cpu_count() or 2 @@ -1716,9 +1864,10 @@ class MainWindow(QMainWindow): settings_row.addWidget(self._chk_rand_square) settings_row.addWidget(self._chk_track) settings_row.addWidget(self._btn_scan) + settings_row.addWidget(self._btn_auto_export) + settings_row.addWidget(self._spn_auto_fuse) settings_row.addWidget(self._sld_threshold) - settings_row.addWidget(self._cmb_scan_mode) - settings_row.addWidget(self._cmb_scan_ref) + settings_row.addWidget(self._btn_train) settings_row.addStretch() self._lbl_status = QLabel() self._lbl_status.setStyleSheet("color: #888; font-size: 11px;") @@ -2503,16 +2652,6 @@ class MainWindow(QMainWindow): return self._step_cursor(markers[0][0] - self._cursor) # wrap to first - def _on_scan_ref_changed(self, index: int) -> None: - if index == 1: # Custom Folder - folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder") - if folder: - self._scan_folder = folder - else: - self._cmb_scan_ref.blockSignals(True) - self._cmb_scan_ref.setCurrentIndex(0) - self._cmb_scan_ref.blockSignals(False) - def _cleanup_scan_worker(self) -> None: """Disconnect signals and schedule deletion of old scan worker.""" if self._scan_worker is not None: @@ -2540,35 +2679,22 @@ class MainWindow(QMainWindow): # Clean up previous worker self._cleanup_scan_worker() - # Collect reference clip paths - if self._cmb_scan_ref.currentIndex() == 0: - # Current profile — all exports across all files in this profile - clip_paths = [p for p in self._db.get_all_export_paths(self._profile) - if os.path.exists(p)] - else: - # Custom folder - if not self._scan_folder: - self._show_status("No reference folder selected") - return - exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac") - clip_paths = [ - os.path.join(self._scan_folder, f) - for f in sorted(os.listdir(self._scan_folder)) - if f.lower().endswith(exts) - ] - - if not clip_paths: - self._show_status("No reference clips found") - return - - mode = self._cmb_scan_mode.currentText().lower() threshold = self._sld_threshold.value() - self._btn_scan.setEnabled(False) - self._scan_file_path = self._file_path # remember which file we're scanning - self._show_status(f"Scanning with {len(clip_paths)} reference clips...") + from core.audio_scan import load_classifier, default_model_path + model_path = default_model_path(self._profile) + model = load_classifier(model_path) - self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold) + if model is None: + self._show_status("No trained model — click Train first") + return + + self._btn_scan.setEnabled(False) + self._scan_file_path = self._file_path + self._show_status("Scanning...") + self._scan_worker = ScanWorker( + self._file_path, model=model, threshold=threshold, + ) self._scan_worker.scan_done.connect(self._on_scan_done) self._scan_worker.error.connect(self._on_scan_error) self._scan_worker.progress.connect(self._show_status) @@ -2576,6 +2702,7 @@ class MainWindow(QMainWindow): def _on_scan_done(self, regions: list) -> None: self._btn_scan.setEnabled(True) + self._btn_auto_export.setEnabled(True) # Ignore stale results if the user switched files during scan if self._file_path != getattr(self, '_scan_file_path', None): return @@ -2584,8 +2711,294 @@ class MainWindow(QMainWindow): def _on_scan_error(self, msg: str) -> None: self._btn_scan.setEnabled(True) + self._btn_auto_export.setEnabled(True) self._show_status(f"Scan error: {msg}") + # ── Training ──────────────────────────────────────────────── + + def _cleanup_train_worker(self) -> None: + """Disconnect signals and schedule deletion of old train worker.""" + if self._train_worker is not None: + try: + self._train_worker.train_done.disconnect() + self._train_worker.error.disconnect() + self._train_worker.progress.disconnect() + except TypeError: + pass + if self._train_worker.isRunning(): + self._train_worker.cancel() + self._train_worker.finished.connect(self._train_worker.deleteLater) + else: + self._train_worker.deleteLater() + self._train_worker = None + + def _open_train_dialog(self): + """Show the training config dialog and start training if accepted.""" + if self._train_worker and self._train_worker.isRunning(): + self._show_status("Training already in progress…") + return + + # Default video dir: parent of currently loaded file, or saved setting + default_dir = "" + if self._file_path: + default_dir = os.path.dirname(self._file_path) + saved_dir = self._settings.value("train_video_dir", default_dir) + + dlg = TrainDialog(self._db, self._profile, + video_dir=saved_dir or default_dir, parent=self) + if dlg.exec() != QDialog.DialogCode.Accepted: + return + + pos_folder = dlg.positive_folder + embed_model = dlg.embed_model + video_dir = dlg.video_dir + if not pos_folder: + self._show_status("No positive class selected") + return + + # Persist video dir for next time + if video_dir: + self._settings.setValue("train_video_dir", video_dir) + + video_infos = self._db.get_training_data( + self._profile, pos_folder, fallback_video_dir=video_dir, + ) + if not video_infos: + self._show_status("No training data found for this subprofile") + return + + from core.audio_scan import default_model_path + model_path = default_model_path(self._profile) + + self._cleanup_train_worker() + self._btn_train.setEnabled(False) + self._show_status(f"Training {embed_model} on {len(video_infos)} videos...") + + self._train_worker = TrainWorker(video_infos, model_path, embed_model) + self._train_worker.train_done.connect(self._on_train_done) + self._train_worker.error.connect(self._on_train_error) + self._train_worker.progress.connect(self._show_status) + self._train_worker.start() + + def _on_train_done(self, model_path: str): + self._btn_train.setEnabled(True) + self._show_status(f"Model trained and saved") + _log(f"Training complete: {model_path}") + + def _on_train_error(self, msg: str): + self._btn_train.setEnabled(True) + self._show_status(f"Training error: {msg}") + + # ── Auto-export ───────────────────────────────────────────── + + def _auto_export(self) -> None: + """Scan → NMS → export one 8s clip per selected position.""" + if not self._file_path: + self._show_status("No video loaded") + return + if self._export_worker and self._export_worker.isRunning(): + self._show_status("Export already running…") + return + if self._scan_worker and self._scan_worker.isRunning(): + self._show_status("Scan already running") + return + + self._cleanup_scan_worker() + self._btn_auto_export.setEnabled(False) + self._btn_scan.setEnabled(False) + + threshold = self._sld_threshold.value() + + from core.audio_scan import load_classifier, default_model_path + model_path = default_model_path(self._profile) + model = load_classifier(model_path) + + if model is not None: + self._scan_file_path = self._file_path + self._show_status("Auto: scanning with classifier...") + self._scan_worker = ScanWorker( + self._file_path, model=model, threshold=threshold, + ) + else: + self._show_status("Auto: no trained model — click Train first") + self._btn_auto_export.setEnabled(True) + self._btn_scan.setEnabled(True) + return + + self._scan_worker.scan_done.connect(self._on_auto_scan_done) + self._scan_worker.error.connect(self._on_scan_error) + self._scan_worker.progress.connect(self._show_status) + self._scan_worker.start() + + @staticmethod + def _select_export_positions(regions: list[tuple[float, float, float]], + min_gap: float = 2.0, + cluster_fuse: float = 30.0, + ) -> list[float]: + """Cluster scan regions, then fill each cluster with clips spaced min_gap apart. + + 1. Merge overlapping regions into clusters, fusing clusters = min_gap for p in cluster_picks): + cluster_picks.append(start) + picked.extend(cluster_picks) + + return sorted(picked) + + def _on_auto_scan_done(self, regions: list) -> None: + self._btn_scan.setEnabled(True) + if self._file_path != getattr(self, '_scan_file_path', None): + self._btn_auto_export.setEnabled(True) + return + + self._timeline.set_scan_regions(regions) + + if not regions: + self._show_status("Auto: no regions found") + self._btn_auto_export.setEnabled(True) + return + + positions = self._select_export_positions( + regions, min_gap=2.0, cluster_fuse=self._spn_auto_fuse.value(), + ) + if not positions: + self._show_status("Auto: no positions after NMS") + self._btn_auto_export.setEnabled(True) + return + + # Build export jobs — one 8s clip per position + folder = self._txt_folder.text() + name = self._txt_name.text() or "clip" + fmt = self._cmb_format.currentText() + image_sequence = fmt == "WebP sequence" + os.makedirs(folder, exist_ok=True) + + # Find starting counter + counter = 1 + while True: + if image_sequence: + p = build_sequence_dir(folder, name, counter, sub=0) + else: + p = build_export_path(folder, name, counter, sub=0) + if not os.path.exists(p): + break + counter += 1 + + jobs = [] + self._auto_export_positions = [] # stash for DB writes + for start_t in positions: + group_dir = os.path.join(folder, f"{name}_{counter:03d}") + os.makedirs(group_dir, exist_ok=True) + if image_sequence: + out = build_sequence_dir(folder, name, counter, sub=0) + else: + out = build_export_path(folder, name, counter, sub=0) + jobs.append((start_t, out, None, 0.5)) + self._auto_export_positions.append((start_t, counter)) + counter += 1 + + self._show_status(f"Auto: exporting {len(jobs)} clips...") + + short_side = self._spn_resize.value() or None + self._export_short_side = short_side + self._export_portrait = "Off" + self._export_format = fmt + self._export_clip_count = 1 + self._export_spread = 0 + self._export_folder = folder + self._export_folder_suffix = "" + + hw_on = self._chk_hw.isChecked() and self._hw_encoders + encoder = self._hw_encoders[0] if hw_on else "libx264" + max_workers = min(self._spn_workers.value(), 3) if hw_on else self._spn_workers.value() + + self._export_worker = ExportWorker( + self._file_path, jobs, + short_side=short_side, + image_sequence=image_sequence, + max_workers=max_workers, + encoder=encoder, + ) + self._export_worker.finished.connect(self._on_auto_clip_done) + self._export_worker.all_done.connect(self._on_auto_batch_done) + self._export_worker.error.connect(self._on_export_error) + self._export_worker.cancelled.connect(self._on_export_cancelled) + self._btn_cancel.setEnabled(True) + self._btn_export.setEnabled(False) + self._set_subprofile_btns_enabled(False) + self._export_worker.start() + + def _on_auto_clip_done(self, path: str): + """Record each auto-exported clip to DB.""" + # Find the start_time for this clip from stashed positions + counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042" + name = self._txt_name.text() or "clip" + start_t = None + for t, c in self._auto_export_positions: + if counter_str == f"{name}_{c:03d}": + start_t = t + break + + label = self._txt_label.currentText().strip() + category = self._cmb_category.currentText() + self._db.add( + os.path.basename(self._file_path), + start_t or 0.0, + path, + label=label, + category=category, + short_side=self._export_short_side, + portrait_ratio="", + crop_center=0.5, + fmt=self._export_format, + clip_count=1, + spread=0, + profile=self._profile, + source_path=self._file_path, + ) + upsert_clip_annotation(self._export_folder, path, label) + self._show_status(f"Auto: {os.path.basename(path)}") + _log(f" auto clip done: {os.path.basename(path)}") + + def _on_auto_batch_done(self): + n = len(self._auto_export_positions) + self._btn_auto_export.setEnabled(True) + self._btn_cancel.setEnabled(False) + self._btn_export.setEnabled(True) + self._set_subprofile_btns_enabled(True) + self._refresh_markers() + markers = self._db.get_markers(os.path.basename(self._file_path), self._profile) + self._playlist.mark_done(self._file_path, len(markers)) + self._update_next_label() + self._show_status(f"Auto export complete: {n} clips") + _log(f"Auto export complete: {n} clips") + def _jump_to_next_scan_region(self) -> None: regions = sorted(self._timeline._scan_regions, key=lambda r: r[0]) if not regions: @@ -2812,6 +3225,7 @@ class MainWindow(QMainWindow): clip_count=self._export_clip_count, spread=self._export_spread, profile=self._profile, + source_path=self._file_path, ) upsert_clip_annotation(self._export_folder, path, label) self._last_export_path = path @@ -2851,6 +3265,7 @@ class MainWindow(QMainWindow): _log(f"Export error: {msg}") self._btn_cancel.setEnabled(False) self._btn_export.setEnabled(True) + self._btn_auto_export.setEnabled(True) self._set_subprofile_btns_enabled(True) self._btn_export.setText("Export") self._btn_export.setStyleSheet("") @@ -2866,6 +3281,7 @@ class MainWindow(QMainWindow): def _on_export_cancelled(self): _log("Export cancelled") self._btn_export.setEnabled(True) + self._btn_auto_export.setEnabled(True) self._set_subprofile_btns_enabled(True) self._btn_export.setText("Export") self._btn_export.setStyleSheet("") @@ -2886,6 +3302,9 @@ class MainWindow(QMainWindow): _log("Shutting down…") # Save session playlist for resume. self._settings.setValue("session_files", self._playlist._paths) + # Cancel background workers to prevent callbacks into dead objects. + self._cleanup_scan_worker() + self._cleanup_train_worker() # Stop timers first to prevent callbacks into dead objects. self._preview_timer.stop() self._mpv._render_timer.stop() diff --git a/requirements.txt b/requirements.txt index 180af07..f520695 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,25 @@ +# Core GUI PyQt6>=6.4 python-mpv>=1.0 -pytest>=7.0 + +# Audio & ML +librosa>=0.10 +numpy>=1.24 +scikit-learn>=1.3 +joblib>=1.3 +soundfile>=0.12 + +# Deep learning (torch installed separately for CUDA support) +# torch and torchaudio are installed via --index-url in setup_env.sh +torchaudio>=2.0 + +# Object detection ultralytics>=8.0 + +# Server API +fastapi>=0.100 +pydantic>=2.0 +uvicorn>=0.23 + +# Dev +pytest>=7.0 diff --git a/setup_env.sh b/setup_env.sh new file mode 100755 index 0000000..f888830 --- /dev/null +++ b/setup_env.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ────────────────────────────────────────────────────────────────────── +# 8-cut environment setup — supports conda (miniforge) or python venv +# +# Usage: +# ./setup_env.sh # auto-detect (prefers conda if available) +# ./setup_env.sh --conda # force conda +# ./setup_env.sh --venv # force python venv +# ─��────────────────────────────��─────────────────────────────────────── + +ENV_NAME="8cut" +PYTHON_VERSION="3.12" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +VENV_DIR="$SCRIPT_DIR/.venv" + +# CUDA version for PyTorch index URL +TORCH_INDEX="https://download.pytorch.org/whl/cu128" + +# ── Parse args ──────────────────────────────────────────────────────── + +MODE="" +for arg in "$@"; do + case "$arg" in + --conda) MODE="conda" ;; + --venv) MODE="venv" ;; + *) echo "Unknown arg: $arg"; exit 1 ;; + esac +done + +if [ -z "$MODE" ]; then + if command -v conda &>/dev/null; then + MODE="conda" + else + MODE="venv" + fi + echo "Auto-detected mode: $MODE" +fi + +# ── Conda setup ─────────────��───────────────────────────────────────── + +setup_conda() { + echo "==> Setting up conda environment: $ENV_NAME" + + # Source conda shell hooks if not already active + if ! command -v conda &>/dev/null; then + echo "conda not found in PATH" + exit 1 + fi + eval "$(conda shell.bash hook)" + + if conda env list | grep -qw "$ENV_NAME"; then + echo " Environment '$ENV_NAME' already exists, updating..." + conda activate "$ENV_NAME" + else + echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..." + conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION" + conda activate "$ENV_NAME" + fi + + echo " Installing PyTorch + torchaudio (CUDA 12.8)..." + pip install torch torchaudio --index-url "$TORCH_INDEX" + + echo " Installing project dependencies..." + pip install -r "$SCRIPT_DIR/requirements.txt" + + echo "" + echo "Done! Activate with:" + echo " conda activate $ENV_NAME" +} + +# ── Venv setup ───────��──────────────────────────────────────────────── + +setup_venv() { + echo "==> Setting up Python venv at: $VENV_DIR" + + if [ ! -d "$VENV_DIR" ]; then + python3 -m venv "$VENV_DIR" + echo " Created venv" + else + echo " Venv already exists, updating..." + fi + + source "$VENV_DIR/bin/activate" + + echo " Installing PyTorch + torchaudio (CUDA 12.8)..." + pip install torch torchaudio --index-url "$TORCH_INDEX" + + echo " Installing project dependencies..." + pip install -r "$SCRIPT_DIR/requirements.txt" + + echo "" + echo "Done! Activate with:" + echo " source $VENV_DIR/bin/activate" +} + +# ── Run ─────────────────────────────────────────────────────────────── + +case "$MODE" in + conda) setup_conda ;; + venv) setup_venv ;; +esac + +echo "" +echo "Verify with:" +echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\"" +echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\"" diff --git a/tests/test_audio_scan.py b/tests/test_audio_scan.py index bdd1c6e..1335527 100644 --- a/tests/test_audio_scan.py +++ b/tests/test_audio_scan.py @@ -1,154 +1,28 @@ import tempfile, os import numpy as np -from core.audio_scan import build_profile, _extract_features, scan_video, _similarity +from core.audio_scan import scan_video, load_classifier, default_model_path -def _make_wav(path: str, duration: float = 8.0, sr: int = 16000, freq: float = 440.0): - """Create a short sine-wave WAV file for testing.""" - import soundfile as sf - t = np.linspace(0, duration, int(sr * duration), endpoint=False) - audio = 0.5 * np.sin(2 * np.pi * freq * t) - sf.write(path, audio, sr) - - -def test_extract_features_returns_62d_vector(): - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: - _make_wav(f.name) - try: - vec = _extract_features(f.name) - assert vec.shape == (62,) - assert not np.isnan(vec).any() - finally: - os.unlink(f.name) - - -def test_build_profile_single_clip(): - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: - _make_wav(f.name) - try: - profile = build_profile([f.name]) - assert "mean_vector" in profile - assert "clip_vectors" in profile - assert profile["mean_vector"].shape == (62,) - assert len(profile["clip_vectors"]) == 1 - finally: - os.unlink(f.name) - - -def test_build_profile_multiple_clips(): - paths = [] - try: - for i in range(3): - f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) - _make_wav(f.name, freq=440 + i * 200) - paths.append(f.name) - f.close() - - profile = build_profile(paths) - assert len(profile["clip_vectors"]) == 3 - assert profile["mean_vector"].shape == (62,) - finally: - for p in paths: - os.unlink(p) - - -def test_build_profile_skips_missing_files(): - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: - _make_wav(f.name) - try: - profile = build_profile([f.name, "/no/such/file.wav"]) - assert len(profile["clip_vectors"]) == 1 - finally: - os.unlink(f.name) - - -def test_build_profile_empty_returns_none(): - result = build_profile([]) - assert result is None - - -def test_similarity_identical_is_one(): - a = np.array([1.0, 2.0, 3.0]) - assert abs(_similarity(a, a) - 1.0) < 1e-9 - - -def test_similarity_distant_is_low(): - a = np.zeros(62) - b = np.ones(62) * 100 - assert _similarity(a, b) < 0.01 - - -def test_scan_video_finds_matching_region(): - """A video made of the same sine wave as the reference should match.""" - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref: - _make_wav(ref.name, duration=8.0) +def test_scan_video_no_model_returns_empty(): + """scan_video with no model should return empty list.""" with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid: - _make_wav(vid.name, duration=20.0) - try: - profile = build_profile([ref.name]) - regions = scan_video(vid.name, profile, mode="average", threshold=0.01, hop=1.0) - assert len(regions) > 0 - for start, end, score in regions: - assert abs((end - start) - 8.0) < 0.1 - assert score >= 0.01 - finally: - os.unlink(ref.name) - os.unlink(vid.name) - - -def test_scan_video_nearest_mode(): - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref: - _make_wav(ref.name, duration=8.0) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid: - _make_wav(vid.name, duration=20.0) - try: - profile = build_profile([ref.name]) - regions = scan_video(vid.name, profile, mode="nearest", threshold=0.01, hop=1.0) - assert len(regions) > 0 - finally: - os.unlink(ref.name) - os.unlink(vid.name) - - -def test_scan_video_high_threshold_no_match(): - """Different frequencies with very high threshold should not match.""" - import soundfile as sf - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref: - _make_wav(ref.name, duration=8.0, freq=440) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid: - # White noise — very different from sine wave + import soundfile as sf sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000) try: - profile = build_profile([ref.name]) - regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0) - assert len(regions) == 0 + regions = scan_video(vid.name, model=None) + assert regions == [] finally: - os.unlink(ref.name) os.unlink(vid.name) -def test_scan_video_same_vs_different_discrimination(): - """Same-frequency match should score higher than cross-frequency.""" - import soundfile as sf - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref: - _make_wav(ref.name, duration=8.0, freq=440) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as same: - _make_wav(same.name, duration=10.0, freq=440) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as diff: - # White noise - sf.write(diff.name, np.random.randn(16000 * 10).astype(np.float32) * 0.1, 16000) - try: - profile = build_profile([ref.name]) - same_regions = scan_video(same.name, profile, mode="average", threshold=0.0, hop=1.0) - diff_regions = scan_video(diff.name, profile, mode="average", threshold=0.0, hop=1.0) - # Same-audio scores should be higher than noise scores - best_same = max(r[2] for r in same_regions) - best_diff = max(r[2] for r in diff_regions) - assert best_same > best_diff - finally: - os.unlink(ref.name) - os.unlink(same.name) - os.unlink(diff.name) +def test_load_classifier_missing_returns_none(): + assert load_classifier("/no/such/model.joblib") is None + + +def test_default_model_path_contains_profile(): + path = default_model_path("test_profile") + assert "test_profile" in path + assert path.endswith(".joblib") def test_db_get_all_export_paths():