Compare commits
83 Commits
v0.9.0
..
7855ea62c2
| Author | SHA1 | Date | |
|---|---|---|---|
| 7855ea62c2 | |||
| 70be5974cf | |||
| a0286d5cf9 | |||
| 2b7dfb330d | |||
| 518554f788 | |||
| 282156e8ed | |||
| 3417a0f603 | |||
| cd0552197f | |||
| 7dffcb08eb | |||
| 93bcb23fa7 | |||
| eda7826a40 | |||
| e7e20b0fe6 | |||
| 814ef946eb | |||
| 2e738df9ae | |||
| 6ddfcde8ee | |||
| b161412d94 | |||
| 5a9e068903 | |||
| 6870e5aaf3 | |||
| f597ff29e8 | |||
| e1789d4e71 | |||
| 7834b1d05c | |||
| 12ed183f1b | |||
| f2c38aee79 | |||
| 8ab5bdba77 | |||
| c6c5934fe8 | |||
| 73d5367424 | |||
| 1e2cebd424 | |||
| c439aca9b9 | |||
| afda9b2d9f | |||
| fd42791c9f | |||
| 4cf54f2642 | |||
| e7f4de9ec1 | |||
| 9cf9e3233f | |||
| e17d8f67aa | |||
| b1980de6d1 | |||
| 85e0641440 | |||
| 834b89b682 | |||
| a67e189aa0 | |||
| 2b6c56cd15 | |||
| 0f6082061f | |||
| 9662b815db | |||
| 9776b83ac5 | |||
| 39f873bec2 | |||
| 409eb82e5c | |||
| 297aafa51c | |||
| b4cf972d59 | |||
| 5cc1e52e75 | |||
| 6bf0b0ae99 | |||
| b6fbda01dd | |||
| 51d41f0a56 | |||
| 16bd1a9ae0 | |||
| 2036c49b52 | |||
| b12758c53c | |||
| 3d484952c2 | |||
| 12dae93671 | |||
| 1e65fd6b0f | |||
| f7756320e5 | |||
| cd0331d4ce | |||
| 38c6174f83 | |||
| 5b22bceed2 | |||
| 80f21915e3 | |||
| b09ba3fa9e | |||
| 5b7a55a05d | |||
| 2200da491f | |||
| 3d6469c60c | |||
| 6a4ac8b8ed | |||
| 1f6906c946 | |||
| dfba88a601 | |||
| e94c088df0 | |||
| 9569103edd | |||
| 079afeee7c | |||
| fbbfa6fdce | |||
| 56920a5247 | |||
| 08c1dd8b33 | |||
| 2b63ad1857 | |||
| 72f6a4e8f5 | |||
| 799a2ab353 | |||
| 066f4431ba | |||
| 97f9ef7073 | |||
| 592e40c1a6 | |||
| 73dd7a1569 | |||
| 7abf0b4d4c | |||
| 9e5bd4a8ec |
@@ -0,0 +1,36 @@
|
|||||||
|
name: Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # manual only — build locally and push to ghcr.io
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- uses: docker/metadata-action@v5
|
||||||
|
id: meta
|
||||||
|
with:
|
||||||
|
images: ghcr.io/${{ github.repository }}-server
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
|
type=sha,prefix=
|
||||||
|
|
||||||
|
- uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
@@ -3,3 +3,8 @@ __pycache__/
|
|||||||
*.pyo
|
*.pyo
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.worktrees/
|
.worktrees/
|
||||||
|
.venv/
|
||||||
|
models/
|
||||||
|
cache/
|
||||||
|
*.joblib
|
||||||
|
*.pt
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
@echo off
|
@echo off
|
||||||
cd /d "%~dp0"
|
cd /d "%~dp0"
|
||||||
python main.py %*
|
if exist ".venv\Scripts\python.exe" (
|
||||||
|
.venv\Scripts\python.exe main.py %*
|
||||||
|
) else (
|
||||||
|
python main.py %*
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Launch 8-cut with auto-detected venv/conda environment
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
ENV_NAME="8cut"
|
||||||
|
CONDA_PREFIX_BASE="/media/p5/miniforge3"
|
||||||
|
|
||||||
|
# 1. Try .venv in project dir
|
||||||
|
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
|
||||||
|
source "$SCRIPT_DIR/.venv/bin/activate"
|
||||||
|
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 2. Try conda env (works without shell init)
|
||||||
|
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
|
||||||
|
if [ -x "$CONDA_PYTHON" ]; then
|
||||||
|
exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3. Try conda via shell hook (interactive shells)
|
||||||
|
if command -v conda &>/dev/null; then
|
||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Fallback to system Python
|
||||||
|
exec python3 "$SCRIPT_DIR/main.py" "$@"
|
||||||
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Calibration — per-video normalized features + classifier."""
|
||||||
|
import sys, os, time, warnings
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
from sklearn.ensemble import GradientBoostingClassifier
|
||||||
|
|
||||||
|
from core.audio_scan import _SR, _WINDOW
|
||||||
|
|
||||||
|
_HOP_LENGTH = 1024
|
||||||
|
_N_FFT = 2048
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||||
|
PROFILE_NAME = "JAV_missionary"
|
||||||
|
TOLERANCE = 12.0
|
||||||
|
NEG_MARGIN = 120.0
|
||||||
|
|
||||||
|
|
||||||
|
def extract_rich_features(y, sr=_SR):
|
||||||
|
"""Per-frame features: onset, energy, spectral shape, mel bands (22 features)."""
|
||||||
|
hop = _HOP_LENGTH
|
||||||
|
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2
|
||||||
|
rms = librosa.feature.rms(S=S, hop_length=hop)
|
||||||
|
cent = librosa.feature.spectral_centroid(S=S, sr=sr)
|
||||||
|
bw = librosa.feature.spectral_bandwidth(S=S, sr=sr)
|
||||||
|
rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr)
|
||||||
|
flatness = librosa.feature.spectral_flatness(S=S)
|
||||||
|
zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop)
|
||||||
|
onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1)
|
||||||
|
|
||||||
|
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128)
|
||||||
|
mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
|
||||||
|
bands = [(0, 100), (100, 300), (300, 600), (600, 1200),
|
||||||
|
(1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
|
||||||
|
band_feats = []
|
||||||
|
for flo, fhi in bands:
|
||||||
|
mask = (mel_freqs >= flo) & (mel_freqs < fhi)
|
||||||
|
if mask.sum() > 0:
|
||||||
|
band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10))
|
||||||
|
else:
|
||||||
|
band_feats.append(np.zeros((1, mel_S.shape[1])))
|
||||||
|
|
||||||
|
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop)
|
||||||
|
|
||||||
|
min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1],
|
||||||
|
band_feats[0].shape[1])
|
||||||
|
return np.vstack([
|
||||||
|
rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t],
|
||||||
|
flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t],
|
||||||
|
] + [b[:, :min_t] for b in band_feats]
|
||||||
|
+ [sc[:, :min_t]])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_window_stats(feat, hop=1.0):
|
||||||
|
"""Sliding window mean/std → (timestamps, feature_vectors)."""
|
||||||
|
n_feats, T = feat.shape
|
||||||
|
fps = _SR / _HOP_LENGTH
|
||||||
|
win_frames = int(_WINDOW * fps)
|
||||||
|
hop_frames = int(hop * fps)
|
||||||
|
if win_frames > T:
|
||||||
|
return np.array([]), np.array([])
|
||||||
|
|
||||||
|
cumsum = np.zeros((n_feats, T + 1))
|
||||||
|
cumsum[:, 1:] = np.cumsum(feat, axis=1)
|
||||||
|
cumsq = np.zeros((n_feats, T + 1))
|
||||||
|
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
|
||||||
|
|
||||||
|
starts = np.arange(0, T - win_frames + 1, hop_frames)
|
||||||
|
ends = starts + win_frames
|
||||||
|
sums = cumsum[:, ends] - cumsum[:, starts]
|
||||||
|
sq_sums = cumsq[:, ends] - cumsq[:, starts]
|
||||||
|
means = sums / win_frames
|
||||||
|
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
|
||||||
|
|
||||||
|
return starts / fps, np.vstack([means, stds]).T
|
||||||
|
|
||||||
|
|
||||||
|
def label_windows(timestamps, gt_intense, gt_soft):
|
||||||
|
all_gt = list(gt_intense) + list(gt_soft)
|
||||||
|
labels = np.zeros(len(timestamps), dtype=int)
|
||||||
|
for i, t in enumerate(timestamps):
|
||||||
|
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||||
|
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||||
|
if di < TOLERANCE:
|
||||||
|
labels[i] = 1
|
||||||
|
elif da > NEG_MARGIN:
|
||||||
|
labels[i] = -1
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
db = ProcessedDB()
|
||||||
|
rows = db._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path FROM processed WHERE profile = ?",
|
||||||
|
(PROFILE_NAME,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
intense_by_video, soft_by_video = {}, {}
|
||||||
|
for fn, st, op in rows:
|
||||||
|
if '/mp4_Intense/' in op:
|
||||||
|
intense_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif '/mp4_Soft/' in op:
|
||||||
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
|
videos = [fn for fn in intense_by_video
|
||||||
|
if os.path.exists(os.path.join(PLEX_DIR, fn))]
|
||||||
|
n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
|
||||||
|
videos = videos[:n_vids]
|
||||||
|
print(f"Processing {len(videos)} videos...")
|
||||||
|
|
||||||
|
all_data_raw = [] # raw features
|
||||||
|
all_data_norm = [] # per-video z-scored features
|
||||||
|
|
||||||
|
for vi, vname in enumerate(videos):
|
||||||
|
vpath = os.path.join(PLEX_DIR, vname)
|
||||||
|
gt_intense = sorted(intense_by_video.get(vname, set()))
|
||||||
|
gt_soft = sorted(soft_by_video.get(vname, set()))
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
y, _ = librosa.load(vpath, sr=_SR, mono=True)
|
||||||
|
feat = extract_rich_features(y)
|
||||||
|
timestamps, window_vectors = compute_window_stats(feat, hop=1.0)
|
||||||
|
dt = time.time() - t0
|
||||||
|
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
labels = label_windows(timestamps, gt_intense, gt_soft)
|
||||||
|
|
||||||
|
# Per-video z-score normalization
|
||||||
|
vid_mean = window_vectors.mean(axis=0)
|
||||||
|
vid_std = window_vectors.std(axis=0)
|
||||||
|
vid_std = np.maximum(vid_std, 1e-6)
|
||||||
|
normed = (window_vectors - vid_mean) / vid_std
|
||||||
|
|
||||||
|
n_pos = (labels == 1).sum()
|
||||||
|
n_neg = (labels == -1).sum()
|
||||||
|
print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")
|
||||||
|
|
||||||
|
all_data_raw.append((vi, vname, timestamps, window_vectors, labels))
|
||||||
|
all_data_norm.append((vi, vname, timestamps, normed, labels))
|
||||||
|
|
||||||
|
# Run CV for both raw and normalized
|
||||||
|
for label, data in [("RAW features", all_data_raw),
|
||||||
|
("PER-VIDEO NORMALIZED features", all_data_norm)]:
|
||||||
|
print(f"\n{'='*70}")
|
||||||
|
print(f" {label}")
|
||||||
|
print(f"{'='*70}")
|
||||||
|
|
||||||
|
all_y_true, all_y_prob = [], []
|
||||||
|
|
||||||
|
for test_idx in range(len(data)):
|
||||||
|
_, vname, _, test_X, test_labels = data[test_idx]
|
||||||
|
test_mask = test_labels != 0
|
||||||
|
if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0:
|
||||||
|
continue
|
||||||
|
X_test = test_X[test_mask]
|
||||||
|
y_test = (test_labels[test_mask] == 1).astype(int)
|
||||||
|
|
||||||
|
X_parts, y_parts = [], []
|
||||||
|
for i, (_, _, _, feats, labs) in enumerate(data):
|
||||||
|
if i == test_idx:
|
||||||
|
continue
|
||||||
|
m = labs != 0
|
||||||
|
if m.sum() == 0:
|
||||||
|
continue
|
||||||
|
X_parts.append(feats[m])
|
||||||
|
y_parts.append((labs[m] == 1).astype(int))
|
||||||
|
|
||||||
|
if not X_parts:
|
||||||
|
continue
|
||||||
|
X_train = np.vstack(X_parts)
|
||||||
|
y_train = np.concatenate(y_parts)
|
||||||
|
|
||||||
|
pos_idx = np.where(y_train == 1)[0]
|
||||||
|
neg_idx = np.where(y_train == 0)[0]
|
||||||
|
if len(pos_idx) == 0 or len(neg_idx) == 0:
|
||||||
|
continue
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
n_neg = min(len(neg_idx), len(pos_idx) * 3)
|
||||||
|
neg_sample = rng.choice(neg_idx, n_neg, replace=False)
|
||||||
|
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||||
|
|
||||||
|
clf = GradientBoostingClassifier(
|
||||||
|
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
|
||||||
|
)
|
||||||
|
clf.fit(X_train[train_idx], y_train[train_idx])
|
||||||
|
probs = clf.predict_proba(X_test)[:, 1]
|
||||||
|
|
||||||
|
tp = ((probs >= 0.5) & (y_test == 1)).sum()
|
||||||
|
fp = ((probs >= 0.5) & (y_test == 0)).sum()
|
||||||
|
fn_count = ((probs < 0.5) & (y_test == 1)).sum()
|
||||||
|
pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0
|
||||||
|
neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0
|
||||||
|
print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")
|
||||||
|
|
||||||
|
all_y_true.extend(y_test)
|
||||||
|
all_y_prob.extend(probs)
|
||||||
|
|
||||||
|
if not all_y_true:
|
||||||
|
print(" No test results.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true = np.array(all_y_true)
|
||||||
|
y_prob = np.array(all_y_prob)
|
||||||
|
pos_probs = y_prob[y_true == 1]
|
||||||
|
neg_probs = y_prob[y_true == 0]
|
||||||
|
|
||||||
|
if len(pos_probs) > 0 and len(neg_probs) > 0:
|
||||||
|
print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
|
||||||
|
f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
|
||||||
|
print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
|
||||||
|
f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")
|
||||||
|
|
||||||
|
best_f1, best_thr = 0, 0
|
||||||
|
print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
|
||||||
|
for thr in np.arange(0.10, 0.91, 0.05):
|
||||||
|
tp = ((y_prob >= thr) & (y_true == 1)).sum()
|
||||||
|
fp = ((y_prob >= thr) & (y_true == 0)).sum()
|
||||||
|
fn_count = ((y_prob < thr) & (y_true == 1)).sum()
|
||||||
|
prec = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||||
|
rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
|
||||||
|
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
|
||||||
|
if f1 > best_f1:
|
||||||
|
best_f1, best_thr = f1, thr
|
||||||
|
print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
|
||||||
|
print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")
|
||||||
|
|
||||||
|
# Feature importance
|
||||||
|
X_all = np.vstack([f[l != 0] for _, _, _, f, l in data])
|
||||||
|
y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
|
||||||
|
pos_idx = np.where(y_all == 1)[0]
|
||||||
|
neg_idx = np.where(y_all == 0)[0]
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
|
||||||
|
clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
|
||||||
|
clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])])
|
||||||
|
|
||||||
|
feat_names = (
|
||||||
|
["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
|
||||||
|
+ [f"mel{i}" for i in range(8)]
|
||||||
|
+ [f"sc{i}" for i in range(7)]
|
||||||
|
)
|
||||||
|
stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
|
||||||
|
imp = clf.feature_importances_
|
||||||
|
top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10]
|
||||||
|
print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Train an audio scan classifier from DB ground truth.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python 8cut_train.py # default model, auto-detect positive
|
||||||
|
python 8cut_train.py --model BEATS # specific embedding model
|
||||||
|
python 8cut_train.py --positive mp4_Intense # explicit positive folder
|
||||||
|
python 8cut_train.py --positive mp4_Intense --model BEATS # both
|
||||||
|
"""
|
||||||
|
import sys, os, warnings
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
PROFILE_NAME = "JAV_missionary"
|
||||||
|
|
||||||
|
# Fallback for old DB rows without source_path
|
||||||
|
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
embed_model = None
|
||||||
|
if "--model" in sys.argv:
|
||||||
|
idx = sys.argv.index("--model")
|
||||||
|
if idx + 1 < len(sys.argv):
|
||||||
|
embed_model = sys.argv[idx + 1]
|
||||||
|
if embed_model not in _EMBED_MODELS:
|
||||||
|
print(f"Unknown model: {embed_model}")
|
||||||
|
print(f"Available: {', '.join(_EMBED_MODELS)}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
positive_suffix = None
|
||||||
|
if "--positive" in sys.argv:
|
||||||
|
idx = sys.argv.index("--positive")
|
||||||
|
if idx + 1 < len(sys.argv):
|
||||||
|
positive_suffix = sys.argv[idx + 1]
|
||||||
|
|
||||||
|
db = ProcessedDB()
|
||||||
|
|
||||||
|
# If --positive given, use the new DB helper
|
||||||
|
if positive_suffix:
|
||||||
|
video_infos = db.get_training_data(
|
||||||
|
PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
|
||||||
|
)
|
||||||
|
if not video_infos:
|
||||||
|
print(f"No training data found for positive='{positive_suffix}'")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# Legacy fallback: classify by folder path pattern
|
||||||
|
rows = db._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path, source_path"
|
||||||
|
" FROM processed WHERE profile = ?",
|
||||||
|
(PROFILE_NAME,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
intense_by_video, soft_by_video = {}, {}
|
||||||
|
source_by_fn = {}
|
||||||
|
for fn, st, op, sp in rows:
|
||||||
|
if sp:
|
||||||
|
source_by_fn[fn] = sp
|
||||||
|
if "/mp4_Intense/" in op or "_Intense/" in op:
|
||||||
|
intense_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif "/mp4_Soft/" in op or "_Soft/" in op:
|
||||||
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
|
video_infos = []
|
||||||
|
for fn in intense_by_video:
|
||||||
|
# Try source_path from DB first, fall back to PLEX_DIR
|
||||||
|
vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
|
||||||
|
if not os.path.exists(vpath):
|
||||||
|
print(f" skip (not found): {fn}")
|
||||||
|
continue
|
||||||
|
gt_intense = sorted(intense_by_video[fn])
|
||||||
|
gt_soft = sorted(soft_by_video.get(fn, set()))
|
||||||
|
video_infos.append((vpath, gt_intense, gt_soft))
|
||||||
|
|
||||||
|
label = embed_model or "WAV2VEC2_BASE"
|
||||||
|
print(f"Training {label} model on {len(video_infos)} videos...")
|
||||||
|
model_path = default_model_path(PROFILE_NAME)
|
||||||
|
result = train_classifier(
|
||||||
|
video_infos, model_path=model_path, embed_model=embed_model,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
print("Training failed: no valid samples or missing class balance")
|
||||||
|
sys.exit(1)
|
||||||
|
print(f"Model saved to {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -8,7 +8,7 @@
|
|||||||
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets.
|
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets. Includes audio classification for automated scanning and batch export.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
@@ -22,19 +22,44 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
### Clip export
|
||||||
|
|
||||||
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
||||||
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
||||||
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
||||||
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
||||||
- **Random portrait** — optionally apply a random portrait crop to a subset of each batch
|
- **Random portrait/square** — optionally apply a random crop to a subset of each batch
|
||||||
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
||||||
- **Sound annotation** — label and category fields saved to the clip database; label also written to `dataset.json`
|
- **Hardware encoding** — GPU-accelerated export via NVENC, VAAPI, QSV, AMF, or VideoToolbox
|
||||||
- **Export history** — timeline markers show previously exported clips; double-click to enter overwrite mode; right-click to delete
|
- **Subject tracking** — auto-adjust crop center using YOLOv8 detection (optional)
|
||||||
- **End-frame preview** — floating window shows the last frame of the selection region
|
|
||||||
- **Playlist** — drag-and-drop or use the Open Files button; right-click to remove items
|
### Audio scanning
|
||||||
- **Playback loop** — plays the exact selection region on loop so you can preview what will be exported
|
|
||||||
- **Group operations** — delete or overwrite acts on all sub-clips in a batch, not just one
|
- **Embedding models** — WAV2VEC2 (base/large), HuBERT (base/large/xlarge), BEATs
|
||||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait") for the same video
|
- **Train classifier** — train a gradient boosting classifier on your exported clips to find similar audio
|
||||||
|
- **Scan video** — detect regions matching your trained model with configurable threshold
|
||||||
|
- **Scan All** — batch scan every video in the playlist
|
||||||
|
- **Region fusion** — merge overlapping detections into contiguous regions
|
||||||
|
- **Hard negatives** — mark false positives to refine training
|
||||||
|
- **Model versioning** — timestamped backups with rollback support
|
||||||
|
- **Scan export** — batch export from scan results with spread and minimum duration filtering
|
||||||
|
|
||||||
|
### Scan results panel
|
||||||
|
|
||||||
|
- **Tabbed results** — one tab per model, showing start/end/score per region
|
||||||
|
- **Disable regions** — Delete/Backspace toggles regions off (greyed out, excluded from export) without removing them
|
||||||
|
- **Resize regions** — double-click Time or End cells to edit, or drag region edges directly on the timeline
|
||||||
|
- **Grey ghost** — trimmed portions of resized regions shown as grey overlay on timeline
|
||||||
|
- **Undo** — Ctrl+Z reverts the last disable, resize, drag, or negative toggle
|
||||||
|
|
||||||
|
### Organization
|
||||||
|
|
||||||
|
- **Sound annotation** — label and category fields saved to the clip database and `dataset.json`
|
||||||
|
- **Export history** — timeline markers show previously exported clips; double-click to overwrite; right-click to delete
|
||||||
|
- **Playlist** — drag-and-drop video queue with progress tracking
|
||||||
|
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait")
|
||||||
|
- **Subprofiles** — lightweight export folder variants for multiple output targets
|
||||||
|
- **Review mode** — clean timeline view for navigating scan results without export clutter
|
||||||
|
|
||||||
## Keyboard shortcuts
|
## Keyboard shortcuts
|
||||||
|
|
||||||
@@ -50,37 +75,158 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
|||||||
| `M` | Jump to next marker (wraps) |
|
| `M` | Jump to next marker (wraps) |
|
||||||
| `N` | Next file in playlist |
|
| `N` | Next file in playlist |
|
||||||
| `G` | Toggle cursor lock |
|
| `G` | Toggle cursor lock |
|
||||||
|
| `Delete` / `Backspace` | Toggle disable on selected scan regions |
|
||||||
|
| `Ctrl+Z` | Undo last scan panel action |
|
||||||
| `?` / `F1` | Show keyboard shortcuts |
|
| `?` / `F1` | Show keyboard shortcuts |
|
||||||
|
|
||||||
Shortcuts are suppressed when a text field has focus.
|
Shortcuts are suppressed when a text field has focus.
|
||||||
|
|
||||||
## Requirements
|
## Installation
|
||||||
|
|
||||||
- Python 3.11+
|
### Prerequisites
|
||||||
- `ffmpeg` on `PATH`
|
|
||||||
- PyQt6
|
|
||||||
- python-mpv (requires libmpv)
|
|
||||||
|
|
||||||
|
- **Python 3.11+** — [python.org/downloads](https://www.python.org/downloads/)
|
||||||
|
- **ffmpeg** — video encoding
|
||||||
|
- **libmpv** — video playback
|
||||||
|
|
||||||
|
### Quick start (all platforms)
|
||||||
|
|
||||||
|
The setup script creates a virtual environment and installs everything including PyTorch with CUDA support:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux / macOS
|
||||||
|
./setup_env.sh
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux / macOS
|
||||||
|
./8cut.sh
|
||||||
|
|
||||||
|
# Windows
|
||||||
|
8cut.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
The launch scripts auto-detect your venv or conda environment.
|
||||||
|
|
||||||
|
### Manual installation
|
||||||
|
|
||||||
|
#### 1. Install system dependencies
|
||||||
|
|
||||||
|
**Linux (Arch):**
|
||||||
|
```bash
|
||||||
|
pacman -S python mpv ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Debian/Ubuntu):**
|
||||||
|
```bash
|
||||||
|
apt install python3 python3-venv libmpv-dev ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Windows:**
|
||||||
|
```powershell
|
||||||
|
# ffmpeg
|
||||||
|
winget install ffmpeg
|
||||||
|
|
||||||
|
# libmpv — download mpv-2.dll and place next to main.py
|
||||||
|
# https://sourceforge.net/projects/mpv-player-windows/files/libmpv/
|
||||||
|
```
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install python mpv ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Create a virtual environment
|
||||||
|
|
||||||
|
A virtual environment keeps 8-cut's dependencies isolated from your system Python. This is strongly recommended — PyTorch alone is several GB and can conflict with other projects.
|
||||||
|
|
||||||
|
**Using venv (recommended):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create the venv in the project directory
|
||||||
|
python3 -m venv .venv
|
||||||
|
|
||||||
|
# Activate it
|
||||||
|
# Linux / macOS:
|
||||||
|
source .venv/bin/activate
|
||||||
|
# Windows (cmd):
|
||||||
|
.venv\Scripts\activate.bat
|
||||||
|
# Windows (PowerShell):
|
||||||
|
.venv\Scripts\Activate.ps1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using conda / miniforge:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n 8cut python=3.12
|
||||||
|
conda activate 8cut
|
||||||
|
```
|
||||||
|
|
||||||
|
You must activate the environment every time you open a new terminal before running 8-cut. The `8cut.sh` launcher does this automatically.
|
||||||
|
|
||||||
|
#### 3. Install PyTorch
|
||||||
|
|
||||||
|
PyTorch must be installed separately with the correct CUDA version for GPU acceleration. Without CUDA, audio scanning will fall back to CPU (much slower).
|
||||||
|
|
||||||
|
**With NVIDIA GPU (CUDA 12.8):**
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
```
|
||||||
|
|
||||||
|
**CPU only (no GPU):**
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
Check available CUDA versions at [pytorch.org/get-started](https://pytorch.org/get-started/locally/).
|
||||||
|
|
||||||
|
#### 4. Install project dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Platform notes
|
#### 5. Verify
|
||||||
|
|
||||||
| Platform | libmpv |
|
```bash
|
||||||
|----------|--------|
|
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||||
| **Linux** | `apt install libmpv-dev` or `pacman -S mpv` |
|
python -c "import librosa, torchaudio, sklearn; print('All imports OK')"
|
||||||
| **macOS** | `brew install mpv` |
|
```
|
||||||
| **Windows** | Download `mpv-2.dll` from [mpv Windows builds](https://sourceforge.net/projects/mpv-player-windows/files/libmpv/) and place it in `PATH` or next to `main.py` |
|
|
||||||
|
|
||||||
Windows also needs `ffmpeg.exe` on `PATH` (e.g. `winget install ffmpeg`).
|
### Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# With venv activated:
|
||||||
|
python main.py
|
||||||
|
|
||||||
|
# Or use the launcher (auto-activates venv/conda):
|
||||||
|
./8cut.sh # Linux / macOS
|
||||||
|
8cut.bat # Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
### GPU encoding
|
||||||
|
|
||||||
|
Hardware encoders are auto-detected from ffmpeg. Available encoders by platform:
|
||||||
|
|
||||||
|
| Platform | Encoders |
|
||||||
|
|----------|----------|
|
||||||
|
| **Linux** | `h264_nvenc` (NVIDIA), `h264_vaapi` (AMD/Intel), `h264_qsv` (Intel) |
|
||||||
|
| **Windows** | `h264_nvenc` (NVIDIA), `h264_qsv` (Intel), `h264_amf` (AMD) |
|
||||||
|
| **macOS** | `h264_videotoolbox` |
|
||||||
|
|
||||||
|
Enable the **HW** checkbox in the export controls to use GPU encoding.
|
||||||
|
|
||||||
|
### Optional: audio scanning
|
||||||
|
|
||||||
|
Audio scanning requires PyTorch (installed above). Embedding models are downloaded on first use and cached in `cache/downloads/`. A CUDA-capable GPU is strongly recommended for training and scanning speed.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
```
|
|
||||||
python main.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
||||||
|
|
||||||
### Export layout
|
### Export layout
|
||||||
@@ -109,6 +255,20 @@ output/
|
|||||||
clip_001_0.wav
|
clip_001_0.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Scan export layout
|
||||||
|
|
||||||
|
Scan exports create one group folder per detected area:
|
||||||
|
|
||||||
|
```
|
||||||
|
output/
|
||||||
|
clip_037/
|
||||||
|
clip_037_a1_0.mp4 # area 1, clip 0
|
||||||
|
clip_037_a1_1.mp4 # area 1, clip 1
|
||||||
|
clip_038/
|
||||||
|
clip_038_a2_0.mp4 # area 2, clip 0
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
### Sound annotation
|
### Sound annotation
|
||||||
|
|
||||||
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
||||||
@@ -124,9 +284,73 @@ Labels persist between exports so you can cut many clips of the same class witho
|
|||||||
- **Right-click** a marker to delete it from the database
|
- **Right-click** a marker to delete it from the database
|
||||||
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
||||||
|
|
||||||
|
## Audio scan workflow
|
||||||
|
|
||||||
|
### 1. Build a dataset
|
||||||
|
|
||||||
|
Export clips manually from several videos. Clips from the same export folder (e.g. `mp4_Intense`) become your positive training class.
|
||||||
|
|
||||||
|
**Minimum dataset:** ~20 clips from 2–3 different videos. This is enough for the classifier to learn a basic boundary, but expect noisy results — you'll need to mark hard negatives and retrain.
|
||||||
|
|
||||||
|
**Ideal dataset:** 50–100+ clips from 5+ videos covering the full range of variation in your target sound (different recording conditions, distances, intensities). More variety in your positives makes the model generalize better to unseen footage. Negatives are sampled automatically from regions far from your markers, but adding explicit negatives of confusable sounds (e.g. thunder when training for explosions) significantly reduces false positives.
|
||||||
|
|
||||||
|
The classifier improves iteratively: export a small initial set → train → scan → mark false positives as hard negatives → retrain. Each cycle sharpens the decision boundary without needing a large upfront dataset.
|
||||||
|
|
||||||
|
### 2. Train a classifier
|
||||||
|
|
||||||
|
Click **Train** to open the training dialog:
|
||||||
|
|
||||||
|
- **Positive class** — select the export folder containing your target sounds
|
||||||
|
- **Negative class** — optional explicit negatives, or leave as "(auto only)" for automatic sampling
|
||||||
|
- **Model** — embedding model to use (HuBERT XLARGE recommended)
|
||||||
|
- **Auto-neg margin** — distance from markers to sample automatic negatives (30s default)
|
||||||
|
- **Include scan-exported clips** — whether to include previously scan-exported clips in training
|
||||||
|
|
||||||
|
The classifier trains a `HistGradientBoostingClassifier` on audio embeddings and saves to `models/`.
|
||||||
|
|
||||||
|
### 3. Scan videos
|
||||||
|
|
||||||
|
Select a trained model from the dropdown and click **Scan**. Adjust the threshold slider to control sensitivity. Detected regions appear as colored bands on the timeline and as rows in the results panel.
|
||||||
|
|
||||||
|
Audio embeddings are computed once per video and cached to disk (`cache/w2v/`). Subsequent scans with the same embedding model skip the GPU entirely and only re-run the classifier, which takes milliseconds. This makes the retrain → rescan loop nearly free after the first pass.
|
||||||
|
|
||||||
|
### 4. Review and refine
|
||||||
|
|
||||||
|
- Toggle **Review** mode for a clean timeline focused on scan results
|
||||||
|
- **Disable** false positive regions (Delete key) — they stay in the list but are excluded from export
|
||||||
|
- **Resize** regions by dragging edges on the timeline or editing times in the table
|
||||||
|
- **Mark as negative** — add false positives to the hard negative set for retraining
|
||||||
|
- **Ctrl+Z** to undo any of the above
|
||||||
|
|
||||||
|
### 5. Export results
|
||||||
|
|
||||||
|
Click **Export Scan Results** to batch export all enabled regions. The button shows the estimated clip count based on spread and minimum duration settings.
|
||||||
|
|
||||||
|
### 6. Retrain with feedback
|
||||||
|
|
||||||
|
Train again — hard negatives are automatically included. Each training run saves with a timestamp. Click the **⏲** button next to the model dropdown to restore a previous version if results degrade — restoring automatically rescans with the selected version.
|
||||||
|
|
||||||
## Database
|
## Database
|
||||||
|
|
||||||
Export history is stored in `~/.8cut.db` (SQLite). The database records filename, start time, output path, label, category, and all encoding settings for every clip. When you open a file, 8-cut matches the filename and pre-populates the timeline with existing markers.
|
Export history is stored in `~/.8cut.db` (SQLite). Tables:
|
||||||
|
|
||||||
|
| Table | Purpose |
|
||||||
|
|-------|---------|
|
||||||
|
| `processed` | Every exported clip with full encoding settings |
|
||||||
|
| `scan_results` | Audio scan detections per video/model |
|
||||||
|
| `hard_negatives` | Timestamps marked as false positives for training |
|
||||||
|
| `hidden_files` | Playlist files hidden by the user |
|
||||||
|
|
||||||
|
The database auto-migrates when new columns are added.
|
||||||
|
|
||||||
|
## File locations
|
||||||
|
|
||||||
|
| Path | Contents |
|
||||||
|
|------|----------|
|
||||||
|
| `~/.8cut.db` | SQLite database |
|
||||||
|
| `models/` | Trained classifier models (`.joblib`) |
|
||||||
|
| `cache/w2v/` | Embedding cache (`.npz`, keyed by video hash) |
|
||||||
|
| `cache/downloads/` | Downloaded pretrained models |
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def build_annotation_json_path(folder: str) -> str:
|
||||||
|
return os.path.join(folder, "dataset.json")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||||
|
"""Remove the entry for *clip_path* from <folder>/dataset.json if present."""
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
if not os.path.exists(json_path):
|
||||||
|
return
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return
|
||||||
|
entries = [e for e in entries if e.get("path") != abs_path]
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||||
|
"""Insert or update one entry in <folder>/dataset.json.
|
||||||
|
|
||||||
|
Each entry stores a path relative to *folder* and the sound label.
|
||||||
|
Matches on ``path``; if an entry for the same clip already exists it is
|
||||||
|
replaced (overwrite-export case). Nothing is written when *label* is
|
||||||
|
empty.
|
||||||
|
"""
|
||||||
|
if not label.strip():
|
||||||
|
return
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
entries: list[dict] = []
|
||||||
|
if os.path.exists(json_path):
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
entries = []
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
entry: dict = {"path": abs_path, "label": label}
|
||||||
|
for i, e in enumerate(entries):
|
||||||
|
if e.get("path") == abs_path:
|
||||||
|
entries[i] = entry
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
entries.append(entry)
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
@@ -0,0 +1,655 @@
|
|||||||
|
"""Audio scanning — embedding-based classifier for audio event detection."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_SR = 16000 # lower sr = faster
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
|
||||||
|
"""Load audio from any file as mono float32 numpy array using ffmpeg directly."""
|
||||||
|
cmd = [
|
||||||
|
_bin("ffmpeg"), "-i", path,
|
||||||
|
"-vn", # skip video
|
||||||
|
"-ac", "1", # mono
|
||||||
|
"-ar", str(sr), # resample
|
||||||
|
"-f", "f32le", # raw 32-bit float little-endian
|
||||||
|
"-loglevel", "error",
|
||||||
|
"pipe:1",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
proc = subprocess.run(cmd, capture_output=True, timeout=300)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise RuntimeError(f"ffmpeg timed out (300s) on {os.path.basename(path)}")
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode().strip()}")
|
||||||
|
return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||||
|
_WINDOW = 8.0 # seconds
|
||||||
|
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
_MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
|
||||||
|
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
|
||||||
|
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")
|
||||||
|
|
||||||
|
# Redirect torch hub and huggingface downloads into the project
|
||||||
|
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
|
||||||
|
os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Embedding extraction (lazy-loaded)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_w2v_model = None
|
||||||
|
_w2v_device = None
|
||||||
|
_w2v_model_name = None
|
||||||
|
|
||||||
|
# Supported embedding models — name → embed_dim
|
||||||
|
_EMBED_MODELS = {
|
||||||
|
"WAV2VEC2_BASE": 768,
|
||||||
|
"WAV2VEC2_LARGE": 1024,
|
||||||
|
"WAV2VEC2_LARGE_LV60K":1024,
|
||||||
|
"HUBERT_BASE": 768,
|
||||||
|
"HUBERT_LARGE": 1024,
|
||||||
|
"HUBERT_XLARGE": 1280,
|
||||||
|
"BEATS": 768,
|
||||||
|
}
|
||||||
|
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||||
|
|
||||||
|
_BEATS_CHECKPOINT = os.path.join(
|
||||||
|
_DL_CACHE_DIR, "huggingface", "hub",
|
||||||
|
"models--lpepino--beats_ckpts", "snapshots",
|
||||||
|
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_w2v_model(model_name: str | None = None):
|
||||||
|
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||||
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
if _w2v_model is None or _w2v_model_name != model_name:
|
||||||
|
import torch
|
||||||
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
if model_name == "BEATS":
|
||||||
|
from .beats_model import BEATs, BEATsConfig
|
||||||
|
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||||
|
weights_only=False)
|
||||||
|
cfg = BEATsConfig(checkpoint['cfg'])
|
||||||
|
_w2v_model = BEATs(cfg)
|
||||||
|
_w2v_model.load_state_dict(checkpoint['model'])
|
||||||
|
_w2v_model.to(_w2v_device)
|
||||||
|
else:
|
||||||
|
import torchaudio
|
||||||
|
bundle = getattr(torchaudio.pipelines, model_name)
|
||||||
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
|
||||||
|
_w2v_model.eval()
|
||||||
|
_w2v_model_name = model_name
|
||||||
|
_log(f"audio_scan: {model_name} loaded on {_w2v_device}")
|
||||||
|
return _w2v_model, _w2v_device
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_dim(model_name: str | None = None) -> int:
|
||||||
|
"""Return embedding dimension for a model name."""
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
return _EMBED_MODELS.get(model_name, 768)
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> str:
|
||||||
|
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
abspath = os.path.abspath(video_path)
|
||||||
|
mtime = os.path.getmtime(abspath)
|
||||||
|
key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}"
|
||||||
|
h = hashlib.sha256(key.encode()).hexdigest()[:16]
|
||||||
|
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_exists(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> bool:
|
||||||
|
"""Check if embedding cache exists for a video."""
|
||||||
|
try:
|
||||||
|
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
return os.path.exists(path)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_load(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
|
||||||
|
"""Load embeddings from cache. Returns (timestamps, embeddings) or None."""
|
||||||
|
try:
|
||||||
|
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
if os.path.exists(path):
|
||||||
|
data = np.load(path)
|
||||||
|
_log(f"audio_scan: cache hit ({path})")
|
||||||
|
return data["timestamps"], data["embeddings"]
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache read failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||||
|
hop: float = 1.0, window: float = _WINDOW,
|
||||||
|
video_path: str | None = None,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
model_name: str | None = None,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Extract embeddings for all sliding windows using a torchaudio model.
|
||||||
|
|
||||||
|
If video_path is given, results are cached to disk for fast re-scans.
|
||||||
|
Returns (timestamps, embeddings) where embeddings is (N, D).
|
||||||
|
"""
|
||||||
|
edim = _embed_dim(model_name)
|
||||||
|
|
||||||
|
# Try loading from cache
|
||||||
|
cache_file = None
|
||||||
|
if video_path:
|
||||||
|
try:
|
||||||
|
cache_file = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
data = np.load(cache_file)
|
||||||
|
_log(f"audio_scan: cache hit ({cache_file})")
|
||||||
|
return data["timestamps"], data["embeddings"]
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache read failed: {e}")
|
||||||
|
|
||||||
|
win_samples = int(window * sr)
|
||||||
|
hop_samples = int(hop * sr)
|
||||||
|
n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)
|
||||||
|
|
||||||
|
if n_windows == 0:
|
||||||
|
return np.array([]), np.empty((0, edim))
|
||||||
|
|
||||||
|
import torch
|
||||||
|
model, device = _get_w2v_model(model_name)
|
||||||
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
# Auto-size batches based on available GPU memory
|
||||||
|
batch_size = 16
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
|
||||||
|
if vram_gb >= 16:
|
||||||
|
batch_size = 64
|
||||||
|
elif vram_gb >= 8:
|
||||||
|
batch_size = 32
|
||||||
|
_log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
timestamps = np.arange(n_windows) * hop
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
for batch_start in range(0, n_windows, batch_size):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return np.array([]), np.empty((0, edim))
|
||||||
|
batch_end = min(batch_start + batch_size, n_windows)
|
||||||
|
chunks = []
|
||||||
|
for i in range(batch_start, batch_end):
|
||||||
|
start = i * hop_samples
|
||||||
|
chunks.append(y[start:start + win_samples])
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings.append(batch_emb)
|
||||||
|
|
||||||
|
result_ts = timestamps
|
||||||
|
result_emb = np.vstack(embeddings)
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
if cache_file:
|
||||||
|
try:
|
||||||
|
os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
|
||||||
|
np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
|
||||||
|
_log(f"audio_scan: w2v cache saved ({cache_file})")
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache write failed: {e}")
|
||||||
|
|
||||||
|
return result_ts, result_emb
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||||
|
gt_soft: list[float], tolerance: float = 12.0,
|
||||||
|
neg_margin: float = 120.0,
|
||||||
|
model_name: str | None = None,
|
||||||
|
gt_negative: list[float] | None = None,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""Extract embeddings only near positives and distant negatives.
|
||||||
|
|
||||||
|
Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig.
|
||||||
|
"""
|
||||||
|
edim = _embed_dim(model_name)
|
||||||
|
duration = len(y) / sr
|
||||||
|
win_samples = int(_WINDOW * sr)
|
||||||
|
all_gt = list(gt_intense) + list(gt_soft)
|
||||||
|
|
||||||
|
# Positive windows: every second near intense markers
|
||||||
|
pos_times = set()
|
||||||
|
for gt in gt_intense:
|
||||||
|
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||||
|
t = gt + offset
|
||||||
|
if 0 <= t <= duration - _WINDOW:
|
||||||
|
pos_times.add(int(t))
|
||||||
|
|
||||||
|
# Manual negative windows: near explicit negative markers
|
||||||
|
manual_neg_times = set()
|
||||||
|
if gt_negative:
|
||||||
|
for gt in gt_negative:
|
||||||
|
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||||
|
t = gt + offset
|
||||||
|
if 0 <= t <= duration - _WINDOW:
|
||||||
|
manual_neg_times.add(int(t))
|
||||||
|
# Don't let manual negatives overlap with positives
|
||||||
|
manual_neg_times -= pos_times
|
||||||
|
|
||||||
|
# Auto negative windows: every 4s, far from any marker (skip if margin <= 0 or no markers)
|
||||||
|
neg_times = set()
|
||||||
|
if all_gt and neg_margin > 0:
|
||||||
|
for t in range(0, int(duration - _WINDOW), 4):
|
||||||
|
if min(abs(t - g) for g in all_gt) > neg_margin:
|
||||||
|
neg_times.add(t)
|
||||||
|
|
||||||
|
all_times = sorted(pos_times | neg_times | manual_neg_times)
|
||||||
|
# Filter out windows that go past the end
|
||||||
|
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
|
||||||
|
|
||||||
|
if not valid_times:
|
||||||
|
return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
model, device = _get_w2v_model(model_name)
|
||||||
|
batch_size = 16
|
||||||
|
timestamps_list: list[float] = []
|
||||||
|
embeddings_list: list[np.ndarray] = []
|
||||||
|
|
||||||
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
|
||||||
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
|
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||||
|
chunks = []
|
||||||
|
for t in valid_times[batch_start:batch_end]:
|
||||||
|
start = int(t * sr)
|
||||||
|
chunks.append(y[start:start + win_samples])
|
||||||
|
timestamps_list.append(float(t))
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings_list.append(batch_emb)
|
||||||
|
|
||||||
|
timestamps = np.array(timestamps_list)
|
||||||
|
embeddings = np.vstack(embeddings_list)
|
||||||
|
|
||||||
|
labels = np.zeros(len(timestamps), dtype=int)
|
||||||
|
for i, t in enumerate(timestamps):
|
||||||
|
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||||
|
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||||
|
dm = min((abs(t - g) for g in (gt_negative or [])), default=9999)
|
||||||
|
if di < tolerance:
|
||||||
|
labels[i] = 1
|
||||||
|
elif dm < tolerance or (neg_margin > 0 and da > neg_margin):
|
||||||
|
labels[i] = -1
|
||||||
|
return timestamps, embeddings, labels
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Classifier mode — train / save / load / scan
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
||||||
|
model_path: str | None = None,
|
||||||
|
tolerance: float = 12.0,
|
||||||
|
neg_margin: float = 120.0,
|
||||||
|
embed_model: str | None = None,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
n_workers: int = 4,
|
||||||
|
progress_cb: object = None) -> dict:
|
||||||
|
"""Train a classifier from labeled videos.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_infos: list of (video_path, intense_times, soft_times)
|
||||||
|
model_path: if given, save model to this path
|
||||||
|
tolerance/neg_margin: labeling parameters
|
||||||
|
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
|
||||||
|
cancel_flag: object with _cancel attribute; if set, training aborts early
|
||||||
|
n_workers: number of threads for parallel audio loading
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with 'classifier', 'embed_model', and metadata, or None on failure.
|
||||||
|
"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||||
|
|
||||||
|
def _progress(msg: str) -> None:
|
||||||
|
_log(msg)
|
||||||
|
if progress_cb:
|
||||||
|
progress_cb(msg)
|
||||||
|
|
||||||
|
def _load_audio(path: str) -> np.ndarray:
|
||||||
|
return _load_audio_ffmpeg(path, sr=_SR)
|
||||||
|
|
||||||
|
# Phase 1: load all audio in parallel (cap workers — disk I/O bound)
|
||||||
|
n = len(video_infos)
|
||||||
|
load_workers = min(n_workers, 4)
|
||||||
|
_progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
|
||||||
|
audio_data: dict[int, np.ndarray] = {}
|
||||||
|
with ThreadPoolExecutor(max_workers=load_workers) as pool:
|
||||||
|
future_to_idx = {
|
||||||
|
pool.submit(_load_audio, vi[0]): i
|
||||||
|
for i, vi in enumerate(video_infos)
|
||||||
|
}
|
||||||
|
failed = set()
|
||||||
|
for future in as_completed(future_to_idx):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: training cancelled")
|
||||||
|
return None
|
||||||
|
idx = future_to_idx[future]
|
||||||
|
try:
|
||||||
|
audio_data[idx] = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
|
||||||
|
failed.add(idx)
|
||||||
|
_progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")
|
||||||
|
|
||||||
|
# Phase 2: extract embeddings sequentially on GPU
|
||||||
|
_progress(f"Extracting embeddings: 0/{n}")
|
||||||
|
all_X, all_y = [], []
|
||||||
|
for vi, vinfo in enumerate(video_infos):
|
||||||
|
if vi in failed:
|
||||||
|
continue
|
||||||
|
vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2]
|
||||||
|
gt_negative = vinfo[3] if len(vinfo) > 3 else []
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: training cancelled")
|
||||||
|
return None
|
||||||
|
_progress(f"Extracting embeddings: {vi+1}/{n}")
|
||||||
|
y = audio_data.pop(vi)
|
||||||
|
|
||||||
|
timestamps, embeddings, labels = _extract_w2v_targeted(
|
||||||
|
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
|
||||||
|
model_name=embed_model, gt_negative=gt_negative,
|
||||||
|
)
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
continue
|
||||||
|
# Per-video z-score normalize
|
||||||
|
vid_mean = embeddings.mean(axis=0)
|
||||||
|
vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
|
||||||
|
normed = (embeddings - vid_mean) / vid_std
|
||||||
|
for i in range(len(labels)):
|
||||||
|
if labels[i] == 1:
|
||||||
|
all_X.append(normed[i])
|
||||||
|
all_y.append(1)
|
||||||
|
elif labels[i] == -1:
|
||||||
|
all_X.append(normed[i])
|
||||||
|
all_y.append(0)
|
||||||
|
|
||||||
|
if not all_X:
|
||||||
|
_log("audio_scan: no training samples collected")
|
||||||
|
return None
|
||||||
|
|
||||||
|
X = np.stack(all_X)
|
||||||
|
y_arr = np.array(all_y)
|
||||||
|
n_pos = (y_arr == 1).sum()
|
||||||
|
n_neg = (y_arr == 0).sum()
|
||||||
|
_log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")
|
||||||
|
|
||||||
|
if n_pos == 0 or n_neg == 0:
|
||||||
|
_log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Subsample negatives for balance
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
pos_idx = np.where(y_arr == 1)[0]
|
||||||
|
neg_idx = np.where(y_arr == 0)[0]
|
||||||
|
n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
|
||||||
|
neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
|
||||||
|
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||||
|
rng.shuffle(train_idx)
|
||||||
|
|
||||||
|
_progress(f"Fitting classifier on {len(train_idx)} samples...")
|
||||||
|
clf = HistGradientBoostingClassifier(
|
||||||
|
max_iter=200, max_depth=5, learning_rate=0.1, random_state=42,
|
||||||
|
)
|
||||||
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
model = {"classifier": clf, "n_features": X.shape[1],
|
||||||
|
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
||||||
|
|
||||||
|
if model_path:
|
||||||
|
import joblib
|
||||||
|
from datetime import datetime
|
||||||
|
parent = os.path.dirname(model_path)
|
||||||
|
if parent:
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
# Save with timestamp in name; keep a symlink/copy as the "latest"
|
||||||
|
stem, ext = os.path.splitext(model_path)
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
versioned = f"{stem}_{ts}{ext}"
|
||||||
|
joblib.dump(model, versioned)
|
||||||
|
_log(f"audio_scan: model saved to {versioned}")
|
||||||
|
# Update the base path to point to latest version (copy)
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(versioned, model_path)
|
||||||
|
_log(f"audio_scan: latest model updated: {model_path}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_classifier(model_path: str) -> dict | None:
|
||||||
|
"""Load a saved classifier model."""
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
return None
|
||||||
|
import joblib
|
||||||
|
return joblib.load(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def default_model_path(profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> str:
|
||||||
|
"""Return the path for a profile's classifier model.
|
||||||
|
|
||||||
|
When embed_model is given the file is ``{profile}_{model}.joblib``,
|
||||||
|
otherwise ``{profile}.joblib`` (legacy single-model layout).
|
||||||
|
"""
|
||||||
|
if embed_model:
|
||||||
|
return os.path.join(_MODEL_DIR, f"{profile_name}_{embed_model}.joblib")
|
||||||
|
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def list_model_versions(profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> list[tuple[str, str]]:
|
||||||
|
"""Return available backup versions for a model, newest first.
|
||||||
|
|
||||||
|
Returns list of (timestamp_label, file_path).
|
||||||
|
The current (active) model is listed first as "current".
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
current = default_model_path(profile_name, embed_model)
|
||||||
|
stem, ext = os.path.splitext(current)
|
||||||
|
versions: list[tuple[str, str]] = []
|
||||||
|
if os.path.exists(current):
|
||||||
|
versions.append(("current", current))
|
||||||
|
if not os.path.isdir(_MODEL_DIR):
|
||||||
|
return versions
|
||||||
|
pattern = re.compile(re.escape(os.path.basename(stem)) + r"_(\d{8}_\d{6})" + re.escape(ext) + "$")
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
m = pattern.match(fname)
|
||||||
|
if m:
|
||||||
|
versions.append((m.group(1), os.path.join(_MODEL_DIR, fname)))
|
||||||
|
# Sort backups newest first (after "current")
|
||||||
|
current_entry = versions[:1]
|
||||||
|
backups = sorted(versions[1:], key=lambda v: v[0], reverse=True)
|
||||||
|
return current_entry + backups
|
||||||
|
|
||||||
|
|
||||||
|
def restore_model_version(version_path: str, profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> None:
|
||||||
|
"""Restore a backup version as the active model."""
|
||||||
|
import filecmp, shutil
|
||||||
|
from datetime import datetime
|
||||||
|
current = default_model_path(profile_name, embed_model)
|
||||||
|
if version_path == current:
|
||||||
|
return
|
||||||
|
# Back up current before replacing — but only if no identical backup exists
|
||||||
|
if os.path.exists(current):
|
||||||
|
stem, ext = os.path.splitext(current)
|
||||||
|
already_saved = False
|
||||||
|
if os.path.isdir(_MODEL_DIR):
|
||||||
|
import re
|
||||||
|
pat = re.compile(re.escape(os.path.basename(stem)) + r"_\d{8}_\d{6}" + re.escape(ext) + "$")
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
if pat.match(fname):
|
||||||
|
candidate = os.path.join(_MODEL_DIR, fname)
|
||||||
|
if filecmp.cmp(current, candidate, shallow=False):
|
||||||
|
already_saved = True
|
||||||
|
break
|
||||||
|
if not already_saved:
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
shutil.move(current, f"{stem}_{ts}{ext}")
|
||||||
|
shutil.copy2(version_path, current)
|
||||||
|
_log(f"audio_scan: restored {os.path.basename(version_path)} as active model")
|
||||||
|
|
||||||
|
|
||||||
|
def list_trained_models(profile_name: str = "default") -> list[str]:
|
||||||
|
"""Return embedding model names that have a trained .joblib for *profile_name*.
|
||||||
|
|
||||||
|
Looks for files matching ``{profile}_{MODEL}.joblib`` in the models dir.
|
||||||
|
"""
|
||||||
|
prefix = f"{profile_name}_"
|
||||||
|
suffix = ".joblib"
|
||||||
|
result = []
|
||||||
|
if not os.path.isdir(_MODEL_DIR):
|
||||||
|
return result
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
if fname.startswith(prefix) and fname.endswith(suffix):
|
||||||
|
model_name = fname[len(prefix):-len(suffix)]
|
||||||
|
if model_name in _EMBED_MODELS:
|
||||||
|
result.append(model_name)
|
||||||
|
# Also check legacy {profile}.joblib
|
||||||
|
legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||||
|
if os.path.exists(legacy) and not result:
|
||||||
|
# Legacy model — we don't know the embed model, but it's usable
|
||||||
|
result.append("")
|
||||||
|
return sorted(result)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scanning
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _fuse_regions(regions: list[tuple[float, float, float]]
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Merge overlapping/adjacent regions, keeping max score."""
|
||||||
|
if not regions:
|
||||||
|
return []
|
||||||
|
by_start = sorted(regions, key=lambda r: r[0])
|
||||||
|
fused: list[tuple[float, float, float]] = []
|
||||||
|
s, e, sc = by_start[0]
|
||||||
|
for s2, e2, sc2 in by_start[1:]:
|
||||||
|
if s2 <= e: # overlapping or touching
|
||||||
|
e = max(e, e2)
|
||||||
|
sc = max(sc, sc2)
|
||||||
|
else:
|
||||||
|
fused.append((s, e, sc))
|
||||||
|
s, e, sc = s2, e2, sc2
|
||||||
|
fused.append((s, e, sc))
|
||||||
|
return fused
|
||||||
|
|
||||||
|
|
||||||
|
def prefetch_audio(video_path: str, embed_model: str | None = None,
|
||||||
|
hop: float = 1.0, window: float = _WINDOW) -> np.ndarray | None:
|
||||||
|
"""Pre-load audio for a video if embeddings aren't cached.
|
||||||
|
|
||||||
|
Returns the raw audio array, or None if cache already exists.
|
||||||
|
Call from a background thread while the GPU is busy with another video.
|
||||||
|
"""
|
||||||
|
if _w2v_cache_exists(video_path, hop, window, embed_model):
|
||||||
|
return None
|
||||||
|
_log(f"audio_scan: prefetching {os.path.basename(video_path)}")
|
||||||
|
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||||
|
_log(f"audio_scan: prefetched {len(y)/_SR:.1f}s")
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def scan_video(
|
||||||
|
video_path: str,
|
||||||
|
model: dict = None,
|
||||||
|
threshold: float = 0.30,
|
||||||
|
hop: float = 1.0,
|
||||||
|
window: float = _WINDOW,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
prefetched_audio: np.ndarray | None = None,
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Scan a video for matching audio regions using a trained classifier.
|
||||||
|
|
||||||
|
Returns list of (start_time, end_time, score) above threshold.
|
||||||
|
If prefetched_audio is provided, skips the ffmpeg decode step.
|
||||||
|
"""
|
||||||
|
if model is None:
|
||||||
|
_log("audio_scan: no model provided")
|
||||||
|
return []
|
||||||
|
|
||||||
|
clf = model["classifier"]
|
||||||
|
embed_model = model.get("embed_model")
|
||||||
|
|
||||||
|
# Try cache first — skip expensive audio loading if embeddings exist
|
||||||
|
cached = _w2v_cache_load(video_path, hop, window, embed_model)
|
||||||
|
if cached is not None:
|
||||||
|
timestamps, window_vectors = cached
|
||||||
|
else:
|
||||||
|
if prefetched_audio is not None:
|
||||||
|
_log(f"audio_scan: using prefetched audio")
|
||||||
|
y = prefetched_audio
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: loading {video_path}")
|
||||||
|
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||||
|
sr = _SR
|
||||||
|
_log(f"audio_scan: {len(y)/sr:.1f}s loaded")
|
||||||
|
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return []
|
||||||
|
|
||||||
|
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
|
||||||
|
timestamps, window_vectors = _extract_w2v_windows(
|
||||||
|
y, sr, hop=hop, window=window, video_path=video_path,
|
||||||
|
cancel_flag=cancel_flag, model_name=embed_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
_log("audio_scan: video shorter than window")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Per-video z-score normalize
|
||||||
|
vid_mean = window_vectors.mean(axis=0)
|
||||||
|
vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
|
||||||
|
normed = (window_vectors - vid_mean) / vid_std
|
||||||
|
|
||||||
|
_log(f"audio_scan: classifying {len(normed)} windows...")
|
||||||
|
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return []
|
||||||
|
|
||||||
|
probs = clf.predict_proba(normed)[:, 1]
|
||||||
|
mask = probs >= threshold
|
||||||
|
raw = [
|
||||||
|
(timestamps[i], timestamps[i] + window, float(probs[i]))
|
||||||
|
for i in np.nonzero(mask)[0]
|
||||||
|
]
|
||||||
|
results = _fuse_regions(raw)
|
||||||
|
_log(f"audio_scan: {len(results)} regions above threshold {threshold} (from {len(raw)} raw)")
|
||||||
|
return results
|
||||||
@@ -0,0 +1,783 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import LayerNorm, Parameter
|
||||||
|
from .beats_modules import (
|
||||||
|
GradMultiply,
|
||||||
|
SamePad,
|
||||||
|
get_activation_fn,
|
||||||
|
GLU_Linear,
|
||||||
|
quant_noise,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.embedding_dim = args.encoder_embed_dim
|
||||||
|
|
||||||
|
self.pos_conv = nn.Conv1d(
|
||||||
|
self.embedding_dim,
|
||||||
|
self.embedding_dim,
|
||||||
|
kernel_size=args.conv_pos,
|
||||||
|
padding=args.conv_pos // 2,
|
||||||
|
groups=args.conv_pos_groups,
|
||||||
|
)
|
||||||
|
dropout = 0
|
||||||
|
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||||
|
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||||
|
nn.init.constant_(self.pos_conv.bias, 0)
|
||||||
|
|
||||||
|
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||||
|
|
||||||
|
if hasattr(args, "relative_position_embedding"):
|
||||||
|
self.relative_position_embedding = args.relative_position_embedding
|
||||||
|
self.num_buckets = args.num_buckets
|
||||||
|
self.max_distance = args.max_distance
|
||||||
|
else:
|
||||||
|
self.relative_position_embedding = False
|
||||||
|
self.num_buckets = 0
|
||||||
|
self.max_distance = 0
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerSentenceEncoderLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
deep_norm=args.deep_norm,
|
||||||
|
has_relative_attention_bias=self.relative_position_embedding,
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
|
gru_rel_pos=args.gru_rel_pos,
|
||||||
|
encoder_layers=args.encoder_layers,
|
||||||
|
)
|
||||||
|
for i in range(args.encoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.relative_position_embedding:
|
||||||
|
for i in range(1, args.encoder_layers):
|
||||||
|
del self.layers[i].self_attn.relative_attention_bias
|
||||||
|
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
|
||||||
|
|
||||||
|
self.layer_norm_first = args.layer_norm_first
|
||||||
|
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
self.layerdrop = args.encoder_layerdrop
|
||||||
|
|
||||||
|
self.apply(init_bert_params)
|
||||||
|
|
||||||
|
if args.deep_norm:
|
||||||
|
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
||||||
|
for i in range(args.encoder_layers):
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
|
||||||
|
|
||||||
|
def forward(self, x, padding_mask=None, layer=None):
|
||||||
|
x, layer_results = self.extract_features(x, padding_mask, layer)
|
||||||
|
|
||||||
|
if self.layer_norm_first and layer is None:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
x[padding_mask] = 0
|
||||||
|
|
||||||
|
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||||
|
x_conv = x_conv.transpose(1, 2)
|
||||||
|
x = x + x_conv
|
||||||
|
|
||||||
|
if not self.layer_norm_first:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# B x T x C -> T x B x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
layer_results = []
|
||||||
|
z = None
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
r = None
|
||||||
|
pos_bias = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
if self.layer_wise_gradient_decay_ratio != 1.0:
|
||||||
|
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
||||||
|
dropout_probability = np.random.random()
|
||||||
|
if not self.training or (dropout_probability > self.layerdrop):
|
||||||
|
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
if i == tgt_layer:
|
||||||
|
r = x
|
||||||
|
break
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
x = r
|
||||||
|
|
||||||
|
# T x B x C -> B x T x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerSentenceEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: float = 768,
|
||||||
|
ffn_embedding_dim: float = 3072,
|
||||||
|
num_attention_heads: float = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
attention_dropout: float = 0.1,
|
||||||
|
activation_dropout: float = 0.1,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
layer_norm_first: bool = False,
|
||||||
|
deep_norm: bool = False,
|
||||||
|
has_relative_attention_bias: bool = False,
|
||||||
|
num_buckets: int = 0,
|
||||||
|
max_distance: int = 0,
|
||||||
|
rescale_init: bool = False,
|
||||||
|
gru_rel_pos: bool = False,
|
||||||
|
encoder_layers: int = 0,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
|
||||||
|
self.activation_name = activation_fn
|
||||||
|
self.activation_fn = get_activation_fn(activation_fn)
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
self.embedding_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
dropout=attention_dropout,
|
||||||
|
self_attention=True,
|
||||||
|
has_relative_attention_bias=has_relative_attention_bias,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
max_distance=max_distance,
|
||||||
|
rescale_init=rescale_init,
|
||||||
|
gru_rel_pos=gru_rel_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.layer_norm_first = layer_norm_first
|
||||||
|
|
||||||
|
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||||
|
else:
|
||||||
|
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||||
|
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||||
|
|
||||||
|
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
self.deep_norm = deep_norm
|
||||||
|
if self.deep_norm:
|
||||||
|
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
||||||
|
else:
|
||||||
|
self.deep_norm_alpha = 1
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
self_attn_mask: torch.Tensor = None,
|
||||||
|
self_attn_padding_mask: torch.Tensor = None,
|
||||||
|
need_weights: bool = False,
|
||||||
|
pos_bias=None
|
||||||
|
):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if self.layer_norm_first:
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual + x
|
||||||
|
else:
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
|
return x, attn, pos_bias
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
"""Multi-headed attention.
|
||||||
|
|
||||||
|
See "Attention Is All You Need" for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
self_attention=False,
|
||||||
|
encoder_decoder_attention=False,
|
||||||
|
q_noise=0.0,
|
||||||
|
qn_block_size=8,
|
||||||
|
has_relative_attention_bias=False,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128,
|
||||||
|
gru_rel_pos=False,
|
||||||
|
rescale_init=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout_module = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||||
|
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.q_head_dim = self.head_dim
|
||||||
|
self.k_head_dim = self.head_dim
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.self_attention = self_attention
|
||||||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
|
||||||
|
assert not self.self_attention or self.qkv_same_dim, (
|
||||||
|
"Self-attention requires query, key and " "value to be of the same size"
|
||||||
|
)
|
||||||
|
|
||||||
|
k_bias = True
|
||||||
|
if rescale_init:
|
||||||
|
k_bias = False
|
||||||
|
|
||||||
|
k_embed_dim = embed_dim
|
||||||
|
q_embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.k_proj = quant_noise(
|
||||||
|
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.v_proj = quant_noise(
|
||||||
|
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.q_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self.gru_rel_pos = gru_rel_pos
|
||||||
|
if self.gru_rel_pos:
|
||||||
|
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||||
|
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if self.qkv_same_dim:
|
||||||
|
# Empirically observed the convergence to be much better with
|
||||||
|
# the scaled initialization
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
else:
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||||
|
|
||||||
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
|
if self.out_proj.bias is not None:
|
||||||
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
if self.bias_k is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_v)
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||||
|
|
||||||
|
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||||
|
num_buckets = self.num_buckets
|
||||||
|
max_distance = self.max_distance
|
||||||
|
relative_buckets = 0
|
||||||
|
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets = num_buckets // 2
|
||||||
|
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||||
|
relative_positions = torch.abs(relative_positions)
|
||||||
|
else:
|
||||||
|
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_positions < max_exact
|
||||||
|
|
||||||
|
relative_postion_if_large = max_exact + (
|
||||||
|
torch.log(relative_positions.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
relative_postion_if_large = torch.min(
|
||||||
|
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length):
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||||
|
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||||
|
relative_position = memory_position - context_position
|
||||||
|
relative_position_bucket = self._relative_positions_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=True
|
||||||
|
)
|
||||||
|
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket)
|
||||||
|
values = values.permute([2, 0, 1])
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key: Optional[Tensor],
|
||||||
|
value: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
static_kv: bool = False,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
before_softmax: bool = False,
|
||||||
|
need_head_weights: bool = False,
|
||||||
|
position_bias: Optional[Tensor] = None
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||||
|
"""Input shape: Time x Batch x Channel
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||||
|
keys that are pads, of shape `(batch, src_len)`, where
|
||||||
|
padding elements are indicated by 1s.
|
||||||
|
need_weights (bool, optional): return the attention weights,
|
||||||
|
averaged over heads (default: False).
|
||||||
|
attn_mask (ByteTensor, optional): typically used to
|
||||||
|
implement causal attention, where the mask prevents the
|
||||||
|
attention from looking forward in time (default: None).
|
||||||
|
before_softmax (bool, optional): return the raw attention
|
||||||
|
weights and values before the attention softmax.
|
||||||
|
need_head_weights (bool, optional): return the attention
|
||||||
|
weights for each head. Implies *need_weights*. Default:
|
||||||
|
return the average attention weights over all heads.
|
||||||
|
"""
|
||||||
|
if need_head_weights:
|
||||||
|
need_weights = True
|
||||||
|
|
||||||
|
is_tpu = query.device.type == "xla"
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
src_len = tgt_len
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||||
|
if key is not None:
|
||||||
|
src_len, key_bsz, _ = key.size()
|
||||||
|
if not torch.jit.is_scripting():
|
||||||
|
assert key_bsz == bsz
|
||||||
|
assert value is not None
|
||||||
|
assert src_len, bsz == value.shape[:2]
|
||||||
|
|
||||||
|
if self.has_relative_attention_bias and position_bias is None:
|
||||||
|
position_bias = self.compute_bias(tgt_len, src_len)
|
||||||
|
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
saved_state = self._get_input_buffer(incremental_state)
|
||||||
|
if saved_state is not None and "prev_key" in saved_state:
|
||||||
|
# previous time steps are cached - no need to recompute
|
||||||
|
# key and value if they are static
|
||||||
|
if static_kv:
|
||||||
|
assert self.encoder_decoder_attention and not self.self_attention
|
||||||
|
key = value = None
|
||||||
|
else:
|
||||||
|
saved_state = None
|
||||||
|
|
||||||
|
if self.self_attention:
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(query)
|
||||||
|
v = self.v_proj(query)
|
||||||
|
elif self.encoder_decoder_attention:
|
||||||
|
# encoder-decoder attention
|
||||||
|
q = self.q_proj(query)
|
||||||
|
if key is None:
|
||||||
|
assert value is None
|
||||||
|
k = v = None
|
||||||
|
else:
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(key)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert key is not None and value is not None
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(value)
|
||||||
|
q *= self.scaling
|
||||||
|
alpha = 32
|
||||||
|
q *= 1 / alpha
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
assert self.bias_v is not None
|
||||||
|
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||||
|
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
q.contiguous()
|
||||||
|
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if k is not None:
|
||||||
|
k = (
|
||||||
|
k.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if v is not None:
|
||||||
|
v = (
|
||||||
|
v.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if saved_state is not None:
|
||||||
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||||
|
if "prev_key" in saved_state:
|
||||||
|
_prev_key = saved_state["prev_key"]
|
||||||
|
assert _prev_key is not None
|
||||||
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
k = prev_key
|
||||||
|
else:
|
||||||
|
assert k is not None
|
||||||
|
k = torch.cat([prev_key, k], dim=1)
|
||||||
|
src_len = k.size(1)
|
||||||
|
if "prev_value" in saved_state:
|
||||||
|
_prev_value = saved_state["prev_value"]
|
||||||
|
assert _prev_value is not None
|
||||||
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
v = prev_value
|
||||||
|
else:
|
||||||
|
assert v is not None
|
||||||
|
v = torch.cat([prev_value, v], dim=1)
|
||||||
|
prev_key_padding_mask: Optional[Tensor] = None
|
||||||
|
if "prev_key_padding_mask" in saved_state:
|
||||||
|
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||||
|
assert k is not None and v is not None
|
||||||
|
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
prev_key_padding_mask=prev_key_padding_mask,
|
||||||
|
batch_size=bsz,
|
||||||
|
src_len=k.size(1),
|
||||||
|
static_kv=static_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||||
|
# In this branch incremental_state is never None
|
||||||
|
assert incremental_state is not None
|
||||||
|
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||||
|
assert k is not None
|
||||||
|
assert k.size(1) == src_len
|
||||||
|
|
||||||
|
# This is part of a workaround to get around fork/join parallelism
|
||||||
|
# not supporting Optional types.
|
||||||
|
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||||
|
key_padding_mask = None
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz
|
||||||
|
assert key_padding_mask.size(1) == src_len
|
||||||
|
|
||||||
|
if self.add_zero_attn:
|
||||||
|
assert v is not None
|
||||||
|
src_len += 1
|
||||||
|
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||||
|
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
||||||
|
key_padding_mask
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
||||||
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||||
|
|
||||||
|
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# don't attend to padding symbols
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
if not is_tpu:
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if before_softmax:
|
||||||
|
return attn_weights, v, position_bias
|
||||||
|
|
||||||
|
if position_bias is not None:
|
||||||
|
attn_mask_rel_pos = position_bias
|
||||||
|
if self.gru_rel_pos == 1:
|
||||||
|
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
||||||
|
_B, _H, _L, __ = query_layer.size()
|
||||||
|
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
||||||
|
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
||||||
|
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||||
|
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
||||||
|
|
||||||
|
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attn_mask_rel_pos
|
||||||
|
|
||||||
|
attn_weights_float = F.softmax(
|
||||||
|
attn_weights, dim=-1
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||||
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
|
assert v is not None
|
||||||
|
attn = torch.bmm(attn_probs, v)
|
||||||
|
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
|
attn = self.out_proj(attn)
|
||||||
|
attn_weights: Optional[Tensor] = None
|
||||||
|
if need_weights:
|
||||||
|
attn_weights = attn_weights_float.view(
|
||||||
|
bsz, self.num_heads, tgt_len, src_len
|
||||||
|
).transpose(1, 0)
|
||||||
|
if not need_head_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_weights = attn_weights.mean(dim=0)
|
||||||
|
|
||||||
|
return attn, attn_weights, position_bias
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_prev_key_padding_mask(
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
prev_key_padding_mask: Optional[Tensor],
|
||||||
|
batch_size: int,
|
||||||
|
src_len: int,
|
||||||
|
static_kv: bool,
|
||||||
|
) -> Optional[Tensor]:
|
||||||
|
# saved key padding masks have shape (bsz, seq_len)
|
||||||
|
if prev_key_padding_mask is not None and static_kv:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
# During incremental decoding, as the padding token enters and
|
||||||
|
# leaves the frame, there will be a time when prev or current
|
||||||
|
# is None
|
||||||
|
elif prev_key_padding_mask is not None:
|
||||||
|
if src_len > prev_key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||||
|
device=prev_key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), filler.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask.float()
|
||||||
|
elif key_padding_mask is not None:
|
||||||
|
if src_len > key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - key_padding_mask.size(1)),
|
||||||
|
device=key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[filler.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = key_padding_mask.float()
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
return new_key_padding_mask
|
||||||
|
|
||||||
|
def _get_input_buffer(
|
||||||
|
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||||
|
) -> Dict[str, Optional[Tensor]]:
|
||||||
|
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||||
|
return empty_result
|
||||||
|
|
||||||
|
def _set_input_buffer(
|
||||||
|
self,
|
||||||
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||||
|
buffer: Dict[str, Optional[Tensor]],
|
||||||
|
):
|
||||||
|
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||||
|
|
||||||
|
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||||
|
return attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def init_bert_params(module):
|
||||||
|
"""
|
||||||
|
Initialize the weights specific to the BERT Model.
|
||||||
|
This overrides the default initializations depending on the specified arguments.
|
||||||
|
1. If normal_init_linear_weights is set then weights of linear
|
||||||
|
layer will be initialized using the normal distribution and
|
||||||
|
bais will be set to the specified value.
|
||||||
|
2. If normal_init_embed_weights is set then weights of embedding
|
||||||
|
layer will be initialized using the normal distribution.
|
||||||
|
3. If normal_init_proj_weights is set then weights of
|
||||||
|
in_project_weight for MultiHeadAttention initialized using
|
||||||
|
the normal distribution (to be validated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def normal_(data):
|
||||||
|
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||||
|
# so that the RNG is consistent with and without FSDP
|
||||||
|
data.copy_(
|
||||||
|
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
if isinstance(module, MultiheadAttention):
|
||||||
|
normal_(module.q_proj.weight.data)
|
||||||
|
normal_(module.k_proj.weight.data)
|
||||||
|
normal_(module.v_proj.weight.data)
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
import torchaudio.compliance.kaldi as ta_kaldi
|
||||||
|
|
||||||
|
from .beats_backbone import (
|
||||||
|
TransformerEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATsConfig:
|
||||||
|
def __init__(self, cfg=None):
|
||||||
|
self.input_patch_size: int = -1 # path size of patch embedding
|
||||||
|
self.embed_dim: int = 512 # patch embedding dimension
|
||||||
|
self.conv_bias: bool = False # include bias in conv encoder
|
||||||
|
|
||||||
|
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||||
|
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||||
|
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||||
|
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||||
|
self.activation_fn: str = "gelu" # activation function to use
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
||||||
|
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||||
|
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||||
|
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||||
|
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||||
|
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
||||||
|
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||||
|
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||||
|
|
||||||
|
# relative position embedding
|
||||||
|
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||||
|
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||||
|
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||||
|
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||||
|
|
||||||
|
# label predictor
|
||||||
|
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
||||||
|
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
||||||
|
self.predictor_class: int = 527 # target class number for the predictor
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
self.update(cfg)
|
||||||
|
|
||||||
|
def update(self, cfg: dict):
|
||||||
|
self.__dict__.update(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATs(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: BEATsConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info(f"BEATs Config: {cfg.__dict__}")
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
self.embed = cfg.embed_dim
|
||||||
|
self.post_extract_proj = (
|
||||||
|
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||||
|
if self.embed != cfg.encoder_embed_dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_patch_size = cfg.input_patch_size
|
||||||
|
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
||||||
|
bias=cfg.conv_bias)
|
||||||
|
|
||||||
|
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||||
|
|
||||||
|
assert not cfg.deep_norm or not cfg.layer_norm_first
|
||||||
|
self.encoder = TransformerEncoder(cfg)
|
||||||
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
|
if cfg.finetuned_model:
|
||||||
|
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
||||||
|
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
||||||
|
else:
|
||||||
|
self.predictor = None
|
||||||
|
|
||||||
|
def forward_padding_mask(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
padding_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
extra = padding_mask.size(1) % features.size(1)
|
||||||
|
if extra > 0:
|
||||||
|
padding_mask = padding_mask[:, :-extra]
|
||||||
|
padding_mask = padding_mask.view(
|
||||||
|
padding_mask.size(0), features.size(1), -1
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask.all(-1)
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
fbanks = []
|
||||||
|
for waveform in source:
|
||||||
|
waveform = waveform.unsqueeze(0) * 2 ** 15
|
||||||
|
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
||||||
|
fbanks.append(fbank)
|
||||||
|
fbank = torch.stack(fbanks, dim=0)
|
||||||
|
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
):
|
||||||
|
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
||||||
|
|
||||||
|
fbank = fbank.unsqueeze(1)
|
||||||
|
features = self.patch_embedding(fbank)
|
||||||
|
features = features.reshape(features.shape[0], features.shape[1], -1)
|
||||||
|
features = features.transpose(1, 2)
|
||||||
|
features = self.layer_norm(features)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||||
|
|
||||||
|
if self.post_extract_proj is not None:
|
||||||
|
features = self.post_extract_proj(features)
|
||||||
|
|
||||||
|
x = self.dropout_input(features)
|
||||||
|
|
||||||
|
x, layer_results = self.encoder(
|
||||||
|
x,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.predictor is not None:
|
||||||
|
x = self.predictor_dropout(x)
|
||||||
|
logits = self.predictor(x)
|
||||||
|
|
||||||
|
if padding_mask is not None and padding_mask.any():
|
||||||
|
logits[padding_mask] = 0
|
||||||
|
logits = logits.sum(dim=1)
|
||||||
|
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
|
||||||
|
else:
|
||||||
|
logits = logits.mean(dim=1)
|
||||||
|
|
||||||
|
lprobs = torch.sigmoid(logits)
|
||||||
|
|
||||||
|
return lprobs, padding_mask
|
||||||
|
else:
|
||||||
|
return x, padding_mask
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class GradMultiply(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, scale):
|
||||||
|
ctx.scale = scale
|
||||||
|
res = x.new(x)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
return grad * ctx.scale, None
|
||||||
|
|
||||||
|
|
||||||
|
class SamePad(nn.Module):
|
||||||
|
def __init__(self, kernel_size, causal=False):
|
||||||
|
super().__init__()
|
||||||
|
if causal:
|
||||||
|
self.remove = kernel_size - 1
|
||||||
|
else:
|
||||||
|
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.remove > 0:
|
||||||
|
x = x[:, :, : -self.remove]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Swish, self).__init__()
|
||||||
|
self.act = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.act(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU_Linear(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||||
|
super(GLU_Linear, self).__init__()
|
||||||
|
|
||||||
|
self.glu_type = glu_type
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
if glu_type == "sigmoid":
|
||||||
|
self.glu_act = torch.nn.Sigmoid()
|
||||||
|
elif glu_type == "swish":
|
||||||
|
self.glu_act = Swish()
|
||||||
|
elif glu_type == "relu":
|
||||||
|
self.glu_act = torch.nn.ReLU()
|
||||||
|
elif glu_type == "gelu":
|
||||||
|
self.glu_act = torch.nn.GELU()
|
||||||
|
|
||||||
|
if bias_in_glu:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||||
|
else:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||||
|
x = self.linear(x)
|
||||||
|
|
||||||
|
if self.glu_type == "bilinear":
|
||||||
|
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||||
|
else:
|
||||||
|
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_accurate(x):
|
||||||
|
if not hasattr(gelu_accurate, "_a"):
|
||||||
|
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||||
|
return (
|
||||||
|
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_fn(activation: str):
|
||||||
|
"""Returns the activation function corresponding to `activation`"""
|
||||||
|
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
return gelu
|
||||||
|
elif activation == "gelu_fast":
|
||||||
|
warnings.warn(
|
||||||
|
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||||
|
)
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "gelu_accurate":
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "tanh":
|
||||||
|
return torch.tanh
|
||||||
|
elif activation == "linear":
|
||||||
|
return lambda x: x
|
||||||
|
elif activation == "glu":
|
||||||
|
return lambda x: x
|
||||||
|
else:
|
||||||
|
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||||
|
|
||||||
|
|
||||||
|
def quant_noise(module, p, block_size):
|
||||||
|
"""
|
||||||
|
Wraps modules and applies quantization noise to the weights for
|
||||||
|
subsequent quantization with Iterative Product Quantization as
|
||||||
|
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- module: nn.Module
|
||||||
|
- p: amount of Quantization Noise
|
||||||
|
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||||
|
|
||||||
|
Remarks:
|
||||||
|
- Module weights must have the right sizes wrt the block size
|
||||||
|
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||||
|
- For more detail on how to quantize by blocks with convolutional weights,
|
||||||
|
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||||
|
- We implement the simplest form of noise here as stated in the paper
|
||||||
|
which consists in randomly dropping blocks
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if no quantization noise, don't register hook
|
||||||
|
if p <= 0:
|
||||||
|
return module
|
||||||
|
|
||||||
|
# supported modules
|
||||||
|
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||||
|
|
||||||
|
# test whether module.weight has the right sizes wrt block_size
|
||||||
|
is_conv = module.weight.ndim == 4
|
||||||
|
|
||||||
|
# 2D matrix
|
||||||
|
if not is_conv:
|
||||||
|
assert (
|
||||||
|
module.weight.size(1) % block_size == 0
|
||||||
|
), "Input features must be a multiple of block sizes"
|
||||||
|
|
||||||
|
# 4D matrix
|
||||||
|
else:
|
||||||
|
# 1x1 convolutions
|
||||||
|
if module.kernel_size == (1, 1):
|
||||||
|
assert (
|
||||||
|
module.in_channels % block_size == 0
|
||||||
|
), "Input channels must be a multiple of block sizes"
|
||||||
|
# regular convolutions
|
||||||
|
else:
|
||||||
|
k = module.kernel_size[0] * module.kernel_size[1]
|
||||||
|
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||||
|
|
||||||
|
def _forward_pre_hook(mod, input):
|
||||||
|
# no noise for evaluation
|
||||||
|
if mod.training:
|
||||||
|
if not is_conv:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_features = weight.size(1)
|
||||||
|
out_features = weight.size(0)
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
mask = torch.zeros(
|
||||||
|
in_features // block_size * out_features, device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_channels = mod.in_channels
|
||||||
|
out_channels = mod.out_channels
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
if mod.kernel_size == (1, 1):
|
||||||
|
mask = torch.zeros(
|
||||||
|
int(in_channels // block_size * out_channels),
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros(
|
||||||
|
weight.size(0), weight.size(1), device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = (
|
||||||
|
mask.unsqueeze(2)
|
||||||
|
.unsqueeze(3)
|
||||||
|
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale weights and apply mask
|
||||||
|
mask = mask.to(
|
||||||
|
torch.bool
|
||||||
|
) # x.bool() is not currently supported in TorchScript
|
||||||
|
s = 1 / (1 - p)
|
||||||
|
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||||
|
|
||||||
|
module.register_forward_pre_hook(_forward_pre_hook)
|
||||||
|
return module
|
||||||
|
|
||||||
+625
@@ -0,0 +1,625 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .paths import _log
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessedDB:
|
||||||
|
_SCHEMA_VERSION = 4 # bump when schema changes
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | None = None):
|
||||||
|
if db_path is None:
|
||||||
|
db_path = str(Path.home() / ".8cut.db")
|
||||||
|
self._path = db_path
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
try:
|
||||||
|
self._con = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
|
self._migrate()
|
||||||
|
self._enabled = True
|
||||||
|
_log(f"DB opened: {db_path}")
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"DB unavailable: {e}")
|
||||||
|
self._con = None
|
||||||
|
self._enabled = False
|
||||||
|
|
||||||
|
def _migrate(self) -> None:
|
||||||
|
"""Create table if missing, then add any new columns for old DBs."""
|
||||||
|
cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(processed)").fetchall()
|
||||||
|
}
|
||||||
|
if not cols:
|
||||||
|
# Fresh DB — create from scratch
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS processed ("
|
||||||
|
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||||
|
" filename TEXT NOT NULL,"
|
||||||
|
" start_time REAL NOT NULL,"
|
||||||
|
" output_path TEXT NOT NULL,"
|
||||||
|
" label TEXT NOT NULL DEFAULT '',"
|
||||||
|
" category TEXT NOT NULL DEFAULT '',"
|
||||||
|
" short_side INTEGER DEFAULT 512,"
|
||||||
|
" portrait_ratio TEXT NOT NULL DEFAULT '',"
|
||||||
|
" crop_center REAL NOT NULL DEFAULT 0.5,"
|
||||||
|
" format TEXT NOT NULL DEFAULT 'MP4',"
|
||||||
|
" clip_count INTEGER NOT NULL DEFAULT 3,"
|
||||||
|
" spread REAL NOT NULL DEFAULT 3.0,"
|
||||||
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
|
" source_path TEXT NOT NULL DEFAULT '',"
|
||||||
|
" scan_export INTEGER NOT NULL DEFAULT 0,"
|
||||||
|
" processed_at TEXT NOT NULL"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Add missing columns to legacy tables
|
||||||
|
new_cols = {
|
||||||
|
"label": "TEXT NOT NULL DEFAULT ''",
|
||||||
|
"category": "TEXT NOT NULL DEFAULT ''",
|
||||||
|
"short_side": "INTEGER DEFAULT 512",
|
||||||
|
"portrait_ratio": "TEXT NOT NULL DEFAULT ''",
|
||||||
|
"crop_center": "REAL NOT NULL DEFAULT 0.5",
|
||||||
|
"format": "TEXT NOT NULL DEFAULT 'MP4'",
|
||||||
|
"clip_count": "INTEGER NOT NULL DEFAULT 3",
|
||||||
|
"spread": "REAL NOT NULL DEFAULT 3.0",
|
||||||
|
"profile": "TEXT NOT NULL DEFAULT 'default'",
|
||||||
|
"source_path": "TEXT NOT NULL DEFAULT ''",
|
||||||
|
"scan_export": "INTEGER NOT NULL DEFAULT 0",
|
||||||
|
}
|
||||||
|
for col, typedef in new_cols.items():
|
||||||
|
if col not in cols:
|
||||||
|
self._con.execute(
|
||||||
|
f"ALTER TABLE processed ADD COLUMN {col} {typedef}"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_filename ON processed(filename)"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS hidden_files ("
|
||||||
|
" filename TEXT NOT NULL,"
|
||||||
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
|
" PRIMARY KEY (filename, profile)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS scan_results ("
|
||||||
|
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||||
|
" filename TEXT NOT NULL,"
|
||||||
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
|
" model TEXT NOT NULL,"
|
||||||
|
" start_time REAL NOT NULL,"
|
||||||
|
" end_time REAL NOT NULL,"
|
||||||
|
" score REAL NOT NULL,"
|
||||||
|
" disabled INTEGER NOT NULL DEFAULT 0,"
|
||||||
|
" orig_start_time REAL,"
|
||||||
|
" orig_end_time REAL"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
# Migrate: add new columns to existing scan_results tables
|
||||||
|
sr_cols = {
|
||||||
|
row[1]
|
||||||
|
for row in self._con.execute("PRAGMA table_info(scan_results)").fetchall()
|
||||||
|
}
|
||||||
|
for col, typedef in [
|
||||||
|
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
||||||
|
("orig_start_time", "REAL"),
|
||||||
|
("orig_end_time", "REAL"),
|
||||||
|
]:
|
||||||
|
if col not in sr_cols:
|
||||||
|
self._con.execute(
|
||||||
|
f"ALTER TABLE scan_results ADD COLUMN {col} {typedef}"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model"
|
||||||
|
" ON scan_results(filename, profile, model)"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS hard_negatives ("
|
||||||
|
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||||
|
" filename TEXT NOT NULL,"
|
||||||
|
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||||
|
" start_time REAL NOT NULL,"
|
||||||
|
" source_path TEXT NOT NULL DEFAULT ''"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
self._con.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||||
|
" ON hard_negatives(filename, profile)"
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def add(self, filename: str, start_time: float, output_path: str,
|
||||||
|
label: str = "", category: str = "",
|
||||||
|
short_side: int | None = None, portrait_ratio: str = "",
|
||||||
|
crop_center: float = 0.5, fmt: str = "MP4",
|
||||||
|
clip_count: int = 3, spread: float = 3.0,
|
||||||
|
profile: str = "default", source_path: str = "",
|
||||||
|
scan_export: bool = False) -> None:
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"INSERT INTO processed"
|
||||||
|
" (filename, start_time, output_path, label, category,"
|
||||||
|
" short_side, portrait_ratio, crop_center, format,"
|
||||||
|
" clip_count, spread, profile, source_path, scan_export, processed_at)"
|
||||||
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(filename, start_time, output_path, label, category,
|
||||||
|
short_side, portrait_ratio, crop_center, fmt,
|
||||||
|
clip_count, spread, profile, source_path,
|
||||||
|
1 if scan_export else 0,
|
||||||
|
datetime.now(timezone.utc).isoformat()),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_labels(self) -> list[str]:
|
||||||
|
"""Return distinct non-empty labels ordered by most recently used."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT label FROM processed"
|
||||||
|
" WHERE label != '' ORDER BY processed_at DESC"
|
||||||
|
).fetchall()
|
||||||
|
# Deduplicate while preserving order (DISTINCT on processed_at DESC
|
||||||
|
# may return duplicates if the same label was used multiple times).
|
||||||
|
seen: set[str] = set()
|
||||||
|
result = []
|
||||||
|
for (lbl,) in rows:
|
||||||
|
if lbl not in seen:
|
||||||
|
seen.add(lbl)
|
||||||
|
result.append(lbl)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_by_output_path(self, output_path: str) -> dict | None:
|
||||||
|
"""Return config dict for an output_path, or None."""
|
||||||
|
if not self._enabled:
|
||||||
|
return None
|
||||||
|
cur = self._con.cursor()
|
||||||
|
cur.row_factory = sqlite3.Row
|
||||||
|
row = cur.execute(
|
||||||
|
"SELECT label, category, short_side, portrait_ratio, crop_center, format,"
|
||||||
|
" clip_count, spread"
|
||||||
|
" FROM processed WHERE output_path = ?",
|
||||||
|
(output_path,),
|
||||||
|
).fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
def delete_by_output_path(self, output_path: str) -> None:
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute("DELETE FROM processed WHERE output_path = ?", (output_path,))
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_group(self, output_path: str, profile: str = "") -> list[str]:
|
||||||
|
"""Return all output_paths sharing the same (filename, start_time, profile) as *output_path*."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
row = self._con.execute(
|
||||||
|
"SELECT filename, start_time, profile FROM processed WHERE output_path = ?",
|
||||||
|
(output_path,),
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return []
|
||||||
|
filename, start_time, row_profile = row
|
||||||
|
p = profile or row_profile
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT output_path FROM processed"
|
||||||
|
" WHERE filename = ? AND start_time = ? AND profile = ? ORDER BY output_path",
|
||||||
|
(filename, start_time, p),
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
def delete_group(self, output_path: str, profile: str = "") -> list[str]:
|
||||||
|
"""Delete all rows sharing the same (filename, start_time, profile) as *output_path*.
|
||||||
|
Returns list of deleted output_paths."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
with self._lock:
|
||||||
|
row = self._con.execute(
|
||||||
|
"SELECT filename, start_time, profile FROM processed WHERE output_path = ?",
|
||||||
|
(output_path,),
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return []
|
||||||
|
filename, start_time, row_profile = row
|
||||||
|
p = profile or row_profile
|
||||||
|
paths = [r[0] for r in self._con.execute(
|
||||||
|
"SELECT output_path FROM processed"
|
||||||
|
" WHERE filename = ? AND start_time = ? AND profile = ?",
|
||||||
|
(filename, start_time, p),
|
||||||
|
).fetchall()]
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM processed WHERE filename = ? AND start_time = ? AND profile = ?",
|
||||||
|
(filename, start_time, p),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
return paths
|
||||||
|
|
||||||
|
def _get_markers_for(self, match: str, profile: str = "default") -> list[tuple[float, int, str]]:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT start_time, output_path FROM processed"
|
||||||
|
" WHERE filename = ? AND profile = ? AND scan_export = 0"
|
||||||
|
" ORDER BY start_time",
|
||||||
|
(match, profile),
|
||||||
|
).fetchall()
|
||||||
|
# Deduplicate by start_time — batch exports share the same cursor.
|
||||||
|
seen_times: dict[float, tuple[float, int, str]] = {}
|
||||||
|
n = 0
|
||||||
|
for t, p in rows:
|
||||||
|
if t not in seen_times:
|
||||||
|
n += 1
|
||||||
|
seen_times[t] = (t, n, p)
|
||||||
|
return list(seen_times.values())
|
||||||
|
|
||||||
|
def get_markers(self, filename: str, profile: str = "default") -> list[tuple[float, int, str]]:
|
||||||
|
"""Return [(start_time, marker_number, output_path), ...] for exact
|
||||||
|
filename match, sorted by start_time. Empty list if no match.
|
||||||
|
Excludes scan exports (shown via scan panel instead)."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
return self._get_markers_for(filename, profile)
|
||||||
|
|
||||||
|
def get_clip_count(self, filename: str, profile: str = "default") -> int:
|
||||||
|
"""Return total number of exported clips (including scan exports)."""
|
||||||
|
if not self._enabled:
|
||||||
|
return 0
|
||||||
|
row = self._con.execute(
|
||||||
|
"SELECT COUNT(*) FROM processed WHERE filename = ? AND profile = ?",
|
||||||
|
(filename, profile),
|
||||||
|
).fetchone()
|
||||||
|
return row[0] if row else 0
|
||||||
|
|
||||||
|
def get_profiles(self) -> list[str]:
|
||||||
|
"""Return distinct profile names, ordered alphabetically."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT profile FROM processed ORDER BY profile"
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||||
|
"""Return all unique output_path values for a given profile."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
def get_export_folders(self, profile: str = "default") -> list[str]:
|
||||||
|
"""Return distinct export folder names found in output_paths for a profile.
|
||||||
|
|
||||||
|
Export paths follow the structure:
|
||||||
|
.../export_folder/group_dir/clip.mp4
|
||||||
|
The export folder is 2 levels up from the clip file.
|
||||||
|
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
folder_names: set[str] = set()
|
||||||
|
for (op,) in rows:
|
||||||
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
|
if grandparent:
|
||||||
|
folder_names.add(grandparent)
|
||||||
|
return sorted(folder_names)
|
||||||
|
|
||||||
|
def get_training_data(self, profile: str, positive_folder: str,
|
||||||
|
negative_folder: str = "",
|
||||||
|
fallback_video_dir: str = "",
|
||||||
|
include_scan_exports: bool = False,
|
||||||
|
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||||
|
"""Build training video_infos from DB data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profile: profile name
|
||||||
|
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
|
||||||
|
negative_folder: export folder name for explicit negatives (optional)
|
||||||
|
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||||
|
include_scan_exports: if True, include auto-exported scan clips
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||||
|
per video. Soft times = clips from any other non-negative folder.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
if include_scan_exports:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path, source_path"
|
||||||
|
" FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path, source_path"
|
||||||
|
" FROM processed WHERE profile = ? AND scan_export = 0",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
# Collect times by video, split by folder role
|
||||||
|
pos_by_video: dict[str, set[float]] = {}
|
||||||
|
neg_by_video: dict[str, set[float]] = {}
|
||||||
|
soft_by_video: dict[str, set[float]] = {}
|
||||||
|
source_by_filename: dict[str, str] = {}
|
||||||
|
|
||||||
|
for fn, st, op, sp in rows:
|
||||||
|
if sp:
|
||||||
|
source_by_filename[fn] = sp
|
||||||
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
|
if grandparent == positive_folder:
|
||||||
|
pos_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif negative_folder and grandparent == negative_folder:
|
||||||
|
neg_by_video.setdefault(fn, set()).add(st)
|
||||||
|
else:
|
||||||
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
|
# Include hard negatives from scan feedback
|
||||||
|
hard_rows = self._con.execute(
|
||||||
|
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||||
|
" WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
for fn, st, sp in hard_rows:
|
||||||
|
neg_by_video.setdefault(fn, set()).add(st)
|
||||||
|
if sp:
|
||||||
|
source_by_filename.setdefault(fn, sp)
|
||||||
|
|
||||||
|
# Remove positive times from soft/neg to avoid conflicting labels
|
||||||
|
for fn in pos_by_video:
|
||||||
|
if fn in soft_by_video:
|
||||||
|
soft_by_video[fn] -= pos_by_video[fn]
|
||||||
|
if fn in neg_by_video:
|
||||||
|
neg_by_video[fn] -= pos_by_video[fn]
|
||||||
|
|
||||||
|
# Deduplicate nearby markers (spread clips from same position)
|
||||||
|
def _dedup_times(times: set[float], min_gap: float = 8.0) -> list[float]:
|
||||||
|
if not times:
|
||||||
|
return []
|
||||||
|
ordered = sorted(times)
|
||||||
|
result = [ordered[0]]
|
||||||
|
for t in ordered[1:]:
|
||||||
|
if t - result[-1] >= min_gap:
|
||||||
|
result.append(t)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Include videos that have positives OR explicit negatives
|
||||||
|
all_videos = set(pos_by_video) | set(neg_by_video)
|
||||||
|
result = []
|
||||||
|
for fn in all_videos:
|
||||||
|
sp = source_by_filename.get(fn, "")
|
||||||
|
if not sp or not os.path.exists(sp):
|
||||||
|
if fallback_video_dir:
|
||||||
|
sp = os.path.join(fallback_video_dir, fn)
|
||||||
|
if not sp or not os.path.exists(sp):
|
||||||
|
continue
|
||||||
|
gt_pos = _dedup_times(pos_by_video.get(fn, set()))
|
||||||
|
gt_soft = _dedup_times(soft_by_video.get(fn, set()))
|
||||||
|
gt_neg = _dedup_times(neg_by_video.get(fn, set()))
|
||||||
|
result.append((sp, gt_pos, gt_soft, gt_neg))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_training_stats(self, profile: str,
|
||||||
|
include_scan_exports: bool = False) -> dict[str, dict]:
|
||||||
|
"""Return per-subprofile stats for training readiness display.
|
||||||
|
|
||||||
|
Returns dict mapping subprofile_name → {
|
||||||
|
'videos': number of distinct source videos,
|
||||||
|
'clips': total clip count,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return {}
|
||||||
|
if include_scan_exports:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT filename, output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT filename, output_path FROM processed"
|
||||||
|
" WHERE profile = ? AND scan_export = 0",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
folders = self.get_export_folders(profile)
|
||||||
|
stats: dict[str, dict] = {}
|
||||||
|
for folder_name in folders:
|
||||||
|
videos: set[str] = set()
|
||||||
|
clips = 0
|
||||||
|
for fn, op in rows:
|
||||||
|
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||||
|
if grandparent == folder_name:
|
||||||
|
videos.add(fn)
|
||||||
|
clips += 1
|
||||||
|
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
||||||
|
return stats
|
||||||
|
|
||||||
|
# ── Scan results ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def save_scan_results(self, filename: str, profile: str, model: str,
|
||||||
|
regions: list[tuple[float, float, float]]) -> None:
|
||||||
|
"""Replace scan results for (filename, profile, model) with new regions.
|
||||||
|
|
||||||
|
regions: list of (start_time, end_time, score).
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? AND model = ?",
|
||||||
|
(filename, profile, model),
|
||||||
|
)
|
||||||
|
self._con.executemany(
|
||||||
|
"INSERT INTO scan_results"
|
||||||
|
" (filename, profile, model, start_time, end_time, score,"
|
||||||
|
" orig_start_time, orig_end_time)"
|
||||||
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
[(filename, profile, model, s, e, sc, s, e) for s, e, sc in regions],
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_scan_results(self, filename: str, profile: str
|
||||||
|
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
||||||
|
"""Return scan results grouped by model.
|
||||||
|
|
||||||
|
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
||||||
|
sorted by start_time.
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return {}
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT id, model, start_time, end_time, score, disabled,"
|
||||||
|
" orig_start_time, orig_end_time"
|
||||||
|
" FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ?"
|
||||||
|
" ORDER BY model, start_time",
|
||||||
|
(filename, profile),
|
||||||
|
).fetchall()
|
||||||
|
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
||||||
|
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
||||||
|
# Fall back to current bounds for legacy rows without orig
|
||||||
|
result.setdefault(model, []).append(
|
||||||
|
(row_id, s, e, sc, bool(dis), os_ if os_ is not None else s,
|
||||||
|
oe if oe is not None else e))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def delete_scan_result(self, row_id: int) -> None:
|
||||||
|
"""Delete a single scan result row."""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute("DELETE FROM scan_results WHERE id = ?", (row_id,))
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def toggle_scan_result_disabled(self, row_id: int, disabled: bool) -> None:
|
||||||
|
"""Set disabled flag on a scan result row."""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"UPDATE scan_results SET disabled = ? WHERE id = ?",
|
||||||
|
(1 if disabled else 0, row_id),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def update_scan_result_times(self, row_id: int,
|
||||||
|
start: float, end: float) -> None:
|
||||||
|
"""Update start/end times of a scan result row (resize)."""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"UPDATE scan_results SET start_time = ?, end_time = ? WHERE id = ?",
|
||||||
|
(start, end, row_id),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_scan_models(self, filename: str, profile: str) -> list[str]:
|
||||||
|
"""Return model names that have scan results for this file."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT model FROM scan_results"
|
||||||
|
" WHERE filename = ? AND profile = ? ORDER BY model",
|
||||||
|
(filename, profile),
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
def get_scanned_filenames(self, profile: str, model: str) -> set[str]:
|
||||||
|
"""Return filenames that already have scan results for this model."""
|
||||||
|
if not self._enabled:
|
||||||
|
return set()
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT filename FROM scan_results"
|
||||||
|
" WHERE profile = ? AND model = ?",
|
||||||
|
(profile, model),
|
||||||
|
).fetchall()
|
||||||
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
def add_hard_negatives(self, filename: str, profile: str,
|
||||||
|
times: list[float], source_path: str = "") -> None:
|
||||||
|
"""Save timestamps as hard-negative training examples."""
|
||||||
|
if not self._enabled or not times:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
for t in times:
|
||||||
|
self._con.execute(
|
||||||
|
"INSERT INTO hard_negatives (filename, profile, start_time, source_path)"
|
||||||
|
" VALUES (?, ?, ?, ?)",
|
||||||
|
(filename, profile, t, source_path),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_hard_negative_times(self, filename: str, profile: str) -> set[float]:
|
||||||
|
"""Return start_times marked as hard negatives for this file."""
|
||||||
|
if not self._enabled:
|
||||||
|
return set()
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT start_time FROM hard_negatives"
|
||||||
|
" WHERE filename = ? AND profile = ?",
|
||||||
|
(filename, profile),
|
||||||
|
).fetchall()
|
||||||
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
def remove_hard_negatives(self, filename: str, profile: str,
|
||||||
|
times: list[float]) -> None:
|
||||||
|
"""Remove specific hard-negative timestamps."""
|
||||||
|
if not self._enabled or not times:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
for t in times:
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM hard_negatives"
|
||||||
|
" WHERE filename = ? AND profile = ? AND start_time = ?",
|
||||||
|
(filename, profile, t),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_training_filenames(self, profile: str) -> set[str]:
|
||||||
|
"""Return filenames used in training (have exported clips)."""
|
||||||
|
if not self._enabled:
|
||||||
|
return set()
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT filename FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
# ── Hidden files ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def hide_file(self, filename: str, profile: str = "default") -> None:
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"INSERT OR IGNORE INTO hidden_files (filename, profile) VALUES (?, ?)",
|
||||||
|
(filename, profile),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def unhide_file(self, filename: str, profile: str = "default") -> None:
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
self._con.execute(
|
||||||
|
"DELETE FROM hidden_files WHERE filename = ? AND profile = ?",
|
||||||
|
(filename, profile),
|
||||||
|
)
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
|
def get_hidden_files(self, profile: str = "default") -> set[str]:
|
||||||
|
if not self._enabled:
|
||||||
|
return set()
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT filename FROM hidden_files WHERE profile = ?", (profile,)
|
||||||
|
).fetchall()
|
||||||
|
return {r[0] for r in rows}
|
||||||
+178
@@ -0,0 +1,178 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
_RATIOS: dict[str, tuple[int, int]] = {
|
||||||
|
"9:16": (9, 16),
|
||||||
|
"4:5": (4, 5),
|
||||||
|
"1:1": (1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||||
|
"""Return an ffmpeg crop= filter expression for the given portrait ratio.
|
||||||
|
|
||||||
|
Uses ffmpeg expression syntax so source dimensions are resolved at runtime.
|
||||||
|
Commas inside min()/max() are escaped with \\, to prevent ffmpeg's
|
||||||
|
filtergraph parser from treating them as filter-chain separators.
|
||||||
|
"""
|
||||||
|
num, den = _RATIOS[ratio]
|
||||||
|
cw = f"ih*{num}/{den}"
|
||||||
|
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||||
|
return f"crop={cw}:ih:{x}:0"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_keyframe(
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
t: float,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||||
|
"""Return the latest keyframe at or before *t*, or None."""
|
||||||
|
result = None
|
||||||
|
for kf in keyframes:
|
||||||
|
if kf[0] <= t + tolerance:
|
||||||
|
result = kf
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_keyframes_to_jobs(
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
base_center: float,
|
||||||
|
base_ratio: str | None,
|
||||||
|
base_rand_p: bool,
|
||||||
|
base_rand_s: bool,
|
||||||
|
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||||
|
"""Resolve each job's crop state from keyframes, returning widened tuples.
|
||||||
|
|
||||||
|
Returns list of (start, path, ratio, center, rand_portrait, rand_square).
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for s, o, _r, _c in jobs:
|
||||||
|
kf = resolve_keyframe(keyframes, s)
|
||||||
|
if kf is not None:
|
||||||
|
_, center, ratio, rp, rs = kf
|
||||||
|
else:
|
||||||
|
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||||
|
result.append((s, o, ratio, center, rp, rs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _find_vaapi_device() -> str:
|
||||||
|
"""Return the first available VAAPI render device path (Linux)."""
|
||||||
|
import glob
|
||||||
|
devices = sorted(glob.glob("/dev/dri/renderD*"))
|
||||||
|
return devices[0] if devices else "/dev/dri/renderD128"
|
||||||
|
|
||||||
|
|
||||||
|
def build_ffmpeg_command(
|
||||||
|
input_path: str, start: float, output_path: str,
|
||||||
|
short_side: int | None = None,
|
||||||
|
portrait_ratio: str | None = None,
|
||||||
|
crop_center: float = 0.5,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
) -> list[str]:
|
||||||
|
# -ss before -i: fast input-seeking. Safe here because we always re-encode,
|
||||||
|
# so there is no keyframe-alignment issue from pre-input seek.
|
||||||
|
# Image sequences always use libwebp, so skip HW encoder setup.
|
||||||
|
use_hw_vaapi = (encoder == "h264_vaapi" and not image_sequence
|
||||||
|
and sys.platform == "linux")
|
||||||
|
cmd = [_bin("ffmpeg"), "-y"]
|
||||||
|
|
||||||
|
# VAAPI needs a render device for hardware context (Linux only).
|
||||||
|
if use_hw_vaapi:
|
||||||
|
vaapi_dev = _find_vaapi_device()
|
||||||
|
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||||
|
"-vaapi_device", vaapi_dev]
|
||||||
|
|
||||||
|
cmd += [
|
||||||
|
"-threads", "0",
|
||||||
|
"-ss", str(start),
|
||||||
|
"-i", input_path,
|
||||||
|
"-t", "8",
|
||||||
|
]
|
||||||
|
|
||||||
|
filters: list[str] = []
|
||||||
|
if portrait_ratio is not None:
|
||||||
|
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||||
|
if short_side is not None:
|
||||||
|
# Scale so the shorter dimension equals short_side.
|
||||||
|
filters.append(
|
||||||
|
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||||
|
)
|
||||||
|
|
||||||
|
# VAAPI: decoded frames are GPU surfaces. CPU filters need hwdownload first.
|
||||||
|
if use_hw_vaapi:
|
||||||
|
if filters:
|
||||||
|
filters.insert(0, "hwdownload")
|
||||||
|
filters.insert(1, "format=nv12")
|
||||||
|
filters.append("format=nv12")
|
||||||
|
filters.append("hwupload")
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
cmd += ["-vf", ",".join(filters)]
|
||||||
|
|
||||||
|
if image_sequence:
|
||||||
|
cmd += [
|
||||||
|
"-an",
|
||||||
|
"-c:v", "libwebp",
|
||||||
|
"-quality", "92",
|
||||||
|
"-compression_level", "1",
|
||||||
|
os.path.join(output_path, "frame_%04d.webp"),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||||
|
"""Return an ffmpeg command that extracts audio to <sequence_dir>.wav."""
|
||||||
|
audio_path = sequence_dir + ".wav"
|
||||||
|
return [
|
||||||
|
_bin("ffmpeg"), "-y",
|
||||||
|
"-ss", str(start),
|
||||||
|
"-i", input_path,
|
||||||
|
"-t", "8",
|
||||||
|
"-vn",
|
||||||
|
"-c:a", "pcm_s16le",
|
||||||
|
audio_path,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_hw_encoders() -> list[str]:
|
||||||
|
"""Probe ffmpeg for available H.264 hardware encoders.
|
||||||
|
|
||||||
|
Returns only encoders relevant to the current platform:
|
||||||
|
- Windows: h264_nvenc, h264_qsv, h264_amf
|
||||||
|
- Linux: h264_nvenc, h264_vaapi, h264_qsv
|
||||||
|
- macOS: h264_videotoolbox
|
||||||
|
"""
|
||||||
|
if sys.platform == "win32":
|
||||||
|
candidates = ["h264_nvenc", "h264_qsv", "h264_amf"]
|
||||||
|
elif sys.platform == "darwin":
|
||||||
|
candidates = ["h264_videotoolbox"]
|
||||||
|
else:
|
||||||
|
candidates = ["h264_nvenc", "h264_vaapi", "h264_qsv"]
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||||
|
capture_output=True, text=True, timeout=5,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return []
|
||||||
|
output = result.stdout
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
available = [enc for enc in candidates if re.search(rf'\b{enc}\b', output)]
|
||||||
|
if available:
|
||||||
|
_log(f"HW encoders detected: {', '.join(available)}")
|
||||||
|
else:
|
||||||
|
_log("No HW encoders detected — GPU export unavailable")
|
||||||
|
return available
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _frozen_path() -> Path:
|
||||||
|
if getattr(sys, "frozen", False):
|
||||||
|
return Path(sys._MEIPASS)
|
||||||
|
return Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _bin(name: str) -> str:
|
||||||
|
"""Resolve a binary name (e.g. 'ffmpeg') to its full path in frozen builds."""
|
||||||
|
p = _frozen_path() / name
|
||||||
|
if p.exists():
|
||||||
|
return str(p)
|
||||||
|
return name # fall back to PATH
|
||||||
|
|
||||||
|
|
||||||
|
def _log(*args) -> None:
|
||||||
|
"""Print a timestamped log line to stderr."""
|
||||||
|
ts = datetime.now().strftime("%H:%M:%S")
|
||||||
|
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name + ".mp4")
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
m = int(seconds // 60)
|
||||||
|
# Floor-truncate to 1 dp (not round) — prevents "X:60.0" rollover when
|
||||||
|
# seconds is e.g. 59.95.
|
||||||
|
s = int(seconds % 60 * 10) / 10
|
||||||
|
return f"{m}:{s:04.1f}"
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_yolo():
|
||||||
|
"""Lazy-load YOLOv8-nano. Returns None if ultralytics is not installed."""
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model is None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
_yolo_model = YOLO("yolov8n.pt")
|
||||||
|
_log("YOLO model loaded")
|
||||||
|
except ImportError:
|
||||||
|
_log("ultralytics not installed — tracking disabled")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"YOLO load failed: {e}")
|
||||||
|
return None
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frame_cv(video_path: str, time: float):
|
||||||
|
"""Extract a single frame as a numpy array (BGR) via ffmpeg -> temp PNG -> cv2."""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||||
|
os.close(fd)
|
||||||
|
try:
|
||||||
|
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||||
|
"-frames:v", "1", tmp]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return cv2.imread(tmp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_subject_center(
|
||||||
|
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||||
|
) -> tuple[int | None, float, float] | None:
|
||||||
|
"""Detect objects at *time* and return (class_id, norm_x, norm_y) of the
|
||||||
|
best match to (target_cls, last_x, last_y). Returns None on failure."""
|
||||||
|
model = _get_yolo()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
frame = extract_frame_cv(video_path, time)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
results = model(frame, verbose=False)
|
||||||
|
if not results or len(results[0].boxes) == 0:
|
||||||
|
return None
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
dets = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
cx = (x1 + x2) / 2 / w
|
||||||
|
cy = (y1 + y2) / 2 / h
|
||||||
|
dets.append((cls, cx, cy))
|
||||||
|
# Prefer same class, nearest to last known position.
|
||||||
|
def score(d):
|
||||||
|
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||||
|
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||||
|
return cls_penalty + dist
|
||||||
|
best = min(dets, key=score)
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def track_centers_for_jobs(
|
||||||
|
video_path: str, cursor: float, crop_center: float,
|
||||||
|
starts: list[float],
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run detection at the cursor (to identify the target) then at each start
|
||||||
|
time. Returns a list of horizontal crop centers (one per start)."""
|
||||||
|
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||||
|
if ref is None:
|
||||||
|
_log("Tracking: no detection at cursor, using fixed center")
|
||||||
|
return [crop_center] * len(starts)
|
||||||
|
target_cls, last_x, last_y = ref
|
||||||
|
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||||
|
centers = []
|
||||||
|
for t in starts:
|
||||||
|
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||||
|
if det is not None:
|
||||||
|
_, cx, cy = det
|
||||||
|
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||||
|
centers.append(cx)
|
||||||
|
last_x, last_y = cx, cy
|
||||||
|
else:
|
||||||
|
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||||
|
centers.append(last_x)
|
||||||
|
return centers
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
# 8-cut Client Design
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Build a Tauri + Svelte desktop client that connects to the 8-cut server API for remote video editing. Full feature parity with the Qt app. Targets Linux first, then Mac.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Tauri app (Rust shell + Svelte webview)
|
||||||
|
├── mpv sidecar (bundled binary)
|
||||||
|
│ ├── plays video: http://server/api/stream/{path}?quality=low
|
||||||
|
│ ├── plays audio: http://server/api/audio/{path}
|
||||||
|
│ └── controlled via JSON IPC socket
|
||||||
|
├── Svelte UI
|
||||||
|
│ ├── File browser
|
||||||
|
│ ├── Canvas timeline (markers, cursor, play region)
|
||||||
|
│ ├── Canvas crop overlay
|
||||||
|
│ ├── Export controls + WebSocket progress
|
||||||
|
│ └── Settings panel (profile, subprofiles, quality)
|
||||||
|
└── Rust backend
|
||||||
|
├── Spawn/manage mpv process + IPC
|
||||||
|
├── Proxy server API calls (avoid CORS)
|
||||||
|
└── Tauri commands exposed to Svelte frontend
|
||||||
|
```
|
||||||
|
|
||||||
|
## Playback
|
||||||
|
|
||||||
|
mpv runs as a sidecar process, controlled via JSON IPC socket. Two streams:
|
||||||
|
- Video: `http://server/api/stream/{path}?root={root}&quality={quality}` (transcoded, no audio)
|
||||||
|
- Audio: `http://server/api/audio/{path}?root={root}` (full quality WAV)
|
||||||
|
|
||||||
|
mpv's `--audio-file=` flag syncs both streams with frame-accurate seeking.
|
||||||
|
|
||||||
|
Quality presets: potato (480p), low (720p), medium (1080p), high (original).
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### File management
|
||||||
|
- Browse server video roots (`GET /api/roots`, `GET /api/files`)
|
||||||
|
- Hide/unhide files per profile (`POST/DELETE /api/hidden/{filename}`)
|
||||||
|
- Sort by name/size, filter hidden
|
||||||
|
|
||||||
|
### Playback
|
||||||
|
- Play/pause/resume from pause point
|
||||||
|
- AB-loop with current spread/clips settings
|
||||||
|
- Play region adapts to spread changes without restarting
|
||||||
|
- Quality selector
|
||||||
|
|
||||||
|
### Timeline (Canvas)
|
||||||
|
- Cursor position, markers, play position indicator
|
||||||
|
- Click to seek, drag cursor
|
||||||
|
- Lock mode: cursor locked to marker, double-click jumps to end of clip span
|
||||||
|
- Autoclip: when paused, auto-adjust clip count to fit pause position
|
||||||
|
|
||||||
|
### Crop & keyframes
|
||||||
|
- Portrait ratio selector (9:16, 4:5, 1:1, off)
|
||||||
|
- Crop center slider with live canvas overlay
|
||||||
|
- Crop keyframes at arbitrary timeline positions
|
||||||
|
- Subject tracking (triggered server-side)
|
||||||
|
- Random portrait/square toggles
|
||||||
|
|
||||||
|
### Export
|
||||||
|
- Configurable: clips, spread, short side, format (MP4/WebP sequence)
|
||||||
|
- Label + category annotation
|
||||||
|
- Encoder selection (libx264 / h264_nvenc)
|
||||||
|
- Subprofiles with folder suffix routing
|
||||||
|
- Number keys 1-9 for subprofile quick export, E for main
|
||||||
|
- WebSocket progress (`WS /ws/export`), per-clip completion
|
||||||
|
- Delete/re-export from marker context menu
|
||||||
|
|
||||||
|
### Profiles
|
||||||
|
- Profile switcher, markers reload per profile
|
||||||
|
- Subprofile management (add/remove)
|
||||||
|
|
||||||
|
### Settings
|
||||||
|
- Server URL (configurable)
|
||||||
|
- Default quality preset
|
||||||
|
- All settings persisted client-side via Tauri store
|
||||||
|
|
||||||
|
## Server API endpoints used
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/roots
|
||||||
|
GET /api/files?root={root}
|
||||||
|
GET /api/video/{path}?root={root}
|
||||||
|
GET /api/stream/{path}?root={root}&quality={quality}
|
||||||
|
GET /api/audio/{path}?root={root}
|
||||||
|
GET /api/cache/status/{path}?root={root}
|
||||||
|
GET /api/markers/{filename}?profile={profile}
|
||||||
|
GET /api/profiles
|
||||||
|
GET /api/labels
|
||||||
|
POST /api/export
|
||||||
|
GET /api/export/{job_id}
|
||||||
|
DELETE /api/export?output_path={path}
|
||||||
|
POST /api/hidden/{filename}?profile={profile}
|
||||||
|
DELETE /api/hidden/{filename}?profile={profile}
|
||||||
|
GET /api/hidden?profile={profile}
|
||||||
|
WS /ws/export
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
client/
|
||||||
|
├── src-tauri/
|
||||||
|
│ ├── src/
|
||||||
|
│ │ ├── main.rs (Tauri entry, app setup)
|
||||||
|
│ │ ├── mpv.rs (mpv sidecar spawn + IPC)
|
||||||
|
│ │ ├── commands.rs (Tauri commands for Svelte)
|
||||||
|
│ │ └── lib.rs
|
||||||
|
│ ├── Cargo.toml
|
||||||
|
│ └── tauri.conf.json
|
||||||
|
├── src/
|
||||||
|
│ ├── App.svelte
|
||||||
|
│ ├── lib/
|
||||||
|
│ │ ├── api.ts (server API client)
|
||||||
|
│ │ ├── mpv.ts (mpv IPC bridge via Tauri commands)
|
||||||
|
│ │ ├── ws.ts (WebSocket export progress)
|
||||||
|
│ │ └── stores.ts (Svelte stores: files, markers, settings)
|
||||||
|
│ ├── components/
|
||||||
|
│ │ ├── FileBrowser.svelte
|
||||||
|
│ │ ├── Timeline.svelte
|
||||||
|
│ │ ├── CropOverlay.svelte
|
||||||
|
│ │ ├── ExportPanel.svelte
|
||||||
|
│ │ ├── SettingsPanel.svelte
|
||||||
|
│ │ └── ProfileBar.svelte
|
||||||
|
│ └── main.ts
|
||||||
|
├── package.json
|
||||||
|
└── vite.config.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation order
|
||||||
|
|
||||||
|
1. Scaffold Tauri + Svelte project
|
||||||
|
2. mpv sidecar: spawn, IPC, basic play/pause/seek
|
||||||
|
3. API client module + server connection
|
||||||
|
4. File browser component
|
||||||
|
5. Video playback: load file → stream URL → mpv
|
||||||
|
6. Canvas timeline: cursor, seek, markers
|
||||||
|
7. Export panel + WebSocket progress
|
||||||
|
8. Crop overlay + keyframes
|
||||||
|
9. Lock mode, autoclip, play region
|
||||||
|
10. Profiles, subprofiles, hidden files
|
||||||
|
11. Keyboard shortcuts
|
||||||
|
12. Settings persistence
|
||||||
|
13. Package for Linux (.deb / .AppImage)
|
||||||
|
14. Package for Mac (.dmg)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,207 @@
|
|||||||
|
# 8-cut Server API Design
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Run 8-cut as a FastAPI server on Unraid (Docker) so a Tauri desktop client on Mac can edit remotely over WireGuard — no file transfers, no auth.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Unraid (Docker container):
|
||||||
|
FastAPI + ffmpeg + SQLite
|
||||||
|
├── /api/files list videos from mounted volumes
|
||||||
|
├── /api/stream/{path} transcoded video (cached, no audio)
|
||||||
|
├── /api/audio/{path} full-quality audio (cached, passthrough)
|
||||||
|
├── /api/video/{path} raw file (for reference/download)
|
||||||
|
├── /api/markers CRUD markers per profile
|
||||||
|
├── /api/profiles list/create profiles
|
||||||
|
├── /api/export trigger + manage exports
|
||||||
|
├── /api/labels label history
|
||||||
|
├── /api/hidden hidden file management
|
||||||
|
└── ws://…/ws/export real-time export progress
|
||||||
|
|
||||||
|
Mac (Tauri + Svelte + libmpv):
|
||||||
|
├── mpv plays stream URL (video) + audio URL separately
|
||||||
|
├── Canvas timeline + crop overlay + keyframes
|
||||||
|
├── Full UI: profiles, subprofiles, settings
|
||||||
|
└── Stateless — all state lives on server
|
||||||
|
```
|
||||||
|
|
||||||
|
## Docker mounts
|
||||||
|
|
||||||
|
| Mount | Purpose | Env var |
|
||||||
|
|-------------|--------------------------------|--------------|
|
||||||
|
| `/videos` | Source video files (read-only) | `MEDIA_DIRS` |
|
||||||
|
| `/exports` | Export output | `EXPORT_DIR` |
|
||||||
|
| `/data` | SQLite DB + transcode cache | `DB_PATH`, `CACHE_DIR` |
|
||||||
|
|
||||||
|
`MEDIA_DIRS` supports multiple paths: `/videos1,/videos2`.
|
||||||
|
|
||||||
|
## Video streaming with transcode cache
|
||||||
|
|
||||||
|
The client needs low-bitrate video for scrubbing over the network but full-quality audio for accurate editing.
|
||||||
|
|
||||||
|
**Flow:**
|
||||||
|
1. Client requests `/api/stream/{path}?quality=low`
|
||||||
|
2. Server checks cache: `{CACHE_DIR}/{quality}/{hash}.mp4`
|
||||||
|
3. If cached → serve with range requests (instant seeking)
|
||||||
|
4. If not → start background ffmpeg transcode, return `202 Accepted` with job ID
|
||||||
|
5. Client polls or gets WebSocket notification when ready
|
||||||
|
6. Audio: `/api/audio/{path}` extracts audio (passthrough, fast) to cache on first request
|
||||||
|
|
||||||
|
**Quality presets:**
|
||||||
|
|
||||||
|
| Preset | Resolution | Bitrate |
|
||||||
|
|----------|-----------|----------|
|
||||||
|
| `potato` | 480p | ~500 Kbps |
|
||||||
|
| `low` | 720p | ~2 Mbps |
|
||||||
|
| `medium` | 1080p | ~5 Mbps |
|
||||||
|
| `high` | original | ~10 Mbps |
|
||||||
|
|
||||||
|
Each quality level cached separately. Client can switch quality — mpv reloads the URL.
|
||||||
|
|
||||||
|
**mpv on client:**
|
||||||
|
```
|
||||||
|
video = http://server/api/stream/file.mp4?quality=low
|
||||||
|
audio = http://server/api/audio/file.mp4
|
||||||
|
```
|
||||||
|
mpv's `--audio-file=` flag plays both in sync with frame-accurate seeking.
|
||||||
|
|
||||||
|
## API endpoints
|
||||||
|
|
||||||
|
### Files
|
||||||
|
```
|
||||||
|
GET /api/files?root={root}
|
||||||
|
→ [{path, name, size, duration?, markers_count}]
|
||||||
|
|
||||||
|
GET /api/video/{path}
|
||||||
|
→ raw file with range requests
|
||||||
|
|
||||||
|
GET /api/stream/{path}?quality=low|medium|high|potato
|
||||||
|
→ cached transcoded video (no audio), range requests
|
||||||
|
→ 202 if transcode in progress
|
||||||
|
|
||||||
|
GET /api/audio/{path}
|
||||||
|
→ cached full-quality audio, range requests
|
||||||
|
→ 202 if extraction in progress
|
||||||
|
|
||||||
|
GET /api/cache/status/{path}
|
||||||
|
→ {qualities: {potato: "ready", low: "transcoding", ...}, audio: "ready"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Markers & profiles
|
||||||
|
```
|
||||||
|
GET /api/markers/{filename}?profile=default
|
||||||
|
→ [{start_time, marker_number, output_path}]
|
||||||
|
|
||||||
|
GET /api/profiles
|
||||||
|
→ ["default", "intense", ...]
|
||||||
|
|
||||||
|
GET /api/labels
|
||||||
|
→ ["dog barking", "rain", ...]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export
|
||||||
|
```
|
||||||
|
POST /api/export
|
||||||
|
body: {input_path, cursor, folder_suffix?, name, clips, spread,
|
||||||
|
short_side?, portrait_ratio?, crop_center, format,
|
||||||
|
label?, category?, profile, crop_keyframes?,
|
||||||
|
rand_portrait?, rand_square?, track_subject?}
|
||||||
|
→ {job_id}
|
||||||
|
|
||||||
|
GET /api/export/{job_id}
|
||||||
|
→ {status, completed, total, outputs: [...]}
|
||||||
|
|
||||||
|
DELETE /api/export/{output_path}
|
||||||
|
→ delete from DB + disk
|
||||||
|
|
||||||
|
WS /ws/export
|
||||||
|
→ server pushes: {type: "clip_done", path: "..."} | {type: "all_done"} | {type: "error", msg: "..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hidden files
|
||||||
|
```
|
||||||
|
POST /api/hidden/{filename}?profile=default
|
||||||
|
DELETE /api/hidden/{filename}?profile=default
|
||||||
|
GET /api/hidden?profile=default
|
||||||
|
→ ["file1.mp4", "file2.mp4"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code reuse from main.py
|
||||||
|
|
||||||
|
**Extracted to shared module (used by both server and Qt app):**
|
||||||
|
- `ProcessedDB` — SQLite operations
|
||||||
|
- `build_ffmpeg_command` — ffmpeg command construction
|
||||||
|
- `build_audio_extract_command`
|
||||||
|
- `build_export_path` / `build_sequence_dir`
|
||||||
|
- `detect_hw_encoders`
|
||||||
|
- `upsert_clip_annotation` / `remove_clip_annotation`
|
||||||
|
- `apply_keyframes_to_jobs` / `resolve_keyframe`
|
||||||
|
- `track_centers_for_jobs` (subject tracking)
|
||||||
|
|
||||||
|
**Server-specific (new):**
|
||||||
|
- FastAPI app + route handlers
|
||||||
|
- Transcode cache manager
|
||||||
|
- Export worker (plain threading, replaces QThread-based ExportWorker)
|
||||||
|
- File listing / media root scanning
|
||||||
|
- WebSocket export progress broadcaster
|
||||||
|
|
||||||
|
**Tauri client (new, Svelte):**
|
||||||
|
- mpv integration via Tauri plugin or sidecar
|
||||||
|
- Canvas-based timeline widget
|
||||||
|
- Canvas-based crop overlay
|
||||||
|
- All UI controls
|
||||||
|
- API client module
|
||||||
|
|
||||||
|
## Dockerfile
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
FROM python:3.12-slim
|
||||||
|
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||||
|
WORKDIR /app
|
||||||
|
COPY server/ .
|
||||||
|
RUN pip install --no-cache-dir fastapi uvicorn
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
8-cut/
|
||||||
|
├── main.py (existing Qt app, unchanged)
|
||||||
|
├── core/ (shared logic, extracted from main.py)
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── db.py (ProcessedDB)
|
||||||
|
│ ├── ffmpeg.py (build commands, detect encoders)
|
||||||
|
│ ├── export.py (ExportWorker — plain threading)
|
||||||
|
│ ├── paths.py (build_export_path, build_sequence_dir)
|
||||||
|
│ └── annotations.py (dataset.json helpers)
|
||||||
|
├── server/
|
||||||
|
│ ├── app.py (FastAPI app)
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── files.py
|
||||||
|
│ │ ├── stream.py
|
||||||
|
│ │ ├── markers.py
|
||||||
|
│ │ ├── export.py
|
||||||
|
│ │ └── hidden.py
|
||||||
|
│ ├── cache.py (transcode cache manager)
|
||||||
|
│ ├── ws.py (WebSocket handler)
|
||||||
|
│ └── config.py (env vars, settings)
|
||||||
|
├── client/ (Tauri + Svelte — future)
|
||||||
|
│ └── ...
|
||||||
|
├── Dockerfile
|
||||||
|
└── docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation order
|
||||||
|
|
||||||
|
1. Extract shared logic from main.py → `core/`
|
||||||
|
2. Update main.py to import from `core/` (verify Qt app still works)
|
||||||
|
3. Build FastAPI server with file listing + video serving
|
||||||
|
4. Add transcode cache + audio extraction
|
||||||
|
5. Add markers/profiles/labels/hidden API
|
||||||
|
6. Add export endpoint + WebSocket progress
|
||||||
|
7. Dockerfile + docker-compose
|
||||||
|
8. (Later) Tauri client
|
||||||
@@ -0,0 +1,948 @@
|
|||||||
|
# Server API Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Extract shared logic from main.py into a `core/` package, then build the FastAPI server that serves video files, manages the DB, and runs exports.
|
||||||
|
|
||||||
|
**Architecture:** Shared logic (DB, ffmpeg, paths, annotations, tracking) moves to `core/`. Both `main.py` (Qt app) and `server/` import from `core/`. The server adds HTTP video streaming with transcode cache, REST endpoints, and WebSocket export progress.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.12, FastAPI, uvicorn, SQLite, ffmpeg
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Create core/ package — paths and helpers
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/__init__.py`
|
||||||
|
- Create: `core/paths.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/__init__.py**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create core/paths.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _frozen_path() -> Path:
|
||||||
|
if getattr(sys, "frozen", False):
|
||||||
|
return Path(sys._MEIPASS)
|
||||||
|
return Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _bin(name: str) -> str:
|
||||||
|
p = _frozen_path() / name
|
||||||
|
if p.exists():
|
||||||
|
return str(p)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _log(*args) -> None:
|
||||||
|
ts = datetime.now().strftime("%H:%M:%S")
|
||||||
|
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name + ".mp4")
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
m = int(seconds // 60)
|
||||||
|
s = int(seconds % 60 * 10) / 10
|
||||||
|
return f"{m}:{s:04.1f}"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/
|
||||||
|
git commit -m "feat: create core/paths module with shared path helpers"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Create core/ffmpeg.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/ffmpeg.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/ffmpeg.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 77-112 and 244-289: `_RATIOS`, `_portrait_crop_filter`, `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`, `detect_hw_encoders`. (Lines 115-188 are also ffmpeg-related. Lines 191-241 are annotations — extracted separately in Task 4.)
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
_RATIOS: dict[str, tuple[int, int]] = {
|
||||||
|
"9:16": (9, 16),
|
||||||
|
"4:5": (4, 5),
|
||||||
|
"1:1": (1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||||
|
num, den = _RATIOS[ratio]
|
||||||
|
cw = f"ih*{num}/{den}"
|
||||||
|
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||||
|
return f"crop={cw}:ih:{x}:0"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_keyframe(
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
t: float,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||||
|
result = None
|
||||||
|
for kf in keyframes:
|
||||||
|
if kf[0] <= t + tolerance:
|
||||||
|
result = kf
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_keyframes_to_jobs(
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
base_center: float,
|
||||||
|
base_ratio: str | None,
|
||||||
|
base_rand_p: bool,
|
||||||
|
base_rand_s: bool,
|
||||||
|
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||||
|
result = []
|
||||||
|
for s, o, _r, _c in jobs:
|
||||||
|
kf = resolve_keyframe(keyframes, s)
|
||||||
|
if kf is not None:
|
||||||
|
_, center, ratio, rp, rs = kf
|
||||||
|
else:
|
||||||
|
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||||
|
result.append((s, o, ratio, center, rp, rs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def build_ffmpeg_command(
|
||||||
|
input_path: str, start: float, output_path: str,
|
||||||
|
short_side: int | None = None,
|
||||||
|
portrait_ratio: str | None = None,
|
||||||
|
crop_center: float = 0.5,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
) -> list[str]:
|
||||||
|
use_hw_vaapi = encoder == "h264_vaapi" and not image_sequence
|
||||||
|
cmd = [_bin("ffmpeg"), "-y"]
|
||||||
|
if use_hw_vaapi:
|
||||||
|
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||||
|
"-vaapi_device", "/dev/dri/renderD128"]
|
||||||
|
cmd += ["-threads", "0", "-ss", str(start), "-i", input_path, "-t", "8"]
|
||||||
|
filters: list[str] = []
|
||||||
|
if portrait_ratio is not None:
|
||||||
|
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||||
|
if short_side is not None:
|
||||||
|
filters.append(
|
||||||
|
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||||
|
)
|
||||||
|
if use_hw_vaapi:
|
||||||
|
if filters:
|
||||||
|
filters.insert(0, "hwdownload")
|
||||||
|
filters.insert(1, "format=nv12")
|
||||||
|
filters.append("format=nv12")
|
||||||
|
filters.append("hwupload")
|
||||||
|
if filters:
|
||||||
|
cmd += ["-vf", ",".join(filters)]
|
||||||
|
if image_sequence:
|
||||||
|
cmd += ["-an", "-c:v", "libwebp", "-quality", "92", "-compression_level", "1",
|
||||||
|
os.path.join(output_path, "frame_%04d.webp")]
|
||||||
|
else:
|
||||||
|
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||||
|
audio_path = sequence_dir + ".wav"
|
||||||
|
return [_bin("ffmpeg"), "-y", "-ss", str(start), "-i", input_path,
|
||||||
|
"-t", "8", "-vn", "-c:a", "pcm_s16le", audio_path]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_hw_encoders() -> list[str]:
|
||||||
|
_HW_ENCODERS = ["h264_nvenc", "h264_vaapi", "h264_qsv", "h264_amf", "h264_videotoolbox"]
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||||
|
capture_output=True, text=True, timeout=5,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return []
|
||||||
|
output = result.stdout
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
available = []
|
||||||
|
for enc in _HW_ENCODERS:
|
||||||
|
if re.search(rf'\b{enc}\b', output):
|
||||||
|
available.append(enc)
|
||||||
|
if available:
|
||||||
|
_log(f"HW encoders detected: {', '.join(available)}")
|
||||||
|
else:
|
||||||
|
_log("No HW encoders detected — GPU export unavailable")
|
||||||
|
return available
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/ffmpeg.py
|
||||||
|
git commit -m "feat: create core/ffmpeg module with ffmpeg helpers"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Create core/db.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/db.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/db.py**
|
||||||
|
|
||||||
|
Extract the entire `ProcessedDB` class from main.py lines 398-626. Import `_log` from `core.paths`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .paths import _log
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessedDB:
|
||||||
|
_SCHEMA_VERSION = 3
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | None = None):
|
||||||
|
# ... exact copy of existing class ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Copy the full class body verbatim — all methods unchanged.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/db.py
|
||||||
|
git commit -m "feat: create core/db module with ProcessedDB"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Create core/annotations.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/annotations.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/annotations.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 191-241: `build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def build_annotation_json_path(folder: str) -> str:
|
||||||
|
return os.path.join(folder, "dataset.json")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
if not os.path.exists(json_path):
|
||||||
|
return
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return
|
||||||
|
entries = [e for e in entries if e.get("path") != abs_path]
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||||
|
if not label.strip():
|
||||||
|
return
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
entries: list[dict] = []
|
||||||
|
if os.path.exists(json_path):
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
entries = []
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
entry: dict = {"path": abs_path, "label": label}
|
||||||
|
for i, e in enumerate(entries):
|
||||||
|
if e.get("path") == abs_path:
|
||||||
|
entries[i] = entry
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
entries.append(entry)
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/annotations.py
|
||||||
|
git commit -m "feat: create core/annotations module"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Create core/export.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/export.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/export.py**
|
||||||
|
|
||||||
|
A plain-threading version of `ExportWorker` (no QThread dependency). Used by the server. The Qt app continues using its own QThread-based worker.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from .ffmpeg import build_ffmpeg_command, build_audio_extract_command
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
class ExportRunner:
|
||||||
|
"""Run ffmpeg export jobs in a background thread pool.
|
||||||
|
|
||||||
|
Callbacks:
|
||||||
|
on_clip_done(path: str)
|
||||||
|
on_all_done()
|
||||||
|
on_error(msg: str)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_path: str,
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
short_side: int | None = None,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
max_workers: int | None = None,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
on_clip_done: Callable[[str], None] | None = None,
|
||||||
|
on_all_done: Callable[[], None] | None = None,
|
||||||
|
on_error: Callable[[str], None] | None = None,
|
||||||
|
):
|
||||||
|
self._input = input_path
|
||||||
|
self._jobs = jobs
|
||||||
|
self._short_side = short_side
|
||||||
|
self._image_sequence = image_sequence
|
||||||
|
self._max_workers = max_workers
|
||||||
|
self._encoder = encoder
|
||||||
|
self._on_clip_done = on_clip_done
|
||||||
|
self._on_all_done = on_all_done
|
||||||
|
self._on_error = on_error
|
||||||
|
self._cancel = False
|
||||||
|
self._procs: list[subprocess.Popen] = []
|
||||||
|
self._procs_lock = threading.Lock()
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
self._cancel = True
|
||||||
|
with self._procs_lock:
|
||||||
|
for proc in self._procs:
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._thread is not None and self._thread.is_alive()
|
||||||
|
|
||||||
|
def _run_one(self, start: float, output: str,
|
||||||
|
portrait_ratio: str | None, crop_center: float) -> str:
|
||||||
|
if self._cancel:
|
||||||
|
raise RuntimeError("cancelled")
|
||||||
|
if self._image_sequence:
|
||||||
|
os.makedirs(output, exist_ok=True)
|
||||||
|
cmd = build_ffmpeg_command(
|
||||||
|
self._input, start, output,
|
||||||
|
short_side=self._short_side,
|
||||||
|
portrait_ratio=portrait_ratio,
|
||||||
|
crop_center=crop_center,
|
||||||
|
image_sequence=self._image_sequence,
|
||||||
|
encoder=self._encoder,
|
||||||
|
)
|
||||||
|
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
with self._procs_lock:
|
||||||
|
self._procs.append(proc)
|
||||||
|
try:
|
||||||
|
_, stderr = proc.communicate(timeout=120)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
raise RuntimeError("ffmpeg timed out")
|
||||||
|
finally:
|
||||||
|
with self._procs_lock:
|
||||||
|
self._procs.remove(proc)
|
||||||
|
if self._cancel:
|
||||||
|
raise RuntimeError("cancelled")
|
||||||
|
if proc.returncode != 0:
|
||||||
|
msg = stderr.decode(errors='replace')[-500:] if stderr else "ffmpeg failed"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
if self._image_sequence:
|
||||||
|
audio_cmd = build_audio_extract_command(self._input, start, output)
|
||||||
|
subprocess.run(audio_cmd, capture_output=True, text=True, timeout=60)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
cap = self._max_workers or (os.cpu_count() or 2)
|
||||||
|
workers = min(len(self._jobs), cap)
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||||
|
futures = {
|
||||||
|
pool.submit(self._run_one, s, o, pr, cc): o
|
||||||
|
for s, o, pr, cc in self._jobs
|
||||||
|
}
|
||||||
|
for fut in as_completed(futures):
|
||||||
|
if self._cancel:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
path = fut.result()
|
||||||
|
if self._on_clip_done:
|
||||||
|
self._on_clip_done(path)
|
||||||
|
except Exception as e:
|
||||||
|
if "cancelled" not in str(e) and self._on_error:
|
||||||
|
self._on_error(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
if self._on_error:
|
||||||
|
self._on_error(str(e))
|
||||||
|
return
|
||||||
|
if self._cancel:
|
||||||
|
return
|
||||||
|
if self._on_all_done:
|
||||||
|
self._on_all_done()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/export.py
|
||||||
|
git commit -m "feat: create core/export module with ExportRunner"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Create core/tracking.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/tracking.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/tracking.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 294-395: YOLO tracking functions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_yolo():
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model is None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
_yolo_model = YOLO("yolov8n.pt")
|
||||||
|
_log("YOLO model loaded")
|
||||||
|
except ImportError:
|
||||||
|
_log("ultralytics not installed — tracking disabled")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"YOLO load failed: {e}")
|
||||||
|
return None
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frame_cv(video_path: str, time: float):
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||||
|
os.close(fd)
|
||||||
|
try:
|
||||||
|
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||||
|
"-frames:v", "1", tmp]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return cv2.imread(tmp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_subject_center(
|
||||||
|
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||||
|
) -> tuple[int | None, float, float] | None:
|
||||||
|
model = _get_yolo()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
frame = extract_frame_cv(video_path, time)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
results = model(frame, verbose=False)
|
||||||
|
if not results or len(results[0].boxes) == 0:
|
||||||
|
return None
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
dets = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
cx = (x1 + x2) / 2 / w
|
||||||
|
cy = (y1 + y2) / 2 / h
|
||||||
|
dets.append((cls, cx, cy))
|
||||||
|
def score(d):
|
||||||
|
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||||
|
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||||
|
return cls_penalty + dist
|
||||||
|
best = min(dets, key=score)
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def track_centers_for_jobs(
|
||||||
|
video_path: str, cursor: float, crop_center: float,
|
||||||
|
starts: list[float],
|
||||||
|
) -> list[float]:
|
||||||
|
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||||
|
if ref is None:
|
||||||
|
_log("Tracking: no detection at cursor, using fixed center")
|
||||||
|
return [crop_center] * len(starts)
|
||||||
|
target_cls, last_x, last_y = ref
|
||||||
|
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||||
|
centers = []
|
||||||
|
for t in starts:
|
||||||
|
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||||
|
if det is not None:
|
||||||
|
_, cx, cy = det
|
||||||
|
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||||
|
centers.append(cx)
|
||||||
|
last_x, last_y = cx, cy
|
||||||
|
else:
|
||||||
|
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||||
|
centers.append(last_x)
|
||||||
|
return centers
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/tracking.py
|
||||||
|
git commit -m "feat: create core/tracking module with YOLO subject tracking"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Update main.py to import from core/
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py`
|
||||||
|
|
||||||
|
**Step 1: Replace function definitions with imports**
|
||||||
|
|
||||||
|
At the top of main.py, after the existing stdlib imports (line 17), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from core.paths import _bin, _log, build_export_path, build_sequence_dir, format_time
|
||||||
|
from core.ffmpeg import (
|
||||||
|
_RATIOS, resolve_keyframe, apply_keyframes_to_jobs,
|
||||||
|
build_ffmpeg_command, build_audio_extract_command, detect_hw_encoders,
|
||||||
|
)
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
from core.annotations import remove_clip_annotation, upsert_clip_annotation
|
||||||
|
from core.tracking import track_centers_for_jobs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Delete the extracted function definitions and dead imports**
|
||||||
|
|
||||||
|
Remove definitions from main.py:
|
||||||
|
- Lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`
|
||||||
|
- Lines 77-188: `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`
|
||||||
|
- Lines 191-241: annotation functions (`build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`)
|
||||||
|
- Lines 244-289: `detect_hw_encoders`, `_RATIOS`, `_portrait_crop_filter`
|
||||||
|
- Lines 294-395: tracking functions (`_yolo_model`, `_get_yolo`, `extract_frame_cv`, `detect_subject_center`, `track_centers_for_jobs`)
|
||||||
|
- Lines 398-626: `ProcessedDB` class
|
||||||
|
|
||||||
|
Remove now-dead stdlib imports from the top of main.py:
|
||||||
|
- `re` (only used in `detect_hw_encoders`)
|
||||||
|
- `json` (only used in annotation functions)
|
||||||
|
- `sqlite3` (only used in `ProcessedDB`)
|
||||||
|
- `tempfile` (only used in `extract_frame_cv`)
|
||||||
|
- `datetime`, `timezone` from the datetime import (only used in `_log` and `ProcessedDB`)
|
||||||
|
|
||||||
|
Keep in main.py:
|
||||||
|
- `_SELVA_CATEGORIES` (UI constant, line 291)
|
||||||
|
- `_RATIOS` reference — imported from core.ffmpeg
|
||||||
|
- `ExportWorker` (QThread-based, stays in main.py — the server uses `core.export.ExportRunner` instead)
|
||||||
|
- `_DBWorker` and `FrameGrabber` (QThread-based, stay in main.py)
|
||||||
|
|
||||||
|
**Step 3: Verify Qt app still works**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Open a video, export a clip, check markers — verify nothing broke.
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "refactor: import shared logic from core/ instead of inline definitions"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: Create server/config.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/__init__.py` (empty package marker)
|
||||||
|
- Create: `server/config.py`
|
||||||
|
|
||||||
|
**Step 1: Create `server/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create config**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
MEDIA_DIRS: list[str] = [
|
||||||
|
d.strip() for d in os.environ.get("MEDIA_DIRS", str(Path.home())).split(",") if d.strip()
|
||||||
|
]
|
||||||
|
EXPORT_DIR: str = os.environ.get("EXPORT_DIR", str(Path.home() / "8cut-exports"))
|
||||||
|
DB_PATH: str = os.environ.get("DB_PATH", str(Path.home() / ".8cut.db"))
|
||||||
|
CACHE_DIR: str = os.environ.get("CACHE_DIR", str(Path.home() / ".8cut-cache"))
|
||||||
|
HOST: str = os.environ.get("HOST", "0.0.0.0")
|
||||||
|
PORT: int = int(os.environ.get("PORT", "8000"))
|
||||||
|
|
||||||
|
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".ts", ".flv", ".wmv"}
|
||||||
|
|
||||||
|
QUALITY_PRESETS = {
|
||||||
|
"potato": {"height": 480, "bitrate": "500k"},
|
||||||
|
"low": {"height": 720, "bitrate": "2M"},
|
||||||
|
"medium": {"height": 1080, "bitrate": "5M"},
|
||||||
|
"high": {"height": 0, "bitrate": "10M"}, # 0 = original resolution
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/
|
||||||
|
git commit -m "feat: create server/config with env var settings and quality presets"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 9: Create server/app.py — FastAPI skeleton + file listing
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/app.py`
|
||||||
|
- Create: `server/routes/__init__.py`
|
||||||
|
- Create: `server/routes/files.py`
|
||||||
|
|
||||||
|
**Step 1: Create FastAPI app**
|
||||||
|
|
||||||
|
`server/app.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from .routes import files, stream, markers, export, hidden
|
||||||
|
|
||||||
|
app = FastAPI(title="8-cut Server")
|
||||||
|
app.include_router(files.router, prefix="/api")
|
||||||
|
app.include_router(stream.router, prefix="/api")
|
||||||
|
app.include_router(markers.router, prefix="/api")
|
||||||
|
app.include_router(export.router, prefix="/api")
|
||||||
|
app.include_router(hidden.router, prefix="/api")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create file listing route**
|
||||||
|
|
||||||
|
`server/routes/files.py`:
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter, Query
|
||||||
|
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_videos(root: str) -> list[dict]:
|
||||||
|
results = []
|
||||||
|
for dirpath, _, filenames in os.walk(root):
|
||||||
|
for f in sorted(filenames):
|
||||||
|
if os.path.splitext(f)[1].lower() in VIDEO_EXTENSIONS:
|
||||||
|
full = os.path.join(dirpath, f)
|
||||||
|
rel = os.path.relpath(full, root)
|
||||||
|
results.append({
|
||||||
|
"name": f,
|
||||||
|
"path": rel,
|
||||||
|
"root": root,
|
||||||
|
"size": os.path.getsize(full),
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files")
|
||||||
|
def list_files(root: str | None = Query(None)):
|
||||||
|
dirs = [root] if root and root in MEDIA_DIRS else MEDIA_DIRS
|
||||||
|
files = []
|
||||||
|
for d in dirs:
|
||||||
|
files.extend(_scan_videos(d))
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/roots")
|
||||||
|
def list_roots():
|
||||||
|
return MEDIA_DIRS
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Create `server/routes/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Create stub routers** so app.py imports don't fail. Each file gets a minimal router — later tasks fill in the real endpoints.
|
||||||
|
|
||||||
|
`server/routes/stream.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/markers.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/export.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/hidden.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/
|
||||||
|
git commit -m "feat: add FastAPI app with file listing endpoint"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 10: Create server/routes/stream.py — video serving + transcode cache
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/cache.py`
|
||||||
|
- Create: `server/routes/stream.py`
|
||||||
|
|
||||||
|
**Step 1: Create cache manager**
|
||||||
|
|
||||||
|
`server/cache.py` handles:
|
||||||
|
- Computing cache paths from source file hash + quality
|
||||||
|
- Checking cache status
|
||||||
|
- Launching background ffmpeg transcodes
|
||||||
|
- Tracking in-progress jobs
|
||||||
|
|
||||||
|
**Step 2: Create stream routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/video/{path} — raw file, range requests
|
||||||
|
GET /api/stream/{path}?quality=low — cached transcode, range requests (202 if not ready)
|
||||||
|
GET /api/audio/{path} — cached audio extraction, range requests (202 if not ready)
|
||||||
|
GET /api/cache/status/{path} — cache status for all qualities
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/cache.py server/routes/stream.py
|
||||||
|
git commit -m "feat: add video streaming with transcode cache and audio extraction"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 11: Create server/routes/markers.py — DB endpoints
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/markers.py`
|
||||||
|
|
||||||
|
**Step 1: Create markers/profiles/labels routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/markers/{filename}?profile=default
|
||||||
|
GET /api/profiles
|
||||||
|
GET /api/labels
|
||||||
|
```
|
||||||
|
|
||||||
|
Uses `ProcessedDB` singleton from `core.db`.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/markers.py
|
||||||
|
git commit -m "feat: add markers, profiles, and labels API endpoints"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 12: Create server/routes/export.py + WebSocket
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/export.py`
|
||||||
|
- Create: `server/ws.py`
|
||||||
|
|
||||||
|
**Step 1: Create export routes + WS**
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/export — start export job
|
||||||
|
GET /api/export/{id} — check job status
|
||||||
|
DELETE /api/export/{path} — delete export from DB + disk
|
||||||
|
WS /ws/export — real-time progress
|
||||||
|
```
|
||||||
|
|
||||||
|
Uses `ExportRunner` from `core.export`.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/export.py server/ws.py
|
||||||
|
git commit -m "feat: add export endpoint with WebSocket progress"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 13: Create server/routes/hidden.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/hidden.py`
|
||||||
|
|
||||||
|
**Step 1: Create hidden file routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/hidden/{filename}?profile=default
|
||||||
|
DELETE /api/hidden/{filename}?profile=default
|
||||||
|
GET /api/hidden?profile=default
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/hidden.py
|
||||||
|
git commit -m "feat: add hidden files API endpoints"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 14: Create Dockerfile + docker-compose.yml
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `Dockerfile`
|
||||||
|
- Create: `docker-compose.yml`
|
||||||
|
|
||||||
|
**Step 1: Create Dockerfile**
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
FROM python:3.12-slim
|
||||||
|
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||||
|
WORKDIR /app
|
||||||
|
COPY core/ core/
|
||||||
|
COPY server/ server/
|
||||||
|
# Note: ultralytics + opencv-python needed only if subject tracking is used.
|
||||||
|
# Add them here if tracking is required on the server.
|
||||||
|
RUN pip install --no-cache-dir fastapi uvicorn
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create docker-compose.yml**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
services:
|
||||||
|
8cut:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- /path/to/videos:/videos:ro
|
||||||
|
- /path/to/exports:/exports
|
||||||
|
- 8cut-data:/data
|
||||||
|
environment:
|
||||||
|
MEDIA_DIRS: /videos
|
||||||
|
EXPORT_DIR: /exports
|
||||||
|
DB_PATH: /data/8cut.db
|
||||||
|
CACHE_DIR: /data/cache
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
8cut-data:
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add Dockerfile docker-compose.yml
|
||||||
|
git commit -m "feat: add Dockerfile and docker-compose for server deployment"
|
||||||
|
```
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
# Audio Similarity Scanning — Design
|
||||||
|
|
||||||
|
**Goal:** Scan a video's audio track and highlight segments that match the sound profile of existing reference clips, so the user can quickly find similar moments without scrubbing manually.
|
||||||
|
|
||||||
|
**Runs in:** Python/Qt client (`main.py`), not the server.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Core Module: `core/audio_scan.py`
|
||||||
|
|
||||||
|
New module alongside `core/tracking.py`. Two main functions:
|
||||||
|
|
||||||
|
- `build_profile(clip_paths: list[str]) -> dict` — extracts MFCCs (20 coefficients) from each clip using `librosa`, returns a profile containing both the averaged vector and individual clip vectors.
|
||||||
|
- `scan_video(video_path: str, profile: dict, mode: str, threshold: float, hop: float) -> list[tuple[float, float, float]]` — slides an 8s window across the video's audio, returns `(start_time, end_time, score)` tuples for segments above threshold.
|
||||||
|
|
||||||
|
### Feature Extraction
|
||||||
|
|
||||||
|
- Audio loaded via `librosa.load()` (handles video files directly, mono, 22050Hz).
|
||||||
|
- MFCCs: `librosa.feature.mfcc(n_mfcc=20)`, averaged over time axis to produce a single vector per window/clip.
|
||||||
|
- Similarity: cosine similarity (`numpy` dot product on L2-normalized vectors).
|
||||||
|
|
||||||
|
### Matching Modes
|
||||||
|
|
||||||
|
- **Average mode:** Compare each window to the mean of all reference MFCC vectors. Fast, good when references are homogeneous.
|
||||||
|
- **Nearest mode:** Compare each window to every reference vector, take the max score. Better when references have variety within the style.
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
|
||||||
|
- `threshold` (float, 0.0–1.0): minimum cosine similarity to include a segment. Default 0.7.
|
||||||
|
- `hop` (float, seconds): step size for the sliding window. Default 1.0s.
|
||||||
|
- Window size fixed at 8s to match reference clip length.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## UI Integration in `main.py`
|
||||||
|
|
||||||
|
### Controls
|
||||||
|
|
||||||
|
Added near the existing tracking checkbox area:
|
||||||
|
|
||||||
|
- **"Scan" button** — triggers audio scan on current video.
|
||||||
|
- **Threshold slider** (0.0–1.0, step 0.05) — controls match strictness.
|
||||||
|
- **Mode combobox** — "Average" / "Nearest".
|
||||||
|
- **Reference source combobox** — "Current Profile" / "Custom Folder" (shows folder picker when "Custom Folder" selected).
|
||||||
|
|
||||||
|
### Scan Workflow
|
||||||
|
|
||||||
|
1. User clicks Scan.
|
||||||
|
2. Reference clips collected: either all export `output_path` values from the current profile (via DB) or all audio/video files in a custom folder.
|
||||||
|
3. Scan runs in a `QThread` so UI stays responsive.
|
||||||
|
4. On completion, results sent to Timeline widget via signal.
|
||||||
|
|
||||||
|
### Timeline Display
|
||||||
|
|
||||||
|
- New `set_scan_regions(regions: list[tuple[float, float, float]])` method on Timeline.
|
||||||
|
- Drawn as semi-transparent colored rectangles behind existing markers.
|
||||||
|
- Color intensity proportional to score (brighter = higher match).
|
||||||
|
- Cleared on file change or re-scan.
|
||||||
|
|
||||||
|
### Keyboard Shortcut
|
||||||
|
|
||||||
|
- `S` — jump cursor to the next scan region (similar to `M` for next marker).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
Reference clips (DB export paths or folder)
|
||||||
|
|
|
||||||
|
librosa.load() each -> MFCC vectors (20-dim)
|
||||||
|
|
|
||||||
|
Profile: { mean_vector, clip_vectors[] }
|
||||||
|
|
|
||||||
|
Current video -> librosa.load() full audio (mono 22050Hz)
|
||||||
|
|
|
||||||
|
Sliding 8s window (hop=1s) -> MFCC per window
|
||||||
|
|
|
||||||
|
Cosine similarity vs profile -> score per position
|
||||||
|
|
|
||||||
|
Threshold filter -> [(start, end, score), ...]
|
||||||
|
|
|
||||||
|
Timeline: semi-transparent highlight regions
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
- 2-hour video at 22050Hz mono ~ 380MB memory.
|
||||||
|
- MFCC extraction + sliding window: ~10-30s.
|
||||||
|
- QThread keeps UI responsive.
|
||||||
|
|
||||||
|
## What This Does NOT Do
|
||||||
|
|
||||||
|
- No DB schema changes — scan results are ephemeral (visual only).
|
||||||
|
- No auto-export — user decides what to cut.
|
||||||
|
- No server integration — runs entirely in the Python client.
|
||||||
|
- No GPU/ML model dependency — just librosa + numpy.
|
||||||
@@ -0,0 +1,739 @@
|
|||||||
|
# Audio Similarity Scanning — Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Scan a video's audio track to find segments matching a reference sound profile, displayed as highlighted regions on the timeline.
|
||||||
|
|
||||||
|
**Architecture:** New `core/audio_scan.py` module extracts MFCC features from reference clips and slides an 8s window across the target video's audio, scoring each position via cosine similarity. A `ScanWorker` QThread runs the scan in the background, and results are drawn as semi-transparent rectangles on the existing Timeline widget.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3, librosa 0.11, numpy, PyQt6
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Core audio_scan module — build_profile
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/audio_scan.py`
|
||||||
|
- Create: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the tests**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# tests/test_audio_scan.py
|
||||||
|
import tempfile, os
|
||||||
|
import numpy as np
|
||||||
|
from core.audio_scan import build_profile, _extract_mfcc
|
||||||
|
|
||||||
|
|
||||||
|
def _make_wav(path: str, duration: float = 8.0, sr: int = 22050):
|
||||||
|
"""Create a short sine-wave WAV file for testing."""
|
||||||
|
import soundfile as sf
|
||||||
|
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
|
||||||
|
audio = 0.5 * np.sin(2 * np.pi * 440 * t)
|
||||||
|
sf.write(path, audio, sr)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_mfcc_returns_1d_vector():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
vec = _extract_mfcc(f.name)
|
||||||
|
assert vec.shape == (20,)
|
||||||
|
assert not np.isnan(vec).any()
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_single_clip():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
profile = build_profile([f.name])
|
||||||
|
assert "mean_vector" in profile
|
||||||
|
assert "clip_vectors" in profile
|
||||||
|
assert profile["mean_vector"].shape == (20,)
|
||||||
|
assert len(profile["clip_vectors"]) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_multiple_clips():
|
||||||
|
paths = []
|
||||||
|
try:
|
||||||
|
for i in range(3):
|
||||||
|
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
|
freq = 440 + i * 200
|
||||||
|
import soundfile as sf
|
||||||
|
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||||
|
sf.write(f.name, 0.5 * np.sin(2 * np.pi * freq * t), 22050)
|
||||||
|
paths.append(f.name)
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
profile = build_profile(paths)
|
||||||
|
assert len(profile["clip_vectors"]) == 3
|
||||||
|
assert profile["mean_vector"].shape == (20,)
|
||||||
|
finally:
|
||||||
|
for p in paths:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_skips_missing_files():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
profile = build_profile([f.name, "/no/such/file.wav"])
|
||||||
|
assert len(profile["clip_vectors"]) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_empty_returns_none():
|
||||||
|
result = build_profile([])
|
||||||
|
assert result is None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run tests to verify they fail**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: FAIL with `ModuleNotFoundError: No module named 'core.audio_scan'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# core/audio_scan.py
|
||||||
|
"""Audio similarity scanning — MFCC-based profile matching."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
from .paths import _log
|
||||||
|
|
||||||
|
_N_MFCC = 20
|
||||||
|
_SR = 22050
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_mfcc(path: str, sr: int = _SR) -> np.ndarray:
|
||||||
|
"""Load audio from a file and return a mean MFCC vector (20-dim)."""
|
||||||
|
y, _ = librosa.load(path, sr=sr, mono=True)
|
||||||
|
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=_N_MFCC)
|
||||||
|
return mfcc.mean(axis=1) # average over time → (20,)
|
||||||
|
|
||||||
|
|
||||||
|
def build_profile(clip_paths: list[str]) -> dict | None:
|
||||||
|
"""Extract MFCCs from reference clips.
|
||||||
|
|
||||||
|
Returns dict with:
|
||||||
|
- mean_vector: averaged MFCC across all clips (20,)
|
||||||
|
- clip_vectors: list of individual MFCC vectors
|
||||||
|
Returns None if no clips could be loaded.
|
||||||
|
"""
|
||||||
|
vectors = []
|
||||||
|
for p in clip_paths:
|
||||||
|
try:
|
||||||
|
vec = _extract_mfcc(p)
|
||||||
|
vectors.append(vec)
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: skip {p}: {e}")
|
||||||
|
if not vectors:
|
||||||
|
return None
|
||||||
|
arr = np.stack(vectors)
|
||||||
|
return {
|
||||||
|
"mean_vector": arr.mean(axis=0),
|
||||||
|
"clip_vectors": vectors,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run tests to verify they pass**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: all 5 PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add audio_scan module with build_profile"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Core audio_scan module — scan_video
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py`
|
||||||
|
- Modify: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the tests**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from core.audio_scan import scan_video
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_finds_matching_region():
|
||||||
|
"""A video made of the same sine wave as the reference should match."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
_make_wav(ref.name, duration=8.0)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
_make_wav(vid.name, duration=20.0)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0)
|
||||||
|
assert len(regions) > 0
|
||||||
|
for start, end, score in regions:
|
||||||
|
assert abs((end - start) - 8.0) < 1e-9
|
||||||
|
assert score >= 0.5
|
||||||
|
assert score >= 0.5
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_nearest_mode():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
_make_wav(ref.name, duration=8.0)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
_make_wav(vid.name, duration=20.0)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.5, hop=1.0)
|
||||||
|
assert len(regions) > 0
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_high_threshold_no_match():
|
||||||
|
"""Different frequencies with very high threshold should not match."""
|
||||||
|
import soundfile as sf
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||||
|
sf.write(ref.name, 0.5 * np.sin(2 * np.pi * 440 * t), 22050)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
# White noise — very different from sine wave
|
||||||
|
sf.write(vid.name, np.random.randn(22050 * 20).astype(np.float32) * 0.1, 22050)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="average", threshold=0.99, hop=1.0)
|
||||||
|
assert len(regions) == 0
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run tests to verify they fail**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_scan_video_finds_matching_region -v`
|
||||||
|
Expected: FAIL with `ImportError: cannot import name 'scan_video'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
Add to `core/audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""Cosine similarity between two vectors.
|
||||||
|
|
||||||
|
Returns value in [-1, 1]. Negative means anti-correlated (very
|
||||||
|
dissimilar). For threshold filtering this is fine — negative scores
|
||||||
|
never exceed the threshold. Scores near 0 may be uncorrelated or
|
||||||
|
weakly anti-correlated.
|
||||||
|
"""
|
||||||
|
na = np.linalg.norm(a)
|
||||||
|
nb = np.linalg.norm(b)
|
||||||
|
if na == 0 or nb == 0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.dot(a, b) / (na * nb))
|
||||||
|
|
||||||
|
|
||||||
|
def scan_video(
|
||||||
|
video_path: str,
|
||||||
|
profile: dict,
|
||||||
|
mode: str = "average",
|
||||||
|
threshold: float = 0.7,
|
||||||
|
hop: float = 1.0,
|
||||||
|
window: float = 8.0,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Slide a window across the video audio and score against the profile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: path to video/audio file
|
||||||
|
profile: dict from build_profile()
|
||||||
|
mode: "average" (compare to mean) or "nearest" (max over all clips)
|
||||||
|
threshold: minimum cosine similarity to include
|
||||||
|
hop: step size in seconds
|
||||||
|
window: window size in seconds (default 8s)
|
||||||
|
cancel_flag: object with _cancel bool attribute; checked each iteration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (start_time, end_time, score) for regions above threshold
|
||||||
|
"""
|
||||||
|
_log(f"audio_scan: loading {video_path}")
|
||||||
|
y, sr = librosa.load(video_path, sr=_SR, mono=True)
|
||||||
|
duration = len(y) / sr
|
||||||
|
_log(f"audio_scan: {duration:.1f}s loaded, scanning with hop={hop}s")
|
||||||
|
|
||||||
|
win_samples = int(window * sr)
|
||||||
|
hop_samples = int(hop * sr)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
pos = 0
|
||||||
|
while pos + win_samples <= len(y):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: cancelled")
|
||||||
|
return results
|
||||||
|
|
||||||
|
chunk = y[pos : pos + win_samples]
|
||||||
|
mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=_N_MFCC)
|
||||||
|
vec = mfcc.mean(axis=1)
|
||||||
|
|
||||||
|
if mode == "nearest":
|
||||||
|
score = max(
|
||||||
|
_cosine_similarity(vec, cv) for cv in profile["clip_vectors"]
|
||||||
|
)
|
||||||
|
else: # average
|
||||||
|
score = _cosine_similarity(vec, profile["mean_vector"])
|
||||||
|
|
||||||
|
if score >= threshold:
|
||||||
|
start_t = pos / sr
|
||||||
|
results.append((start_t, start_t + window, score))
|
||||||
|
|
||||||
|
pos += hop_samples
|
||||||
|
|
||||||
|
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
|
||||||
|
return results
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run tests to verify they pass**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: all 8 PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add scan_video with average and nearest modes"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Timeline — draw scan regions
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (Timeline class, around lines 209-260 and 300-375)
|
||||||
|
|
||||||
|
**Step 1: Add scan region storage to Timeline.__init__**
|
||||||
|
|
||||||
|
In `main.py`, find the Timeline class `__init__` method (around line 198). After `self._markers` initialization (line 209), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add set_scan_regions method**
|
||||||
|
|
||||||
|
After the `set_markers` method (line 249-252), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
|
||||||
|
"""regions: list of (start_time, end_time, score)"""
|
||||||
|
self._scan_regions = regions
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def clear_scan_regions(self) -> None:
|
||||||
|
self._scan_regions = []
|
||||||
|
self.update()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Draw scan regions in paintEvent**
|
||||||
|
|
||||||
|
In `paintEvent` (starts around line 282), find the marker drawing section (line 363, comment `# ── export markers`). BEFORE that section, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ── scan regions ──────────────────────────────────────────────
|
||||||
|
if self._scan_regions and self._duration > 0:
|
||||||
|
for (start, end, score) in self._scan_regions:
|
||||||
|
x1 = int(start / self._duration * w)
|
||||||
|
x2 = int(end / self._duration * w)
|
||||||
|
alpha = int(40 + score * 80) # 40–120 opacity
|
||||||
|
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: app starts without errors. No scan regions visible yet (none set).
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: timeline scan region rendering"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: ScanWorker QThread
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (add ScanWorker class, after ExportWorker around line 165)
|
||||||
|
|
||||||
|
**Step 1: Add the ScanWorker class**
|
||||||
|
|
||||||
|
After the `ExportWorker` class (ends around line 165), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ScanWorker(QThread):
|
||||||
|
"""Runs audio similarity scan off the main thread."""
|
||||||
|
finished = pyqtSignal(list) # emits list of (start, end, score)
|
||||||
|
error = pyqtSignal(str)
|
||||||
|
progress = pyqtSignal(str) # status message
|
||||||
|
|
||||||
|
def __init__(self, video_path: str, clip_paths: list[str],
|
||||||
|
mode: str = "average", threshold: float = 0.7):
|
||||||
|
super().__init__()
|
||||||
|
self._video_path = video_path
|
||||||
|
self._clip_paths = clip_paths
|
||||||
|
self._mode = mode
|
||||||
|
self._threshold = threshold
|
||||||
|
self._cancel = False
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
self._cancel = True
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
from core.audio_scan import build_profile, scan_video
|
||||||
|
try:
|
||||||
|
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
|
||||||
|
profile = build_profile(self._clip_paths)
|
||||||
|
if self._cancel:
|
||||||
|
return
|
||||||
|
if profile is None:
|
||||||
|
self.error.emit("No valid reference clips found")
|
||||||
|
return
|
||||||
|
self.progress.emit("Scanning audio...")
|
||||||
|
regions = scan_video(
|
||||||
|
self._video_path, profile,
|
||||||
|
mode=self._mode, threshold=self._threshold,
|
||||||
|
cancel_flag=self,
|
||||||
|
)
|
||||||
|
if not self._cancel:
|
||||||
|
self.finished.emit(regions)
|
||||||
|
except Exception as e:
|
||||||
|
if not self._cancel:
|
||||||
|
self.error.emit(str(e))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify import works**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -c "from main import ScanWorker; print('ok')"`
|
||||||
|
Expected: `ok`
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add ScanWorker QThread for background scanning"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: DB helper — get_all_export_paths
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/db.py`
|
||||||
|
- Modify: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the test**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_db_get_all_export_paths():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||||
|
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||||
|
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||||
|
paths = db.get_all_export_paths("test")
|
||||||
|
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run test to verify it fails**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||||
|
Expected: FAIL with `AttributeError: 'ProcessedDB' object has no attribute 'get_all_export_paths'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
Add to `core/db.py`, after the `get_markers` method. Note: no lock needed — follows
|
||||||
|
the codebase convention where read-only methods don't acquire the lock.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||||
|
"""Return all unique output_path values for a given profile."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run test to verify it passes**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/db.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add get_all_export_paths to ProcessedDB"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: UI controls for audio scanning
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (MainWindow class — control creation ~1490-1575, layout ~1620-1640)
|
||||||
|
|
||||||
|
**Step 1: Add scan control widgets**
|
||||||
|
|
||||||
|
In the MainWindow `__init__`, find the control creation section. After `self._chk_track` (around line 1501), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ── audio scan controls ──────────────────────────────────────
|
||||||
|
self._btn_scan = QPushButton("Scan")
|
||||||
|
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
|
||||||
|
self._btn_scan.clicked.connect(self._start_scan)
|
||||||
|
|
||||||
|
self._sld_threshold = QDoubleSpinBox()
|
||||||
|
self._sld_threshold.setRange(0.0, 1.0)
|
||||||
|
self._sld_threshold.setSingleStep(0.05)
|
||||||
|
self._sld_threshold.setValue(0.7)
|
||||||
|
self._sld_threshold.setPrefix("Thr: ")
|
||||||
|
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
|
||||||
|
|
||||||
|
self._cmb_scan_mode = QComboBox()
|
||||||
|
self._cmb_scan_mode.addItems(["Average", "Nearest"])
|
||||||
|
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
|
||||||
|
|
||||||
|
self._cmb_scan_ref = QComboBox()
|
||||||
|
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
|
||||||
|
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
|
||||||
|
self._scan_folder: str = ""
|
||||||
|
|
||||||
|
self._scan_worker: ScanWorker | None = None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add controls to settings_row layout**
|
||||||
|
|
||||||
|
Find the `settings_row` assembly (around line 1620). Before `settings_row.addStretch()` (around line 1635), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
settings_row.addWidget(self._btn_scan)
|
||||||
|
settings_row.addWidget(self._sld_threshold)
|
||||||
|
settings_row.addWidget(self._cmb_scan_mode)
|
||||||
|
settings_row.addWidget(self._cmb_scan_ref)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Add handler methods**
|
||||||
|
|
||||||
|
Add these methods to MainWindow (after `_jump_to_next_marker` around line 2410):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _on_scan_ref_changed(self, index: int) -> None:
|
||||||
|
if index == 1: # Custom Folder
|
||||||
|
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
|
||||||
|
if folder:
|
||||||
|
self._scan_folder = folder
|
||||||
|
else:
|
||||||
|
self._cmb_scan_ref.setCurrentIndex(0)
|
||||||
|
|
||||||
|
def _cleanup_scan_worker(self) -> None:
|
||||||
|
"""Disconnect signals and schedule deletion of old scan worker."""
|
||||||
|
if self._scan_worker is not None:
|
||||||
|
try:
|
||||||
|
self._scan_worker.finished.disconnect()
|
||||||
|
self._scan_worker.error.disconnect()
|
||||||
|
self._scan_worker.progress.disconnect()
|
||||||
|
except TypeError:
|
||||||
|
pass # already disconnected
|
||||||
|
self._scan_worker.deleteLater()
|
||||||
|
self._scan_worker = None
|
||||||
|
|
||||||
|
def _start_scan(self) -> None:
|
||||||
|
if not self._file_path:
|
||||||
|
self._show_status("No video loaded")
|
||||||
|
return
|
||||||
|
if self._scan_worker and self._scan_worker.isRunning():
|
||||||
|
self._show_status("Scan already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clean up previous worker
|
||||||
|
self._cleanup_scan_worker()
|
||||||
|
|
||||||
|
# Collect reference clip paths
|
||||||
|
if self._cmb_scan_ref.currentIndex() == 0:
|
||||||
|
# Current profile — all exports across all files in this profile
|
||||||
|
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
|
||||||
|
if os.path.exists(p)]
|
||||||
|
else:
|
||||||
|
# Custom folder
|
||||||
|
if not self._scan_folder:
|
||||||
|
self._show_status("No reference folder selected")
|
||||||
|
return
|
||||||
|
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
|
||||||
|
clip_paths = [
|
||||||
|
os.path.join(self._scan_folder, f)
|
||||||
|
for f in sorted(os.listdir(self._scan_folder))
|
||||||
|
if f.lower().endswith(exts)
|
||||||
|
]
|
||||||
|
|
||||||
|
if not clip_paths:
|
||||||
|
self._show_status("No reference clips found")
|
||||||
|
return
|
||||||
|
|
||||||
|
mode = self._cmb_scan_mode.currentText().lower()
|
||||||
|
threshold = self._sld_threshold.value()
|
||||||
|
|
||||||
|
self._btn_scan.setEnabled(False)
|
||||||
|
self._scan_file_path = self._file_path # remember which file we're scanning
|
||||||
|
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
|
||||||
|
|
||||||
|
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
|
||||||
|
self._scan_worker.finished.connect(self._on_scan_done)
|
||||||
|
self._scan_worker.error.connect(self._on_scan_error)
|
||||||
|
self._scan_worker.progress.connect(self._show_status)
|
||||||
|
self._scan_worker.start()
|
||||||
|
|
||||||
|
def _on_scan_done(self, regions: list) -> None:
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
# Ignore stale results if the user switched files during scan
|
||||||
|
if self._file_path != getattr(self, '_scan_file_path', None):
|
||||||
|
return
|
||||||
|
self._timeline.set_scan_regions(regions)
|
||||||
|
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
||||||
|
|
||||||
|
def _on_scan_error(self, msg: str) -> None:
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
self._show_status(f"Scan error: {msg}")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: Scan button, threshold spinner, mode dropdown, and reference source dropdown visible in the settings row. Clicking Scan with no file loaded shows "No video loaded" in status.
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add scan UI controls and start_scan handler"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Keyboard shortcut — jump to next scan region
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py`
|
||||||
|
|
||||||
|
**Step 1: Add the keyboard shortcut**
|
||||||
|
|
||||||
|
Find the shortcut definitions (around line 1728, where `QShortcut(QKeySequence("M"), ...)` is defined). Add after it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add the jump method**
|
||||||
|
|
||||||
|
After `_on_scan_error` (or after `_jump_to_next_marker`), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _jump_to_next_scan_region(self) -> None:
|
||||||
|
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
|
||||||
|
if not regions:
|
||||||
|
return
|
||||||
|
for (start, _end, _score) in regions:
|
||||||
|
if start > self._cursor + 0.1:
|
||||||
|
self._step_cursor(start - self._cursor)
|
||||||
|
return
|
||||||
|
# Wrap to first region
|
||||||
|
self._step_cursor(regions[0][0] - self._cursor)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Update help text**
|
||||||
|
|
||||||
|
Find the help/shortcuts tooltip (around line 1757). Add a row:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Clear scan regions and cancel running scan on file change**
|
||||||
|
|
||||||
|
Find `_load_file` method (around line 1931). After the existing marker/state resets, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self._timeline.clear_scan_regions()
|
||||||
|
if self._scan_worker and self._scan_worker.isRunning():
|
||||||
|
self._scan_worker.cancel()
|
||||||
|
self._cleanup_scan_worker()
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: S key does nothing when no scan regions exist. After a scan, S jumps through matched regions.
|
||||||
|
|
||||||
|
**Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add S shortcut and clear scan on file change"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: Final integration test
|
||||||
|
|
||||||
|
**Step 1: End-to-end manual test**
|
||||||
|
|
||||||
|
1. Open the app: `cd /media/p5/8-cut && python main.py`
|
||||||
|
2. Load a video file
|
||||||
|
3. Export a few clips (these become the reference)
|
||||||
|
4. Set reference source to "Current Profile"
|
||||||
|
5. Click "Scan"
|
||||||
|
6. Verify: status shows progress messages, then "Scan complete: N matching regions"
|
||||||
|
7. Verify: cyan-tinted regions appear on the timeline
|
||||||
|
8. Press S to jump through scan regions
|
||||||
|
9. Change threshold and re-scan — verify different number of regions
|
||||||
|
10. Switch mode to "Nearest" and re-scan
|
||||||
|
11. Switch reference to "Custom Folder", pick a folder with clips
|
||||||
|
12. Re-scan and verify results
|
||||||
|
|
||||||
|
**Step 2: Run all tests**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/ -v`
|
||||||
|
Expected: all tests PASS
|
||||||
|
|
||||||
|
**Step 3: Final commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add -A
|
||||||
|
git commit -m "feat: audio similarity scanning complete"
|
||||||
|
```
|
||||||
+18
-1
@@ -1,4 +1,21 @@
|
|||||||
|
# Core GUI
|
||||||
PyQt6>=6.4
|
PyQt6>=6.4
|
||||||
python-mpv>=1.0
|
python-mpv>=1.0
|
||||||
pytest>=7.0
|
|
||||||
|
# Audio & ML
|
||||||
|
librosa>=0.10
|
||||||
|
numpy>=1.24
|
||||||
|
scikit-learn>=1.3
|
||||||
|
joblib>=1.3
|
||||||
|
soundfile>=0.12
|
||||||
|
|
||||||
|
# Deep learning — install via setup_env.sh for correct CUDA version,
|
||||||
|
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
torch>=2.0
|
||||||
|
torchaudio>=2.0
|
||||||
|
|
||||||
|
# Object detection
|
||||||
ultralytics>=8.0
|
ultralytics>=8.0
|
||||||
|
|
||||||
|
# Dev
|
||||||
|
pytest>=7.0
|
||||||
|
|||||||
+42
-8
@@ -1,19 +1,48 @@
|
|||||||
# 8-cut Windows setup script
|
# 8-cut Windows setup script
|
||||||
# Run once: powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
# Run once: powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||||
#
|
#
|
||||||
# Prerequisites: Python 3.10+ must be installed and on PATH
|
# Prerequisites: Python 3.11+ must be installed and on PATH
|
||||||
# https://www.python.org/downloads/
|
# https://www.python.org/downloads/
|
||||||
|
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
|
trap { Write-Host "`n$_" -ForegroundColor Red; Read-Host "Press Enter to close"; exit 1 }
|
||||||
$root = Split-Path -Parent $MyInvocation.MyCommand.Path
|
$root = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||||
|
|
||||||
Write-Host "=== 8-cut Windows Setup ===" -ForegroundColor Cyan
|
Write-Host "=== 8-cut Windows Setup ===" -ForegroundColor Cyan
|
||||||
|
|
||||||
# ── Python deps ────────────────────────────────────────────
|
# ── Virtual environment ───────────────────────────────────
|
||||||
Write-Host "`nInstalling Python dependencies..."
|
$venvDir = Join-Path $root ".venv"
|
||||||
pip install PyQt6 python-mpv
|
if (Test-Path (Join-Path $venvDir "Scripts\python.exe")) {
|
||||||
|
Write-Host "`nVirtual environment already exists, activating..." -ForegroundColor Green
|
||||||
|
} else {
|
||||||
|
Write-Host "`nCreating virtual environment..."
|
||||||
|
python -m venv $venvDir
|
||||||
|
Write-Host "Virtual environment created at $venvDir" -ForegroundColor Green
|
||||||
|
}
|
||||||
|
& "$venvDir\Scripts\Activate.ps1"
|
||||||
|
|
||||||
# ── libmpv ─────────────────────────────────────────────────
|
# ── PyTorch ───────────────────────────────────────────────
|
||||||
|
$hasTorch = python -c "import torch" 2>&1
|
||||||
|
if ($LASTEXITCODE -eq 0) {
|
||||||
|
Write-Host "`nPyTorch already installed, skipping." -ForegroundColor Green
|
||||||
|
} else {
|
||||||
|
# Detect NVIDIA GPU via nvidia-smi
|
||||||
|
$hasNvidia = Get-Command nvidia-smi -ErrorAction SilentlyContinue
|
||||||
|
if ($hasNvidia) {
|
||||||
|
Write-Host "`nNVIDIA GPU detected — installing PyTorch with CUDA 12.8..." -ForegroundColor Green
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
} else {
|
||||||
|
Write-Host "`nNo NVIDIA GPU detected — installing CPU-only PyTorch..." -ForegroundColor Yellow
|
||||||
|
Write-Host "(Audio scanning will work but will be slower without GPU)" -ForegroundColor Yellow
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Python deps ───────────────────────────────────────────
|
||||||
|
Write-Host "`nInstalling project dependencies..."
|
||||||
|
pip install -r (Join-Path $root "requirements.txt")
|
||||||
|
|
||||||
|
# ── libmpv ────────────────────────────────────────────────
|
||||||
$mpvDll = Join-Path $root "libmpv-2.dll"
|
$mpvDll = Join-Path $root "libmpv-2.dll"
|
||||||
if (Test-Path $mpvDll) {
|
if (Test-Path $mpvDll) {
|
||||||
Write-Host "`nlibmpv-2.dll already present, skipping." -ForegroundColor Green
|
Write-Host "`nlibmpv-2.dll already present, skipping." -ForegroundColor Green
|
||||||
@@ -30,12 +59,11 @@ if (Test-Path $mpvDll) {
|
|||||||
Write-Host "libmpv-2.dll downloaded." -ForegroundColor Green
|
Write-Host "libmpv-2.dll downloaded." -ForegroundColor Green
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── ffmpeg ─────────────────────────────────────────────────
|
# ── ffmpeg ────────────────────────────────────────────────
|
||||||
$ffmpeg = Join-Path $root "ffmpeg.exe"
|
$ffmpeg = Join-Path $root "ffmpeg.exe"
|
||||||
if (Test-Path $ffmpeg) {
|
if (Test-Path $ffmpeg) {
|
||||||
Write-Host "`nffmpeg.exe already present, skipping." -ForegroundColor Green
|
Write-Host "`nffmpeg.exe already present, skipping." -ForegroundColor Green
|
||||||
} else {
|
} else {
|
||||||
# Check if ffmpeg is on PATH
|
|
||||||
$onPath = Get-Command ffmpeg -ErrorAction SilentlyContinue
|
$onPath = Get-Command ffmpeg -ErrorAction SilentlyContinue
|
||||||
if ($onPath) {
|
if ($onPath) {
|
||||||
Write-Host "`nffmpeg found on PATH: $($onPath.Source)" -ForegroundColor Green
|
Write-Host "`nffmpeg found on PATH: $($onPath.Source)" -ForegroundColor Green
|
||||||
@@ -54,6 +82,12 @@ if (Test-Path $ffmpeg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ── Verify ────────────────────────────────────────────────
|
||||||
|
Write-Host "`n--- Verification ---" -ForegroundColor Cyan
|
||||||
|
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||||
|
python -c "import sklearn, librosa, torchaudio; print('All imports OK')"
|
||||||
|
|
||||||
Write-Host "`n=== Setup complete ===" -ForegroundColor Cyan
|
Write-Host "`n=== Setup complete ===" -ForegroundColor Cyan
|
||||||
Write-Host "Run 8-cut with: python main.py"
|
Write-Host "Run 8-cut with: .venv\Scripts\python.exe main.py"
|
||||||
Write-Host "Or double-click: 8cut.bat"
|
Write-Host "Or double-click: 8cut.bat"
|
||||||
|
Read-Host "`nPress Enter to close"
|
||||||
|
|||||||
Executable
+114
@@ -0,0 +1,114 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# 8-cut environment setup — supports conda (miniforge) or python venv
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./setup_env.sh # auto-detect (prefers conda if available)
|
||||||
|
# ./setup_env.sh --conda # force conda
|
||||||
|
# ./setup_env.sh --venv # force python venv
|
||||||
|
# ─��────────────────────────────��───────────────────────────────────────
|
||||||
|
|
||||||
|
ENV_NAME="8cut"
|
||||||
|
PYTHON_VERSION="3.12"
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||||
|
|
||||||
|
# Auto-detect GPU for PyTorch index URL
|
||||||
|
if command -v nvidia-smi &>/dev/null; then
|
||||||
|
TORCH_INDEX="https://download.pytorch.org/whl/cu128"
|
||||||
|
echo "NVIDIA GPU detected — will install PyTorch with CUDA 12.8"
|
||||||
|
else
|
||||||
|
TORCH_INDEX="https://download.pytorch.org/whl/cpu"
|
||||||
|
echo "No NVIDIA GPU detected — will install CPU-only PyTorch"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Parse args ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
MODE=""
|
||||||
|
for arg in "$@"; do
|
||||||
|
case "$arg" in
|
||||||
|
--conda) MODE="conda" ;;
|
||||||
|
--venv) MODE="venv" ;;
|
||||||
|
*) echo "Unknown arg: $arg"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$MODE" ]; then
|
||||||
|
if command -v conda &>/dev/null; then
|
||||||
|
MODE="conda"
|
||||||
|
else
|
||||||
|
MODE="venv"
|
||||||
|
fi
|
||||||
|
echo "Auto-detected mode: $MODE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Conda setup ─────────────��─────────────────────────────────────────
|
||||||
|
|
||||||
|
setup_conda() {
|
||||||
|
echo "==> Setting up conda environment: $ENV_NAME"
|
||||||
|
|
||||||
|
# Source conda shell hooks if not already active
|
||||||
|
if ! command -v conda &>/dev/null; then
|
||||||
|
echo "conda not found in PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
eval "$(conda shell.bash hook)"
|
||||||
|
|
||||||
|
if conda env list | grep -qw "$ENV_NAME"; then
|
||||||
|
echo " Environment '$ENV_NAME' already exists, updating..."
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
else
|
||||||
|
echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..."
|
||||||
|
conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION"
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
|
pip install torch torchaudio --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo " Installing project dependencies..."
|
||||||
|
pip install -r "$SCRIPT_DIR/requirements.txt"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done! Activate with:"
|
||||||
|
echo " conda activate $ENV_NAME"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Venv setup ───────��────────────────────────────────────────────────
|
||||||
|
|
||||||
|
setup_venv() {
|
||||||
|
echo "==> Setting up Python venv at: $VENV_DIR"
|
||||||
|
|
||||||
|
if [ ! -d "$VENV_DIR" ]; then
|
||||||
|
python3 -m venv "$VENV_DIR"
|
||||||
|
echo " Created venv"
|
||||||
|
else
|
||||||
|
echo " Venv already exists, updating..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
source "$VENV_DIR/bin/activate"
|
||||||
|
|
||||||
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
|
pip install torch torchaudio --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo " Installing project dependencies..."
|
||||||
|
pip install -r "$SCRIPT_DIR/requirements.txt"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done! Activate with:"
|
||||||
|
echo " source $VENV_DIR/bin/activate"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Run ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
case "$MODE" in
|
||||||
|
conda) setup_conda ;;
|
||||||
|
venv) setup_venv ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Verify with:"
|
||||||
|
echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\""
|
||||||
|
echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\""
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import tempfile, os
|
||||||
|
import numpy as np
|
||||||
|
from core.audio_scan import scan_video, load_classifier, default_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_no_model_returns_empty():
|
||||||
|
"""scan_video with no model should return empty list."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000)
|
||||||
|
try:
|
||||||
|
regions = scan_video(vid.name, model=None)
|
||||||
|
assert regions == []
|
||||||
|
finally:
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_classifier_missing_returns_none():
|
||||||
|
assert load_classifier("/no/such/model.joblib") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_model_path_contains_profile():
|
||||||
|
path = default_model_path("test_profile")
|
||||||
|
assert "test_profile" in path
|
||||||
|
assert path.endswith(".joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_get_all_export_paths():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||||
|
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||||
|
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||||
|
paths = db.get_all_export_paths("test")
|
||||||
|
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
+2
-1
@@ -1,5 +1,6 @@
|
|||||||
import tempfile, os, json
|
import tempfile, os, json
|
||||||
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, build_annotation_json_path, upsert_clip_annotation, resolve_keyframe, apply_keyframes_to_jobs
|
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, resolve_keyframe, apply_keyframes_to_jobs
|
||||||
|
from core.annotations import build_annotation_json_path, upsert_clip_annotation
|
||||||
from main import ProcessedDB
|
from main import ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user