feat: integrate training UI, BEATs model, and clean up legacy code

- Remove legacy distance-mode scanning (build_profile, _similarity, etc.)
  and hand-crafted intensity features — pipeline is now embedding-only
- Integrate Microsoft BEATs as embedding option alongside wav2vec2/HuBERT
- Add TrainDialog with positive class selector, model picker, video dir
  fallback, and live training stats
- Add TrainWorker QThread with cancel support and proper lifecycle cleanup
- Add source_path column to DB for robust source video tracking
- Add get_export_folders/get_training_data/get_training_stats to DB
- Wire source_path in all export DB writes (_on_clip_done, _on_auto_clip_done)
- Cancel scan/train workers in closeEvent to prevent use-after-free crashes
- Add setup_env.sh supporting both conda and python venv (CUDA 12.8)
- Update requirements.txt with all actual dependencies
- Update 8cut_train.py with --positive flag for new DB-driven training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-18 11:52:27 +02:00
parent f2c38aee79
commit 12ed183f1b
11 changed files with 2608 additions and 338 deletions
+255
View File
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""Calibration — per-video normalized features + classifier."""
import sys, os, time, warnings
sys.path.insert(0, os.path.dirname(__file__))
warnings.filterwarnings("ignore")
import numpy as np
import librosa
from sklearn.ensemble import GradientBoostingClassifier
from core.audio_scan import _SR, _WINDOW
_HOP_LENGTH = 1024
_N_FFT = 2048
from core.db import ProcessedDB
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
PROFILE_NAME = "JAV_missionary"
TOLERANCE = 12.0
NEG_MARGIN = 120.0
def extract_rich_features(y, sr=_SR):
"""Per-frame features: onset, energy, spectral shape, mel bands (22 features)."""
hop = _HOP_LENGTH
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2
rms = librosa.feature.rms(S=S, hop_length=hop)
cent = librosa.feature.spectral_centroid(S=S, sr=sr)
bw = librosa.feature.spectral_bandwidth(S=S, sr=sr)
rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr)
flatness = librosa.feature.spectral_flatness(S=S)
zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop)
onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1)
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128)
mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
bands = [(0, 100), (100, 300), (300, 600), (600, 1200),
(1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
band_feats = []
for flo, fhi in bands:
mask = (mel_freqs >= flo) & (mel_freqs < fhi)
if mask.sum() > 0:
band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10))
else:
band_feats.append(np.zeros((1, mel_S.shape[1])))
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop)
min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1],
band_feats[0].shape[1])
return np.vstack([
rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t],
flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t],
] + [b[:, :min_t] for b in band_feats]
+ [sc[:, :min_t]])
def compute_window_stats(feat, hop=1.0):
"""Sliding window mean/std → (timestamps, feature_vectors)."""
n_feats, T = feat.shape
fps = _SR / _HOP_LENGTH
win_frames = int(_WINDOW * fps)
hop_frames = int(hop * fps)
if win_frames > T:
return np.array([]), np.array([])
cumsum = np.zeros((n_feats, T + 1))
cumsum[:, 1:] = np.cumsum(feat, axis=1)
cumsq = np.zeros((n_feats, T + 1))
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
starts = np.arange(0, T - win_frames + 1, hop_frames)
ends = starts + win_frames
sums = cumsum[:, ends] - cumsum[:, starts]
sq_sums = cumsq[:, ends] - cumsq[:, starts]
means = sums / win_frames
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
return starts / fps, np.vstack([means, stds]).T
def label_windows(timestamps, gt_intense, gt_soft):
all_gt = list(gt_intense) + list(gt_soft)
labels = np.zeros(len(timestamps), dtype=int)
for i, t in enumerate(timestamps):
di = min((abs(t - g) for g in gt_intense), default=9999)
da = min((abs(t - g) for g in all_gt), default=9999)
if di < TOLERANCE:
labels[i] = 1
elif da > NEG_MARGIN:
labels[i] = -1
return labels
def main():
db = ProcessedDB()
rows = db._con.execute(
"SELECT filename, start_time, output_path FROM processed WHERE profile = ?",
(PROFILE_NAME,),
).fetchall()
intense_by_video, soft_by_video = {}, {}
for fn, st, op in rows:
if '/mp4_Intense/' in op:
intense_by_video.setdefault(fn, set()).add(st)
elif '/mp4_Soft/' in op:
soft_by_video.setdefault(fn, set()).add(st)
videos = [fn for fn in intense_by_video
if os.path.exists(os.path.join(PLEX_DIR, fn))]
n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
videos = videos[:n_vids]
print(f"Processing {len(videos)} videos...")
all_data_raw = [] # raw features
all_data_norm = [] # per-video z-scored features
for vi, vname in enumerate(videos):
vpath = os.path.join(PLEX_DIR, vname)
gt_intense = sorted(intense_by_video.get(vname, set()))
gt_soft = sorted(soft_by_video.get(vname, set()))
t0 = time.time()
y, _ = librosa.load(vpath, sr=_SR, mono=True)
feat = extract_rich_features(y)
timestamps, window_vectors = compute_window_stats(feat, hop=1.0)
dt = time.time() - t0
if len(timestamps) == 0:
continue
labels = label_windows(timestamps, gt_intense, gt_soft)
# Per-video z-score normalization
vid_mean = window_vectors.mean(axis=0)
vid_std = window_vectors.std(axis=0)
vid_std = np.maximum(vid_std, 1e-6)
normed = (window_vectors - vid_mean) / vid_std
n_pos = (labels == 1).sum()
n_neg = (labels == -1).sum()
print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")
all_data_raw.append((vi, vname, timestamps, window_vectors, labels))
all_data_norm.append((vi, vname, timestamps, normed, labels))
# Run CV for both raw and normalized
for label, data in [("RAW features", all_data_raw),
("PER-VIDEO NORMALIZED features", all_data_norm)]:
print(f"\n{'='*70}")
print(f" {label}")
print(f"{'='*70}")
all_y_true, all_y_prob = [], []
for test_idx in range(len(data)):
_, vname, _, test_X, test_labels = data[test_idx]
test_mask = test_labels != 0
if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0:
continue
X_test = test_X[test_mask]
y_test = (test_labels[test_mask] == 1).astype(int)
X_parts, y_parts = [], []
for i, (_, _, _, feats, labs) in enumerate(data):
if i == test_idx:
continue
m = labs != 0
if m.sum() == 0:
continue
X_parts.append(feats[m])
y_parts.append((labs[m] == 1).astype(int))
if not X_parts:
continue
X_train = np.vstack(X_parts)
y_train = np.concatenate(y_parts)
pos_idx = np.where(y_train == 1)[0]
neg_idx = np.where(y_train == 0)[0]
if len(pos_idx) == 0 or len(neg_idx) == 0:
continue
rng = np.random.RandomState(42)
n_neg = min(len(neg_idx), len(pos_idx) * 3)
neg_sample = rng.choice(neg_idx, n_neg, replace=False)
train_idx = np.concatenate([pos_idx, neg_sample])
clf = GradientBoostingClassifier(
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
)
clf.fit(X_train[train_idx], y_train[train_idx])
probs = clf.predict_proba(X_test)[:, 1]
tp = ((probs >= 0.5) & (y_test == 1)).sum()
fp = ((probs >= 0.5) & (y_test == 0)).sum()
fn_count = ((probs < 0.5) & (y_test == 1)).sum()
pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0
neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0
print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")
all_y_true.extend(y_test)
all_y_prob.extend(probs)
if not all_y_true:
print(" No test results.")
continue
y_true = np.array(all_y_true)
y_prob = np.array(all_y_prob)
pos_probs = y_prob[y_true == 1]
neg_probs = y_prob[y_true == 0]
if len(pos_probs) > 0 and len(neg_probs) > 0:
print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")
best_f1, best_thr = 0, 0
print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
for thr in np.arange(0.10, 0.91, 0.05):
tp = ((y_prob >= thr) & (y_true == 1)).sum()
fp = ((y_prob >= thr) & (y_true == 0)).sum()
fn_count = ((y_prob < thr) & (y_true == 1)).sum()
prec = tp / (tp + fp) if (tp + fp) > 0 else 0
rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
if f1 > best_f1:
best_f1, best_thr = f1, thr
print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")
# Feature importance
X_all = np.vstack([f[l != 0] for _, _, _, f, l in data])
y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
pos_idx = np.where(y_all == 1)[0]
neg_idx = np.where(y_all == 0)[0]
rng = np.random.RandomState(42)
neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])])
feat_names = (
["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
+ [f"mel{i}" for i in range(8)]
+ [f"sc{i}" for i in range(7)]
)
stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
imp = clf.feature_importances_
top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10]
print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
if __name__ == "__main__":
main()
+92
View File
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Train an audio scan classifier from DB ground truth.
Usage:
python 8cut_train.py # default model, auto-detect positive
python 8cut_train.py --model BEATS # specific embedding model
python 8cut_train.py --positive mp4_Intense # explicit positive folder
python 8cut_train.py --positive mp4_Intense --model BEATS # both
"""
import sys, os, warnings
sys.path.insert(0, os.path.dirname(__file__))
warnings.filterwarnings("ignore")
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
from core.db import ProcessedDB
PROFILE_NAME = "JAV_missionary"
# Fallback for old DB rows without source_path
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
def main():
embed_model = None
if "--model" in sys.argv:
idx = sys.argv.index("--model")
if idx + 1 < len(sys.argv):
embed_model = sys.argv[idx + 1]
if embed_model not in _EMBED_MODELS:
print(f"Unknown model: {embed_model}")
print(f"Available: {', '.join(_EMBED_MODELS)}")
sys.exit(1)
positive_suffix = None
if "--positive" in sys.argv:
idx = sys.argv.index("--positive")
if idx + 1 < len(sys.argv):
positive_suffix = sys.argv[idx + 1]
db = ProcessedDB()
# If --positive given, use the new DB helper
if positive_suffix:
video_infos = db.get_training_data(
PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
)
if not video_infos:
print(f"No training data found for positive='{positive_suffix}'")
sys.exit(1)
else:
# Legacy fallback: classify by folder path pattern
rows = db._con.execute(
"SELECT filename, start_time, output_path, source_path"
" FROM processed WHERE profile = ?",
(PROFILE_NAME,),
).fetchall()
intense_by_video, soft_by_video = {}, {}
source_by_fn = {}
for fn, st, op, sp in rows:
if sp:
source_by_fn[fn] = sp
if "/mp4_Intense/" in op or "_Intense/" in op:
intense_by_video.setdefault(fn, set()).add(st)
elif "/mp4_Soft/" in op or "_Soft/" in op:
soft_by_video.setdefault(fn, set()).add(st)
video_infos = []
for fn in intense_by_video:
# Try source_path from DB first, fall back to PLEX_DIR
vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
if not os.path.exists(vpath):
print(f" skip (not found): {fn}")
continue
gt_intense = sorted(intense_by_video[fn])
gt_soft = sorted(soft_by_video.get(fn, set()))
video_infos.append((vpath, gt_intense, gt_soft))
label = embed_model or "WAV2VEC2_BASE"
print(f"Training {label} model on {len(video_infos)} videos...")
model_path = default_model_path(PROFILE_NAME)
result = train_classifier(
video_infos, model_path=model_path, embed_model=embed_model,
)
if result is None:
print("Training failed: no valid samples or missing class balance")
sys.exit(1)
print(f"Model saved to {model_path}")
if __name__ == "__main__":
main()
+347 -128
View File
@@ -1,105 +1,359 @@
"""Audio similarity scanning — MFCC + spectral contrast profile matching.""" """Audio scanning — embedding-based classifier for audio event detection."""
import hashlib
import os
import numpy as np import numpy as np
import librosa import librosa
from .paths import _log from .paths import _log
_N_MFCC = 13 # coefficients 0-12; we drop C0 → 12 usable _SR = 16000 # lower sr = faster
_SR = 16000 # lower sr = faster, no quality loss for style matching
_HOP_LENGTH = 1024 # STFT hop (~64ms frames at 16kHz)
_N_FFT = 2048 # STFT window
_WINDOW = 8.0 # seconds _WINDOW = 8.0 # seconds
_N_FEATURES = 62 # (12 mfcc + 12 delta + 7 sc) * 2 (mean + std) _MODEL_DIR = os.path.join(os.path.expanduser("~"), ".8cut_models")
_W2V_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".8cut_cache", "w2v")
# ---------------------------------------------------------------------------
# Embedding extraction (lazy-loaded)
# ---------------------------------------------------------------------------
def _extract_features_from_signal(y: np.ndarray, sr: int = _SR) -> np.ndarray: _w2v_model = None
"""Compute feature matrix (31 x T) from a raw audio signal. _w2v_device = None
_w2v_model_name = None
Features per frame: 12 MFCCs (skip C0) + 12 delta MFCCs + 7 spectral contrast. # Supported embedding models — name → embed_dim
""" _EMBED_MODELS = {
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=_HOP_LENGTH)) ** 2 "WAV2VEC2_BASE": 768,
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=_HOP_LENGTH) "WAV2VEC2_LARGE": 1024,
mfcc = librosa.feature.mfcc(S=librosa.power_to_db(mel_S), sr=sr, n_mfcc=_N_MFCC) "WAV2VEC2_LARGE_LV60K":1024,
mfcc = mfcc[1:] # drop C0 (energy) — dominates cosine sim, kills discrimination "HUBERT_BASE": 768,
delta = librosa.feature.delta(mfcc) "HUBERT_LARGE": 1024,
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=_HOP_LENGTH) "HUBERT_XLARGE": 1280,
return np.vstack([mfcc, delta, sc]) # (31, T) "BEATS": 768,
def _aggregate(feature_matrix: np.ndarray) -> np.ndarray:
"""Collapse a (31, T) feature matrix into a (62,) vector via mean + std."""
return np.concatenate([
feature_matrix.mean(axis=1),
feature_matrix.std(axis=1),
])
def _extract_features(path: str, sr: int = _SR) -> np.ndarray:
"""Load audio from a file and return a 62-dim feature vector."""
y, _ = librosa.load(path, sr=sr, mono=True)
feat = _extract_features_from_signal(y, sr)
return _aggregate(feat)
def build_profile(clip_paths: list[str]) -> dict | None:
"""Extract features from reference clips.
Returns dict with:
- mean_vector: averaged feature vector across all clips (62,)
- clip_vectors: list of individual feature vectors
Returns None if no clips could be loaded.
"""
vectors = []
for p in clip_paths:
try:
vec = _extract_features(p)
vectors.append(vec)
except Exception as e:
_log(f"audio_scan: skip {p}: {e}")
if not vectors:
return None
arr = np.stack(vectors)
return {
"mean_vector": arr.mean(axis=0),
"clip_vectors": vectors,
} }
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
_BEATS_CHECKPOINT = os.path.join(
os.path.expanduser("~"), ".cache", "huggingface", "hub",
"models--lpepino--beats_ckpts", "snapshots",
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
)
def _similarity(a: np.ndarray, b: np.ndarray) -> float: def _get_w2v_model(model_name: str | None = None):
"""Euclidean-distance-based similarity in (0, 1]. """Lazy-load an embedding model. Reloads if model_name differs from cached."""
global _w2v_model, _w2v_device, _w2v_model_name
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
if _w2v_model is None or _w2v_model_name != model_name:
import torch
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
1/(1+dist): identical → 1.0, very different → near 0. if model_name == "BEATS":
from .beats_model import BEATs, BEATsConfig
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
weights_only=False)
cfg = BEATsConfig(checkpoint['cfg'])
_w2v_model = BEATs(cfg)
_w2v_model.load_state_dict(checkpoint['model'])
_w2v_model.to(_w2v_device)
else:
import torchaudio
bundle = getattr(torchaudio.pipelines, model_name)
_w2v_model = bundle.get_model().to(_w2v_device)
_w2v_model.eval()
_w2v_model_name = model_name
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
return _w2v_model, _w2v_device
def _embed_dim(model_name: str | None = None) -> int:
"""Return embedding dimension for a model name."""
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
return _EMBED_MODELS.get(model_name, 768)
def _w2v_cache_path(video_path: str, hop: float, window: float,
model_name: str | None = None) -> str:
"""Return cache file path for a video's embeddings (includes model name)."""
if model_name is None:
model_name = _DEFAULT_EMBED_MODEL
abspath = os.path.abspath(video_path)
mtime = os.path.getmtime(abspath)
key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}"
h = hashlib.sha256(key.encode()).hexdigest()[:16]
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
hop: float = 1.0, window: float = _WINDOW,
video_path: str | None = None,
cancel_flag: object = None,
model_name: str | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract embeddings for all sliding windows using a torchaudio model.
If video_path is given, results are cached to disk for fast re-scans.
Returns (timestamps, embeddings) where embeddings is (N, D).
""" """
return float(1.0 / (1.0 + np.linalg.norm(a - b))) edim = _embed_dim(model_name)
# Try loading from cache
cache_file = None
if video_path:
try:
cache_file = _w2v_cache_path(video_path, hop, window, model_name)
if os.path.exists(cache_file):
data = np.load(cache_file)
_log(f"audio_scan: cache hit ({cache_file})")
return data["timestamps"], data["embeddings"]
except Exception as e:
_log(f"audio_scan: cache read failed: {e}")
win_samples = int(window * sr)
hop_samples = int(hop * sr)
n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)
if n_windows == 0:
return np.array([]), np.empty((0, edim))
import torch
model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
batch_size = 16
timestamps = np.arange(n_windows) * hop
embeddings = []
for batch_start in range(0, n_windows, batch_size):
if cancel_flag and getattr(cancel_flag, '_cancel', False):
return np.array([]), np.empty((0, edim))
batch_end = min(batch_start + batch_size, n_windows)
chunks = []
for i in range(batch_start, batch_end):
start = i * hop_samples
chunks.append(y[start:start + win_samples])
with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings.append(batch_emb)
result_ts = timestamps
result_emb = np.vstack(embeddings)
# Save to cache
if cache_file:
try:
os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
_log(f"audio_scan: w2v cache saved ({cache_file})")
except Exception as e:
_log(f"audio_scan: cache write failed: {e}")
return result_ts, result_emb
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
gt_soft: list[float], tolerance: float = 12.0,
neg_margin: float = 120.0,
model_name: str | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract embeddings only near positives and distant negatives.
Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig.
"""
edim = _embed_dim(model_name)
duration = len(y) / sr
win_samples = int(_WINDOW * sr)
all_gt = list(gt_intense) + list(gt_soft)
# Positive windows: every second near intense markers
pos_times = set()
for gt in gt_intense:
for offset in range(-int(tolerance), int(tolerance) + 1):
t = gt + offset
if 0 <= t <= duration - _WINDOW:
pos_times.add(int(t))
# Negative windows: every 4s, far from any marker
neg_times = set()
for t in range(0, int(duration - _WINDOW), 4):
if min((abs(t - g) for g in all_gt), default=9999) > neg_margin:
neg_times.add(t)
all_times = sorted(pos_times | neg_times)
# Filter out windows that go past the end
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
if not valid_times:
return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)
import torch
model, device = _get_w2v_model(model_name)
batch_size = 16
timestamps_list: list[float] = []
embeddings_list: list[np.ndarray] = []
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
for batch_start in range(0, len(valid_times), batch_size):
batch_end = min(batch_start + batch_size, len(valid_times))
chunks = []
for t in valid_times[batch_start:batch_end]:
start = int(t * sr)
chunks.append(y[start:start + win_samples])
timestamps_list.append(float(t))
with torch.no_grad():
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
if is_beats:
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
else:
features, _ = model(waveforms)
batch_emb = features.mean(dim=1).cpu().numpy()
embeddings_list.append(batch_emb)
timestamps = np.array(timestamps_list)
embeddings = np.vstack(embeddings_list)
labels = np.zeros(len(timestamps), dtype=int)
for i, t in enumerate(timestamps):
di = min((abs(t - g) for g in gt_intense), default=9999)
da = min((abs(t - g) for g in all_gt), default=9999)
if di < tolerance:
labels[i] = 1
elif da > neg_margin:
labels[i] = -1
return timestamps, embeddings, labels
# ---------------------------------------------------------------------------
# Classifier mode — train / save / load / scan
# ---------------------------------------------------------------------------
def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
model_path: str | None = None,
tolerance: float = 12.0,
neg_margin: float = 120.0,
embed_model: str | None = None) -> dict:
"""Train a classifier from labeled videos.
Args:
video_infos: list of (video_path, intense_times, soft_times)
model_path: if given, save model to this path
tolerance/neg_margin: labeling parameters
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
Returns:
dict with 'classifier', 'embed_model', and metadata, or None on failure.
"""
from sklearn.ensemble import GradientBoostingClassifier
all_X, all_y = [], []
for vi, (vpath, gt_intense, gt_soft) in enumerate(video_infos):
_log(f"audio_scan: training [{vi+1}/{len(video_infos)}] {os.path.basename(vpath)}")
y, _ = librosa.load(vpath, sr=_SR, mono=True)
timestamps, embeddings, labels = _extract_w2v_targeted(
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
model_name=embed_model,
)
if len(timestamps) == 0:
continue
# Per-video z-score normalize
vid_mean = embeddings.mean(axis=0)
vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
normed = (embeddings - vid_mean) / vid_std
for i in range(len(labels)):
if labels[i] == 1:
all_X.append(normed[i])
all_y.append(1)
elif labels[i] == -1:
all_X.append(normed[i])
all_y.append(0)
if not all_X:
_log("audio_scan: no training samples collected")
return None
X = np.stack(all_X)
y_arr = np.array(all_y)
n_pos = (y_arr == 1).sum()
n_neg = (y_arr == 0).sum()
_log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")
if n_pos == 0 or n_neg == 0:
_log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
return None
# Subsample negatives for balance
rng = np.random.RandomState(42)
pos_idx = np.where(y_arr == 1)[0]
neg_idx = np.where(y_arr == 0)[0]
n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
train_idx = np.concatenate([pos_idx, neg_sample])
rng.shuffle(train_idx)
clf = GradientBoostingClassifier(
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42,
)
clf.fit(X[train_idx], y_arr[train_idx])
_log("audio_scan: classifier trained")
model = {"classifier": clf, "n_features": X.shape[1],
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
if model_path:
import joblib
parent = os.path.dirname(model_path)
if parent:
os.makedirs(parent, exist_ok=True)
joblib.dump(model, model_path)
_log(f"audio_scan: model saved to {model_path}")
return model
def load_classifier(model_path: str) -> dict | None:
"""Load a saved classifier model."""
if not os.path.exists(model_path):
return None
import joblib
return joblib.load(model_path)
def default_model_path(profile_name: str = "default") -> str:
"""Return the default path for a profile's classifier model."""
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
# ---------------------------------------------------------------------------
# Scanning
# ---------------------------------------------------------------------------
def scan_video( def scan_video(
video_path: str, video_path: str,
profile: dict, model: dict = None,
mode: str = "average", threshold: float = 0.30,
threshold: float = 0.05,
hop: float = 1.0, hop: float = 1.0,
window: float = _WINDOW, window: float = _WINDOW,
cancel_flag: object = None, cancel_flag: object = None,
) -> list[tuple[float, float, float]]: ) -> list[tuple[float, float, float]]:
"""Slide a window across the video audio and score against the profile. """Scan a video for matching audio regions using a trained classifier.
Pre-computes STFT once for the whole file, then uses vectorized Returns list of (start_time, end_time, score) above threshold.
cumulative-sum sliding window for speed.
Args:
video_path: path to video/audio file
profile: dict from build_profile()
mode: "average" (compare to mean) or "nearest" (max over all clips)
threshold: minimum similarity to include (0-1, default 0.05)
hop: step size in seconds
window: window size in seconds (default 8s)
cancel_flag: object with _cancel bool attribute; checked periodically
Returns:
list of (start_time, end_time, score) for regions above threshold
""" """
if model is None:
_log("audio_scan: no model provided")
return []
_log(f"audio_scan: loading {video_path}") _log(f"audio_scan: loading {video_path}")
y, sr = librosa.load(video_path, sr=_SR, mono=True) y, sr = librosa.load(video_path, sr=_SR, mono=True)
duration = len(y) / sr duration = len(y) / sr
@@ -108,68 +362,33 @@ def scan_video(
if cancel_flag and getattr(cancel_flag, '_cancel', False): if cancel_flag and getattr(cancel_flag, '_cancel', False):
return [] return []
# Compute features for the entire file at once (one STFT) clf = model["classifier"]
feat = _extract_features_from_signal(y, sr) # (31, T) embed_model = model.get("embed_model")
n_feats, T = feat.shape
fps = sr / _HOP_LENGTH # frames per second
win_frames = int(window * fps)
hop_frames = int(hop * fps)
if win_frames > T: _log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
timestamps, window_vectors = _extract_w2v_windows(
y, sr, hop=hop, window=window, video_path=video_path,
cancel_flag=cancel_flag, model_name=embed_model,
)
if len(timestamps) == 0:
_log("audio_scan: video shorter than window") _log("audio_scan: video shorter than window")
return [] return []
_log(f"audio_scan: scanning {T} frames, win={win_frames}, hop={hop_frames}") # Per-video z-score normalize
vid_mean = window_vectors.mean(axis=0)
vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
normed = (window_vectors - vid_mean) / vid_std
# Vectorized sliding window via cumulative sums _log(f"audio_scan: classifying {len(normed)} windows...")
cumsum = np.zeros((n_feats, T + 1))
cumsum[:, 1:] = np.cumsum(feat, axis=1)
cumsq = np.zeros((n_feats, T + 1))
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
starts = np.arange(0, T - win_frames + 1, hop_frames)
ends = starts + win_frames
sums = cumsum[:, ends] - cumsum[:, starts] # (31, n_windows)
sq_sums = cumsq[:, ends] - cumsq[:, starts]
means = sums / win_frames
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
window_vectors = np.vstack([means, stds]).T # (n_windows, 62)
if cancel_flag and getattr(cancel_flag, '_cancel', False): if cancel_flag and getattr(cancel_flag, '_cancel', False):
return [] return []
# Score all windows probs = clf.predict_proba(normed)[:, 1]
if mode == "nearest": mask = probs >= threshold
# Compare each window to every clip vector, take max
clip_vecs = np.stack(profile["clip_vectors"]) # (n_clips, 62)
results = []
# Process in batches to check cancel_flag periodically
batch = 500
for i in range(0, len(window_vectors), batch):
if cancel_flag and getattr(cancel_flag, '_cancel', False):
_log("audio_scan: cancelled")
return results
chunk = window_vectors[i:i + batch]
# cdist: (batch, n_clips) distances
dists = np.linalg.norm(chunk[:, None, :] - clip_vecs[None, :, :], axis=2)
scores = 1.0 / (1.0 + dists.min(axis=1)) # min dist = max similarity
for j, score in enumerate(scores):
if score >= threshold:
idx = i + j
start_t = starts[idx] / fps
results.append((start_t, start_t + window, float(score)))
else:
# Average mode: compare to mean vector
ref = profile["mean_vector"]
dists = np.linalg.norm(window_vectors - ref, axis=1)
scores = 1.0 / (1.0 + dists)
mask = scores >= threshold
results = [ results = [
(starts[i] / fps, starts[i] / fps + window, float(scores[i])) (timestamps[i], timestamps[i] + window, float(probs[i]))
for i in np.nonzero(mask)[0] for i in np.nonzero(mask)[0]
] ]
_log(f"audio_scan: {len(results)} regions above threshold {threshold}") _log(f"audio_scan: {len(results)} regions above threshold {threshold}")
return results return results
+783
View File
@@ -0,0 +1,783 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import numpy as np
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Parameter
from .beats_modules import (
GradMultiply,
SamePad,
get_activation_fn,
GLU_Linear,
quant_noise,
)
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
deep_norm=args.deep_norm,
has_relative_attention_bias=self.relative_position_embedding,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
encoder_layers=args.encoder_layers,
)
for i in range(args.encoder_layers)
]
)
if self.relative_position_embedding:
for i in range(1, args.encoder_layers):
del self.layers[i].self_attn.relative_attention_bias
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
if args.deep_norm:
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
for i in range(args.encoder_layers):
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
if self.layer_wise_gradient_decay_ratio != 1.0:
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
deep_norm: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
encoder_layers: int = 0,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = LayerNorm(self.embedding_dim)
self.deep_norm = deep_norm
if self.deep_norm:
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
else:
self.deep_norm_alpha = 1
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None
):
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual * self.deep_norm_alpha + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual * self.deep_norm_alpha + x
x = self.final_layer_norm(x)
return x, attn, pos_bias
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(
relative_position,
bidirectional=True
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
alpha = 32
q *= 1 / alpha
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.k_head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
attn_weights = attn_weights + attn_mask_rel_pos
attn_weights_float = F.softmax(
attn_weights, dim=-1
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
+179
View File
@@ -0,0 +1,179 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi
from .beats_backbone import (
TransformerEncoder,
)
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class BEATsConfig:
def __init__(self, cfg=None):
self.input_patch_size: int = -1 # path size of patch embedding
self.embed_dim: int = 512 # patch embedding dimension
self.conv_bias: bool = False # include bias in conv encoder
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = False # apply deep_norm first in the transformer
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
# label predictor
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
self.predictor_class: int = 527 # target class number for the predictor
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class BEATs(nn.Module):
def __init__(
self,
cfg: BEATsConfig,
) -> None:
super().__init__()
logger.info(f"BEATs Config: {cfg.__dict__}")
self.cfg = cfg
self.embed = cfg.embed_dim
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.dropout_input = nn.Dropout(cfg.dropout_input)
assert not cfg.deep_norm or not cfg.layer_norm_first
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
if cfg.finetuned_model:
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
else:
self.predictor = None
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(fbank, padding_mask)
fbank = fbank.unsqueeze(1)
features = self.patch_embedding(fbank)
features = features.reshape(features.shape[0], features.shape[1], -1)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x = self.dropout_input(features)
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
)
if self.predictor is not None:
x = self.predictor_dropout(x)
logits = self.predictor(x)
if padding_mask is not None and padding_mask.any():
logits[padding_mask] = 0
logits = logits.sum(dim=1)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
else:
logits = logits.mean(dim=1)
lprobs = torch.sigmoid(logits)
return lprobs, padding_mask
else:
return x, padding_mask
+219
View File
@@ -0,0 +1,219 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import warnings
import torch
from torch import Tensor, nn
import torch.nn.functional as F
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
else:
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
return x
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
+106 -5
View File
@@ -1,3 +1,4 @@
import os
import sqlite3 import sqlite3
import threading import threading
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -7,7 +8,7 @@ from .paths import _log
class ProcessedDB: class ProcessedDB:
_SCHEMA_VERSION = 3 # bump when schema changes _SCHEMA_VERSION = 4 # bump when schema changes
def __init__(self, db_path: str | None = None): def __init__(self, db_path: str | None = None):
if db_path is None: if db_path is None:
@@ -47,6 +48,7 @@ class ProcessedDB:
" clip_count INTEGER NOT NULL DEFAULT 3," " clip_count INTEGER NOT NULL DEFAULT 3,"
" spread REAL NOT NULL DEFAULT 3.0," " spread REAL NOT NULL DEFAULT 3.0,"
" profile TEXT NOT NULL DEFAULT 'default'," " profile TEXT NOT NULL DEFAULT 'default',"
" source_path TEXT NOT NULL DEFAULT '',"
" processed_at TEXT NOT NULL" " processed_at TEXT NOT NULL"
")" ")"
) )
@@ -62,6 +64,7 @@ class ProcessedDB:
"clip_count": "INTEGER NOT NULL DEFAULT 3", "clip_count": "INTEGER NOT NULL DEFAULT 3",
"spread": "REAL NOT NULL DEFAULT 3.0", "spread": "REAL NOT NULL DEFAULT 3.0",
"profile": "TEXT NOT NULL DEFAULT 'default'", "profile": "TEXT NOT NULL DEFAULT 'default'",
"source_path": "TEXT NOT NULL DEFAULT ''",
} }
for col, typedef in new_cols.items(): for col, typedef in new_cols.items():
if col not in cols: if col not in cols:
@@ -85,7 +88,7 @@ class ProcessedDB:
short_side: int | None = None, portrait_ratio: str = "", short_side: int | None = None, portrait_ratio: str = "",
crop_center: float = 0.5, fmt: str = "MP4", crop_center: float = 0.5, fmt: str = "MP4",
clip_count: int = 3, spread: float = 3.0, clip_count: int = 3, spread: float = 3.0,
profile: str = "default") -> None: profile: str = "default", source_path: str = "") -> None:
if not self._enabled: if not self._enabled:
return return
with self._lock: with self._lock:
@@ -93,11 +96,11 @@ class ProcessedDB:
"INSERT INTO processed" "INSERT INTO processed"
" (filename, start_time, output_path, label, category," " (filename, start_time, output_path, label, category,"
" short_side, portrait_ratio, crop_center, format," " short_side, portrait_ratio, crop_center, format,"
" clip_count, spread, profile, processed_at)" " clip_count, spread, profile, source_path, processed_at)"
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(filename, start_time, output_path, label, category, (filename, start_time, output_path, label, category,
short_side, portrait_ratio, crop_center, fmt, short_side, portrait_ratio, crop_center, fmt,
clip_count, spread, profile, clip_count, spread, profile, source_path,
datetime.now(timezone.utc).isoformat()), datetime.now(timezone.utc).isoformat()),
) )
self._con.commit() self._con.commit()
@@ -223,6 +226,104 @@ class ProcessedDB:
).fetchall() ).fetchall()
return [r[0] for r in rows] return [r[0] for r in rows]
def get_export_folders(self, profile: str = "default") -> list[str]:
"""Return distinct export folder names found in output_paths for a profile.
Export paths follow the structure:
.../export_folder/group_dir/clip.mp4
The export folder is 2 levels up from the clip file.
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
"""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
folder_names: set[str] = set()
for (op,) in rows:
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent:
folder_names.add(grandparent)
return sorted(folder_names)
def get_training_data(self, profile: str, positive_folder: str,
fallback_video_dir: str = "",
) -> list[tuple[str, list[float], list[float]]]:
"""Build training video_infos from DB data.
Args:
profile: profile name
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
fallback_video_dir: if source_path is empty, try filename in this dir
Returns:
list of (source_video_path, positive_times, soft_times) per video.
Soft times = clips from any other export folder.
"""
if not self._enabled:
return []
rows = self._con.execute(
"SELECT filename, start_time, output_path, source_path"
" FROM processed WHERE profile = ?",
(profile,),
).fetchall()
# Collect times by video, split by positive vs other folders
pos_by_video: dict[str, set[float]] = {}
soft_by_video: dict[str, set[float]] = {}
source_by_filename: dict[str, str] = {}
for fn, st, op, sp in rows:
if sp:
source_by_filename[fn] = sp
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent == positive_folder:
pos_by_video.setdefault(fn, set()).add(st)
else:
soft_by_video.setdefault(fn, set()).add(st)
result = []
for fn in pos_by_video:
sp = source_by_filename.get(fn, "")
if not sp or not os.path.exists(sp):
# Fallback: try video_dir / filename
if fallback_video_dir:
sp = os.path.join(fallback_video_dir, fn)
if not sp or not os.path.exists(sp):
continue
gt_pos = sorted(pos_by_video[fn])
gt_soft = sorted(soft_by_video.get(fn, set()))
result.append((sp, gt_pos, gt_soft))
return result
def get_training_stats(self, profile: str) -> dict[str, dict]:
"""Return per-subprofile stats for training readiness display.
Returns dict mapping subprofile_name → {
'videos': number of distinct source videos,
'clips': total clip count,
}
"""
if not self._enabled:
return {}
rows = self._con.execute(
"SELECT filename, output_path FROM processed WHERE profile = ?",
(profile,),
).fetchall()
folders = self.get_export_folders(profile)
stats: dict[str, dict] = {}
for folder_name in folders:
videos: set[str] = set()
clips = 0
for fn, op in rows:
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
if grandparent == folder_name:
videos.add(fn)
clips += 1
stats[folder_name] = {"videos": len(videos), "clips": clips}
return stats
def hide_file(self, filename: str, profile: str = "default") -> None: def hide_file(self, filename: str, profile: str = "default") -> None:
if not self._enabled: if not self._enabled:
return return
+483 -64
View File
@@ -15,7 +15,7 @@ from PyQt6.QtWidgets import (
QLabel, QPushButton, QLineEdit, QFileDialog, QLabel, QPushButton, QLineEdit, QFileDialog,
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip, QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
QMessageBox, QInputDialog, QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
) )
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
@@ -191,12 +191,11 @@ class ScanWorker(QThread):
error = pyqtSignal(str) error = pyqtSignal(str)
progress = pyqtSignal(str) # status message progress = pyqtSignal(str) # status message
def __init__(self, video_path: str, clip_paths: list[str], def __init__(self, video_path: str, model: dict,
mode: str = "average", threshold: float = 0.7): threshold: float = 0.30):
super().__init__() super().__init__()
self._video_path = video_path self._video_path = video_path
self._clip_paths = clip_paths self._model = model
self._mode = mode
self._threshold = threshold self._threshold = threshold
self._cancel = False self._cancel = False
@@ -204,20 +203,12 @@ class ScanWorker(QThread):
self._cancel = True self._cancel = True
def run(self): def run(self):
from core.audio_scan import build_profile, scan_video from core.audio_scan import scan_video
try: try:
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
profile = build_profile(self._clip_paths)
if self._cancel:
return
if profile is None:
self.error.emit("No valid reference clips found")
return
self.progress.emit("Scanning audio...") self.progress.emit("Scanning audio...")
regions = scan_video( regions = scan_video(
self._video_path, profile, self._video_path, model=self._model,
mode=self._mode, threshold=self._threshold, threshold=self._threshold, cancel_flag=self,
cancel_flag=self,
) )
if not self._cancel: if not self._cancel:
self.scan_done.emit(regions) self.scan_done.emit(regions)
@@ -226,6 +217,151 @@ class ScanWorker(QThread):
self.error.emit(str(e)) self.error.emit(str(e))
class TrainDialog(QDialog):
"""Dialog for configuring and launching classifier training."""
def __init__(self, db: ProcessedDB, profile: str, video_dir: str = "",
parent=None):
super().__init__(parent)
self.setWindowTitle("Train Classifier")
self.setMinimumWidth(400)
from core.audio_scan import _EMBED_MODELS
self._db = db
self._profile = profile
self._video_dir = video_dir
layout = QVBoxLayout(self)
form = QFormLayout()
# Positive class selector — lists export folders
self._cmb_positive = QComboBox()
stats = db.get_training_stats(profile)
if not stats:
form.addRow("", QLabel("No exported clips found for this profile."))
for folder_name, info in stats.items():
label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)"
self._cmb_positive.addItem(label, userData=folder_name)
form.addRow("Positive class:", self._cmb_positive)
# Model selector
self._cmb_model = QComboBox()
for name in _EMBED_MODELS:
self._cmb_model.addItem(name)
self._cmb_model.setCurrentText("WAV2VEC2_BASE")
form.addRow("Model:", self._cmb_model)
# Video source directory (fallback for old DB rows without source_path)
self._txt_video_dir = QLineEdit(video_dir)
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
self._debounce = QTimer(self)
self._debounce.setSingleShot(True)
self._debounce.setInterval(400)
self._debounce.timeout.connect(self._update_stats)
self._txt_video_dir.textChanged.connect(lambda: self._debounce.start())
vid_row = QHBoxLayout()
vid_row.addWidget(self._txt_video_dir)
btn_browse = QPushButton("...")
btn_browse.setFixedWidth(30)
btn_browse.clicked.connect(self._browse_video_dir)
vid_row.addWidget(btn_browse)
form.addRow("Video dir:", vid_row)
layout.addLayout(form)
# Stats summary
self._lbl_stats = QLabel()
self._update_stats()
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
layout.addWidget(self._lbl_stats)
# Buttons
btns = QDialogButtonBox(
QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel
)
btns.button(QDialogButtonBox.StandardButton.Ok).setText("Train")
btns.button(QDialogButtonBox.StandardButton.Ok).setEnabled(
self._cmb_positive.count() > 0
)
btns.accepted.connect(self.accept)
btns.rejected.connect(self.reject)
layout.addWidget(btns)
def _browse_video_dir(self):
d = QFileDialog.getExistingDirectory(self, "Select video source directory")
if d:
self._txt_video_dir.setText(d)
def _update_stats(self):
folder = self._cmb_positive.currentData()
if not folder:
self._lbl_stats.setText("No export folder data available.")
return
video_infos = self._db.get_training_data(
self._profile, folder,
fallback_video_dir=self._txt_video_dir.text(),
)
n_videos = len(video_infos)
n_pos = sum(len(gt) for _, gt, _ in video_infos)
n_soft = sum(len(s) for _, _, s in video_infos)
lines = [f"<b>{n_videos}</b> videos with positive clips"]
lines.append(f"<b>{n_pos}</b> positive markers, <b>{n_soft}</b> soft/buffer markers")
if n_videos == 0:
lines.append("<i>No source videos found. Set Video dir above.</i>")
elif n_videos < 3:
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
self._lbl_stats.setText("<br>".join(lines))
@property
def positive_folder(self) -> str:
return self._cmb_positive.currentData() or ""
@property
def embed_model(self) -> str:
return self._cmb_model.currentText()
@property
def video_dir(self) -> str:
return self._txt_video_dir.text()
class TrainWorker(QThread):
"""Trains an audio classifier off the main thread."""
train_done = pyqtSignal(str) # emits model path on success
error = pyqtSignal(str)
progress = pyqtSignal(str) # per-video status
def __init__(self, video_infos: list, model_path: str,
embed_model: str | None = None):
super().__init__()
self._video_infos = video_infos
self._model_path = model_path
self._embed_model = embed_model
self._cancel = False
def cancel(self) -> None:
self._cancel = True
def run(self):
from core.audio_scan import train_classifier
try:
self.progress.emit(f"Training on {len(self._video_infos)} videos...")
result = train_classifier(
self._video_infos,
model_path=self._model_path,
embed_model=self._embed_model,
)
if self._cancel:
return
if result is None:
self.error.emit("Training failed: not enough data or missing class balance")
else:
self.train_done.emit(self._model_path)
except Exception as e:
if not self._cancel:
self.error.emit(str(e))
class TimelineWidget(QWidget): class TimelineWidget(QWidget):
cursor_changed = pyqtSignal(float) # emits position in seconds cursor_changed = pyqtSignal(float) # emits position in seconds
seek_changed = pyqtSignal(float) # emits seek position (lock mode) seek_changed = pyqtSignal(float) # emits seek position (lock mode)
@@ -1564,23 +1700,35 @@ class MainWindow(QMainWindow):
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips") self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
self._btn_scan.clicked.connect(self._start_scan) self._btn_scan.clicked.connect(self._start_scan)
self._btn_auto_export = QPushButton("Auto")
self._btn_auto_export.setToolTip("Scan + auto-export best 8s clips")
self._btn_auto_export.clicked.connect(self._auto_export)
self._btn_train = QPushButton("Train")
self._btn_train.setToolTip("Train audio classifier from exported clips")
self._btn_train.clicked.connect(self._open_train_dialog)
self._train_worker: TrainWorker | None = None
self._spn_auto_fuse = QDoubleSpinBox()
self._spn_auto_fuse.setDecimals(1)
self._spn_auto_fuse.setRange(0.0, 60.0)
self._spn_auto_fuse.setSingleStep(1.0)
self._spn_auto_fuse.setValue(float(self._settings.value("auto_fuse", "4.0")))
self._spn_auto_fuse.setPrefix("Fuse: ")
self._spn_auto_fuse.setSuffix("s")
self._spn_auto_fuse.setToolTip("Max gap between scan regions to merge into one cluster")
self._spn_auto_fuse.valueChanged.connect(
lambda v: self._settings.setValue("auto_fuse", str(v))
)
self._sld_threshold = QDoubleSpinBox() self._sld_threshold = QDoubleSpinBox()
self._sld_threshold.setDecimals(2) self._sld_threshold.setDecimals(2)
self._sld_threshold.setRange(0.0, 1.0) self._sld_threshold.setRange(0.0, 1.0)
self._sld_threshold.setSingleStep(0.01) self._sld_threshold.setSingleStep(0.01)
self._sld_threshold.setValue(0.05) self._sld_threshold.setValue(0.30)
self._sld_threshold.setPrefix("Thr: ") self._sld_threshold.setPrefix("Thr: ")
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)") self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
self._cmb_scan_mode = QComboBox()
self._cmb_scan_mode.addItems(["Average", "Nearest"])
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
self._cmb_scan_ref = QComboBox()
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
self._scan_folder: str = ""
self._scan_worker: ScanWorker | None = None self._scan_worker: ScanWorker | None = None
cpu_count = os.cpu_count() or 2 cpu_count = os.cpu_count() or 2
@@ -1716,9 +1864,10 @@ class MainWindow(QMainWindow):
settings_row.addWidget(self._chk_rand_square) settings_row.addWidget(self._chk_rand_square)
settings_row.addWidget(self._chk_track) settings_row.addWidget(self._chk_track)
settings_row.addWidget(self._btn_scan) settings_row.addWidget(self._btn_scan)
settings_row.addWidget(self._btn_auto_export)
settings_row.addWidget(self._spn_auto_fuse)
settings_row.addWidget(self._sld_threshold) settings_row.addWidget(self._sld_threshold)
settings_row.addWidget(self._cmb_scan_mode) settings_row.addWidget(self._btn_train)
settings_row.addWidget(self._cmb_scan_ref)
settings_row.addStretch() settings_row.addStretch()
self._lbl_status = QLabel() self._lbl_status = QLabel()
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;") self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
@@ -2503,16 +2652,6 @@ class MainWindow(QMainWindow):
return return
self._step_cursor(markers[0][0] - self._cursor) # wrap to first self._step_cursor(markers[0][0] - self._cursor) # wrap to first
def _on_scan_ref_changed(self, index: int) -> None:
if index == 1: # Custom Folder
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
if folder:
self._scan_folder = folder
else:
self._cmb_scan_ref.blockSignals(True)
self._cmb_scan_ref.setCurrentIndex(0)
self._cmb_scan_ref.blockSignals(False)
def _cleanup_scan_worker(self) -> None: def _cleanup_scan_worker(self) -> None:
"""Disconnect signals and schedule deletion of old scan worker.""" """Disconnect signals and schedule deletion of old scan worker."""
if self._scan_worker is not None: if self._scan_worker is not None:
@@ -2540,35 +2679,22 @@ class MainWindow(QMainWindow):
# Clean up previous worker # Clean up previous worker
self._cleanup_scan_worker() self._cleanup_scan_worker()
# Collect reference clip paths
if self._cmb_scan_ref.currentIndex() == 0:
# Current profile — all exports across all files in this profile
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
if os.path.exists(p)]
else:
# Custom folder
if not self._scan_folder:
self._show_status("No reference folder selected")
return
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
clip_paths = [
os.path.join(self._scan_folder, f)
for f in sorted(os.listdir(self._scan_folder))
if f.lower().endswith(exts)
]
if not clip_paths:
self._show_status("No reference clips found")
return
mode = self._cmb_scan_mode.currentText().lower()
threshold = self._sld_threshold.value() threshold = self._sld_threshold.value()
self._btn_scan.setEnabled(False) from core.audio_scan import load_classifier, default_model_path
self._scan_file_path = self._file_path # remember which file we're scanning model_path = default_model_path(self._profile)
self._show_status(f"Scanning with {len(clip_paths)} reference clips...") model = load_classifier(model_path)
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold) if model is None:
self._show_status("No trained model — click Train first")
return
self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path
self._show_status("Scanning...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
self._scan_worker.scan_done.connect(self._on_scan_done) self._scan_worker.scan_done.connect(self._on_scan_done)
self._scan_worker.error.connect(self._on_scan_error) self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status) self._scan_worker.progress.connect(self._show_status)
@@ -2576,6 +2702,7 @@ class MainWindow(QMainWindow):
def _on_scan_done(self, regions: list) -> None: def _on_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True) self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
# Ignore stale results if the user switched files during scan # Ignore stale results if the user switched files during scan
if self._file_path != getattr(self, '_scan_file_path', None): if self._file_path != getattr(self, '_scan_file_path', None):
return return
@@ -2584,8 +2711,294 @@ class MainWindow(QMainWindow):
def _on_scan_error(self, msg: str) -> None: def _on_scan_error(self, msg: str) -> None:
self._btn_scan.setEnabled(True) self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._show_status(f"Scan error: {msg}") self._show_status(f"Scan error: {msg}")
# ── Training ────────────────────────────────────────────────
def _cleanup_train_worker(self) -> None:
"""Disconnect signals and schedule deletion of old train worker."""
if self._train_worker is not None:
try:
self._train_worker.train_done.disconnect()
self._train_worker.error.disconnect()
self._train_worker.progress.disconnect()
except TypeError:
pass
if self._train_worker.isRunning():
self._train_worker.cancel()
self._train_worker.finished.connect(self._train_worker.deleteLater)
else:
self._train_worker.deleteLater()
self._train_worker = None
def _open_train_dialog(self):
"""Show the training config dialog and start training if accepted."""
if self._train_worker and self._train_worker.isRunning():
self._show_status("Training already in progress…")
return
# Default video dir: parent of currently loaded file, or saved setting
default_dir = ""
if self._file_path:
default_dir = os.path.dirname(self._file_path)
saved_dir = self._settings.value("train_video_dir", default_dir)
dlg = TrainDialog(self._db, self._profile,
video_dir=saved_dir or default_dir, parent=self)
if dlg.exec() != QDialog.DialogCode.Accepted:
return
pos_folder = dlg.positive_folder
embed_model = dlg.embed_model
video_dir = dlg.video_dir
if not pos_folder:
self._show_status("No positive class selected")
return
# Persist video dir for next time
if video_dir:
self._settings.setValue("train_video_dir", video_dir)
video_infos = self._db.get_training_data(
self._profile, pos_folder, fallback_video_dir=video_dir,
)
if not video_infos:
self._show_status("No training data found for this subprofile")
return
from core.audio_scan import default_model_path
model_path = default_model_path(self._profile)
self._cleanup_train_worker()
self._btn_train.setEnabled(False)
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
self._train_worker = TrainWorker(video_infos, model_path, embed_model)
self._train_worker.train_done.connect(self._on_train_done)
self._train_worker.error.connect(self._on_train_error)
self._train_worker.progress.connect(self._show_status)
self._train_worker.start()
def _on_train_done(self, model_path: str):
self._btn_train.setEnabled(True)
self._show_status(f"Model trained and saved")
_log(f"Training complete: {model_path}")
def _on_train_error(self, msg: str):
self._btn_train.setEnabled(True)
self._show_status(f"Training error: {msg}")
# ── Auto-export ─────────────────────────────────────────────
def _auto_export(self) -> None:
"""Scan → NMS → export one 8s clip per selected position."""
if not self._file_path:
self._show_status("No video loaded")
return
if self._export_worker and self._export_worker.isRunning():
self._show_status("Export already running…")
return
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
self._cleanup_scan_worker()
self._btn_auto_export.setEnabled(False)
self._btn_scan.setEnabled(False)
threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
if model is not None:
self._scan_file_path = self._file_path
self._show_status("Auto: scanning with classifier...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
else:
self._show_status("Auto: no trained model — click Train first")
self._btn_auto_export.setEnabled(True)
self._btn_scan.setEnabled(True)
return
self._scan_worker.scan_done.connect(self._on_auto_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
self._scan_worker.start()
@staticmethod
def _select_export_positions(regions: list[tuple[float, float, float]],
min_gap: float = 2.0,
cluster_fuse: float = 30.0,
) -> list[float]:
"""Cluster scan regions, then fill each cluster with clips spaced min_gap apart.
1. Merge overlapping regions into clusters, fusing clusters <cluster_fuse apart.
2. Within each cluster, greedily pick positions by score, min_gap apart.
"""
if not regions:
return []
# Build clusters — merge overlapping + fuse if gap < cluster_fuse
sorted_r = sorted(regions, key=lambda r: r[0])
clusters: list[list[tuple[float, float, float]]] = []
cur_start, cur_end = sorted_r[0][0], sorted_r[0][1]
cur_regions = [sorted_r[0]]
for start, end, score in sorted_r[1:]:
if start - cur_end <= cluster_fuse:
cur_end = max(cur_end, end)
cur_regions.append((start, end, score))
else:
clusters.append(cur_regions)
cur_start, cur_end = start, end
cur_regions = [(start, end, score)]
clusters.append(cur_regions)
# Within each cluster, NMS by score with min_gap
picked: list[float] = []
for cluster in clusters:
by_score = sorted(cluster, key=lambda r: -r[2])
cluster_picks: list[float] = []
for start, _end, _score in by_score:
if all(abs(start - p) >= min_gap for p in cluster_picks):
cluster_picks.append(start)
picked.extend(cluster_picks)
return sorted(picked)
def _on_auto_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
if self._file_path != getattr(self, '_scan_file_path', None):
self._btn_auto_export.setEnabled(True)
return
self._timeline.set_scan_regions(regions)
if not regions:
self._show_status("Auto: no regions found")
self._btn_auto_export.setEnabled(True)
return
positions = self._select_export_positions(
regions, min_gap=2.0, cluster_fuse=self._spn_auto_fuse.value(),
)
if not positions:
self._show_status("Auto: no positions after NMS")
self._btn_auto_export.setEnabled(True)
return
# Build export jobs — one 8s clip per position
folder = self._txt_folder.text()
name = self._txt_name.text() or "clip"
fmt = self._cmb_format.currentText()
image_sequence = fmt == "WebP sequence"
os.makedirs(folder, exist_ok=True)
# Find starting counter
counter = 1
while True:
if image_sequence:
p = build_sequence_dir(folder, name, counter, sub=0)
else:
p = build_export_path(folder, name, counter, sub=0)
if not os.path.exists(p):
break
counter += 1
jobs = []
self._auto_export_positions = [] # stash for DB writes
for start_t in positions:
group_dir = os.path.join(folder, f"{name}_{counter:03d}")
os.makedirs(group_dir, exist_ok=True)
if image_sequence:
out = build_sequence_dir(folder, name, counter, sub=0)
else:
out = build_export_path(folder, name, counter, sub=0)
jobs.append((start_t, out, None, 0.5))
self._auto_export_positions.append((start_t, counter))
counter += 1
self._show_status(f"Auto: exporting {len(jobs)} clips...")
short_side = self._spn_resize.value() or None
self._export_short_side = short_side
self._export_portrait = "Off"
self._export_format = fmt
self._export_clip_count = 1
self._export_spread = 0
self._export_folder = folder
self._export_folder_suffix = ""
hw_on = self._chk_hw.isChecked() and self._hw_encoders
encoder = self._hw_encoders[0] if hw_on else "libx264"
max_workers = min(self._spn_workers.value(), 3) if hw_on else self._spn_workers.value()
self._export_worker = ExportWorker(
self._file_path, jobs,
short_side=short_side,
image_sequence=image_sequence,
max_workers=max_workers,
encoder=encoder,
)
self._export_worker.finished.connect(self._on_auto_clip_done)
self._export_worker.all_done.connect(self._on_auto_batch_done)
self._export_worker.error.connect(self._on_export_error)
self._export_worker.cancelled.connect(self._on_export_cancelled)
self._btn_cancel.setEnabled(True)
self._btn_export.setEnabled(False)
self._set_subprofile_btns_enabled(False)
self._export_worker.start()
def _on_auto_clip_done(self, path: str):
"""Record each auto-exported clip to DB."""
# Find the start_time for this clip from stashed positions
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
name = self._txt_name.text() or "clip"
start_t = None
for t, c in self._auto_export_positions:
if counter_str == f"{name}_{c:03d}":
start_t = t
break
label = self._txt_label.currentText().strip()
category = self._cmb_category.currentText()
self._db.add(
os.path.basename(self._file_path),
start_t or 0.0,
path,
label=label,
category=category,
short_side=self._export_short_side,
portrait_ratio="",
crop_center=0.5,
fmt=self._export_format,
clip_count=1,
spread=0,
profile=self._profile,
source_path=self._file_path,
)
upsert_clip_annotation(self._export_folder, path, label)
self._show_status(f"Auto: {os.path.basename(path)}")
_log(f" auto clip done: {os.path.basename(path)}")
def _on_auto_batch_done(self):
n = len(self._auto_export_positions)
self._btn_auto_export.setEnabled(True)
self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True)
self._set_subprofile_btns_enabled(True)
self._refresh_markers()
markers = self._db.get_markers(os.path.basename(self._file_path), self._profile)
self._playlist.mark_done(self._file_path, len(markers))
self._update_next_label()
self._show_status(f"Auto export complete: {n} clips")
_log(f"Auto export complete: {n} clips")
def _jump_to_next_scan_region(self) -> None: def _jump_to_next_scan_region(self) -> None:
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0]) regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
if not regions: if not regions:
@@ -2812,6 +3225,7 @@ class MainWindow(QMainWindow):
clip_count=self._export_clip_count, clip_count=self._export_clip_count,
spread=self._export_spread, spread=self._export_spread,
profile=self._profile, profile=self._profile,
source_path=self._file_path,
) )
upsert_clip_annotation(self._export_folder, path, label) upsert_clip_annotation(self._export_folder, path, label)
self._last_export_path = path self._last_export_path = path
@@ -2851,6 +3265,7 @@ class MainWindow(QMainWindow):
_log(f"Export error: {msg}") _log(f"Export error: {msg}")
self._btn_cancel.setEnabled(False) self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True) self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True) self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export") self._btn_export.setText("Export")
self._btn_export.setStyleSheet("") self._btn_export.setStyleSheet("")
@@ -2866,6 +3281,7 @@ class MainWindow(QMainWindow):
def _on_export_cancelled(self): def _on_export_cancelled(self):
_log("Export cancelled") _log("Export cancelled")
self._btn_export.setEnabled(True) self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True) self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export") self._btn_export.setText("Export")
self._btn_export.setStyleSheet("") self._btn_export.setStyleSheet("")
@@ -2886,6 +3302,9 @@ class MainWindow(QMainWindow):
_log("Shutting down…") _log("Shutting down…")
# Save session playlist for resume. # Save session playlist for resume.
self._settings.setValue("session_files", self._playlist._paths) self._settings.setValue("session_files", self._playlist._paths)
# Cancel background workers to prevent callbacks into dead objects.
self._cleanup_scan_worker()
self._cleanup_train_worker()
# Stop timers first to prevent callbacks into dead objects. # Stop timers first to prevent callbacks into dead objects.
self._preview_timer.stop() self._preview_timer.stop()
self._mpv._render_timer.stop() self._mpv._render_timer.stop()
+22 -1
View File
@@ -1,4 +1,25 @@
# Core GUI
PyQt6>=6.4 PyQt6>=6.4
python-mpv>=1.0 python-mpv>=1.0
pytest>=7.0
# Audio & ML
librosa>=0.10
numpy>=1.24
scikit-learn>=1.3
joblib>=1.3
soundfile>=0.12
# Deep learning (torch installed separately for CUDA support)
# torch and torchaudio are installed via --index-url in setup_env.sh
torchaudio>=2.0
# Object detection
ultralytics>=8.0 ultralytics>=8.0
# Server API
fastapi>=0.100
pydantic>=2.0
uvicorn>=0.23
# Dev
pytest>=7.0
Executable
+108
View File
@@ -0,0 +1,108 @@
#!/usr/bin/env bash
set -euo pipefail
# ──────────────────────────────────────────────────────────────────────
# 8-cut environment setup — supports conda (miniforge) or python venv
#
# Usage:
# ./setup_env.sh # auto-detect (prefers conda if available)
# ./setup_env.sh --conda # force conda
# ./setup_env.sh --venv # force python venv
# ────────────────────────────────────────────────────────────────────
ENV_NAME="8cut"
PYTHON_VERSION="3.12"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
VENV_DIR="$SCRIPT_DIR/.venv"
# CUDA version for PyTorch index URL
TORCH_INDEX="https://download.pytorch.org/whl/cu128"
# ── Parse args ────────────────────────────────────────────────────────
MODE=""
for arg in "$@"; do
case "$arg" in
--conda) MODE="conda" ;;
--venv) MODE="venv" ;;
*) echo "Unknown arg: $arg"; exit 1 ;;
esac
done
if [ -z "$MODE" ]; then
if command -v conda &>/dev/null; then
MODE="conda"
else
MODE="venv"
fi
echo "Auto-detected mode: $MODE"
fi
# ── Conda setup ──────────────────────────────────────────────────────
setup_conda() {
echo "==> Setting up conda environment: $ENV_NAME"
# Source conda shell hooks if not already active
if ! command -v conda &>/dev/null; then
echo "conda not found in PATH"
exit 1
fi
eval "$(conda shell.bash hook)"
if conda env list | grep -qw "$ENV_NAME"; then
echo " Environment '$ENV_NAME' already exists, updating..."
conda activate "$ENV_NAME"
else
echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..."
conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION"
conda activate "$ENV_NAME"
fi
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
pip install torch torchaudio --index-url "$TORCH_INDEX"
echo " Installing project dependencies..."
pip install -r "$SCRIPT_DIR/requirements.txt"
echo ""
echo "Done! Activate with:"
echo " conda activate $ENV_NAME"
}
# ── Venv setup ───────────────────────────────────────────────────────
setup_venv() {
echo "==> Setting up Python venv at: $VENV_DIR"
if [ ! -d "$VENV_DIR" ]; then
python3 -m venv "$VENV_DIR"
echo " Created venv"
else
echo " Venv already exists, updating..."
fi
source "$VENV_DIR/bin/activate"
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
pip install torch torchaudio --index-url "$TORCH_INDEX"
echo " Installing project dependencies..."
pip install -r "$SCRIPT_DIR/requirements.txt"
echo ""
echo "Done! Activate with:"
echo " source $VENV_DIR/bin/activate"
}
# ── Run ───────────────────────────────────────────────────────────────
case "$MODE" in
conda) setup_conda ;;
venv) setup_venv ;;
esac
echo ""
echo "Verify with:"
echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\""
echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\""
+14 -140
View File
@@ -1,154 +1,28 @@
import tempfile, os import tempfile, os
import numpy as np import numpy as np
from core.audio_scan import build_profile, _extract_features, scan_video, _similarity from core.audio_scan import scan_video, load_classifier, default_model_path
def _make_wav(path: str, duration: float = 8.0, sr: int = 16000, freq: float = 440.0): def test_scan_video_no_model_returns_empty():
"""Create a short sine-wave WAV file for testing.""" """scan_video with no model should return empty list."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
import soundfile as sf import soundfile as sf
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
audio = 0.5 * np.sin(2 * np.pi * freq * t)
sf.write(path, audio, sr)
def test_extract_features_returns_62d_vector():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
vec = _extract_features(f.name)
assert vec.shape == (62,)
assert not np.isnan(vec).any()
finally:
os.unlink(f.name)
def test_build_profile_single_clip():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
profile = build_profile([f.name])
assert "mean_vector" in profile
assert "clip_vectors" in profile
assert profile["mean_vector"].shape == (62,)
assert len(profile["clip_vectors"]) == 1
finally:
os.unlink(f.name)
def test_build_profile_multiple_clips():
paths = []
try:
for i in range(3):
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
_make_wav(f.name, freq=440 + i * 200)
paths.append(f.name)
f.close()
profile = build_profile(paths)
assert len(profile["clip_vectors"]) == 3
assert profile["mean_vector"].shape == (62,)
finally:
for p in paths:
os.unlink(p)
def test_build_profile_skips_missing_files():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
_make_wav(f.name)
try:
profile = build_profile([f.name, "/no/such/file.wav"])
assert len(profile["clip_vectors"]) == 1
finally:
os.unlink(f.name)
def test_build_profile_empty_returns_none():
result = build_profile([])
assert result is None
def test_similarity_identical_is_one():
a = np.array([1.0, 2.0, 3.0])
assert abs(_similarity(a, a) - 1.0) < 1e-9
def test_similarity_distant_is_low():
a = np.zeros(62)
b = np.ones(62) * 100
assert _similarity(a, b) < 0.01
def test_scan_video_finds_matching_region():
"""A video made of the same sine wave as the reference should match."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
_make_wav(vid.name, duration=20.0)
try:
profile = build_profile([ref.name])
regions = scan_video(vid.name, profile, mode="average", threshold=0.01, hop=1.0)
assert len(regions) > 0
for start, end, score in regions:
assert abs((end - start) - 8.0) < 0.1
assert score >= 0.01
finally:
os.unlink(ref.name)
os.unlink(vid.name)
def test_scan_video_nearest_mode():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
_make_wav(vid.name, duration=20.0)
try:
profile = build_profile([ref.name])
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.01, hop=1.0)
assert len(regions) > 0
finally:
os.unlink(ref.name)
os.unlink(vid.name)
def test_scan_video_high_threshold_no_match():
"""Different frequencies with very high threshold should not match."""
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0, freq=440)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
# White noise — very different from sine wave
sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000) sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000)
try: try:
profile = build_profile([ref.name]) regions = scan_video(vid.name, model=None)
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0) assert regions == []
assert len(regions) == 0
finally: finally:
os.unlink(ref.name)
os.unlink(vid.name) os.unlink(vid.name)
def test_scan_video_same_vs_different_discrimination(): def test_load_classifier_missing_returns_none():
"""Same-frequency match should score higher than cross-frequency.""" assert load_classifier("/no/such/model.joblib") is None
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
_make_wav(ref.name, duration=8.0, freq=440) def test_default_model_path_contains_profile():
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as same: path = default_model_path("test_profile")
_make_wav(same.name, duration=10.0, freq=440) assert "test_profile" in path
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as diff: assert path.endswith(".joblib")
# White noise
sf.write(diff.name, np.random.randn(16000 * 10).astype(np.float32) * 0.1, 16000)
try:
profile = build_profile([ref.name])
same_regions = scan_video(same.name, profile, mode="average", threshold=0.0, hop=1.0)
diff_regions = scan_video(diff.name, profile, mode="average", threshold=0.0, hop=1.0)
# Same-audio scores should be higher than noise scores
best_same = max(r[2] for r in same_regions)
best_diff = max(r[2] for r in diff_regions)
assert best_same > best_diff
finally:
os.unlink(ref.name)
os.unlink(same.name)
os.unlink(diff.name)
def test_db_get_all_export_paths(): def test_db_get_all_export_paths():