Compare commits
121 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1bdeb33a6f | |||
| 387ed7bc6a | |||
| f268d61fe4 | |||
| 24db32c09f | |||
| 0f6ae88ea6 | |||
| 4d99cf6015 | |||
| b75fa85ff5 | |||
| e7d47331c6 | |||
| 7cd31ebe55 | |||
| 3a37dddfd9 | |||
| b249705506 | |||
| aaf405dd3d | |||
| cb2060beb8 | |||
| 0db412baf4 | |||
| 876026d1f6 | |||
| 6c1d42adfe | |||
| d8b3972bdc | |||
| bd345abca2 | |||
| 7d6fee9df1 | |||
| fd043f4172 | |||
| 3c3b1d74bb | |||
| a3c657c66e | |||
| 5d45b8d8eb | |||
| e6db83f00b | |||
| edc5784ba6 | |||
| 8ed9fbf557 | |||
| 4fb2ae144f | |||
| 2614a765d5 | |||
| c020c0dfec | |||
| e7b791fbfa | |||
| f5361a963e | |||
| 8fb8581816 | |||
| 5b25e85e98 | |||
| e3f133ef84 | |||
| 4736f150b0 | |||
| 52aa982aa2 | |||
| 07457d0d6f | |||
| c5d613fc5f | |||
| 7855ea62c2 | |||
| 70be5974cf | |||
| a0286d5cf9 | |||
| 2b7dfb330d | |||
| 518554f788 | |||
| 282156e8ed | |||
| 3417a0f603 | |||
| cd0552197f | |||
| 7dffcb08eb | |||
| 93bcb23fa7 | |||
| eda7826a40 | |||
| e7e20b0fe6 | |||
| 814ef946eb | |||
| 2e738df9ae | |||
| 6ddfcde8ee | |||
| b161412d94 | |||
| 5a9e068903 | |||
| 6870e5aaf3 | |||
| f597ff29e8 | |||
| e1789d4e71 | |||
| 7834b1d05c | |||
| 12ed183f1b | |||
| f2c38aee79 | |||
| 8ab5bdba77 | |||
| c6c5934fe8 | |||
| 73d5367424 | |||
| 1e2cebd424 | |||
| c439aca9b9 | |||
| afda9b2d9f | |||
| fd42791c9f | |||
| 4cf54f2642 | |||
| e7f4de9ec1 | |||
| 9cf9e3233f | |||
| e17d8f67aa | |||
| b1980de6d1 | |||
| 85e0641440 | |||
| 834b89b682 | |||
| a67e189aa0 | |||
| 2b6c56cd15 | |||
| 0f6082061f | |||
| 9662b815db | |||
| 9776b83ac5 | |||
| 39f873bec2 | |||
| 409eb82e5c | |||
| 297aafa51c | |||
| b4cf972d59 | |||
| 5cc1e52e75 | |||
| 6bf0b0ae99 | |||
| b6fbda01dd | |||
| 51d41f0a56 | |||
| 16bd1a9ae0 | |||
| 2036c49b52 | |||
| b12758c53c | |||
| 3d484952c2 | |||
| 12dae93671 | |||
| 1e65fd6b0f | |||
| f7756320e5 | |||
| cd0331d4ce | |||
| 38c6174f83 | |||
| 5b22bceed2 | |||
| 80f21915e3 | |||
| b09ba3fa9e | |||
| 5b7a55a05d | |||
| 2200da491f | |||
| 3d6469c60c | |||
| 6a4ac8b8ed | |||
| 1f6906c946 | |||
| dfba88a601 | |||
| e94c088df0 | |||
| 9569103edd | |||
| 079afeee7c | |||
| fbbfa6fdce | |||
| 56920a5247 | |||
| 08c1dd8b33 | |||
| 2b63ad1857 | |||
| 72f6a4e8f5 | |||
| 799a2ab353 | |||
| 066f4431ba | |||
| 97f9ef7073 | |||
| 592e40c1a6 | |||
| 73dd7a1569 | |||
| 7abf0b4d4c | |||
| 9e5bd4a8ec |
@@ -0,0 +1,36 @@
|
||||
name: Docker Image
|
||||
|
||||
on:
|
||||
workflow_dispatch: # manual only — build locally and push to ghcr.io
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- uses: docker/metadata-action@v5
|
||||
id: meta
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}-server
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
type=sha,prefix=
|
||||
|
||||
- uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -3,3 +3,8 @@ __pycache__/
|
||||
*.pyo
|
||||
.pytest_cache/
|
||||
.worktrees/
|
||||
.venv/
|
||||
models/
|
||||
cache/
|
||||
*.joblib
|
||||
*.pt
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
@echo off
|
||||
cd /d "%~dp0"
|
||||
python main.py %*
|
||||
if exist ".venv\Scripts\python.exe" (
|
||||
.venv\Scripts\python.exe main.py %*
|
||||
) else (
|
||||
python main.py %*
|
||||
)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
# Launch 8-cut with auto-detected venv/conda environment
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
ENV_NAME="8cut"
|
||||
CONDA_PREFIX_BASE="/media/p5/miniforge3"
|
||||
|
||||
# 1. Try .venv in project dir
|
||||
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
|
||||
source "$SCRIPT_DIR/.venv/bin/activate"
|
||||
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||
fi
|
||||
|
||||
# 2. Try conda env (works without shell init)
|
||||
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
|
||||
if [ -x "$CONDA_PYTHON" ]; then
|
||||
exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
|
||||
fi
|
||||
|
||||
# 3. Try conda via shell hook (interactive shells)
|
||||
if command -v conda &>/dev/null; then
|
||||
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||
if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then
|
||||
conda activate "$ENV_NAME"
|
||||
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||
fi
|
||||
fi
|
||||
|
||||
# 4. Fallback to system Python
|
||||
exec python3 "$SCRIPT_DIR/main.py" "$@"
|
||||
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Calibration — per-video normalized features + classifier."""
|
||||
import sys, os, time, warnings
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
|
||||
from core.audio_scan import _SR, _WINDOW
|
||||
|
||||
_HOP_LENGTH = 1024
|
||||
_N_FFT = 2048
|
||||
from core.db import ProcessedDB
|
||||
|
||||
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||
PROFILE_NAME = "JAV_missionary"
|
||||
TOLERANCE = 12.0
|
||||
NEG_MARGIN = 120.0
|
||||
|
||||
|
||||
def extract_rich_features(y, sr=_SR):
|
||||
"""Per-frame features: onset, energy, spectral shape, mel bands (22 features)."""
|
||||
hop = _HOP_LENGTH
|
||||
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2
|
||||
rms = librosa.feature.rms(S=S, hop_length=hop)
|
||||
cent = librosa.feature.spectral_centroid(S=S, sr=sr)
|
||||
bw = librosa.feature.spectral_bandwidth(S=S, sr=sr)
|
||||
rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr)
|
||||
flatness = librosa.feature.spectral_flatness(S=S)
|
||||
zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop)
|
||||
onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1)
|
||||
|
||||
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128)
|
||||
mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
|
||||
bands = [(0, 100), (100, 300), (300, 600), (600, 1200),
|
||||
(1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
|
||||
band_feats = []
|
||||
for flo, fhi in bands:
|
||||
mask = (mel_freqs >= flo) & (mel_freqs < fhi)
|
||||
if mask.sum() > 0:
|
||||
band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10))
|
||||
else:
|
||||
band_feats.append(np.zeros((1, mel_S.shape[1])))
|
||||
|
||||
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop)
|
||||
|
||||
min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1],
|
||||
band_feats[0].shape[1])
|
||||
return np.vstack([
|
||||
rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t],
|
||||
flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t],
|
||||
] + [b[:, :min_t] for b in band_feats]
|
||||
+ [sc[:, :min_t]])
|
||||
|
||||
|
||||
def compute_window_stats(feat, hop=1.0):
|
||||
"""Sliding window mean/std → (timestamps, feature_vectors)."""
|
||||
n_feats, T = feat.shape
|
||||
fps = _SR / _HOP_LENGTH
|
||||
win_frames = int(_WINDOW * fps)
|
||||
hop_frames = int(hop * fps)
|
||||
if win_frames > T:
|
||||
return np.array([]), np.array([])
|
||||
|
||||
cumsum = np.zeros((n_feats, T + 1))
|
||||
cumsum[:, 1:] = np.cumsum(feat, axis=1)
|
||||
cumsq = np.zeros((n_feats, T + 1))
|
||||
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
|
||||
|
||||
starts = np.arange(0, T - win_frames + 1, hop_frames)
|
||||
ends = starts + win_frames
|
||||
sums = cumsum[:, ends] - cumsum[:, starts]
|
||||
sq_sums = cumsq[:, ends] - cumsq[:, starts]
|
||||
means = sums / win_frames
|
||||
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
|
||||
|
||||
return starts / fps, np.vstack([means, stds]).T
|
||||
|
||||
|
||||
def label_windows(timestamps, gt_intense, gt_soft):
|
||||
all_gt = list(gt_intense) + list(gt_soft)
|
||||
labels = np.zeros(len(timestamps), dtype=int)
|
||||
for i, t in enumerate(timestamps):
|
||||
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||
if di < TOLERANCE:
|
||||
labels[i] = 1
|
||||
elif da > NEG_MARGIN:
|
||||
labels[i] = -1
|
||||
return labels
|
||||
|
||||
|
||||
def main():
|
||||
db = ProcessedDB()
|
||||
rows = db._con.execute(
|
||||
"SELECT filename, start_time, output_path FROM processed WHERE profile = ?",
|
||||
(PROFILE_NAME,),
|
||||
).fetchall()
|
||||
|
||||
intense_by_video, soft_by_video = {}, {}
|
||||
for fn, st, op in rows:
|
||||
if '/mp4_Intense/' in op:
|
||||
intense_by_video.setdefault(fn, set()).add(st)
|
||||
elif '/mp4_Soft/' in op:
|
||||
soft_by_video.setdefault(fn, set()).add(st)
|
||||
|
||||
videos = [fn for fn in intense_by_video
|
||||
if os.path.exists(os.path.join(PLEX_DIR, fn))]
|
||||
n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
|
||||
videos = videos[:n_vids]
|
||||
print(f"Processing {len(videos)} videos...")
|
||||
|
||||
all_data_raw = [] # raw features
|
||||
all_data_norm = [] # per-video z-scored features
|
||||
|
||||
for vi, vname in enumerate(videos):
|
||||
vpath = os.path.join(PLEX_DIR, vname)
|
||||
gt_intense = sorted(intense_by_video.get(vname, set()))
|
||||
gt_soft = sorted(soft_by_video.get(vname, set()))
|
||||
|
||||
t0 = time.time()
|
||||
y, _ = librosa.load(vpath, sr=_SR, mono=True)
|
||||
feat = extract_rich_features(y)
|
||||
timestamps, window_vectors = compute_window_stats(feat, hop=1.0)
|
||||
dt = time.time() - t0
|
||||
|
||||
if len(timestamps) == 0:
|
||||
continue
|
||||
|
||||
labels = label_windows(timestamps, gt_intense, gt_soft)
|
||||
|
||||
# Per-video z-score normalization
|
||||
vid_mean = window_vectors.mean(axis=0)
|
||||
vid_std = window_vectors.std(axis=0)
|
||||
vid_std = np.maximum(vid_std, 1e-6)
|
||||
normed = (window_vectors - vid_mean) / vid_std
|
||||
|
||||
n_pos = (labels == 1).sum()
|
||||
n_neg = (labels == -1).sum()
|
||||
print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")
|
||||
|
||||
all_data_raw.append((vi, vname, timestamps, window_vectors, labels))
|
||||
all_data_norm.append((vi, vname, timestamps, normed, labels))
|
||||
|
||||
# Run CV for both raw and normalized
|
||||
for label, data in [("RAW features", all_data_raw),
|
||||
("PER-VIDEO NORMALIZED features", all_data_norm)]:
|
||||
print(f"\n{'='*70}")
|
||||
print(f" {label}")
|
||||
print(f"{'='*70}")
|
||||
|
||||
all_y_true, all_y_prob = [], []
|
||||
|
||||
for test_idx in range(len(data)):
|
||||
_, vname, _, test_X, test_labels = data[test_idx]
|
||||
test_mask = test_labels != 0
|
||||
if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0:
|
||||
continue
|
||||
X_test = test_X[test_mask]
|
||||
y_test = (test_labels[test_mask] == 1).astype(int)
|
||||
|
||||
X_parts, y_parts = [], []
|
||||
for i, (_, _, _, feats, labs) in enumerate(data):
|
||||
if i == test_idx:
|
||||
continue
|
||||
m = labs != 0
|
||||
if m.sum() == 0:
|
||||
continue
|
||||
X_parts.append(feats[m])
|
||||
y_parts.append((labs[m] == 1).astype(int))
|
||||
|
||||
if not X_parts:
|
||||
continue
|
||||
X_train = np.vstack(X_parts)
|
||||
y_train = np.concatenate(y_parts)
|
||||
|
||||
pos_idx = np.where(y_train == 1)[0]
|
||||
neg_idx = np.where(y_train == 0)[0]
|
||||
if len(pos_idx) == 0 or len(neg_idx) == 0:
|
||||
continue
|
||||
rng = np.random.RandomState(42)
|
||||
n_neg = min(len(neg_idx), len(pos_idx) * 3)
|
||||
neg_sample = rng.choice(neg_idx, n_neg, replace=False)
|
||||
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||
|
||||
clf = GradientBoostingClassifier(
|
||||
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
|
||||
)
|
||||
clf.fit(X_train[train_idx], y_train[train_idx])
|
||||
probs = clf.predict_proba(X_test)[:, 1]
|
||||
|
||||
tp = ((probs >= 0.5) & (y_test == 1)).sum()
|
||||
fp = ((probs >= 0.5) & (y_test == 0)).sum()
|
||||
fn_count = ((probs < 0.5) & (y_test == 1)).sum()
|
||||
pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0
|
||||
neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0
|
||||
print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")
|
||||
|
||||
all_y_true.extend(y_test)
|
||||
all_y_prob.extend(probs)
|
||||
|
||||
if not all_y_true:
|
||||
print(" No test results.")
|
||||
continue
|
||||
|
||||
y_true = np.array(all_y_true)
|
||||
y_prob = np.array(all_y_prob)
|
||||
pos_probs = y_prob[y_true == 1]
|
||||
neg_probs = y_prob[y_true == 0]
|
||||
|
||||
if len(pos_probs) > 0 and len(neg_probs) > 0:
|
||||
print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
|
||||
f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
|
||||
print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
|
||||
f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")
|
||||
|
||||
best_f1, best_thr = 0, 0
|
||||
print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
|
||||
for thr in np.arange(0.10, 0.91, 0.05):
|
||||
tp = ((y_prob >= thr) & (y_true == 1)).sum()
|
||||
fp = ((y_prob >= thr) & (y_true == 0)).sum()
|
||||
fn_count = ((y_prob < thr) & (y_true == 1)).sum()
|
||||
prec = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||
rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
|
||||
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
|
||||
if f1 > best_f1:
|
||||
best_f1, best_thr = f1, thr
|
||||
print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
|
||||
print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")
|
||||
|
||||
# Feature importance
|
||||
X_all = np.vstack([f[l != 0] for _, _, _, f, l in data])
|
||||
y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
|
||||
pos_idx = np.where(y_all == 1)[0]
|
||||
neg_idx = np.where(y_all == 0)[0]
|
||||
rng = np.random.RandomState(42)
|
||||
neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
|
||||
clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
|
||||
clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])])
|
||||
|
||||
feat_names = (
|
||||
["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
|
||||
+ [f"mel{i}" for i in range(8)]
|
||||
+ [f"sc{i}" for i in range(7)]
|
||||
)
|
||||
stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
|
||||
imp = clf.feature_importances_
|
||||
top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10]
|
||||
print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train an audio scan classifier from DB ground truth.
|
||||
|
||||
Usage:
|
||||
python 8cut_train.py # default model, auto-detect positive
|
||||
python 8cut_train.py --model BEATS # specific embedding model
|
||||
python 8cut_train.py --positive mp4_Intense # explicit positive folder
|
||||
python 8cut_train.py --positive mp4_Intense --model BEATS # both
|
||||
"""
|
||||
import sys, os, warnings
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
|
||||
from core.db import ProcessedDB
|
||||
|
||||
PROFILE_NAME = "JAV_missionary"
|
||||
|
||||
# Fallback for old DB rows without source_path
|
||||
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||
|
||||
|
||||
def main():
|
||||
embed_model = None
|
||||
if "--model" in sys.argv:
|
||||
idx = sys.argv.index("--model")
|
||||
if idx + 1 < len(sys.argv):
|
||||
embed_model = sys.argv[idx + 1]
|
||||
if embed_model not in _EMBED_MODELS:
|
||||
print(f"Unknown model: {embed_model}")
|
||||
print(f"Available: {', '.join(_EMBED_MODELS)}")
|
||||
sys.exit(1)
|
||||
|
||||
positive_suffix = None
|
||||
if "--positive" in sys.argv:
|
||||
idx = sys.argv.index("--positive")
|
||||
if idx + 1 < len(sys.argv):
|
||||
positive_suffix = sys.argv[idx + 1]
|
||||
|
||||
db = ProcessedDB()
|
||||
|
||||
# If --positive given, use the new DB helper
|
||||
if positive_suffix:
|
||||
video_infos = db.get_training_data(
|
||||
PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
|
||||
)
|
||||
if not video_infos:
|
||||
print(f"No training data found for positive='{positive_suffix}'")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Legacy fallback: classify by folder path pattern
|
||||
rows = db._con.execute(
|
||||
"SELECT filename, start_time, output_path, source_path"
|
||||
" FROM processed WHERE profile = ?",
|
||||
(PROFILE_NAME,),
|
||||
).fetchall()
|
||||
|
||||
intense_by_video, soft_by_video = {}, {}
|
||||
source_by_fn = {}
|
||||
for fn, st, op, sp in rows:
|
||||
if sp:
|
||||
source_by_fn[fn] = sp
|
||||
if "/mp4_Intense/" in op or "_Intense/" in op:
|
||||
intense_by_video.setdefault(fn, set()).add(st)
|
||||
elif "/mp4_Soft/" in op or "_Soft/" in op:
|
||||
soft_by_video.setdefault(fn, set()).add(st)
|
||||
|
||||
video_infos = []
|
||||
for fn in intense_by_video:
|
||||
# Try source_path from DB first, fall back to PLEX_DIR
|
||||
vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
|
||||
if not os.path.exists(vpath):
|
||||
print(f" skip (not found): {fn}")
|
||||
continue
|
||||
gt_intense = sorted(intense_by_video[fn])
|
||||
gt_soft = sorted(soft_by_video.get(fn, set()))
|
||||
video_infos.append((vpath, gt_intense, gt_soft))
|
||||
|
||||
label = embed_model or "WAV2VEC2_BASE"
|
||||
print(f"Training {label} model on {len(video_infos)} videos...")
|
||||
model_path = default_model_path(PROFILE_NAME)
|
||||
result = train_classifier(
|
||||
video_infos, model_path=model_path, embed_model=embed_model,
|
||||
)
|
||||
if result is None:
|
||||
print("Training failed: no valid samples or missing class balance")
|
||||
sys.exit(1)
|
||||
print(f"Model saved to {model_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -8,7 +8,7 @@
|
||||
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
||||
</p>
|
||||
|
||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets.
|
||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets. Includes audio classification for automated scanning and batch export.
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -22,19 +22,44 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
||||
|
||||
## Features
|
||||
|
||||
### Clip export
|
||||
|
||||
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
||||
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
||||
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
||||
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
||||
- **Random portrait** — optionally apply a random portrait crop to a subset of each batch
|
||||
- **Random portrait/square** — optionally apply a random crop to a subset of each batch
|
||||
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
||||
- **Sound annotation** — label and category fields saved to the clip database; label also written to `dataset.json`
|
||||
- **Export history** — timeline markers show previously exported clips; double-click to enter overwrite mode; right-click to delete
|
||||
- **End-frame preview** — floating window shows the last frame of the selection region
|
||||
- **Playlist** — drag-and-drop or use the Open Files button; right-click to remove items
|
||||
- **Playback loop** — plays the exact selection region on loop so you can preview what will be exported
|
||||
- **Group operations** — delete or overwrite acts on all sub-clips in a batch, not just one
|
||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait") for the same video
|
||||
- **Hardware encoding** — GPU-accelerated export via NVENC, VAAPI, QSV, AMF, or VideoToolbox
|
||||
- **Subject tracking** — auto-adjust crop center using YOLOv8 detection (optional)
|
||||
|
||||
### Audio scanning
|
||||
|
||||
- **Embedding models** — WAV2VEC2 (base/large), HuBERT (base/large/xlarge), BEATs
|
||||
- **Train classifier** — train a gradient boosting classifier on your exported clips to find similar audio
|
||||
- **Scan video** — detect regions matching your trained model with configurable threshold
|
||||
- **Scan All** — batch scan every video in the playlist
|
||||
- **Region fusion** — merge overlapping detections into contiguous regions
|
||||
- **Hard negatives** — mark false positives to refine training
|
||||
- **Model versioning** — timestamped backups with rollback support
|
||||
- **Scan export** — batch export from scan results with spread and minimum duration filtering
|
||||
|
||||
### Scan results panel
|
||||
|
||||
- **Tabbed results** — one tab per model, showing start/end/score per region
|
||||
- **Disable regions** — Delete/Backspace toggles regions off (greyed out, excluded from export) without removing them
|
||||
- **Resize regions** — double-click Time or End cells to edit, or drag region edges directly on the timeline
|
||||
- **Grey ghost** — trimmed portions of resized regions shown as grey overlay on timeline
|
||||
- **Undo** — Ctrl+Z reverts the last disable, resize, drag, or negative toggle
|
||||
|
||||
### Organization
|
||||
|
||||
- **Sound annotation** — label and category fields saved to the clip database and `dataset.json`
|
||||
- **Export history** — timeline markers show previously exported clips; double-click to overwrite; right-click to delete
|
||||
- **Playlist** — drag-and-drop video queue with progress tracking
|
||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait")
|
||||
- **Subprofiles** — lightweight export folder variants for multiple output targets
|
||||
- **Review mode** — clean timeline view for navigating scan results without export clutter
|
||||
|
||||
## Keyboard shortcuts
|
||||
|
||||
@@ -50,37 +75,158 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
||||
| `M` | Jump to next marker (wraps) |
|
||||
| `N` | Next file in playlist |
|
||||
| `G` | Toggle cursor lock |
|
||||
| `Delete` / `Backspace` | Toggle disable on selected scan regions |
|
||||
| `Ctrl+Z` | Undo last scan panel action |
|
||||
| `?` / `F1` | Show keyboard shortcuts |
|
||||
|
||||
Shortcuts are suppressed when a text field has focus.
|
||||
|
||||
## Requirements
|
||||
## Installation
|
||||
|
||||
- Python 3.11+
|
||||
- `ffmpeg` on `PATH`
|
||||
- PyQt6
|
||||
- python-mpv (requires libmpv)
|
||||
### Prerequisites
|
||||
|
||||
- **Python 3.11+** — [python.org/downloads](https://www.python.org/downloads/)
|
||||
- **ffmpeg** — video encoding
|
||||
- **libmpv** — video playback
|
||||
|
||||
### Quick start (all platforms)
|
||||
|
||||
The setup script creates a virtual environment and installs everything including PyTorch with CUDA support:
|
||||
|
||||
```bash
|
||||
# Linux / macOS
|
||||
./setup_env.sh
|
||||
|
||||
# Windows (PowerShell)
|
||||
powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||
```
|
||||
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
# Linux / macOS
|
||||
./8cut.sh
|
||||
|
||||
# Windows
|
||||
8cut.bat
|
||||
```
|
||||
|
||||
The launch scripts auto-detect your venv or conda environment.
|
||||
|
||||
### Manual installation
|
||||
|
||||
#### 1. Install system dependencies
|
||||
|
||||
**Linux (Arch):**
|
||||
```bash
|
||||
pacman -S python mpv ffmpeg
|
||||
```
|
||||
|
||||
**Linux (Debian/Ubuntu):**
|
||||
```bash
|
||||
apt install python3 python3-venv libmpv-dev ffmpeg
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```powershell
|
||||
# ffmpeg
|
||||
winget install ffmpeg
|
||||
|
||||
# libmpv — download mpv-2.dll and place next to main.py
|
||||
# https://sourceforge.net/projects/mpv-player-windows/files/libmpv/
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install python mpv ffmpeg
|
||||
```
|
||||
|
||||
#### 2. Create a virtual environment
|
||||
|
||||
A virtual environment keeps 8-cut's dependencies isolated from your system Python. This is strongly recommended — PyTorch alone is several GB and can conflict with other projects.
|
||||
|
||||
**Using venv (recommended):**
|
||||
|
||||
```bash
|
||||
# Create the venv in the project directory
|
||||
python3 -m venv .venv
|
||||
|
||||
# Activate it
|
||||
# Linux / macOS:
|
||||
source .venv/bin/activate
|
||||
# Windows (cmd):
|
||||
.venv\Scripts\activate.bat
|
||||
# Windows (PowerShell):
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
**Using conda / miniforge:**
|
||||
|
||||
```bash
|
||||
conda create -n 8cut python=3.12
|
||||
conda activate 8cut
|
||||
```
|
||||
|
||||
You must activate the environment every time you open a new terminal before running 8-cut. The `8cut.sh` launcher does this automatically.
|
||||
|
||||
#### 3. Install PyTorch
|
||||
|
||||
PyTorch must be installed separately with the correct CUDA version for GPU acceleration. Without CUDA, audio scanning will fall back to CPU (much slower).
|
||||
|
||||
**With NVIDIA GPU (CUDA 12.8):**
|
||||
```bash
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
```
|
||||
|
||||
**CPU only (no GPU):**
|
||||
```bash
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
Check available CUDA versions at [pytorch.org/get-started](https://pytorch.org/get-started/locally/).
|
||||
|
||||
#### 4. Install project dependencies
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Platform notes
|
||||
#### 5. Verify
|
||||
|
||||
| Platform | libmpv |
|
||||
|----------|--------|
|
||||
| **Linux** | `apt install libmpv-dev` or `pacman -S mpv` |
|
||||
| **macOS** | `brew install mpv` |
|
||||
| **Windows** | Download `mpv-2.dll` from [mpv Windows builds](https://sourceforge.net/projects/mpv-player-windows/files/libmpv/) and place it in `PATH` or next to `main.py` |
|
||||
```bash
|
||||
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||
python -c "import librosa, torchaudio, sklearn; print('All imports OK')"
|
||||
```
|
||||
|
||||
Windows also needs `ffmpeg.exe` on `PATH` (e.g. `winget install ffmpeg`).
|
||||
### Running
|
||||
|
||||
```bash
|
||||
# With venv activated:
|
||||
python main.py
|
||||
|
||||
# Or use the launcher (auto-activates venv/conda):
|
||||
./8cut.sh # Linux / macOS
|
||||
8cut.bat # Windows
|
||||
```
|
||||
|
||||
### GPU encoding
|
||||
|
||||
Hardware encoders are auto-detected from ffmpeg. Available encoders by platform:
|
||||
|
||||
| Platform | Encoders |
|
||||
|----------|----------|
|
||||
| **Linux** | `h264_nvenc` (NVIDIA), `h264_vaapi` (AMD/Intel), `h264_qsv` (Intel) |
|
||||
| **Windows** | `h264_nvenc` (NVIDIA), `h264_qsv` (Intel), `h264_amf` (AMD) |
|
||||
| **macOS** | `h264_videotoolbox` |
|
||||
|
||||
Enable the **HW** checkbox in the export controls to use GPU encoding.
|
||||
|
||||
### Optional: audio scanning
|
||||
|
||||
Audio scanning requires PyTorch (installed above). Embedding models are downloaded on first use and cached in `cache/downloads/`. A CUDA-capable GPU is strongly recommended for training and scanning speed.
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
||||
|
||||
### Export layout
|
||||
@@ -109,6 +255,20 @@ output/
|
||||
clip_001_0.wav
|
||||
```
|
||||
|
||||
### Scan export layout
|
||||
|
||||
Scan exports create one group folder per detected area:
|
||||
|
||||
```
|
||||
output/
|
||||
clip_037/
|
||||
clip_037_a1_0.mp4 # area 1, clip 0
|
||||
clip_037_a1_1.mp4 # area 1, clip 1
|
||||
clip_038/
|
||||
clip_038_a2_0.mp4 # area 2, clip 0
|
||||
...
|
||||
```
|
||||
|
||||
### Sound annotation
|
||||
|
||||
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
||||
@@ -124,9 +284,73 @@ Labels persist between exports so you can cut many clips of the same class witho
|
||||
- **Right-click** a marker to delete it from the database
|
||||
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
||||
|
||||
## Audio scan workflow
|
||||
|
||||
### 1. Build a dataset
|
||||
|
||||
Export clips manually from several videos. Clips from the same export folder (e.g. `mp4_Intense`) become your positive training class.
|
||||
|
||||
**Minimum dataset:** ~20 clips from 2–3 different videos. This is enough for the classifier to learn a basic boundary, but expect noisy results — you'll need to mark hard negatives and retrain.
|
||||
|
||||
**Ideal dataset:** 50–100+ clips from 5+ videos covering the full range of variation in your target sound (different recording conditions, distances, intensities). More variety in your positives makes the model generalize better to unseen footage. Negatives are sampled automatically from regions far from your markers, but adding explicit negatives of confusable sounds (e.g. thunder when training for explosions) significantly reduces false positives.
|
||||
|
||||
The classifier improves iteratively: export a small initial set → train → scan → mark false positives as hard negatives → retrain. Each cycle sharpens the decision boundary without needing a large upfront dataset.
|
||||
|
||||
### 2. Train a classifier
|
||||
|
||||
Click **Train** to open the training dialog:
|
||||
|
||||
- **Positive class** — select the export folder containing your target sounds
|
||||
- **Negative class** — optional explicit negatives, or leave as "(auto only)" for automatic sampling
|
||||
- **Model** — embedding model to use (HuBERT XLARGE recommended)
|
||||
- **Auto-neg margin** — distance from markers to sample automatic negatives (30s default)
|
||||
- **Include scan-exported clips** — whether to include previously scan-exported clips in training
|
||||
|
||||
The classifier trains a `HistGradientBoostingClassifier` on audio embeddings and saves to `models/`.
|
||||
|
||||
### 3. Scan videos
|
||||
|
||||
Select a trained model from the dropdown and click **Scan**. Adjust the threshold slider to control sensitivity. Detected regions appear as colored bands on the timeline and as rows in the results panel.
|
||||
|
||||
Audio embeddings are computed once per video and cached to disk (`cache/w2v/`). Subsequent scans with the same embedding model skip the GPU entirely and only re-run the classifier, which takes milliseconds. This makes the retrain → rescan loop nearly free after the first pass.
|
||||
|
||||
### 4. Review and refine
|
||||
|
||||
- Toggle **Review** mode for a clean timeline focused on scan results
|
||||
- **Disable** false positive regions (Delete key) — they stay in the list but are excluded from export
|
||||
- **Resize** regions by dragging edges on the timeline or editing times in the table
|
||||
- **Mark as negative** — add false positives to the hard negative set for retraining
|
||||
- **Ctrl+Z** to undo any of the above
|
||||
|
||||
### 5. Export results
|
||||
|
||||
Click **Export Scan Results** to batch export all enabled regions. The button shows the estimated clip count based on spread and minimum duration settings.
|
||||
|
||||
### 6. Retrain with feedback
|
||||
|
||||
Train again — hard negatives are automatically included. Each training run saves with a timestamp. Click the **⏲** button next to the model dropdown to restore a previous version if results degrade — restoring automatically rescans with the selected version.
|
||||
|
||||
## Database
|
||||
|
||||
Export history is stored in `~/.8cut.db` (SQLite). The database records filename, start time, output path, label, category, and all encoding settings for every clip. When you open a file, 8-cut matches the filename and pre-populates the timeline with existing markers.
|
||||
Export history is stored in `~/.8cut.db` (SQLite). Tables:
|
||||
|
||||
| Table | Purpose |
|
||||
|-------|---------|
|
||||
| `processed` | Every exported clip with full encoding settings |
|
||||
| `scan_results` | Audio scan detections per video/model |
|
||||
| `hard_negatives` | Timestamps marked as false positives for training |
|
||||
| `hidden_files` | Playlist files hidden by the user |
|
||||
|
||||
The database auto-migrates when new columns are added.
|
||||
|
||||
## File locations
|
||||
|
||||
| Path | Contents |
|
||||
|------|----------|
|
||||
| `~/.8cut.db` | SQLite database |
|
||||
| `models/` | Trained classifier models (`.joblib`) |
|
||||
| `cache/w2v/` | Embedding cache (`.npz`, keyed by video hash) |
|
||||
| `cache/downloads/` | Downloaded pretrained models |
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def build_annotation_json_path(folder: str) -> str:
|
||||
return os.path.join(folder, "dataset.json")
|
||||
|
||||
|
||||
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||
"""Remove the entry for *clip_path* from <folder>/dataset.json if present."""
|
||||
json_path = build_annotation_json_path(folder)
|
||||
if not os.path.exists(json_path):
|
||||
return
|
||||
abs_path = os.path.abspath(clip_path)
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
entries = json.load(f)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return
|
||||
entries = [e for e in entries if e.get("path") != abs_path]
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||
"""Insert or update one entry in <folder>/dataset.json.
|
||||
|
||||
Each entry stores a path relative to *folder* and the sound label.
|
||||
Matches on ``path``; if an entry for the same clip already exists it is
|
||||
replaced (overwrite-export case). Nothing is written when *label* is
|
||||
empty.
|
||||
"""
|
||||
if not label.strip():
|
||||
return
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
json_path = build_annotation_json_path(folder)
|
||||
entries: list[dict] = []
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
entries = json.load(f)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
entries = []
|
||||
abs_path = os.path.abspath(clip_path)
|
||||
entry: dict = {"path": abs_path, "label": label}
|
||||
for i, e in enumerate(entries):
|
||||
if e.get("path") == abs_path:
|
||||
entries[i] = entry
|
||||
break
|
||||
else:
|
||||
entries.append(entry)
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
@@ -0,0 +1,803 @@
|
||||
"""Audio scanning — embedding-based classifier for audio event detection."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
_SR = 16000 # lower sr = faster
|
||||
|
||||
|
||||
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
|
||||
"""Load audio from any file as mono float32 numpy array using ffmpeg directly."""
|
||||
cmd = [
|
||||
_bin("ffmpeg"), "-i", path,
|
||||
"-vn", # skip video
|
||||
"-ac", "1", # mono
|
||||
"-ar", str(sr), # resample
|
||||
"-f", "f32le", # raw 32-bit float little-endian
|
||||
"-loglevel", "error",
|
||||
"pipe:1",
|
||||
]
|
||||
try:
|
||||
proc = subprocess.run(cmd, capture_output=True, timeout=300)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"ffmpeg timed out (300s) on {os.path.basename(path)}")
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode().strip()}")
|
||||
return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
_WINDOW = 8.0 # seconds
|
||||
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
_MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
|
||||
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
|
||||
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")
|
||||
|
||||
# Redirect torch hub and huggingface downloads into the project
|
||||
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
|
||||
os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding extraction (lazy-loaded)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_w2v_model = None
|
||||
_w2v_device = None
|
||||
_w2v_model_name = None
|
||||
_ast_feature_extractor = None
|
||||
|
||||
# Supported embedding models — name → embed_dim
|
||||
_EMBED_MODELS = {
|
||||
"WAV2VEC2_BASE": 768,
|
||||
"WAV2VEC2_LARGE": 1024,
|
||||
"WAV2VEC2_LARGE_LV60K":1024,
|
||||
"HUBERT_BASE": 768,
|
||||
"HUBERT_LARGE": 1024,
|
||||
"HUBERT_XLARGE": 1280,
|
||||
"BEATS": 768,
|
||||
# Multi-layer variants (4 quartile layers concatenated)
|
||||
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||
# Transformers-based models
|
||||
"AST": 768,
|
||||
"AST_ML": 3072, # 768 * 4
|
||||
"EAT": 768,
|
||||
"EAT_LARGE": 1024,
|
||||
}
|
||||
_DEFAULT_EMBED_MODEL = "WAV2VEC2_BASE"
|
||||
|
||||
_BEATS_CHECKPOINT = os.path.join(
|
||||
_DL_CACHE_DIR, "huggingface", "hub",
|
||||
"models--lpepino--beats_ckpts", "snapshots",
|
||||
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
|
||||
)
|
||||
|
||||
|
||||
def _get_w2v_model(model_name: str | None = None):
|
||||
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||
global _w2v_model, _w2v_device, _w2v_model_name
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
# Multi-layer variants use the same base model weights
|
||||
ml = _ml_config(model_name)
|
||||
load_name = ml[0] if ml else model_name
|
||||
if _w2v_model is None or _w2v_model_name != load_name:
|
||||
import torch
|
||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if load_name == "BEATS":
|
||||
from .beats_model import BEATs, BEATsConfig
|
||||
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||
weights_only=False)
|
||||
cfg = BEATsConfig(checkpoint['cfg'])
|
||||
_w2v_model = BEATs(cfg)
|
||||
_w2v_model.load_state_dict(checkpoint['model'])
|
||||
_w2v_model.to(_w2v_device)
|
||||
elif load_name == "AST":
|
||||
from transformers import ASTModel, ASTFeatureExtractor
|
||||
_w2v_model = ASTModel.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
).to(_w2v_device)
|
||||
global _ast_feature_extractor
|
||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
)
|
||||
elif load_name in ("EAT", "EAT_LARGE"):
|
||||
from transformers import AutoModel
|
||||
eat_repo = ("worstchan/EAT-large_epoch20_finetune_AS2M"
|
||||
if load_name == "EAT_LARGE"
|
||||
else "worstchan/EAT-base_epoch30_finetune_AS2M")
|
||||
_w2v_model = AutoModel.from_pretrained(
|
||||
eat_repo, trust_remote_code=True,
|
||||
).to(_w2v_device)
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||
|
||||
_w2v_model.eval()
|
||||
_w2v_model_name = load_name
|
||||
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||
return _w2v_model, _w2v_device
|
||||
|
||||
|
||||
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||
|
||||
Returns tensor of shape [B, 1, T, 128].
|
||||
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||
"""
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
TARGET_LEN = 1024
|
||||
MEAN, STD = -4.268, 4.569
|
||||
|
||||
mels = []
|
||||
for chunk in chunks:
|
||||
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||
fbank = kaldi.fbank(
|
||||
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||
)
|
||||
# Pad or truncate to TARGET_LEN
|
||||
if fbank.shape[0] < TARGET_LEN:
|
||||
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||
else:
|
||||
fbank = fbank[:TARGET_LEN]
|
||||
fbank = (fbank - MEAN) / (STD * 2)
|
||||
mels.append(fbank)
|
||||
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||
|
||||
|
||||
def _embed_dim(model_name: str | None = None) -> int:
|
||||
"""Return embedding dimension for a model name."""
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
return _EMBED_MODELS.get(model_name, 768)
|
||||
|
||||
|
||||
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||
|
||||
Returns None for single-layer models.
|
||||
Layer indices are 0-based into the list returned by extract_features().
|
||||
"""
|
||||
if not model_name.endswith("_ML"):
|
||||
return None
|
||||
base = model_name[:-3] # strip "_ML"
|
||||
if base not in _EMBED_MODELS:
|
||||
return None
|
||||
# Layer counts per model family
|
||||
layer_counts = {
|
||||
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||
"AST": 12,
|
||||
}
|
||||
n = layer_counts.get(base)
|
||||
if n is None:
|
||||
return None
|
||||
# Select 4 layers at quartile boundaries (0-indexed)
|
||||
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||
return base, indices
|
||||
|
||||
|
||||
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||
model_name: str | None = None) -> str:
|
||||
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
abspath = os.path.abspath(video_path)
|
||||
mtime = os.path.getmtime(abspath)
|
||||
key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}"
|
||||
h = hashlib.sha256(key.encode()).hexdigest()[:16]
|
||||
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
|
||||
|
||||
|
||||
def _w2v_cache_exists(video_path: str, hop: float, window: float,
|
||||
model_name: str | None = None) -> bool:
|
||||
"""Check if embedding cache exists for a video."""
|
||||
try:
|
||||
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||
return os.path.exists(path)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _w2v_cache_load(video_path: str, hop: float, window: float,
|
||||
model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
|
||||
"""Load embeddings from cache. Returns (timestamps, embeddings) or None."""
|
||||
try:
|
||||
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||
if os.path.exists(path):
|
||||
data = np.load(path)
|
||||
_log(f"audio_scan: cache hit ({path})")
|
||||
return data["timestamps"], data["embeddings"]
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: cache read failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||
hop: float = 1.0, window: float = _WINDOW,
|
||||
video_path: str | None = None,
|
||||
cancel_flag: object = None,
|
||||
model_name: str | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Extract embeddings for all sliding windows using a torchaudio model.
|
||||
|
||||
If video_path is given, results are cached to disk for fast re-scans.
|
||||
Returns (timestamps, embeddings) where embeddings is (N, D).
|
||||
"""
|
||||
edim = _embed_dim(model_name)
|
||||
|
||||
# Try loading from cache
|
||||
cache_file = None
|
||||
if video_path:
|
||||
try:
|
||||
cache_file = _w2v_cache_path(video_path, hop, window, model_name)
|
||||
if os.path.exists(cache_file):
|
||||
data = np.load(cache_file)
|
||||
_log(f"audio_scan: cache hit ({cache_file})")
|
||||
return data["timestamps"], data["embeddings"]
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: cache read failed: {e}")
|
||||
|
||||
win_samples = int(window * sr)
|
||||
hop_samples = int(hop * sr)
|
||||
n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)
|
||||
|
||||
if n_windows == 0:
|
||||
return np.array([]), np.empty((0, edim))
|
||||
|
||||
import torch
|
||||
model, device = _get_w2v_model(model_name)
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
# Auto-size batches based on available GPU memory
|
||||
batch_size = 16
|
||||
if device == "cuda":
|
||||
try:
|
||||
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
|
||||
if vram_gb >= 16:
|
||||
batch_size = 64
|
||||
elif vram_gb >= 8:
|
||||
batch_size = 32
|
||||
_log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
|
||||
except Exception:
|
||||
pass
|
||||
timestamps = np.arange(n_windows) * hop
|
||||
embeddings = []
|
||||
|
||||
for batch_start in range(0, n_windows, batch_size):
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
return np.array([]), np.empty((0, edim))
|
||||
batch_end = min(batch_start + batch_size, n_windows)
|
||||
chunks = []
|
||||
for i in range(batch_start, batch_end):
|
||||
start = i * hop_samples
|
||||
chunks.append(y[start:start + win_samples])
|
||||
with torch.no_grad():
|
||||
if is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
else:
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings.append(batch_emb)
|
||||
|
||||
result_ts = timestamps
|
||||
result_emb = np.vstack(embeddings)
|
||||
|
||||
# Save to cache
|
||||
if cache_file:
|
||||
try:
|
||||
os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
|
||||
np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
|
||||
_log(f"audio_scan: w2v cache saved ({cache_file})")
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: cache write failed: {e}")
|
||||
|
||||
return result_ts, result_emb
|
||||
|
||||
|
||||
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||
gt_soft: list[float], tolerance: float = 12.0,
|
||||
neg_margin: float = 120.0,
|
||||
model_name: str | None = None,
|
||||
gt_negative: list[float] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Extract embeddings only near positives and distant negatives.
|
||||
|
||||
Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig.
|
||||
"""
|
||||
edim = _embed_dim(model_name)
|
||||
duration = len(y) / sr
|
||||
win_samples = int(_WINDOW * sr)
|
||||
all_gt = list(gt_intense) + list(gt_soft)
|
||||
|
||||
# Positive windows: every second near intense markers
|
||||
pos_times = set()
|
||||
for gt in gt_intense:
|
||||
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||
t = gt + offset
|
||||
if 0 <= t <= duration - _WINDOW:
|
||||
pos_times.add(int(t))
|
||||
|
||||
# Manual negative windows: near explicit negative markers
|
||||
manual_neg_times = set()
|
||||
if gt_negative:
|
||||
for gt in gt_negative:
|
||||
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||
t = gt + offset
|
||||
if 0 <= t <= duration - _WINDOW:
|
||||
manual_neg_times.add(int(t))
|
||||
# Don't let manual negatives overlap with positives
|
||||
manual_neg_times -= pos_times
|
||||
|
||||
# Auto negative windows: every 4s, far from any marker (skip if margin <= 0 or no markers)
|
||||
neg_times = set()
|
||||
if all_gt and neg_margin > 0:
|
||||
for t in range(0, int(duration - _WINDOW), 4):
|
||||
if min(abs(t - g) for g in all_gt) > neg_margin:
|
||||
neg_times.add(t)
|
||||
|
||||
all_times = sorted(pos_times | neg_times | manual_neg_times)
|
||||
# Filter out windows that go past the end
|
||||
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
|
||||
|
||||
if not valid_times:
|
||||
return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)
|
||||
|
||||
import torch
|
||||
model, device = _get_w2v_model(model_name)
|
||||
batch_size = 16
|
||||
timestamps_list: list[float] = []
|
||||
embeddings_list: list[np.ndarray] = []
|
||||
|
||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
|
||||
for batch_start in range(0, len(valid_times), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||
chunks = []
|
||||
for t in valid_times[batch_start:batch_end]:
|
||||
start = int(t * sr)
|
||||
chunks.append(y[start:start + win_samples])
|
||||
timestamps_list.append(float(t))
|
||||
with torch.no_grad():
|
||||
if is_ast:
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
else:
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings_list.append(batch_emb)
|
||||
|
||||
timestamps = np.array(timestamps_list)
|
||||
embeddings = np.vstack(embeddings_list)
|
||||
|
||||
labels = np.zeros(len(timestamps), dtype=int)
|
||||
for i, t in enumerate(timestamps):
|
||||
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||
dm = min((abs(t - g) for g in (gt_negative or [])), default=9999)
|
||||
if di < tolerance:
|
||||
labels[i] = 1
|
||||
elif dm < tolerance or (neg_margin > 0 and da > neg_margin):
|
||||
labels[i] = -1
|
||||
return timestamps, embeddings, labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classifier mode — train / save / load / scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
||||
model_path: str | None = None,
|
||||
tolerance: float = 12.0,
|
||||
neg_margin: float = 120.0,
|
||||
embed_model: str | None = None,
|
||||
cancel_flag: object = None,
|
||||
n_workers: int = 4,
|
||||
progress_cb: object = None) -> dict:
|
||||
"""Train a classifier from labeled videos.
|
||||
|
||||
Args:
|
||||
video_infos: list of (video_path, intense_times, soft_times)
|
||||
model_path: if given, save model to this path
|
||||
tolerance/neg_margin: labeling parameters
|
||||
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
|
||||
cancel_flag: object with _cancel attribute; if set, training aborts early
|
||||
n_workers: number of threads for parallel audio loading
|
||||
|
||||
Returns:
|
||||
dict with 'classifier', 'embed_model', and metadata, or None on failure.
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||
|
||||
def _progress(msg: str) -> None:
|
||||
_log(msg)
|
||||
if progress_cb:
|
||||
progress_cb(msg)
|
||||
|
||||
def _load_audio(path: str) -> np.ndarray:
|
||||
return _load_audio_ffmpeg(path, sr=_SR)
|
||||
|
||||
# Phase 1: load all audio in parallel (cap workers — disk I/O bound)
|
||||
n = len(video_infos)
|
||||
load_workers = min(n_workers, 4)
|
||||
_progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
|
||||
audio_data: dict[int, np.ndarray] = {}
|
||||
with ThreadPoolExecutor(max_workers=load_workers) as pool:
|
||||
future_to_idx = {
|
||||
pool.submit(_load_audio, vi[0]): i
|
||||
for i, vi in enumerate(video_infos)
|
||||
}
|
||||
failed = set()
|
||||
for future in as_completed(future_to_idx):
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
_log("audio_scan: training cancelled")
|
||||
return None
|
||||
idx = future_to_idx[future]
|
||||
try:
|
||||
audio_data[idx] = future.result()
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
|
||||
failed.add(idx)
|
||||
_progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")
|
||||
|
||||
# Phase 2: extract embeddings sequentially on GPU
|
||||
_progress(f"Extracting embeddings: 0/{n}")
|
||||
all_X, all_y = [], []
|
||||
for vi, vinfo in enumerate(video_infos):
|
||||
if vi in failed:
|
||||
continue
|
||||
vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2]
|
||||
gt_negative = vinfo[3] if len(vinfo) > 3 else []
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
_log("audio_scan: training cancelled")
|
||||
return None
|
||||
_progress(f"Extracting embeddings: {vi+1}/{n}")
|
||||
y = audio_data.pop(vi)
|
||||
|
||||
timestamps, embeddings, labels = _extract_w2v_targeted(
|
||||
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
|
||||
model_name=embed_model, gt_negative=gt_negative,
|
||||
)
|
||||
if len(timestamps) == 0:
|
||||
continue
|
||||
# Per-video z-score normalize
|
||||
vid_mean = embeddings.mean(axis=0)
|
||||
vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
|
||||
normed = (embeddings - vid_mean) / vid_std
|
||||
for i in range(len(labels)):
|
||||
if labels[i] == 1:
|
||||
all_X.append(normed[i])
|
||||
all_y.append(1)
|
||||
elif labels[i] == -1:
|
||||
all_X.append(normed[i])
|
||||
all_y.append(0)
|
||||
|
||||
if not all_X:
|
||||
_log("audio_scan: no training samples collected")
|
||||
return None
|
||||
|
||||
X = np.stack(all_X)
|
||||
y_arr = np.array(all_y)
|
||||
n_pos = (y_arr == 1).sum()
|
||||
n_neg = (y_arr == 0).sum()
|
||||
_log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")
|
||||
|
||||
if n_pos == 0 or n_neg == 0:
|
||||
_log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
|
||||
return None
|
||||
|
||||
# Subsample negatives for balance
|
||||
rng = np.random.RandomState(42)
|
||||
pos_idx = np.where(y_arr == 1)[0]
|
||||
neg_idx = np.where(y_arr == 0)[0]
|
||||
n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
|
||||
neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
|
||||
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||
rng.shuffle(train_idx)
|
||||
|
||||
_progress(f"Fitting classifier on {len(train_idx)} samples...")
|
||||
clf = HistGradientBoostingClassifier(
|
||||
max_iter=200, max_depth=5, learning_rate=0.1, random_state=42,
|
||||
)
|
||||
clf.fit(X[train_idx], y_arr[train_idx])
|
||||
_log("audio_scan: classifier trained")
|
||||
|
||||
# Calibrate probabilities for better threshold behavior
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
min_class = min(int(n_pos), int(n_neg_sample))
|
||||
if min_class >= 6:
|
||||
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||
clf = cal_clf
|
||||
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||
else:
|
||||
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||
|
||||
model = {"classifier": clf, "n_features": X.shape[1],
|
||||
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
||||
|
||||
if model_path:
|
||||
import joblib
|
||||
from datetime import datetime
|
||||
parent = os.path.dirname(model_path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
# Save with timestamp in name; keep a symlink/copy as the "latest"
|
||||
stem, ext = os.path.splitext(model_path)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
versioned = f"{stem}_{ts}{ext}"
|
||||
joblib.dump(model, versioned)
|
||||
_log(f"audio_scan: model saved to {versioned}")
|
||||
# Update the base path to point to latest version (copy)
|
||||
import shutil
|
||||
shutil.copy2(versioned, model_path)
|
||||
_log(f"audio_scan: latest model updated: {model_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_classifier(model_path: str) -> dict | None:
|
||||
"""Load a saved classifier model."""
|
||||
if not os.path.exists(model_path):
|
||||
return None
|
||||
import joblib
|
||||
return joblib.load(model_path)
|
||||
|
||||
|
||||
def default_model_path(profile_name: str = "default",
|
||||
embed_model: str | None = None) -> str:
|
||||
"""Return the path for a profile's classifier model.
|
||||
|
||||
When embed_model is given the file is ``{profile}_{model}.joblib``,
|
||||
otherwise ``{profile}.joblib`` (legacy single-model layout).
|
||||
"""
|
||||
if embed_model:
|
||||
return os.path.join(_MODEL_DIR, f"{profile_name}_{embed_model}.joblib")
|
||||
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||
|
||||
|
||||
def list_model_versions(profile_name: str = "default",
|
||||
embed_model: str | None = None) -> list[tuple[str, str]]:
|
||||
"""Return available backup versions for a model, newest first.
|
||||
|
||||
Returns list of (timestamp_label, file_path).
|
||||
The current (active) model is listed first as "current".
|
||||
"""
|
||||
import re
|
||||
current = default_model_path(profile_name, embed_model)
|
||||
stem, ext = os.path.splitext(current)
|
||||
versions: list[tuple[str, str]] = []
|
||||
if os.path.exists(current):
|
||||
versions.append(("current", current))
|
||||
if not os.path.isdir(_MODEL_DIR):
|
||||
return versions
|
||||
pattern = re.compile(re.escape(os.path.basename(stem)) + r"_(\d{8}_\d{6})" + re.escape(ext) + "$")
|
||||
for fname in os.listdir(_MODEL_DIR):
|
||||
m = pattern.match(fname)
|
||||
if m:
|
||||
versions.append((m.group(1), os.path.join(_MODEL_DIR, fname)))
|
||||
# Sort backups newest first (after "current")
|
||||
current_entry = versions[:1]
|
||||
backups = sorted(versions[1:], key=lambda v: v[0], reverse=True)
|
||||
return current_entry + backups
|
||||
|
||||
|
||||
def restore_model_version(version_path: str, profile_name: str = "default",
|
||||
embed_model: str | None = None) -> None:
|
||||
"""Restore a backup version as the active model."""
|
||||
import filecmp, shutil
|
||||
from datetime import datetime
|
||||
current = default_model_path(profile_name, embed_model)
|
||||
if version_path == current:
|
||||
return
|
||||
# Back up current before replacing — but only if no identical backup exists
|
||||
if os.path.exists(current):
|
||||
stem, ext = os.path.splitext(current)
|
||||
already_saved = False
|
||||
if os.path.isdir(_MODEL_DIR):
|
||||
import re
|
||||
pat = re.compile(re.escape(os.path.basename(stem)) + r"_\d{8}_\d{6}" + re.escape(ext) + "$")
|
||||
for fname in os.listdir(_MODEL_DIR):
|
||||
if pat.match(fname):
|
||||
candidate = os.path.join(_MODEL_DIR, fname)
|
||||
if filecmp.cmp(current, candidate, shallow=False):
|
||||
already_saved = True
|
||||
break
|
||||
if not already_saved:
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
shutil.move(current, f"{stem}_{ts}{ext}")
|
||||
shutil.copy2(version_path, current)
|
||||
_log(f"audio_scan: restored {os.path.basename(version_path)} as active model")
|
||||
|
||||
|
||||
def list_trained_models(profile_name: str = "default") -> list[str]:
|
||||
"""Return embedding model names that have a trained .joblib for *profile_name*.
|
||||
|
||||
Looks for files matching ``{profile}_{MODEL}.joblib`` in the models dir.
|
||||
"""
|
||||
prefix = f"{profile_name}_"
|
||||
suffix = ".joblib"
|
||||
result = []
|
||||
if not os.path.isdir(_MODEL_DIR):
|
||||
return result
|
||||
for fname in os.listdir(_MODEL_DIR):
|
||||
if fname.startswith(prefix) and fname.endswith(suffix):
|
||||
model_name = fname[len(prefix):-len(suffix)]
|
||||
if model_name in _EMBED_MODELS:
|
||||
result.append(model_name)
|
||||
# Also check legacy {profile}.joblib
|
||||
legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||
if os.path.exists(legacy) and not result:
|
||||
# Legacy model — we don't know the embed model, but it's usable
|
||||
result.append("")
|
||||
return sorted(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scanning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fuse_regions(regions: list[tuple[float, float, float]]
|
||||
) -> list[tuple[float, float, float]]:
|
||||
"""Merge overlapping/adjacent regions, keeping max score."""
|
||||
if not regions:
|
||||
return []
|
||||
by_start = sorted(regions, key=lambda r: r[0])
|
||||
fused: list[tuple[float, float, float]] = []
|
||||
s, e, sc = by_start[0]
|
||||
for s2, e2, sc2 in by_start[1:]:
|
||||
if s2 <= e: # overlapping or touching
|
||||
e = max(e, e2)
|
||||
sc = max(sc, sc2)
|
||||
else:
|
||||
fused.append((s, e, sc))
|
||||
s, e, sc = s2, e2, sc2
|
||||
fused.append((s, e, sc))
|
||||
return fused
|
||||
|
||||
|
||||
def prefetch_audio(video_path: str, embed_model: str | None = None,
|
||||
hop: float = 1.0, window: float = _WINDOW) -> np.ndarray | None:
|
||||
"""Pre-load audio for a video if embeddings aren't cached.
|
||||
|
||||
Returns the raw audio array, or None if cache already exists.
|
||||
Call from a background thread while the GPU is busy with another video.
|
||||
"""
|
||||
if _w2v_cache_exists(video_path, hop, window, embed_model):
|
||||
return None
|
||||
_log(f"audio_scan: prefetching {os.path.basename(video_path)}")
|
||||
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||
_log(f"audio_scan: prefetched {len(y)/_SR:.1f}s")
|
||||
return y
|
||||
|
||||
|
||||
def scan_video(
|
||||
video_path: str,
|
||||
model: dict = None,
|
||||
threshold: float = 0.50,
|
||||
hop: float = 1.0,
|
||||
window: float = _WINDOW,
|
||||
cancel_flag: object = None,
|
||||
prefetched_audio: np.ndarray | None = None,
|
||||
) -> list[tuple[float, float, float]]:
|
||||
"""Scan a video for matching audio regions using a trained classifier.
|
||||
|
||||
Returns list of (start_time, end_time, score) above threshold.
|
||||
If prefetched_audio is provided, skips the ffmpeg decode step.
|
||||
"""
|
||||
if model is None:
|
||||
_log("audio_scan: no model provided")
|
||||
return []
|
||||
|
||||
clf = model["classifier"]
|
||||
embed_model = model.get("embed_model")
|
||||
|
||||
# Try cache first — skip expensive audio loading if embeddings exist
|
||||
cached = _w2v_cache_load(video_path, hop, window, embed_model)
|
||||
if cached is not None:
|
||||
timestamps, window_vectors = cached
|
||||
else:
|
||||
if prefetched_audio is not None:
|
||||
_log(f"audio_scan: using prefetched audio")
|
||||
y = prefetched_audio
|
||||
else:
|
||||
_log(f"audio_scan: loading {video_path}")
|
||||
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||
sr = _SR
|
||||
_log(f"audio_scan: {len(y)/sr:.1f}s loaded")
|
||||
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
return []
|
||||
|
||||
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
|
||||
timestamps, window_vectors = _extract_w2v_windows(
|
||||
y, sr, hop=hop, window=window, video_path=video_path,
|
||||
cancel_flag=cancel_flag, model_name=embed_model,
|
||||
)
|
||||
|
||||
if len(timestamps) == 0:
|
||||
_log("audio_scan: video shorter than window")
|
||||
return []
|
||||
|
||||
# Per-video z-score normalize
|
||||
vid_mean = window_vectors.mean(axis=0)
|
||||
vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
|
||||
normed = (window_vectors - vid_mean) / vid_std
|
||||
|
||||
_log(f"audio_scan: classifying {len(normed)} windows...")
|
||||
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
return []
|
||||
|
||||
probs = clf.predict_proba(normed)[:, 1]
|
||||
mask = probs >= threshold
|
||||
raw = [
|
||||
(timestamps[i], timestamps[i] + window, float(probs[i]))
|
||||
for i in np.nonzero(mask)[0]
|
||||
]
|
||||
results = _fuse_regions(raw)
|
||||
_log(f"audio_scan: {len(results)} regions above threshold {threshold} (from {len(raw)} raw)")
|
||||
return results
|
||||
@@ -0,0 +1,783 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Dict, Optional, Tuple
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm, Parameter
|
||||
from .beats_modules import (
|
||||
GradMultiply,
|
||||
SamePad,
|
||||
get_activation_fn,
|
||||
GLU_Linear,
|
||||
quant_noise,
|
||||
)
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
self.dropout = args.dropout
|
||||
self.embedding_dim = args.encoder_embed_dim
|
||||
|
||||
self.pos_conv = nn.Conv1d(
|
||||
self.embedding_dim,
|
||||
self.embedding_dim,
|
||||
kernel_size=args.conv_pos,
|
||||
padding=args.conv_pos // 2,
|
||||
groups=args.conv_pos_groups,
|
||||
)
|
||||
dropout = 0
|
||||
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||
nn.init.constant_(self.pos_conv.bias, 0)
|
||||
|
||||
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||
|
||||
if hasattr(args, "relative_position_embedding"):
|
||||
self.relative_position_embedding = args.relative_position_embedding
|
||||
self.num_buckets = args.num_buckets
|
||||
self.max_distance = args.max_distance
|
||||
else:
|
||||
self.relative_position_embedding = False
|
||||
self.num_buckets = 0
|
||||
self.max_distance = 0
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerSentenceEncoderLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
deep_norm=args.deep_norm,
|
||||
has_relative_attention_bias=self.relative_position_embedding,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
gru_rel_pos=args.gru_rel_pos,
|
||||
encoder_layers=args.encoder_layers,
|
||||
)
|
||||
for i in range(args.encoder_layers)
|
||||
]
|
||||
)
|
||||
if self.relative_position_embedding:
|
||||
for i in range(1, args.encoder_layers):
|
||||
del self.layers[i].self_attn.relative_attention_bias
|
||||
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
|
||||
|
||||
self.layer_norm_first = args.layer_norm_first
|
||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||
self.layerdrop = args.encoder_layerdrop
|
||||
|
||||
self.apply(init_bert_params)
|
||||
|
||||
if args.deep_norm:
|
||||
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
||||
for i in range(args.encoder_layers):
|
||||
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
||||
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
||||
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
||||
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
|
||||
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
||||
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
||||
|
||||
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
|
||||
|
||||
def forward(self, x, padding_mask=None, layer=None):
|
||||
x, layer_results = self.extract_features(x, padding_mask, layer)
|
||||
|
||||
if self.layer_norm_first and layer is None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
||||
|
||||
if padding_mask is not None:
|
||||
x[padding_mask] = 0
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x = x + x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
layer_results = []
|
||||
z = None
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
r = None
|
||||
pos_bias = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
if self.layer_wise_gradient_decay_ratio != 1.0:
|
||||
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
||||
dropout_probability = np.random.random()
|
||||
if not self.training or (dropout_probability > self.layerdrop):
|
||||
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
if i == tgt_layer:
|
||||
r = x
|
||||
break
|
||||
|
||||
if r is not None:
|
||||
x = r
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
|
||||
class TransformerSentenceEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: float = 768,
|
||||
ffn_embedding_dim: float = 3072,
|
||||
num_attention_heads: float = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
layer_norm_first: bool = False,
|
||||
deep_norm: bool = False,
|
||||
has_relative_attention_bias: bool = False,
|
||||
num_buckets: int = 0,
|
||||
max_distance: int = 0,
|
||||
rescale_init: bool = False,
|
||||
gru_rel_pos: bool = False,
|
||||
encoder_layers: int = 0,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
self.activation_name = activation_fn
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
self_attention=True,
|
||||
has_relative_attention_bias=has_relative_attention_bias,
|
||||
num_buckets=num_buckets,
|
||||
max_distance=max_distance,
|
||||
rescale_init=rescale_init,
|
||||
gru_rel_pos=gru_rel_pos,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.layer_norm_first = layer_norm_first
|
||||
|
||||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
if self.activation_name == "glu":
|
||||
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||
else:
|
||||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
self.deep_norm = deep_norm
|
||||
if self.deep_norm:
|
||||
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
||||
else:
|
||||
self.deep_norm_alpha = 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_mask: torch.Tensor = None,
|
||||
self_attn_padding_mask: torch.Tensor = None,
|
||||
need_weights: bool = False,
|
||||
pos_bias=None
|
||||
):
|
||||
residual = x
|
||||
|
||||
if self.layer_norm_first:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias
|
||||
)
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
else:
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias
|
||||
)
|
||||
|
||||
x = self.dropout1(x)
|
||||
x = residual * self.deep_norm_alpha + x
|
||||
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual * self.deep_norm_alpha + x
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
return x, attn, pos_bias
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
self_attention=False,
|
||||
encoder_decoder_attention=False,
|
||||
q_noise=0.0,
|
||||
qn_block_size=8,
|
||||
has_relative_attention_bias=False,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
gru_rel_pos=False,
|
||||
rescale_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout_module = nn.Dropout(dropout)
|
||||
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
if self.has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_head_dim = self.head_dim
|
||||
self.k_head_dim = self.head_dim
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.self_attention = self_attention
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
|
||||
assert not self.self_attention or self.qkv_same_dim, (
|
||||
"Self-attention requires query, key and " "value to be of the same size"
|
||||
)
|
||||
|
||||
k_bias = True
|
||||
if rescale_init:
|
||||
k_bias = False
|
||||
|
||||
k_embed_dim = embed_dim
|
||||
q_embed_dim = embed_dim
|
||||
|
||||
self.k_proj = quant_noise(
|
||||
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
||||
)
|
||||
self.v_proj = quant_noise(
|
||||
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
self.q_proj = quant_noise(
|
||||
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
self.out_proj = quant_noise(
|
||||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self.gru_rel_pos = gru_rel_pos
|
||||
if self.gru_rel_pos:
|
||||
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
nn.init.xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
nn.init.xavier_normal_(self.bias_v)
|
||||
if self.has_relative_attention_bias:
|
||||
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||
|
||||
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||
num_buckets = self.num_buckets
|
||||
max_distance = self.max_distance
|
||||
relative_buckets = 0
|
||||
|
||||
if bidirectional:
|
||||
num_buckets = num_buckets // 2
|
||||
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||
relative_positions = torch.abs(relative_positions)
|
||||
else:
|
||||
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_positions < max_exact
|
||||
|
||||
relative_postion_if_large = max_exact + (
|
||||
torch.log(relative_positions.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_postion_if_large = torch.min(
|
||||
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||
)
|
||||
|
||||
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length):
|
||||
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = self._relative_positions_bucket(
|
||||
relative_position,
|
||||
bidirectional=True
|
||||
)
|
||||
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||
values = self.relative_attention_bias(relative_position_bucket)
|
||||
values = values.permute([2, 0, 1])
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query,
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = True,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
before_softmax: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
position_bias: Optional[Tensor] = None
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
"""Input shape: Time x Batch x Channel
|
||||
|
||||
Args:
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
before_softmax (bool, optional): return the raw attention
|
||||
weights and values before the attention softmax.
|
||||
need_head_weights (bool, optional): return the attention
|
||||
weights for each head. Implies *need_weights*. Default:
|
||||
return the average attention weights over all heads.
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
is_tpu = query.device.type == "xla"
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, _ = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert key_bsz == bsz
|
||||
assert value is not None
|
||||
assert src_len, bsz == value.shape[:2]
|
||||
|
||||
if self.has_relative_attention_bias and position_bias is None:
|
||||
position_bias = self.compute_bias(tgt_len, src_len)
|
||||
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if incremental_state is not None:
|
||||
saved_state = self._get_input_buffer(incremental_state)
|
||||
if saved_state is not None and "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute
|
||||
# key and value if they are static
|
||||
if static_kv:
|
||||
assert self.encoder_decoder_attention and not self.self_attention
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
|
||||
if self.self_attention:
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
elif self.encoder_decoder_attention:
|
||||
# encoder-decoder attention
|
||||
q = self.q_proj(query)
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = v = None
|
||||
else:
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(key)
|
||||
|
||||
else:
|
||||
assert key is not None and value is not None
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
q *= self.scaling
|
||||
alpha = 32
|
||||
q *= 1 / alpha
|
||||
|
||||
if self.bias_k is not None:
|
||||
assert self.bias_v is not None
|
||||
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||
)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
q = (
|
||||
q.contiguous()
|
||||
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
if k is not None:
|
||||
k = (
|
||||
k.contiguous()
|
||||
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
if v is not None:
|
||||
v = (
|
||||
v.contiguous()
|
||||
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
if saved_state is not None:
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
assert _prev_key is not None
|
||||
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
k = prev_key
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
src_len = k.size(1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
v = prev_value
|
||||
else:
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
prev_key_padding_mask: Optional[Tensor] = None
|
||||
if "prev_key_padding_mask" in saved_state:
|
||||
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||
assert k is not None and v is not None
|
||||
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||
key_padding_mask=key_padding_mask,
|
||||
prev_key_padding_mask=prev_key_padding_mask,
|
||||
batch_size=bsz,
|
||||
src_len=k.size(1),
|
||||
static_kv=static_kv,
|
||||
)
|
||||
|
||||
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||
# In this branch incremental_state is never None
|
||||
assert incremental_state is not None
|
||||
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||
assert k is not None
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
assert v is not None
|
||||
src_len += 1
|
||||
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||
)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
||||
key_padding_mask
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
||||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||
|
||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
if not is_tpu:
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||
float("-inf"),
|
||||
)
|
||||
else:
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if before_softmax:
|
||||
return attn_weights, v, position_bias
|
||||
|
||||
if position_bias is not None:
|
||||
attn_mask_rel_pos = position_bias
|
||||
if self.gru_rel_pos == 1:
|
||||
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
||||
_B, _H, _L, __ = query_layer.size()
|
||||
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
||||
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
||||
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
||||
|
||||
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
||||
|
||||
attn_weights = attn_weights + attn_mask_rel_pos
|
||||
|
||||
attn_weights_float = F.softmax(
|
||||
attn_weights, dim=-1
|
||||
)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_probs = self.dropout_module(attn_weights)
|
||||
|
||||
assert v is not None
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
attn_weights: Optional[Tensor] = None
|
||||
if need_weights:
|
||||
attn_weights = attn_weights_float.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
).transpose(1, 0)
|
||||
if not need_head_weights:
|
||||
# average attention weights over heads
|
||||
attn_weights = attn_weights.mean(dim=0)
|
||||
|
||||
return attn, attn_weights, position_bias
|
||||
|
||||
@staticmethod
|
||||
def _append_prev_key_padding_mask(
|
||||
key_padding_mask: Optional[Tensor],
|
||||
prev_key_padding_mask: Optional[Tensor],
|
||||
batch_size: int,
|
||||
src_len: int,
|
||||
static_kv: bool,
|
||||
) -> Optional[Tensor]:
|
||||
# saved key padding masks have shape (bsz, seq_len)
|
||||
if prev_key_padding_mask is not None and static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||
new_key_padding_mask = torch.cat(
|
||||
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||
)
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current
|
||||
# is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
if src_len > prev_key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||
device=prev_key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat(
|
||||
[prev_key_padding_mask.float(), filler.float()], dim=1
|
||||
)
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask.float()
|
||||
elif key_padding_mask is not None:
|
||||
if src_len > key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - key_padding_mask.size(1)),
|
||||
device=key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat(
|
||||
[filler.float(), key_padding_mask.float()], dim=1
|
||||
)
|
||||
else:
|
||||
new_key_padding_mask = key_padding_mask.float()
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
return new_key_padding_mask
|
||||
|
||||
def _get_input_buffer(
|
||||
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||
) -> Dict[str, Optional[Tensor]]:
|
||||
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||
return empty_result
|
||||
|
||||
def _set_input_buffer(
|
||||
self,
|
||||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||
buffer: Dict[str, Optional[Tensor]],
|
||||
):
|
||||
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||
|
||||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||
return attn_weights
|
||||
|
||||
|
||||
def init_bert_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the BERT Model.
|
||||
This overrides the default initializations depending on the specified arguments.
|
||||
1. If normal_init_linear_weights is set then weights of linear
|
||||
layer will be initialized using the normal distribution and
|
||||
bais will be set to the specified value.
|
||||
2. If normal_init_embed_weights is set then weights of embedding
|
||||
layer will be initialized using the normal distribution.
|
||||
3. If normal_init_proj_weights is set then weights of
|
||||
in_project_weight for MultiHeadAttention initialized using
|
||||
the normal distribution (to be validated).
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(
|
||||
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||
)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
|
||||
@@ -0,0 +1,179 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
from .beats_backbone import (
|
||||
TransformerEncoder,
|
||||
)
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BEATsConfig:
|
||||
def __init__(self, cfg=None):
|
||||
self.input_patch_size: int = -1 # path size of patch embedding
|
||||
self.embed_dim: int = 512 # patch embedding dimension
|
||||
self.conv_bias: bool = False # include bias in conv encoder
|
||||
|
||||
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||
self.activation_fn: str = "gelu" # activation function to use
|
||||
|
||||
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
||||
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
||||
|
||||
# dropouts
|
||||
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
||||
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||
|
||||
# positional embeddings
|
||||
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||
|
||||
# relative position embedding
|
||||
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||
|
||||
# label predictor
|
||||
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
||||
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
||||
self.predictor_class: int = 527 # target class number for the predictor
|
||||
|
||||
if cfg is not None:
|
||||
self.update(cfg)
|
||||
|
||||
def update(self, cfg: dict):
|
||||
self.__dict__.update(cfg)
|
||||
|
||||
|
||||
class BEATs(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg: BEATsConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logger.info(f"BEATs Config: {cfg.__dict__}")
|
||||
|
||||
self.cfg = cfg
|
||||
|
||||
self.embed = cfg.embed_dim
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||
if self.embed != cfg.encoder_embed_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.input_patch_size = cfg.input_patch_size
|
||||
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
||||
bias=cfg.conv_bias)
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
|
||||
assert not cfg.deep_norm or not cfg.layer_norm_first
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
if cfg.finetuned_model:
|
||||
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
||||
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
||||
else:
|
||||
self.predictor = None
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(
|
||||
padding_mask.size(0), features.size(1), -1
|
||||
)
|
||||
padding_mask = padding_mask.all(-1)
|
||||
return padding_mask
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
fbank_mean: float = 15.41663,
|
||||
fbank_std: float = 6.55582,
|
||||
) -> torch.Tensor:
|
||||
fbanks = []
|
||||
for waveform in source:
|
||||
waveform = waveform.unsqueeze(0) * 2 ** 15
|
||||
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
||||
fbanks.append(fbank)
|
||||
fbank = torch.stack(fbanks, dim=0)
|
||||
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
||||
return fbank
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
fbank_mean: float = 15.41663,
|
||||
fbank_std: float = 6.55582,
|
||||
):
|
||||
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
||||
|
||||
fbank = fbank.unsqueeze(1)
|
||||
features = self.patch_embedding(fbank)
|
||||
features = features.reshape(features.shape[0], features.shape[1], -1)
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
x = self.dropout_input(features)
|
||||
|
||||
x, layer_results = self.encoder(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
if self.predictor is not None:
|
||||
x = self.predictor_dropout(x)
|
||||
logits = self.predictor(x)
|
||||
|
||||
if padding_mask is not None and padding_mask.any():
|
||||
logits[padding_mask] = 0
|
||||
logits = logits.sum(dim=1)
|
||||
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
|
||||
else:
|
||||
logits = logits.mean(dim=1)
|
||||
|
||||
lprobs = torch.sigmoid(logits)
|
||||
|
||||
return lprobs, padding_mask
|
||||
else:
|
||||
return x, padding_mask
|
||||
@@ -0,0 +1,219 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self):
|
||||
super(Swish, self).__init__()
|
||||
self.act = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.act(x)
|
||||
|
||||
|
||||
class GLU_Linear(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||
super(GLU_Linear, self).__init__()
|
||||
|
||||
self.glu_type = glu_type
|
||||
self.output_dim = output_dim
|
||||
|
||||
if glu_type == "sigmoid":
|
||||
self.glu_act = torch.nn.Sigmoid()
|
||||
elif glu_type == "swish":
|
||||
self.glu_act = Swish()
|
||||
elif glu_type == "relu":
|
||||
self.glu_act = torch.nn.ReLU()
|
||||
elif glu_type == "gelu":
|
||||
self.glu_act = torch.nn.GELU()
|
||||
|
||||
if bias_in_glu:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||
else:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||
|
||||
def forward(self, x):
|
||||
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||
x = self.linear(x)
|
||||
|
||||
if self.glu_type == "bilinear":
|
||||
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||
else:
|
||||
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return (
|
||||
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str):
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
warnings.warn(
|
||||
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||
)
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "glu":
|
||||
return lambda x: x
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert (
|
||||
module.weight.size(1) % block_size == 0
|
||||
), "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert (
|
||||
module.in_channels % block_size == 0
|
||||
), "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(
|
||||
in_features // block_size * out_features, device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(
|
||||
weight.size(0), weight.size(1), device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = (
|
||||
mask.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
)
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(
|
||||
torch.bool
|
||||
) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
|
||||
|
||||
+905
@@ -0,0 +1,905 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from .paths import _log
|
||||
|
||||
|
||||
class ProcessedDB:
|
||||
_SCHEMA_VERSION = 4 # bump when schema changes
|
||||
|
||||
def __init__(self, db_path: str | None = None):
|
||||
if db_path is None:
|
||||
db_path = str(Path.home() / ".8cut.db")
|
||||
self._path = db_path
|
||||
self._lock = threading.Lock()
|
||||
try:
|
||||
self._con = sqlite3.connect(db_path, check_same_thread=False)
|
||||
self._migrate()
|
||||
self._enabled = True
|
||||
_log(f"DB opened: {db_path}")
|
||||
except Exception as e:
|
||||
_log(f"DB unavailable: {e}")
|
||||
self._con = None
|
||||
self._enabled = False
|
||||
|
||||
def _migrate(self) -> None:
|
||||
"""Create table if missing, then add any new columns for old DBs."""
|
||||
cols = {
|
||||
row[1]
|
||||
for row in self._con.execute("PRAGMA table_info(processed)").fetchall()
|
||||
}
|
||||
if not cols:
|
||||
# Fresh DB — create from scratch
|
||||
self._con.execute(
|
||||
"CREATE TABLE IF NOT EXISTS processed ("
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||
" filename TEXT NOT NULL,"
|
||||
" start_time REAL NOT NULL,"
|
||||
" output_path TEXT NOT NULL,"
|
||||
" label TEXT NOT NULL DEFAULT '',"
|
||||
" category TEXT NOT NULL DEFAULT '',"
|
||||
" short_side INTEGER DEFAULT 512,"
|
||||
" portrait_ratio TEXT NOT NULL DEFAULT '',"
|
||||
" crop_center REAL NOT NULL DEFAULT 0.5,"
|
||||
" format TEXT NOT NULL DEFAULT 'MP4',"
|
||||
" clip_count INTEGER NOT NULL DEFAULT 3,"
|
||||
" spread REAL NOT NULL DEFAULT 3.0,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" source_path TEXT NOT NULL DEFAULT '',"
|
||||
" scan_export INTEGER NOT NULL DEFAULT 0,"
|
||||
" processed_at TEXT NOT NULL"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
# Add missing columns to legacy tables
|
||||
new_cols = {
|
||||
"label": "TEXT NOT NULL DEFAULT ''",
|
||||
"category": "TEXT NOT NULL DEFAULT ''",
|
||||
"short_side": "INTEGER DEFAULT 512",
|
||||
"portrait_ratio": "TEXT NOT NULL DEFAULT ''",
|
||||
"crop_center": "REAL NOT NULL DEFAULT 0.5",
|
||||
"format": "TEXT NOT NULL DEFAULT 'MP4'",
|
||||
"clip_count": "INTEGER NOT NULL DEFAULT 3",
|
||||
"spread": "REAL NOT NULL DEFAULT 3.0",
|
||||
"profile": "TEXT NOT NULL DEFAULT 'default'",
|
||||
"source_path": "TEXT NOT NULL DEFAULT ''",
|
||||
"scan_export": "INTEGER NOT NULL DEFAULT 0",
|
||||
}
|
||||
for col, typedef in new_cols.items():
|
||||
if col not in cols:
|
||||
self._con.execute(
|
||||
f"ALTER TABLE processed ADD COLUMN {col} {typedef}"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_filename ON processed(filename)"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE TABLE IF NOT EXISTS hidden_files ("
|
||||
" filename TEXT NOT NULL,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" PRIMARY KEY (filename, profile)"
|
||||
")"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE TABLE IF NOT EXISTS scan_results ("
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||
" filename TEXT NOT NULL,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" model TEXT NOT NULL,"
|
||||
" start_time REAL NOT NULL,"
|
||||
" end_time REAL NOT NULL,"
|
||||
" score REAL NOT NULL,"
|
||||
" disabled INTEGER NOT NULL DEFAULT 0,"
|
||||
" orig_start_time REAL,"
|
||||
" orig_end_time REAL,"
|
||||
" scan_timestamp TEXT NOT NULL DEFAULT ''"
|
||||
")"
|
||||
)
|
||||
# Migrate: add new columns to existing scan_results tables
|
||||
sr_cols = {
|
||||
row[1]
|
||||
for row in self._con.execute("PRAGMA table_info(scan_results)").fetchall()
|
||||
}
|
||||
for col, typedef in [
|
||||
("disabled", "INTEGER NOT NULL DEFAULT 0"),
|
||||
("orig_start_time", "REAL"),
|
||||
("orig_end_time", "REAL"),
|
||||
("scan_timestamp", "TEXT NOT NULL DEFAULT ''"),
|
||||
]:
|
||||
if col not in sr_cols:
|
||||
self._con.execute(
|
||||
f"ALTER TABLE scan_results ADD COLUMN {col} {typedef}"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_scan_file_profile_model"
|
||||
" ON scan_results(filename, profile, model)"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE TABLE IF NOT EXISTS hard_negatives ("
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||
" filename TEXT NOT NULL,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" start_time REAL NOT NULL,"
|
||||
" source_path TEXT NOT NULL DEFAULT '',"
|
||||
" source_model TEXT NOT NULL DEFAULT ''"
|
||||
")"
|
||||
)
|
||||
# Migrate: add source_model column to existing hard_negatives tables
|
||||
hn_cols = {
|
||||
row[1]
|
||||
for row in self._con.execute("PRAGMA table_info(hard_negatives)").fetchall()
|
||||
}
|
||||
if "source_model" not in hn_cols:
|
||||
self._con.execute(
|
||||
"ALTER TABLE hard_negatives ADD COLUMN source_model TEXT NOT NULL DEFAULT ''"
|
||||
)
|
||||
self._con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_hardneg_file_profile"
|
||||
" ON hard_negatives(filename, profile)"
|
||||
)
|
||||
self._con.commit()
|
||||
self._migrate_vid_folders()
|
||||
|
||||
def _migrate_vid_folders(self) -> None:
|
||||
"""Migrate old clip_NNN group dirs → vid_NNN per-video folders.
|
||||
|
||||
Old layout: export_folder/clip_NNN/clip_NNN_sub.mp4
|
||||
New layout: export_folder/vid_NNN/clip_NNN_sub.mp4
|
||||
|
||||
Rewrites output_path in DB and moves files on disk.
|
||||
"""
|
||||
# Check if any rows still use the old clip_NNN parent dir layout
|
||||
row = self._con.execute(
|
||||
"SELECT id FROM processed WHERE output_path LIKE '%/clip_%/%' LIMIT 1"
|
||||
).fetchone()
|
||||
if not row:
|
||||
return
|
||||
|
||||
_log("Migrating old clip group dirs → vid folders …")
|
||||
rows = self._con.execute(
|
||||
"SELECT id, filename, profile, output_path FROM processed"
|
||||
" ORDER BY profile, filename, output_path"
|
||||
).fetchall()
|
||||
|
||||
# Assign vid_NNN per (profile, export_folder, filename)
|
||||
vid_map: dict[tuple, str] = {}
|
||||
vid_counters: dict[tuple, int] = {}
|
||||
|
||||
for rid, filename, profile, op in rows:
|
||||
parent = os.path.dirname(op)
|
||||
export_folder = os.path.dirname(parent)
|
||||
key = (profile, export_folder, filename)
|
||||
if key not in vid_map:
|
||||
counter_key = (profile, export_folder)
|
||||
n = vid_counters.get(counter_key, 1)
|
||||
vid_map[key] = f"vid_{n:03d}"
|
||||
vid_counters[counter_key] = n + 1
|
||||
|
||||
updates: list[tuple[str, int]] = []
|
||||
moves: list[tuple[str, str]] = []
|
||||
dirs_to_create: set[str] = set()
|
||||
old_dirs: set[str] = set()
|
||||
|
||||
for rid, filename, profile, op in rows:
|
||||
parent = os.path.dirname(op)
|
||||
parent_name = os.path.basename(parent)
|
||||
# Skip rows already using vid_NNN layout
|
||||
if parent_name.startswith("vid_"):
|
||||
continue
|
||||
export_folder = os.path.dirname(parent)
|
||||
key = (profile, export_folder, filename)
|
||||
vid_name = vid_map[key]
|
||||
new_path = os.path.join(export_folder, vid_name, os.path.basename(op))
|
||||
updates.append((new_path, rid))
|
||||
dirs_to_create.add(os.path.join(export_folder, vid_name))
|
||||
old_dirs.add(parent)
|
||||
if os.path.exists(op):
|
||||
moves.append((op, new_path))
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
# Create vid directories
|
||||
for d in sorted(dirs_to_create):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
# Move files
|
||||
import shutil
|
||||
for old, new in moves:
|
||||
if os.path.exists(old) and not os.path.exists(new):
|
||||
shutil.move(old, new)
|
||||
|
||||
# Update DB
|
||||
self._con.executemany(
|
||||
"UPDATE processed SET output_path = ? WHERE id = ?", updates
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
# Remove empty old group directories
|
||||
for d in sorted(old_dirs, reverse=True):
|
||||
try:
|
||||
if os.path.isdir(d) and not os.listdir(d):
|
||||
os.rmdir(d)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
_log(f"Migrated {len(updates)} rows, moved {len(moves)} files to vid folders")
|
||||
|
||||
def add(self, filename: str, start_time: float, output_path: str,
|
||||
label: str = "", category: str = "",
|
||||
short_side: int | None = None, portrait_ratio: str = "",
|
||||
crop_center: float = 0.5, fmt: str = "MP4",
|
||||
clip_count: int = 3, spread: float = 3.0,
|
||||
profile: str = "default", source_path: str = "",
|
||||
scan_export: bool = False) -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
"INSERT INTO processed"
|
||||
" (filename, start_time, output_path, label, category,"
|
||||
" short_side, portrait_ratio, crop_center, format,"
|
||||
" clip_count, spread, profile, source_path, scan_export, processed_at)"
|
||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(filename, start_time, output_path, label, category,
|
||||
short_side, portrait_ratio, crop_center, fmt,
|
||||
clip_count, spread, profile, source_path,
|
||||
1 if scan_export else 0,
|
||||
datetime.now(timezone.utc).isoformat()),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_labels(self) -> list[str]:
|
||||
"""Return distinct non-empty labels ordered by most recently used."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT label FROM processed"
|
||||
" WHERE label != '' ORDER BY processed_at DESC"
|
||||
).fetchall()
|
||||
# Deduplicate while preserving order (DISTINCT on processed_at DESC
|
||||
# may return duplicates if the same label was used multiple times).
|
||||
seen: set[str] = set()
|
||||
result = []
|
||||
for (lbl,) in rows:
|
||||
if lbl not in seen:
|
||||
seen.add(lbl)
|
||||
result.append(lbl)
|
||||
return result
|
||||
|
||||
def get_by_output_path(self, output_path: str) -> dict | None:
|
||||
"""Return config dict for an output_path, or None."""
|
||||
if not self._enabled:
|
||||
return None
|
||||
cur = self._con.cursor()
|
||||
cur.row_factory = sqlite3.Row
|
||||
row = cur.execute(
|
||||
"SELECT label, category, short_side, portrait_ratio, crop_center, format,"
|
||||
" clip_count, spread"
|
||||
" FROM processed WHERE output_path = ?",
|
||||
(output_path,),
|
||||
).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def delete_by_output_path(self, output_path: str) -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute("DELETE FROM processed WHERE output_path = ?", (output_path,))
|
||||
self._con.commit()
|
||||
|
||||
def get_group(self, output_path: str, profile: str = "") -> list[str]:
|
||||
"""Return all output_paths sharing the same (filename, start_time, profile) as *output_path*."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
row = self._con.execute(
|
||||
"SELECT filename, start_time, profile FROM processed WHERE output_path = ?",
|
||||
(output_path,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return []
|
||||
filename, start_time, row_profile = row
|
||||
p = profile or row_profile
|
||||
rows = self._con.execute(
|
||||
"SELECT output_path FROM processed"
|
||||
" WHERE filename = ? AND start_time = ? AND profile = ? ORDER BY output_path",
|
||||
(filename, start_time, p),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def delete_group(self, output_path: str, profile: str = "") -> list[str]:
|
||||
"""Delete all rows sharing the same (filename, start_time, profile) as *output_path*.
|
||||
Returns list of deleted output_paths."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
with self._lock:
|
||||
row = self._con.execute(
|
||||
"SELECT filename, start_time, profile FROM processed WHERE output_path = ?",
|
||||
(output_path,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return []
|
||||
filename, start_time, row_profile = row
|
||||
p = profile or row_profile
|
||||
paths = [r[0] for r in self._con.execute(
|
||||
"SELECT output_path FROM processed"
|
||||
" WHERE filename = ? AND start_time = ? AND profile = ?",
|
||||
(filename, start_time, p),
|
||||
).fetchall()]
|
||||
self._con.execute(
|
||||
"DELETE FROM processed WHERE filename = ? AND start_time = ? AND profile = ?",
|
||||
(filename, start_time, p),
|
||||
)
|
||||
self._con.commit()
|
||||
return paths
|
||||
|
||||
def _get_markers_for(self, match: str, profile: str = "default") -> list[tuple[float, int, str]]:
|
||||
rows = self._con.execute(
|
||||
"SELECT start_time, output_path FROM processed"
|
||||
" WHERE filename = ? AND profile = ? AND scan_export = 0"
|
||||
" ORDER BY start_time",
|
||||
(match, profile),
|
||||
).fetchall()
|
||||
# Deduplicate by start_time — batch exports share the same cursor.
|
||||
seen_times: dict[float, tuple[float, int, str]] = {}
|
||||
n = 0
|
||||
for t, p in rows:
|
||||
if t not in seen_times:
|
||||
n += 1
|
||||
seen_times[t] = (t, n, p)
|
||||
return list(seen_times.values())
|
||||
|
||||
def get_markers(self, filename: str, profile: str = "default") -> list[tuple[float, int, str]]:
|
||||
"""Return [(start_time, marker_number, output_path), ...] for exact
|
||||
filename match, sorted by start_time. Empty list if no match.
|
||||
Excludes scan exports (shown via scan panel instead)."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
return self._get_markers_for(filename, profile)
|
||||
|
||||
def get_clip_count(self, filename: str, profile: str = "default") -> int:
|
||||
"""Return total number of exported clips (including scan exports)."""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
row = self._con.execute(
|
||||
"SELECT COUNT(*) FROM processed WHERE filename = ? AND profile = ?",
|
||||
(filename, profile),
|
||||
).fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_profiles(self) -> list[str]:
|
||||
"""Return distinct profile names, ordered alphabetically."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT profile FROM processed ORDER BY profile"
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||
"""Return all unique output_path values for a given profile."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_max_counter(self, folder: str, name: str) -> int:
|
||||
"""Return the highest counter N found in output_paths matching folder/name_NNN*.
|
||||
|
||||
Parses the counter from filenames (e.g. 'clip_035_0.mp4' → 35).
|
||||
*folder* is typically the vid folder. Returns 0 if no matches exist.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
prefix = os.path.join(folder, name + "_")
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed"
|
||||
" WHERE output_path LIKE ?",
|
||||
(prefix + "%",),
|
||||
).fetchall()
|
||||
max_n = 0
|
||||
name_prefix = name + "_"
|
||||
for (op,) in rows:
|
||||
stem = os.path.splitext(os.path.basename(op))[0]
|
||||
# stem: "clip_035_0" or "clip_036_a1_0"
|
||||
if not stem.startswith(name_prefix):
|
||||
continue
|
||||
rest = stem[len(name_prefix):] # "035_0" or "036_a1_0"
|
||||
counter_str = rest.split("_")[0]
|
||||
try:
|
||||
max_n = max(max_n, int(counter_str))
|
||||
except ValueError:
|
||||
pass
|
||||
return max_n
|
||||
|
||||
def delete_scan_exports(self, filename: str, profile: str) -> int:
|
||||
"""Delete all scan_export entries for *filename* in *profile*.
|
||||
|
||||
Returns the number of rows deleted.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
cur = self._con.execute(
|
||||
"DELETE FROM processed"
|
||||
" WHERE filename = ? AND profile = ? AND scan_export = 1",
|
||||
(filename, profile),
|
||||
)
|
||||
self._con.commit()
|
||||
return cur.rowcount
|
||||
|
||||
def get_vid_folder(self, filename: str, profile: str,
|
||||
export_folder: str) -> str:
|
||||
"""Return the vid_NNN folder name for a source video.
|
||||
|
||||
Checks existing DB output_paths first; if the video already has a
|
||||
vid_NNN folder, returns it. Otherwise assigns max(existing) + 1,
|
||||
also checking disk for orphan vid folders.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return "vid_001"
|
||||
# Use the most recent entry (ORDER BY rowid DESC) for determinism
|
||||
# when a file has entries across multiple vid folders.
|
||||
row = self._con.execute(
|
||||
"SELECT output_path FROM processed"
|
||||
" WHERE filename = ? AND profile = ?"
|
||||
" ORDER BY rowid DESC LIMIT 1",
|
||||
(filename, profile),
|
||||
).fetchone()
|
||||
if row:
|
||||
parent = os.path.basename(os.path.dirname(row[0]))
|
||||
if parent.startswith("vid_"):
|
||||
return parent
|
||||
# Collect max vid_NNN number from DB + disk (never reuse old numbers)
|
||||
max_n = 0
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
for (op,) in rows:
|
||||
p = os.path.basename(os.path.dirname(op))
|
||||
if p.startswith("vid_"):
|
||||
try:
|
||||
max_n = max(max_n, int(p.split("_")[1]))
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
if os.path.isdir(export_folder):
|
||||
for d in os.listdir(export_folder):
|
||||
if d.startswith("vid_") and os.path.isdir(
|
||||
os.path.join(export_folder, d)
|
||||
):
|
||||
try:
|
||||
max_n = max(max_n, int(d.split("_")[1]))
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
return f"vid_{max_n + 1:03d}"
|
||||
|
||||
def get_export_folders(self, profile: str = "default",
|
||||
include_scan_exports: bool = False) -> list[str]:
|
||||
"""Return distinct export folder names found in output_paths for a profile.
|
||||
|
||||
Export paths follow the structure:
|
||||
.../export_folder/vid_NNN/clip.mp4
|
||||
The export folder is 2 levels up from the clip file.
|
||||
Returns folder names sorted alphabetically (e.g. ["mp4_Intense", "mp4_Soft"]).
|
||||
"""
|
||||
if not self._enabled:
|
||||
return []
|
||||
if include_scan_exports:
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed"
|
||||
" WHERE profile = ? AND scan_export = 0",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
folder_names: set[str] = set()
|
||||
for (op,) in rows:
|
||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||
if grandparent:
|
||||
folder_names.add(grandparent)
|
||||
return sorted(folder_names)
|
||||
|
||||
def get_training_data(self, profile: str, positive_folder: str,
|
||||
negative_folder: str = "",
|
||||
fallback_video_dir: str = "",
|
||||
include_scan_exports: bool = False,
|
||||
use_hard_negatives: bool = True,
|
||||
) -> list[tuple[str, list[float], list[float], list[float]]]:
|
||||
"""Build training video_infos from DB data.
|
||||
|
||||
Args:
|
||||
profile: profile name
|
||||
positive_folder: export folder name for positive class (e.g. "mp4_Intense")
|
||||
negative_folder: export folder name for explicit negatives (optional)
|
||||
fallback_video_dir: if source_path is empty, try filename in this dir
|
||||
include_scan_exports: if True, include auto-exported scan clips
|
||||
use_hard_negatives: if False, skip hard negatives from scan feedback
|
||||
|
||||
Returns:
|
||||
list of (source_video_path, positive_times, soft_times, negative_times)
|
||||
per video. Soft times = clips from any other non-negative folder.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return []
|
||||
if include_scan_exports:
|
||||
rows = self._con.execute(
|
||||
"SELECT filename, start_time, output_path, source_path"
|
||||
" FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._con.execute(
|
||||
"SELECT filename, start_time, output_path, source_path"
|
||||
" FROM processed WHERE profile = ? AND scan_export = 0",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
|
||||
# Collect times by video, split by folder role
|
||||
pos_by_video: dict[str, set[float]] = {}
|
||||
neg_by_video: dict[str, set[float]] = {}
|
||||
soft_by_video: dict[str, set[float]] = {}
|
||||
source_by_filename: dict[str, str] = {}
|
||||
|
||||
for fn, st, op, sp in rows:
|
||||
if sp:
|
||||
source_by_filename[fn] = sp
|
||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||
if grandparent == positive_folder:
|
||||
pos_by_video.setdefault(fn, set()).add(st)
|
||||
elif negative_folder and grandparent == negative_folder:
|
||||
neg_by_video.setdefault(fn, set()).add(st)
|
||||
else:
|
||||
soft_by_video.setdefault(fn, set()).add(st)
|
||||
|
||||
# Include hard negatives from scan feedback
|
||||
if use_hard_negatives:
|
||||
hard_rows = self._con.execute(
|
||||
"SELECT filename, start_time, source_path FROM hard_negatives"
|
||||
" WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
for fn, st, sp in hard_rows:
|
||||
neg_by_video.setdefault(fn, set()).add(st)
|
||||
if sp:
|
||||
source_by_filename.setdefault(fn, sp)
|
||||
|
||||
# Remove positive times from soft/neg to avoid conflicting labels
|
||||
for fn in pos_by_video:
|
||||
if fn in soft_by_video:
|
||||
soft_by_video[fn] -= pos_by_video[fn]
|
||||
if fn in neg_by_video:
|
||||
neg_by_video[fn] -= pos_by_video[fn]
|
||||
|
||||
# Deduplicate nearby markers (spread clips from same position)
|
||||
def _dedup_times(times: set[float], min_gap: float = 8.0) -> list[float]:
|
||||
if not times:
|
||||
return []
|
||||
ordered = sorted(times)
|
||||
result = [ordered[0]]
|
||||
for t in ordered[1:]:
|
||||
if t - result[-1] >= min_gap:
|
||||
result.append(t)
|
||||
return result
|
||||
|
||||
# Include videos that have positives OR explicit negatives
|
||||
all_videos = set(pos_by_video) | set(neg_by_video)
|
||||
result = []
|
||||
for fn in all_videos:
|
||||
sp = source_by_filename.get(fn, "")
|
||||
if not sp or not os.path.exists(sp):
|
||||
if fallback_video_dir:
|
||||
sp = os.path.join(fallback_video_dir, fn)
|
||||
if not sp or not os.path.exists(sp):
|
||||
continue
|
||||
gt_pos = _dedup_times(pos_by_video.get(fn, set()))
|
||||
gt_soft = _dedup_times(soft_by_video.get(fn, set()))
|
||||
gt_neg = _dedup_times(neg_by_video.get(fn, set()))
|
||||
result.append((sp, gt_pos, gt_soft, gt_neg))
|
||||
return result
|
||||
|
||||
def get_training_stats(self, profile: str,
|
||||
include_scan_exports: bool = False) -> dict[str, dict]:
|
||||
"""Return per-subprofile stats for training readiness display.
|
||||
|
||||
Returns dict mapping subprofile_name → {
|
||||
'videos': number of distinct source videos,
|
||||
'clips': total clip count,
|
||||
}
|
||||
"""
|
||||
if not self._enabled:
|
||||
return {}
|
||||
if include_scan_exports:
|
||||
rows = self._con.execute(
|
||||
"SELECT filename, output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._con.execute(
|
||||
"SELECT filename, output_path FROM processed"
|
||||
" WHERE profile = ? AND scan_export = 0",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
folders = self.get_export_folders(profile, include_scan_exports=include_scan_exports)
|
||||
stats: dict[str, dict] = {}
|
||||
for folder_name in folders:
|
||||
videos: set[str] = set()
|
||||
clips = 0
|
||||
for fn, op in rows:
|
||||
grandparent = os.path.basename(os.path.dirname(os.path.dirname(op)))
|
||||
if grandparent == folder_name:
|
||||
videos.add(fn)
|
||||
clips += 1
|
||||
stats[folder_name] = {"videos": len(videos), "clips": clips}
|
||||
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||
|
||||
# ── Scan results ─────────────────────────────────────────────
|
||||
|
||||
def save_scan_results(self, filename: str, profile: str, model: str,
|
||||
regions: list[tuple[float, float, float]],
|
||||
max_versions: int = 5) -> None:
|
||||
"""Save scan results as a new version for (filename, profile, model).
|
||||
|
||||
regions: list of (start_time, end_time, score).
|
||||
Keeps up to max_versions; oldest are pruned automatically.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
with self._lock:
|
||||
self._con.executemany(
|
||||
"INSERT INTO scan_results"
|
||||
" (filename, profile, model, start_time, end_time, score,"
|
||||
" orig_start_time, orig_end_time, scan_timestamp)"
|
||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
[(filename, profile, model, s, e, sc, s, e, ts)
|
||||
for s, e, sc in regions],
|
||||
)
|
||||
# Prune old versions beyond max_versions
|
||||
versions = self._con.execute(
|
||||
"SELECT DISTINCT scan_timestamp FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||
" ORDER BY scan_timestamp DESC",
|
||||
(filename, profile, model),
|
||||
).fetchall()
|
||||
if len(versions) > max_versions:
|
||||
old_ts = [v[0] for v in versions[max_versions:]]
|
||||
self._con.execute(
|
||||
"DELETE FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||
f" AND scan_timestamp IN ({','.join('?' * len(old_ts))})",
|
||||
(filename, profile, model, *old_ts),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_scan_versions(self, filename: str, profile: str, model: str
|
||||
) -> list[dict]:
|
||||
"""Return list of scan versions for (filename, profile, model).
|
||||
|
||||
Returns [{timestamp, count, max_score}, ...] ordered newest first.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT scan_timestamp, COUNT(*), MAX(score)"
|
||||
" FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ? AND model = ?"
|
||||
" AND scan_timestamp != ''"
|
||||
" GROUP BY scan_timestamp"
|
||||
" ORDER BY scan_timestamp DESC",
|
||||
(filename, profile, model),
|
||||
).fetchall()
|
||||
return [{"timestamp": ts, "count": cnt, "max_score": sc}
|
||||
for ts, cnt, sc in rows]
|
||||
|
||||
def get_scan_results(self, filename: str, profile: str,
|
||||
scan_timestamp: str | None = None
|
||||
) -> dict[str, list[tuple[int, float, float, float, bool, float, float]]]:
|
||||
"""Return scan results grouped by model.
|
||||
|
||||
If scan_timestamp is given, returns only that version's rows.
|
||||
Otherwise returns the latest version per model.
|
||||
|
||||
Returns {model: [(row_id, start, end, score, disabled, orig_start, orig_end), ...]}
|
||||
sorted by start_time.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return {}
|
||||
if scan_timestamp:
|
||||
rows = self._con.execute(
|
||||
"SELECT id, model, start_time, end_time, score, disabled,"
|
||||
" orig_start_time, orig_end_time"
|
||||
" FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ? AND scan_timestamp = ?"
|
||||
" ORDER BY model, start_time",
|
||||
(filename, profile, scan_timestamp),
|
||||
).fetchall()
|
||||
else:
|
||||
# For each model, get rows from the latest timestamp only
|
||||
rows = self._con.execute(
|
||||
"SELECT r.id, r.model, r.start_time, r.end_time, r.score,"
|
||||
" r.disabled, r.orig_start_time, r.orig_end_time"
|
||||
" FROM scan_results r"
|
||||
" INNER JOIN ("
|
||||
" SELECT model, MAX(scan_timestamp) AS latest"
|
||||
" FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ?"
|
||||
" GROUP BY model"
|
||||
" ) m ON r.model = m.model AND r.scan_timestamp = m.latest"
|
||||
" WHERE r.filename = ? AND r.profile = ?"
|
||||
" ORDER BY r.model, r.start_time",
|
||||
(filename, profile, filename, profile),
|
||||
).fetchall()
|
||||
result: dict[str, list[tuple[int, float, float, float, bool, float, float]]] = {}
|
||||
for row_id, model, s, e, sc, dis, os_, oe in rows:
|
||||
# Fall back to current bounds for legacy rows without orig
|
||||
result.setdefault(model, []).append(
|
||||
(row_id, s, e, sc, bool(dis), os_ if os_ is not None else s,
|
||||
oe if oe is not None else e))
|
||||
return result
|
||||
|
||||
def delete_scan_result(self, row_id: int) -> None:
|
||||
"""Delete a single scan result row."""
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute("DELETE FROM scan_results WHERE id = ?", (row_id,))
|
||||
self._con.commit()
|
||||
|
||||
def toggle_scan_result_disabled(self, row_id: int, disabled: bool) -> None:
|
||||
"""Set disabled flag on a scan result row."""
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
"UPDATE scan_results SET disabled = ? WHERE id = ?",
|
||||
(1 if disabled else 0, row_id),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def update_scan_result_times(self, row_id: int,
|
||||
start: float, end: float) -> None:
|
||||
"""Update start/end times of a scan result row (resize)."""
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
"UPDATE scan_results SET start_time = ?, end_time = ? WHERE id = ?",
|
||||
(start, end, row_id),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_scan_models(self, filename: str, profile: str) -> list[str]:
|
||||
"""Return model names that have scan results for this file."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT model FROM scan_results"
|
||||
" WHERE filename = ? AND profile = ? ORDER BY model",
|
||||
(filename, profile),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_scanned_filenames(self, profile: str, model: str) -> set[str]:
|
||||
"""Return filenames that already have scan results for this model."""
|
||||
if not self._enabled:
|
||||
return set()
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT filename FROM scan_results"
|
||||
" WHERE profile = ? AND model = ?",
|
||||
(profile, model),
|
||||
).fetchall()
|
||||
return {r[0] for r in rows}
|
||||
|
||||
def add_hard_negatives(self, filename: str, profile: str,
|
||||
times: list[float], source_path: str = "",
|
||||
source_model: str = "") -> None:
|
||||
"""Save timestamps as hard-negative training examples."""
|
||||
if not self._enabled or not times:
|
||||
return
|
||||
with self._lock:
|
||||
for t in times:
|
||||
self._con.execute(
|
||||
"INSERT INTO hard_negatives"
|
||||
" (filename, profile, start_time, source_path, source_model)"
|
||||
" VALUES (?, ?, ?, ?, ?)",
|
||||
(filename, profile, t, source_path, source_model),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_hard_negative_times(self, filename: str, profile: str) -> set[float]:
|
||||
"""Return start_times marked as hard negatives for this file."""
|
||||
if not self._enabled:
|
||||
return set()
|
||||
rows = self._con.execute(
|
||||
"SELECT start_time FROM hard_negatives"
|
||||
" WHERE filename = ? AND profile = ?",
|
||||
(filename, profile),
|
||||
).fetchall()
|
||||
return {r[0] for r in rows}
|
||||
|
||||
def get_hard_negatives(self, profile: str) -> list[dict]:
|
||||
"""Return all hard negatives for a profile with full details."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT id, filename, start_time, source_path, source_model"
|
||||
" FROM hard_negatives WHERE profile = ?"
|
||||
" ORDER BY filename, start_time",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return [{"id": r[0], "filename": r[1], "start_time": r[2],
|
||||
"source_path": r[3], "source_model": r[4]} for r in rows]
|
||||
|
||||
def delete_hard_negatives_by_ids(self, ids: list[int]) -> None:
|
||||
"""Delete hard negatives by row IDs."""
|
||||
if not self._enabled or not ids:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
f"DELETE FROM hard_negatives WHERE id IN ({','.join('?' * len(ids))})",
|
||||
ids,
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def remove_hard_negatives(self, filename: str, profile: str,
|
||||
times: list[float]) -> None:
|
||||
"""Remove specific hard-negative timestamps."""
|
||||
if not self._enabled or not times:
|
||||
return
|
||||
with self._lock:
|
||||
for t in times:
|
||||
self._con.execute(
|
||||
"DELETE FROM hard_negatives"
|
||||
" WHERE filename = ? AND profile = ? AND start_time = ?",
|
||||
(filename, profile, t),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_training_filenames(self, profile: str) -> set[str]:
|
||||
"""Return filenames used in training (have exported clips)."""
|
||||
if not self._enabled:
|
||||
return set()
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT filename FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return {r[0] for r in rows}
|
||||
|
||||
# ── Hidden files ───────────────────────────────────────────
|
||||
|
||||
def hide_file(self, filename: str, profile: str = "default") -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
"INSERT OR IGNORE INTO hidden_files (filename, profile) VALUES (?, ?)",
|
||||
(filename, profile),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def unhide_file(self, filename: str, profile: str = "default") -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
self._con.execute(
|
||||
"DELETE FROM hidden_files WHERE filename = ? AND profile = ?",
|
||||
(filename, profile),
|
||||
)
|
||||
self._con.commit()
|
||||
|
||||
def get_hidden_files(self, profile: str = "default") -> set[str]:
|
||||
if not self._enabled:
|
||||
return set()
|
||||
rows = self._con.execute(
|
||||
"SELECT filename FROM hidden_files WHERE profile = ?", (profile,)
|
||||
).fetchall()
|
||||
return {r[0] for r in rows}
|
||||
+187
@@ -0,0 +1,187 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
|
||||
_RATIOS: dict[str, tuple[int, int]] = {
|
||||
"9:16": (9, 16),
|
||||
"4:5": (4, 5),
|
||||
"1:1": (1, 1),
|
||||
}
|
||||
|
||||
|
||||
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||
"""Return an ffmpeg crop= filter expression for the given portrait ratio.
|
||||
|
||||
Uses ffmpeg expression syntax so source dimensions are resolved at runtime.
|
||||
Commas inside min()/max() are escaped with \\, to prevent ffmpeg's
|
||||
filtergraph parser from treating them as filter-chain separators.
|
||||
"""
|
||||
num, den = _RATIOS[ratio]
|
||||
cw = f"ih*{num}/{den}"
|
||||
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||
return f"crop={cw}:ih:{x}:0"
|
||||
|
||||
|
||||
def resolve_keyframe(
|
||||
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||
t: float,
|
||||
tolerance: float = 0.05,
|
||||
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||
"""Return the latest keyframe at or before *t*, or None."""
|
||||
result = None
|
||||
for kf in keyframes:
|
||||
if kf[0] <= t + tolerance:
|
||||
result = kf
|
||||
else:
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def apply_keyframes_to_jobs(
|
||||
jobs: list[tuple[float, str, str | None, float]],
|
||||
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||
base_center: float,
|
||||
base_ratio: str | None,
|
||||
base_rand_p: bool,
|
||||
base_rand_s: bool,
|
||||
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||
"""Resolve each job's crop state from keyframes, returning widened tuples.
|
||||
|
||||
Returns list of (start, path, ratio, center, rand_portrait, rand_square).
|
||||
"""
|
||||
result = []
|
||||
for s, o, _r, _c in jobs:
|
||||
kf = resolve_keyframe(keyframes, s)
|
||||
if kf is not None:
|
||||
_, center, ratio, rp, rs = kf
|
||||
else:
|
||||
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||
result.append((s, o, ratio, center, rp, rs))
|
||||
return result
|
||||
|
||||
|
||||
def _find_vaapi_device() -> str:
|
||||
"""Return the first available VAAPI render device path (Linux)."""
|
||||
import glob
|
||||
devices = sorted(glob.glob("/dev/dri/renderD*"))
|
||||
return devices[0] if devices else "/dev/dri/renderD128"
|
||||
|
||||
|
||||
def build_ffmpeg_command(
|
||||
input_path: str, start: float, output_path: str,
|
||||
short_side: int | None = None,
|
||||
portrait_ratio: str | None = None,
|
||||
crop_center: float = 0.5,
|
||||
image_sequence: bool = False,
|
||||
encoder: str = "libx264",
|
||||
) -> list[str]:
|
||||
# -ss before -i: fast input-seeking. Safe here because we always re-encode,
|
||||
# so there is no keyframe-alignment issue from pre-input seek.
|
||||
# Image sequences always use libwebp, so skip HW encoder setup.
|
||||
use_hw_vaapi = (encoder == "h264_vaapi" and not image_sequence
|
||||
and sys.platform == "linux")
|
||||
cmd = [_bin("ffmpeg"), "-y"]
|
||||
|
||||
# VAAPI needs a render device for hardware context (Linux only).
|
||||
if use_hw_vaapi:
|
||||
vaapi_dev = _find_vaapi_device()
|
||||
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||
"-vaapi_device", vaapi_dev]
|
||||
|
||||
cmd += [
|
||||
"-threads", "0",
|
||||
"-ss", str(start),
|
||||
"-i", input_path,
|
||||
"-t", "8",
|
||||
]
|
||||
|
||||
filters: list[str] = []
|
||||
if portrait_ratio is not None:
|
||||
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||
if short_side is not None:
|
||||
# Scale so the shorter dimension equals short_side.
|
||||
filters.append(
|
||||
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||
)
|
||||
|
||||
# VAAPI: decoded frames are GPU surfaces. CPU filters need hwdownload first.
|
||||
if use_hw_vaapi:
|
||||
if filters:
|
||||
filters.insert(0, "hwdownload")
|
||||
filters.insert(1, "format=nv12")
|
||||
filters.append("format=nv12")
|
||||
filters.append("hwupload")
|
||||
|
||||
if filters:
|
||||
cmd += ["-vf", ",".join(filters)]
|
||||
|
||||
if image_sequence:
|
||||
cmd += [
|
||||
"-an",
|
||||
"-c:v", "libwebp",
|
||||
"-quality", "92",
|
||||
"-compression_level", "1",
|
||||
os.path.join(output_path, "frame_%04d.webp"),
|
||||
]
|
||||
else:
|
||||
cmd += ["-c:v", encoder]
|
||||
if "nvenc" in encoder:
|
||||
cmd += ["-preset", "p4", "-cq", "28"]
|
||||
elif "vaapi" in encoder:
|
||||
cmd += ["-qp", "28"]
|
||||
elif "qsv" in encoder:
|
||||
cmd += ["-global_quality", "28"]
|
||||
elif "amf" in encoder:
|
||||
cmd += ["-qp_i", "28", "-qp_p", "28"]
|
||||
cmd += ["-c:a", "pcm_s16le", output_path]
|
||||
return cmd
|
||||
|
||||
|
||||
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||
"""Return an ffmpeg command that extracts audio to <sequence_dir>.wav."""
|
||||
audio_path = sequence_dir + ".wav"
|
||||
return [
|
||||
_bin("ffmpeg"), "-y",
|
||||
"-ss", str(start),
|
||||
"-i", input_path,
|
||||
"-t", "8",
|
||||
"-vn",
|
||||
"-c:a", "pcm_s16le",
|
||||
audio_path,
|
||||
]
|
||||
|
||||
|
||||
def detect_hw_encoders() -> list[str]:
|
||||
"""Probe ffmpeg for available H.264 hardware encoders.
|
||||
|
||||
Returns only encoders relevant to the current platform:
|
||||
- Windows: h264_nvenc, h264_qsv, h264_amf
|
||||
- Linux: h264_nvenc, h264_vaapi, h264_qsv
|
||||
- macOS: h264_videotoolbox
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
candidates = ["h264_nvenc", "h264_qsv", "h264_amf"]
|
||||
elif sys.platform == "darwin":
|
||||
candidates = ["h264_videotoolbox"]
|
||||
else:
|
||||
candidates = ["h264_nvenc", "h264_vaapi", "h264_qsv"]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
output = result.stdout
|
||||
except Exception:
|
||||
return []
|
||||
available = [enc for enc in candidates if re.search(rf'\b{enc}\b', output)]
|
||||
if available:
|
||||
_log(f"HW encoders detected: {', '.join(available)}")
|
||||
else:
|
||||
_log("No HW encoders detected — GPU export unavailable")
|
||||
return available
|
||||
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _frozen_path() -> Path:
|
||||
if getattr(sys, "frozen", False):
|
||||
return Path(sys._MEIPASS)
|
||||
return Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def _bin(name: str) -> str:
|
||||
"""Resolve a binary name (e.g. 'ffmpeg') to its full path in frozen builds."""
|
||||
p = _frozen_path() / name
|
||||
if p.exists():
|
||||
return str(p)
|
||||
return name # fall back to PATH
|
||||
|
||||
|
||||
def _log(*args) -> None:
|
||||
"""Print a timestamped log line to stderr."""
|
||||
ts = datetime.now().strftime("%H:%M:%S")
|
||||
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||
|
||||
|
||||
def build_export_path(folder: str, basename: str, counter: int,
|
||||
sub: int | None = None, tag: str | None = None) -> str:
|
||||
"""Build clip output path. *folder* should be the vid folder (e.g. .../mp4/vid_001)."""
|
||||
name = f"{basename}_{counter:03d}"
|
||||
if tag is not None:
|
||||
name = f"{name}_{tag}"
|
||||
if sub is not None:
|
||||
name = f"{name}_{sub}"
|
||||
return os.path.join(folder, name + ".mp4")
|
||||
|
||||
|
||||
def build_sequence_dir(folder: str, basename: str, counter: int,
|
||||
sub: int | None = None, tag: str | None = None) -> str:
|
||||
"""Build WebP sequence output dir. *folder* should be the vid folder."""
|
||||
name = f"{basename}_{counter:03d}"
|
||||
if tag is not None:
|
||||
name = f"{name}_{tag}"
|
||||
if sub is not None:
|
||||
name = f"{name}_{sub}"
|
||||
return os.path.join(folder, name)
|
||||
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
m = int(seconds // 60)
|
||||
# Floor-truncate to 1 dp (not round) — prevents "X:60.0" rollover when
|
||||
# seconds is e.g. 59.95.
|
||||
s = int(seconds % 60 * 10) / 10
|
||||
return f"{m}:{s:04.1f}"
|
||||
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
def _get_yolo():
|
||||
"""Lazy-load YOLOv8-nano. Returns None if ultralytics is not installed."""
|
||||
global _yolo_model
|
||||
if _yolo_model is None:
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
_yolo_model = YOLO("yolov8n.pt")
|
||||
_log("YOLO model loaded")
|
||||
except ImportError:
|
||||
_log("ultralytics not installed — tracking disabled")
|
||||
return None
|
||||
except Exception as e:
|
||||
_log(f"YOLO load failed: {e}")
|
||||
return None
|
||||
return _yolo_model
|
||||
|
||||
|
||||
def extract_frame_cv(video_path: str, time: float):
|
||||
"""Extract a single frame as a numpy array (BGR) via ffmpeg -> temp PNG -> cv2."""
|
||||
try:
|
||||
import cv2
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
return None
|
||||
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||
os.close(fd)
|
||||
try:
|
||||
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||
"-frames:v", "1", tmp]
|
||||
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
return cv2.imread(tmp)
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if os.path.exists(tmp):
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
def detect_subject_center(
|
||||
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||
) -> tuple[int | None, float, float] | None:
|
||||
"""Detect objects at *time* and return (class_id, norm_x, norm_y) of the
|
||||
best match to (target_cls, last_x, last_y). Returns None on failure."""
|
||||
model = _get_yolo()
|
||||
if model is None:
|
||||
return None
|
||||
frame = extract_frame_cv(video_path, time)
|
||||
if frame is None:
|
||||
return None
|
||||
results = model(frame, verbose=False)
|
||||
if not results or len(results[0].boxes) == 0:
|
||||
return None
|
||||
h, w = frame.shape[:2]
|
||||
dets = []
|
||||
for box in results[0].boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
cls = int(box.cls[0])
|
||||
cx = (x1 + x2) / 2 / w
|
||||
cy = (y1 + y2) / 2 / h
|
||||
dets.append((cls, cx, cy))
|
||||
# Prefer same class, nearest to last known position.
|
||||
def score(d):
|
||||
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||
return cls_penalty + dist
|
||||
best = min(dets, key=score)
|
||||
return best
|
||||
|
||||
|
||||
def track_centers_for_jobs(
|
||||
video_path: str, cursor: float, crop_center: float,
|
||||
starts: list[float],
|
||||
) -> list[float]:
|
||||
"""Run detection at the cursor (to identify the target) then at each start
|
||||
time. Returns a list of horizontal crop centers (one per start)."""
|
||||
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||
if ref is None:
|
||||
_log("Tracking: no detection at cursor, using fixed center")
|
||||
return [crop_center] * len(starts)
|
||||
target_cls, last_x, last_y = ref
|
||||
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||
centers = []
|
||||
for t in starts:
|
||||
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||
if det is not None:
|
||||
_, cx, cy = det
|
||||
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||
centers.append(cx)
|
||||
last_x, last_y = cx, cy
|
||||
else:
|
||||
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||
centers.append(last_x)
|
||||
return centers
|
||||
@@ -0,0 +1,148 @@
|
||||
# 8-cut Client Design
|
||||
|
||||
## Goal
|
||||
|
||||
Build a Tauri + Svelte desktop client that connects to the 8-cut server API for remote video editing. Full feature parity with the Qt app. Targets Linux first, then Mac.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Tauri app (Rust shell + Svelte webview)
|
||||
├── mpv sidecar (bundled binary)
|
||||
│ ├── plays video: http://server/api/stream/{path}?quality=low
|
||||
│ ├── plays audio: http://server/api/audio/{path}
|
||||
│ └── controlled via JSON IPC socket
|
||||
├── Svelte UI
|
||||
│ ├── File browser
|
||||
│ ├── Canvas timeline (markers, cursor, play region)
|
||||
│ ├── Canvas crop overlay
|
||||
│ ├── Export controls + WebSocket progress
|
||||
│ └── Settings panel (profile, subprofiles, quality)
|
||||
└── Rust backend
|
||||
├── Spawn/manage mpv process + IPC
|
||||
├── Proxy server API calls (avoid CORS)
|
||||
└── Tauri commands exposed to Svelte frontend
|
||||
```
|
||||
|
||||
## Playback
|
||||
|
||||
mpv runs as a sidecar process, controlled via JSON IPC socket. Two streams:
|
||||
- Video: `http://server/api/stream/{path}?root={root}&quality={quality}` (transcoded, no audio)
|
||||
- Audio: `http://server/api/audio/{path}?root={root}` (full quality WAV)
|
||||
|
||||
mpv's `--audio-file=` flag syncs both streams with frame-accurate seeking.
|
||||
|
||||
Quality presets: potato (480p), low (720p), medium (1080p), high (original).
|
||||
|
||||
## Features
|
||||
|
||||
### File management
|
||||
- Browse server video roots (`GET /api/roots`, `GET /api/files`)
|
||||
- Hide/unhide files per profile (`POST/DELETE /api/hidden/{filename}`)
|
||||
- Sort by name/size, filter hidden
|
||||
|
||||
### Playback
|
||||
- Play/pause/resume from pause point
|
||||
- AB-loop with current spread/clips settings
|
||||
- Play region adapts to spread changes without restarting
|
||||
- Quality selector
|
||||
|
||||
### Timeline (Canvas)
|
||||
- Cursor position, markers, play position indicator
|
||||
- Click to seek, drag cursor
|
||||
- Lock mode: cursor locked to marker, double-click jumps to end of clip span
|
||||
- Autoclip: when paused, auto-adjust clip count to fit pause position
|
||||
|
||||
### Crop & keyframes
|
||||
- Portrait ratio selector (9:16, 4:5, 1:1, off)
|
||||
- Crop center slider with live canvas overlay
|
||||
- Crop keyframes at arbitrary timeline positions
|
||||
- Subject tracking (triggered server-side)
|
||||
- Random portrait/square toggles
|
||||
|
||||
### Export
|
||||
- Configurable: clips, spread, short side, format (MP4/WebP sequence)
|
||||
- Label + category annotation
|
||||
- Encoder selection (libx264 / h264_nvenc)
|
||||
- Subprofiles with folder suffix routing
|
||||
- Number keys 1-9 for subprofile quick export, E for main
|
||||
- WebSocket progress (`WS /ws/export`), per-clip completion
|
||||
- Delete/re-export from marker context menu
|
||||
|
||||
### Profiles
|
||||
- Profile switcher, markers reload per profile
|
||||
- Subprofile management (add/remove)
|
||||
|
||||
### Settings
|
||||
- Server URL (configurable)
|
||||
- Default quality preset
|
||||
- All settings persisted client-side via Tauri store
|
||||
|
||||
## Server API endpoints used
|
||||
|
||||
```
|
||||
GET /api/roots
|
||||
GET /api/files?root={root}
|
||||
GET /api/video/{path}?root={root}
|
||||
GET /api/stream/{path}?root={root}&quality={quality}
|
||||
GET /api/audio/{path}?root={root}
|
||||
GET /api/cache/status/{path}?root={root}
|
||||
GET /api/markers/{filename}?profile={profile}
|
||||
GET /api/profiles
|
||||
GET /api/labels
|
||||
POST /api/export
|
||||
GET /api/export/{job_id}
|
||||
DELETE /api/export?output_path={path}
|
||||
POST /api/hidden/{filename}?profile={profile}
|
||||
DELETE /api/hidden/{filename}?profile={profile}
|
||||
GET /api/hidden?profile={profile}
|
||||
WS /ws/export
|
||||
```
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
client/
|
||||
├── src-tauri/
|
||||
│ ├── src/
|
||||
│ │ ├── main.rs (Tauri entry, app setup)
|
||||
│ │ ├── mpv.rs (mpv sidecar spawn + IPC)
|
||||
│ │ ├── commands.rs (Tauri commands for Svelte)
|
||||
│ │ └── lib.rs
|
||||
│ ├── Cargo.toml
|
||||
│ └── tauri.conf.json
|
||||
├── src/
|
||||
│ ├── App.svelte
|
||||
│ ├── lib/
|
||||
│ │ ├── api.ts (server API client)
|
||||
│ │ ├── mpv.ts (mpv IPC bridge via Tauri commands)
|
||||
│ │ ├── ws.ts (WebSocket export progress)
|
||||
│ │ └── stores.ts (Svelte stores: files, markers, settings)
|
||||
│ ├── components/
|
||||
│ │ ├── FileBrowser.svelte
|
||||
│ │ ├── Timeline.svelte
|
||||
│ │ ├── CropOverlay.svelte
|
||||
│ │ ├── ExportPanel.svelte
|
||||
│ │ ├── SettingsPanel.svelte
|
||||
│ │ └── ProfileBar.svelte
|
||||
│ └── main.ts
|
||||
├── package.json
|
||||
└── vite.config.ts
|
||||
```
|
||||
|
||||
## Implementation order
|
||||
|
||||
1. Scaffold Tauri + Svelte project
|
||||
2. mpv sidecar: spawn, IPC, basic play/pause/seek
|
||||
3. API client module + server connection
|
||||
4. File browser component
|
||||
5. Video playback: load file → stream URL → mpv
|
||||
6. Canvas timeline: cursor, seek, markers
|
||||
7. Export panel + WebSocket progress
|
||||
8. Crop overlay + keyframes
|
||||
9. Lock mode, autoclip, play region
|
||||
10. Profiles, subprofiles, hidden files
|
||||
11. Keyboard shortcuts
|
||||
12. Settings persistence
|
||||
13. Package for Linux (.deb / .AppImage)
|
||||
14. Package for Mac (.dmg)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,207 @@
|
||||
# 8-cut Server API Design
|
||||
|
||||
## Goal
|
||||
|
||||
Run 8-cut as a FastAPI server on Unraid (Docker) so a Tauri desktop client on Mac can edit remotely over WireGuard — no file transfers, no auth.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Unraid (Docker container):
|
||||
FastAPI + ffmpeg + SQLite
|
||||
├── /api/files list videos from mounted volumes
|
||||
├── /api/stream/{path} transcoded video (cached, no audio)
|
||||
├── /api/audio/{path} full-quality audio (cached, passthrough)
|
||||
├── /api/video/{path} raw file (for reference/download)
|
||||
├── /api/markers CRUD markers per profile
|
||||
├── /api/profiles list/create profiles
|
||||
├── /api/export trigger + manage exports
|
||||
├── /api/labels label history
|
||||
├── /api/hidden hidden file management
|
||||
└── ws://…/ws/export real-time export progress
|
||||
|
||||
Mac (Tauri + Svelte + libmpv):
|
||||
├── mpv plays stream URL (video) + audio URL separately
|
||||
├── Canvas timeline + crop overlay + keyframes
|
||||
├── Full UI: profiles, subprofiles, settings
|
||||
└── Stateless — all state lives on server
|
||||
```
|
||||
|
||||
## Docker mounts
|
||||
|
||||
| Mount | Purpose | Env var |
|
||||
|-------------|--------------------------------|--------------|
|
||||
| `/videos` | Source video files (read-only) | `MEDIA_DIRS` |
|
||||
| `/exports` | Export output | `EXPORT_DIR` |
|
||||
| `/data` | SQLite DB + transcode cache | `DB_PATH`, `CACHE_DIR` |
|
||||
|
||||
`MEDIA_DIRS` supports multiple paths: `/videos1,/videos2`.
|
||||
|
||||
## Video streaming with transcode cache
|
||||
|
||||
The client needs low-bitrate video for scrubbing over the network but full-quality audio for accurate editing.
|
||||
|
||||
**Flow:**
|
||||
1. Client requests `/api/stream/{path}?quality=low`
|
||||
2. Server checks cache: `{CACHE_DIR}/{quality}/{hash}.mp4`
|
||||
3. If cached → serve with range requests (instant seeking)
|
||||
4. If not → start background ffmpeg transcode, return `202 Accepted` with job ID
|
||||
5. Client polls or gets WebSocket notification when ready
|
||||
6. Audio: `/api/audio/{path}` extracts audio (passthrough, fast) to cache on first request
|
||||
|
||||
**Quality presets:**
|
||||
|
||||
| Preset | Resolution | Bitrate |
|
||||
|----------|-----------|----------|
|
||||
| `potato` | 480p | ~500 Kbps |
|
||||
| `low` | 720p | ~2 Mbps |
|
||||
| `medium` | 1080p | ~5 Mbps |
|
||||
| `high` | original | ~10 Mbps |
|
||||
|
||||
Each quality level cached separately. Client can switch quality — mpv reloads the URL.
|
||||
|
||||
**mpv on client:**
|
||||
```
|
||||
video = http://server/api/stream/file.mp4?quality=low
|
||||
audio = http://server/api/audio/file.mp4
|
||||
```
|
||||
mpv's `--audio-file=` flag plays both in sync with frame-accurate seeking.
|
||||
|
||||
## API endpoints
|
||||
|
||||
### Files
|
||||
```
|
||||
GET /api/files?root={root}
|
||||
→ [{path, name, size, duration?, markers_count}]
|
||||
|
||||
GET /api/video/{path}
|
||||
→ raw file with range requests
|
||||
|
||||
GET /api/stream/{path}?quality=low|medium|high|potato
|
||||
→ cached transcoded video (no audio), range requests
|
||||
→ 202 if transcode in progress
|
||||
|
||||
GET /api/audio/{path}
|
||||
→ cached full-quality audio, range requests
|
||||
→ 202 if extraction in progress
|
||||
|
||||
GET /api/cache/status/{path}
|
||||
→ {qualities: {potato: "ready", low: "transcoding", ...}, audio: "ready"}
|
||||
```
|
||||
|
||||
### Markers & profiles
|
||||
```
|
||||
GET /api/markers/{filename}?profile=default
|
||||
→ [{start_time, marker_number, output_path}]
|
||||
|
||||
GET /api/profiles
|
||||
→ ["default", "intense", ...]
|
||||
|
||||
GET /api/labels
|
||||
→ ["dog barking", "rain", ...]
|
||||
```
|
||||
|
||||
### Export
|
||||
```
|
||||
POST /api/export
|
||||
body: {input_path, cursor, folder_suffix?, name, clips, spread,
|
||||
short_side?, portrait_ratio?, crop_center, format,
|
||||
label?, category?, profile, crop_keyframes?,
|
||||
rand_portrait?, rand_square?, track_subject?}
|
||||
→ {job_id}
|
||||
|
||||
GET /api/export/{job_id}
|
||||
→ {status, completed, total, outputs: [...]}
|
||||
|
||||
DELETE /api/export/{output_path}
|
||||
→ delete from DB + disk
|
||||
|
||||
WS /ws/export
|
||||
→ server pushes: {type: "clip_done", path: "..."} | {type: "all_done"} | {type: "error", msg: "..."}
|
||||
```
|
||||
|
||||
### Hidden files
|
||||
```
|
||||
POST /api/hidden/{filename}?profile=default
|
||||
DELETE /api/hidden/{filename}?profile=default
|
||||
GET /api/hidden?profile=default
|
||||
→ ["file1.mp4", "file2.mp4"]
|
||||
```
|
||||
|
||||
## Code reuse from main.py
|
||||
|
||||
**Extracted to shared module (used by both server and Qt app):**
|
||||
- `ProcessedDB` — SQLite operations
|
||||
- `build_ffmpeg_command` — ffmpeg command construction
|
||||
- `build_audio_extract_command`
|
||||
- `build_export_path` / `build_sequence_dir`
|
||||
- `detect_hw_encoders`
|
||||
- `upsert_clip_annotation` / `remove_clip_annotation`
|
||||
- `apply_keyframes_to_jobs` / `resolve_keyframe`
|
||||
- `track_centers_for_jobs` (subject tracking)
|
||||
|
||||
**Server-specific (new):**
|
||||
- FastAPI app + route handlers
|
||||
- Transcode cache manager
|
||||
- Export worker (plain threading, replaces QThread-based ExportWorker)
|
||||
- File listing / media root scanning
|
||||
- WebSocket export progress broadcaster
|
||||
|
||||
**Tauri client (new, Svelte):**
|
||||
- mpv integration via Tauri plugin or sidecar
|
||||
- Canvas-based timeline widget
|
||||
- Canvas-based crop overlay
|
||||
- All UI controls
|
||||
- API client module
|
||||
|
||||
## Dockerfile
|
||||
|
||||
```dockerfile
|
||||
FROM python:3.12-slim
|
||||
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /app
|
||||
COPY server/ .
|
||||
RUN pip install --no-cache-dir fastapi uvicorn
|
||||
EXPOSE 8000
|
||||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
8-cut/
|
||||
├── main.py (existing Qt app, unchanged)
|
||||
├── core/ (shared logic, extracted from main.py)
|
||||
│ ├── __init__.py
|
||||
│ ├── db.py (ProcessedDB)
|
||||
│ ├── ffmpeg.py (build commands, detect encoders)
|
||||
│ ├── export.py (ExportWorker — plain threading)
|
||||
│ ├── paths.py (build_export_path, build_sequence_dir)
|
||||
│ └── annotations.py (dataset.json helpers)
|
||||
├── server/
|
||||
│ ├── app.py (FastAPI app)
|
||||
│ ├── routes/
|
||||
│ │ ├── files.py
|
||||
│ │ ├── stream.py
|
||||
│ │ ├── markers.py
|
||||
│ │ ├── export.py
|
||||
│ │ └── hidden.py
|
||||
│ ├── cache.py (transcode cache manager)
|
||||
│ ├── ws.py (WebSocket handler)
|
||||
│ └── config.py (env vars, settings)
|
||||
├── client/ (Tauri + Svelte — future)
|
||||
│ └── ...
|
||||
├── Dockerfile
|
||||
└── docker-compose.yml
|
||||
```
|
||||
|
||||
## Implementation order
|
||||
|
||||
1. Extract shared logic from main.py → `core/`
|
||||
2. Update main.py to import from `core/` (verify Qt app still works)
|
||||
3. Build FastAPI server with file listing + video serving
|
||||
4. Add transcode cache + audio extraction
|
||||
5. Add markers/profiles/labels/hidden API
|
||||
6. Add export endpoint + WebSocket progress
|
||||
7. Dockerfile + docker-compose
|
||||
8. (Later) Tauri client
|
||||
@@ -0,0 +1,948 @@
|
||||
# Server API Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Extract shared logic from main.py into a `core/` package, then build the FastAPI server that serves video files, manages the DB, and runs exports.
|
||||
|
||||
**Architecture:** Shared logic (DB, ffmpeg, paths, annotations, tracking) moves to `core/`. Both `main.py` (Qt app) and `server/` import from `core/`. The server adds HTTP video streaming with transcode cache, REST endpoints, and WebSocket export progress.
|
||||
|
||||
**Tech Stack:** Python 3.12, FastAPI, uvicorn, SQLite, ffmpeg
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Create core/ package — paths and helpers
|
||||
|
||||
**Files:**
|
||||
- Create: `core/__init__.py`
|
||||
- Create: `core/paths.py`
|
||||
|
||||
**Step 1: Create core/__init__.py**
|
||||
|
||||
```python
|
||||
# empty — package marker
|
||||
```
|
||||
|
||||
**Step 2: Create core/paths.py**
|
||||
|
||||
Extract from main.py lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`.
|
||||
|
||||
```python
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _frozen_path() -> Path:
|
||||
if getattr(sys, "frozen", False):
|
||||
return Path(sys._MEIPASS)
|
||||
return Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def _bin(name: str) -> str:
|
||||
p = _frozen_path() / name
|
||||
if p.exists():
|
||||
return str(p)
|
||||
return name
|
||||
|
||||
|
||||
def _log(*args) -> None:
|
||||
ts = datetime.now().strftime("%H:%M:%S")
|
||||
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||
|
||||
|
||||
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||
group = f"{basename}_{counter:03d}"
|
||||
name = f"{group}_{sub}" if sub is not None else group
|
||||
return os.path.join(folder, group, name + ".mp4")
|
||||
|
||||
|
||||
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||
group = f"{basename}_{counter:03d}"
|
||||
name = f"{group}_{sub}" if sub is not None else group
|
||||
return os.path.join(folder, group, name)
|
||||
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
m = int(seconds // 60)
|
||||
s = int(seconds % 60 * 10) / 10
|
||||
return f"{m}:{s:04.1f}"
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add core/
|
||||
git commit -m "feat: create core/paths module with shared path helpers"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Create core/ffmpeg.py
|
||||
|
||||
**Files:**
|
||||
- Create: `core/ffmpeg.py`
|
||||
|
||||
**Step 1: Create core/ffmpeg.py**
|
||||
|
||||
Extract from main.py lines 77-112 and 244-289: `_RATIOS`, `_portrait_crop_filter`, `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`, `detect_hw_encoders`. (Lines 115-188 are also ffmpeg-related. Lines 191-241 are annotations — extracted separately in Task 4.)
|
||||
|
||||
```python
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
|
||||
_RATIOS: dict[str, tuple[int, int]] = {
|
||||
"9:16": (9, 16),
|
||||
"4:5": (4, 5),
|
||||
"1:1": (1, 1),
|
||||
}
|
||||
|
||||
|
||||
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||
num, den = _RATIOS[ratio]
|
||||
cw = f"ih*{num}/{den}"
|
||||
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||
return f"crop={cw}:ih:{x}:0"
|
||||
|
||||
|
||||
def resolve_keyframe(
|
||||
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||
t: float,
|
||||
tolerance: float = 0.05,
|
||||
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||
result = None
|
||||
for kf in keyframes:
|
||||
if kf[0] <= t + tolerance:
|
||||
result = kf
|
||||
else:
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def apply_keyframes_to_jobs(
|
||||
jobs: list[tuple[float, str, str | None, float]],
|
||||
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||
base_center: float,
|
||||
base_ratio: str | None,
|
||||
base_rand_p: bool,
|
||||
base_rand_s: bool,
|
||||
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||
result = []
|
||||
for s, o, _r, _c in jobs:
|
||||
kf = resolve_keyframe(keyframes, s)
|
||||
if kf is not None:
|
||||
_, center, ratio, rp, rs = kf
|
||||
else:
|
||||
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||
result.append((s, o, ratio, center, rp, rs))
|
||||
return result
|
||||
|
||||
|
||||
def build_ffmpeg_command(
|
||||
input_path: str, start: float, output_path: str,
|
||||
short_side: int | None = None,
|
||||
portrait_ratio: str | None = None,
|
||||
crop_center: float = 0.5,
|
||||
image_sequence: bool = False,
|
||||
encoder: str = "libx264",
|
||||
) -> list[str]:
|
||||
use_hw_vaapi = encoder == "h264_vaapi" and not image_sequence
|
||||
cmd = [_bin("ffmpeg"), "-y"]
|
||||
if use_hw_vaapi:
|
||||
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||
"-vaapi_device", "/dev/dri/renderD128"]
|
||||
cmd += ["-threads", "0", "-ss", str(start), "-i", input_path, "-t", "8"]
|
||||
filters: list[str] = []
|
||||
if portrait_ratio is not None:
|
||||
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||
if short_side is not None:
|
||||
filters.append(
|
||||
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||
)
|
||||
if use_hw_vaapi:
|
||||
if filters:
|
||||
filters.insert(0, "hwdownload")
|
||||
filters.insert(1, "format=nv12")
|
||||
filters.append("format=nv12")
|
||||
filters.append("hwupload")
|
||||
if filters:
|
||||
cmd += ["-vf", ",".join(filters)]
|
||||
if image_sequence:
|
||||
cmd += ["-an", "-c:v", "libwebp", "-quality", "92", "-compression_level", "1",
|
||||
os.path.join(output_path, "frame_%04d.webp")]
|
||||
else:
|
||||
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
||||
return cmd
|
||||
|
||||
|
||||
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||
audio_path = sequence_dir + ".wav"
|
||||
return [_bin("ffmpeg"), "-y", "-ss", str(start), "-i", input_path,
|
||||
"-t", "8", "-vn", "-c:a", "pcm_s16le", audio_path]
|
||||
|
||||
|
||||
def detect_hw_encoders() -> list[str]:
|
||||
_HW_ENCODERS = ["h264_nvenc", "h264_vaapi", "h264_qsv", "h264_amf", "h264_videotoolbox"]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
output = result.stdout
|
||||
except Exception:
|
||||
return []
|
||||
available = []
|
||||
for enc in _HW_ENCODERS:
|
||||
if re.search(rf'\b{enc}\b', output):
|
||||
available.append(enc)
|
||||
if available:
|
||||
_log(f"HW encoders detected: {', '.join(available)}")
|
||||
else:
|
||||
_log("No HW encoders detected — GPU export unavailable")
|
||||
return available
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add core/ffmpeg.py
|
||||
git commit -m "feat: create core/ffmpeg module with ffmpeg helpers"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Create core/db.py
|
||||
|
||||
**Files:**
|
||||
- Create: `core/db.py`
|
||||
|
||||
**Step 1: Create core/db.py**
|
||||
|
||||
Extract the entire `ProcessedDB` class from main.py lines 398-626. Import `_log` from `core.paths`.
|
||||
|
||||
```python
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from .paths import _log
|
||||
|
||||
|
||||
class ProcessedDB:
|
||||
_SCHEMA_VERSION = 3
|
||||
|
||||
def __init__(self, db_path: str | None = None):
|
||||
# ... exact copy of existing class ...
|
||||
```
|
||||
|
||||
Copy the full class body verbatim — all methods unchanged.
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add core/db.py
|
||||
git commit -m "feat: create core/db module with ProcessedDB"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Create core/annotations.py
|
||||
|
||||
**Files:**
|
||||
- Create: `core/annotations.py`
|
||||
|
||||
**Step 1: Create core/annotations.py**
|
||||
|
||||
Extract from main.py lines 191-241: `build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`.
|
||||
|
||||
```python
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def build_annotation_json_path(folder: str) -> str:
|
||||
return os.path.join(folder, "dataset.json")
|
||||
|
||||
|
||||
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||
json_path = build_annotation_json_path(folder)
|
||||
if not os.path.exists(json_path):
|
||||
return
|
||||
abs_path = os.path.abspath(clip_path)
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
entries = json.load(f)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return
|
||||
entries = [e for e in entries if e.get("path") != abs_path]
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||
if not label.strip():
|
||||
return
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
json_path = build_annotation_json_path(folder)
|
||||
entries: list[dict] = []
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
entries = json.load(f)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
entries = []
|
||||
abs_path = os.path.abspath(clip_path)
|
||||
entry: dict = {"path": abs_path, "label": label}
|
||||
for i, e in enumerate(entries):
|
||||
if e.get("path") == abs_path:
|
||||
entries[i] = entry
|
||||
break
|
||||
else:
|
||||
entries.append(entry)
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add core/annotations.py
|
||||
git commit -m "feat: create core/annotations module"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Create core/export.py
|
||||
|
||||
**Files:**
|
||||
- Create: `core/export.py`
|
||||
|
||||
**Step 1: Create core/export.py**
|
||||
|
||||
A plain-threading version of `ExportWorker` (no QThread dependency). Used by the server. The Qt app continues using its own QThread-based worker.
|
||||
|
||||
```python
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from .ffmpeg import build_ffmpeg_command, build_audio_extract_command
|
||||
from .paths import _bin, _log
|
||||
|
||||
|
||||
class ExportRunner:
|
||||
"""Run ffmpeg export jobs in a background thread pool.
|
||||
|
||||
Callbacks:
|
||||
on_clip_done(path: str)
|
||||
on_all_done()
|
||||
on_error(msg: str)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
jobs: list[tuple[float, str, str | None, float]],
|
||||
short_side: int | None = None,
|
||||
image_sequence: bool = False,
|
||||
max_workers: int | None = None,
|
||||
encoder: str = "libx264",
|
||||
on_clip_done: Callable[[str], None] | None = None,
|
||||
on_all_done: Callable[[], None] | None = None,
|
||||
on_error: Callable[[str], None] | None = None,
|
||||
):
|
||||
self._input = input_path
|
||||
self._jobs = jobs
|
||||
self._short_side = short_side
|
||||
self._image_sequence = image_sequence
|
||||
self._max_workers = max_workers
|
||||
self._encoder = encoder
|
||||
self._on_clip_done = on_clip_done
|
||||
self._on_all_done = on_all_done
|
||||
self._on_error = on_error
|
||||
self._cancel = False
|
||||
self._procs: list[subprocess.Popen] = []
|
||||
self._procs_lock = threading.Lock()
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def cancel(self):
|
||||
self._cancel = True
|
||||
with self._procs_lock:
|
||||
for proc in self._procs:
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def is_running(self) -> bool:
|
||||
return self._thread is not None and self._thread.is_alive()
|
||||
|
||||
def _run_one(self, start: float, output: str,
|
||||
portrait_ratio: str | None, crop_center: float) -> str:
|
||||
if self._cancel:
|
||||
raise RuntimeError("cancelled")
|
||||
if self._image_sequence:
|
||||
os.makedirs(output, exist_ok=True)
|
||||
cmd = build_ffmpeg_command(
|
||||
self._input, start, output,
|
||||
short_side=self._short_side,
|
||||
portrait_ratio=portrait_ratio,
|
||||
crop_center=crop_center,
|
||||
image_sequence=self._image_sequence,
|
||||
encoder=self._encoder,
|
||||
)
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
with self._procs_lock:
|
||||
self._procs.append(proc)
|
||||
try:
|
||||
_, stderr = proc.communicate(timeout=120)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
raise RuntimeError("ffmpeg timed out")
|
||||
finally:
|
||||
with self._procs_lock:
|
||||
self._procs.remove(proc)
|
||||
if self._cancel:
|
||||
raise RuntimeError("cancelled")
|
||||
if proc.returncode != 0:
|
||||
msg = stderr.decode(errors='replace')[-500:] if stderr else "ffmpeg failed"
|
||||
raise RuntimeError(msg)
|
||||
if self._image_sequence:
|
||||
audio_cmd = build_audio_extract_command(self._input, start, output)
|
||||
subprocess.run(audio_cmd, capture_output=True, text=True, timeout=60)
|
||||
return output
|
||||
|
||||
def _run(self):
|
||||
cap = self._max_workers or (os.cpu_count() or 2)
|
||||
workers = min(len(self._jobs), cap)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(self._run_one, s, o, pr, cc): o
|
||||
for s, o, pr, cc in self._jobs
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
if self._cancel:
|
||||
break
|
||||
try:
|
||||
path = fut.result()
|
||||
if self._on_clip_done:
|
||||
self._on_clip_done(path)
|
||||
except Exception as e:
|
||||
if "cancelled" not in str(e) and self._on_error:
|
||||
self._on_error(str(e))
|
||||
except Exception as e:
|
||||
if self._on_error:
|
||||
self._on_error(str(e))
|
||||
return
|
||||
if self._cancel:
|
||||
return
|
||||
if self._on_all_done:
|
||||
self._on_all_done()
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add core/export.py
|
||||
git commit -m "feat: create core/export module with ExportRunner"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Create core/tracking.py
|
||||
|
||||
**Files:**
|
||||
- Create: `core/tracking.py`
|
||||
|
||||
**Step 1: Create core/tracking.py**
|
||||
|
||||
Extract from main.py lines 294-395: YOLO tracking functions.
|
||||
|
||||
```python
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
def _get_yolo():
|
||||
global _yolo_model
|
||||
if _yolo_model is None:
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
_yolo_model = YOLO("yolov8n.pt")
|
||||
_log("YOLO model loaded")
|
||||
except ImportError:
|
||||
_log("ultralytics not installed — tracking disabled")
|
||||
return None
|
||||
except Exception as e:
|
||||
_log(f"YOLO load failed: {e}")
|
||||
return None
|
||||
return _yolo_model
|
||||
|
||||
|
||||
def extract_frame_cv(video_path: str, time: float):
|
||||
try:
|
||||
import cv2
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
return None
|
||||
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||
os.close(fd)
|
||||
try:
|
||||
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||
"-frames:v", "1", tmp]
|
||||
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
return cv2.imread(tmp)
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if os.path.exists(tmp):
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
def detect_subject_center(
|
||||
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||
) -> tuple[int | None, float, float] | None:
|
||||
model = _get_yolo()
|
||||
if model is None:
|
||||
return None
|
||||
frame = extract_frame_cv(video_path, time)
|
||||
if frame is None:
|
||||
return None
|
||||
results = model(frame, verbose=False)
|
||||
if not results or len(results[0].boxes) == 0:
|
||||
return None
|
||||
h, w = frame.shape[:2]
|
||||
dets = []
|
||||
for box in results[0].boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
cls = int(box.cls[0])
|
||||
cx = (x1 + x2) / 2 / w
|
||||
cy = (y1 + y2) / 2 / h
|
||||
dets.append((cls, cx, cy))
|
||||
def score(d):
|
||||
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||
return cls_penalty + dist
|
||||
best = min(dets, key=score)
|
||||
return best
|
||||
|
||||
|
||||
def track_centers_for_jobs(
|
||||
video_path: str, cursor: float, crop_center: float,
|
||||
starts: list[float],
|
||||
) -> list[float]:
|
||||
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||
if ref is None:
|
||||
_log("Tracking: no detection at cursor, using fixed center")
|
||||
return [crop_center] * len(starts)
|
||||
target_cls, last_x, last_y = ref
|
||||
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||
centers = []
|
||||
for t in starts:
|
||||
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||
if det is not None:
|
||||
_, cx, cy = det
|
||||
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||
centers.append(cx)
|
||||
last_x, last_y = cx, cy
|
||||
else:
|
||||
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||
centers.append(last_x)
|
||||
return centers
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add core/tracking.py
|
||||
git commit -m "feat: create core/tracking module with YOLO subject tracking"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Update main.py to import from core/
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py`
|
||||
|
||||
**Step 1: Replace function definitions with imports**
|
||||
|
||||
At the top of main.py, after the existing stdlib imports (line 17), add:
|
||||
|
||||
```python
|
||||
from core.paths import _bin, _log, build_export_path, build_sequence_dir, format_time
|
||||
from core.ffmpeg import (
|
||||
_RATIOS, resolve_keyframe, apply_keyframes_to_jobs,
|
||||
build_ffmpeg_command, build_audio_extract_command, detect_hw_encoders,
|
||||
)
|
||||
from core.db import ProcessedDB
|
||||
from core.annotations import remove_clip_annotation, upsert_clip_annotation
|
||||
from core.tracking import track_centers_for_jobs
|
||||
```
|
||||
|
||||
**Step 2: Delete the extracted function definitions and dead imports**
|
||||
|
||||
Remove definitions from main.py:
|
||||
- Lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`
|
||||
- Lines 77-188: `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`
|
||||
- Lines 191-241: annotation functions (`build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`)
|
||||
- Lines 244-289: `detect_hw_encoders`, `_RATIOS`, `_portrait_crop_filter`
|
||||
- Lines 294-395: tracking functions (`_yolo_model`, `_get_yolo`, `extract_frame_cv`, `detect_subject_center`, `track_centers_for_jobs`)
|
||||
- Lines 398-626: `ProcessedDB` class
|
||||
|
||||
Remove now-dead stdlib imports from the top of main.py:
|
||||
- `re` (only used in `detect_hw_encoders`)
|
||||
- `json` (only used in annotation functions)
|
||||
- `sqlite3` (only used in `ProcessedDB`)
|
||||
- `tempfile` (only used in `extract_frame_cv`)
|
||||
- `datetime`, `timezone` from the datetime import (only used in `_log` and `ProcessedDB`)
|
||||
|
||||
Keep in main.py:
|
||||
- `_SELVA_CATEGORIES` (UI constant, line 291)
|
||||
- `_RATIOS` reference — imported from core.ffmpeg
|
||||
- `ExportWorker` (QThread-based, stays in main.py — the server uses `core.export.ExportRunner` instead)
|
||||
- `_DBWorker` and `FrameGrabber` (QThread-based, stay in main.py)
|
||||
|
||||
**Step 3: Verify Qt app still works**
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
Open a video, export a clip, check markers — verify nothing broke.
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "refactor: import shared logic from core/ instead of inline definitions"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Create server/config.py
|
||||
|
||||
**Files:**
|
||||
- Create: `server/__init__.py` (empty package marker)
|
||||
- Create: `server/config.py`
|
||||
|
||||
**Step 1: Create `server/__init__.py`**
|
||||
|
||||
```python
|
||||
# empty — package marker
|
||||
```
|
||||
|
||||
**Step 2: Create config**
|
||||
|
||||
```python
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
MEDIA_DIRS: list[str] = [
|
||||
d.strip() for d in os.environ.get("MEDIA_DIRS", str(Path.home())).split(",") if d.strip()
|
||||
]
|
||||
EXPORT_DIR: str = os.environ.get("EXPORT_DIR", str(Path.home() / "8cut-exports"))
|
||||
DB_PATH: str = os.environ.get("DB_PATH", str(Path.home() / ".8cut.db"))
|
||||
CACHE_DIR: str = os.environ.get("CACHE_DIR", str(Path.home() / ".8cut-cache"))
|
||||
HOST: str = os.environ.get("HOST", "0.0.0.0")
|
||||
PORT: int = int(os.environ.get("PORT", "8000"))
|
||||
|
||||
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".ts", ".flv", ".wmv"}
|
||||
|
||||
QUALITY_PRESETS = {
|
||||
"potato": {"height": 480, "bitrate": "500k"},
|
||||
"low": {"height": 720, "bitrate": "2M"},
|
||||
"medium": {"height": 1080, "bitrate": "5M"},
|
||||
"high": {"height": 0, "bitrate": "10M"}, # 0 = original resolution
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add server/
|
||||
git commit -m "feat: create server/config with env var settings and quality presets"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 9: Create server/app.py — FastAPI skeleton + file listing
|
||||
|
||||
**Files:**
|
||||
- Create: `server/app.py`
|
||||
- Create: `server/routes/__init__.py`
|
||||
- Create: `server/routes/files.py`
|
||||
|
||||
**Step 1: Create FastAPI app**
|
||||
|
||||
`server/app.py`:
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from .routes import files, stream, markers, export, hidden
|
||||
|
||||
app = FastAPI(title="8-cut Server")
|
||||
app.include_router(files.router, prefix="/api")
|
||||
app.include_router(stream.router, prefix="/api")
|
||||
app.include_router(markers.router, prefix="/api")
|
||||
app.include_router(export.router, prefix="/api")
|
||||
app.include_router(hidden.router, prefix="/api")
|
||||
```
|
||||
|
||||
**Step 2: Create file listing route**
|
||||
|
||||
`server/routes/files.py`:
|
||||
```python
|
||||
import os
|
||||
from fastapi import APIRouter, Query
|
||||
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _scan_videos(root: str) -> list[dict]:
|
||||
results = []
|
||||
for dirpath, _, filenames in os.walk(root):
|
||||
for f in sorted(filenames):
|
||||
if os.path.splitext(f)[1].lower() in VIDEO_EXTENSIONS:
|
||||
full = os.path.join(dirpath, f)
|
||||
rel = os.path.relpath(full, root)
|
||||
results.append({
|
||||
"name": f,
|
||||
"path": rel,
|
||||
"root": root,
|
||||
"size": os.path.getsize(full),
|
||||
})
|
||||
return results
|
||||
|
||||
|
||||
@router.get("/files")
|
||||
def list_files(root: str | None = Query(None)):
|
||||
dirs = [root] if root and root in MEDIA_DIRS else MEDIA_DIRS
|
||||
files = []
|
||||
for d in dirs:
|
||||
files.extend(_scan_videos(d))
|
||||
return files
|
||||
|
||||
|
||||
@router.get("/roots")
|
||||
def list_roots():
|
||||
return MEDIA_DIRS
|
||||
```
|
||||
|
||||
**Step 3: Create `server/routes/__init__.py`**
|
||||
|
||||
```python
|
||||
# empty — package marker
|
||||
```
|
||||
|
||||
**Step 4: Create stub routers** so app.py imports don't fail. Each file gets a minimal router — later tasks fill in the real endpoints.
|
||||
|
||||
`server/routes/stream.py`:
|
||||
```python
|
||||
from fastapi import APIRouter
|
||||
router = APIRouter()
|
||||
```
|
||||
|
||||
`server/routes/markers.py`:
|
||||
```python
|
||||
from fastapi import APIRouter
|
||||
router = APIRouter()
|
||||
```
|
||||
|
||||
`server/routes/export.py`:
|
||||
```python
|
||||
from fastapi import APIRouter
|
||||
router = APIRouter()
|
||||
```
|
||||
|
||||
`server/routes/hidden.py`:
|
||||
```python
|
||||
from fastapi import APIRouter
|
||||
router = APIRouter()
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add server/
|
||||
git commit -m "feat: add FastAPI app with file listing endpoint"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 10: Create server/routes/stream.py — video serving + transcode cache
|
||||
|
||||
**Files:**
|
||||
- Create: `server/cache.py`
|
||||
- Create: `server/routes/stream.py`
|
||||
|
||||
**Step 1: Create cache manager**
|
||||
|
||||
`server/cache.py` handles:
|
||||
- Computing cache paths from source file hash + quality
|
||||
- Checking cache status
|
||||
- Launching background ffmpeg transcodes
|
||||
- Tracking in-progress jobs
|
||||
|
||||
**Step 2: Create stream routes**
|
||||
|
||||
```
|
||||
GET /api/video/{path} — raw file, range requests
|
||||
GET /api/stream/{path}?quality=low — cached transcode, range requests (202 if not ready)
|
||||
GET /api/audio/{path} — cached audio extraction, range requests (202 if not ready)
|
||||
GET /api/cache/status/{path} — cache status for all qualities
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add server/cache.py server/routes/stream.py
|
||||
git commit -m "feat: add video streaming with transcode cache and audio extraction"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 11: Create server/routes/markers.py — DB endpoints
|
||||
|
||||
**Files:**
|
||||
- Create: `server/routes/markers.py`
|
||||
|
||||
**Step 1: Create markers/profiles/labels routes**
|
||||
|
||||
```
|
||||
GET /api/markers/{filename}?profile=default
|
||||
GET /api/profiles
|
||||
GET /api/labels
|
||||
```
|
||||
|
||||
Uses `ProcessedDB` singleton from `core.db`.
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add server/routes/markers.py
|
||||
git commit -m "feat: add markers, profiles, and labels API endpoints"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 12: Create server/routes/export.py + WebSocket
|
||||
|
||||
**Files:**
|
||||
- Create: `server/routes/export.py`
|
||||
- Create: `server/ws.py`
|
||||
|
||||
**Step 1: Create export routes + WS**
|
||||
|
||||
```
|
||||
POST /api/export — start export job
|
||||
GET /api/export/{id} — check job status
|
||||
DELETE /api/export/{path} — delete export from DB + disk
|
||||
WS /ws/export — real-time progress
|
||||
```
|
||||
|
||||
Uses `ExportRunner` from `core.export`.
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add server/routes/export.py server/ws.py
|
||||
git commit -m "feat: add export endpoint with WebSocket progress"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 13: Create server/routes/hidden.py
|
||||
|
||||
**Files:**
|
||||
- Create: `server/routes/hidden.py`
|
||||
|
||||
**Step 1: Create hidden file routes**
|
||||
|
||||
```
|
||||
POST /api/hidden/{filename}?profile=default
|
||||
DELETE /api/hidden/{filename}?profile=default
|
||||
GET /api/hidden?profile=default
|
||||
```
|
||||
|
||||
**Step 2: Commit**
|
||||
|
||||
```bash
|
||||
git add server/routes/hidden.py
|
||||
git commit -m "feat: add hidden files API endpoints"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 14: Create Dockerfile + docker-compose.yml
|
||||
|
||||
**Files:**
|
||||
- Create: `Dockerfile`
|
||||
- Create: `docker-compose.yml`
|
||||
|
||||
**Step 1: Create Dockerfile**
|
||||
|
||||
```dockerfile
|
||||
FROM python:3.12-slim
|
||||
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /app
|
||||
COPY core/ core/
|
||||
COPY server/ server/
|
||||
# Note: ultralytics + opencv-python needed only if subject tracking is used.
|
||||
# Add them here if tracking is required on the server.
|
||||
RUN pip install --no-cache-dir fastapi uvicorn
|
||||
EXPOSE 8000
|
||||
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
**Step 2: Create docker-compose.yml**
|
||||
|
||||
```yaml
|
||||
services:
|
||||
8cut:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- /path/to/videos:/videos:ro
|
||||
- /path/to/exports:/exports
|
||||
- 8cut-data:/data
|
||||
environment:
|
||||
MEDIA_DIRS: /videos
|
||||
EXPORT_DIR: /exports
|
||||
DB_PATH: /data/8cut.db
|
||||
CACHE_DIR: /data/cache
|
||||
|
||||
volumes:
|
||||
8cut-data:
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add Dockerfile docker-compose.yml
|
||||
git commit -m "feat: add Dockerfile and docker-compose for server deployment"
|
||||
```
|
||||
@@ -0,0 +1,97 @@
|
||||
# Audio Similarity Scanning — Design
|
||||
|
||||
**Goal:** Scan a video's audio track and highlight segments that match the sound profile of existing reference clips, so the user can quickly find similar moments without scrubbing manually.
|
||||
|
||||
**Runs in:** Python/Qt client (`main.py`), not the server.
|
||||
|
||||
---
|
||||
|
||||
## Core Module: `core/audio_scan.py`
|
||||
|
||||
New module alongside `core/tracking.py`. Two main functions:
|
||||
|
||||
- `build_profile(clip_paths: list[str]) -> dict` — extracts MFCCs (20 coefficients) from each clip using `librosa`, returns a profile containing both the averaged vector and individual clip vectors.
|
||||
- `scan_video(video_path: str, profile: dict, mode: str, threshold: float, hop: float) -> list[tuple[float, float, float]]` — slides an 8s window across the video's audio, returns `(start_time, end_time, score)` tuples for segments above threshold.
|
||||
|
||||
### Feature Extraction
|
||||
|
||||
- Audio loaded via `librosa.load()` (handles video files directly, mono, 22050Hz).
|
||||
- MFCCs: `librosa.feature.mfcc(n_mfcc=20)`, averaged over time axis to produce a single vector per window/clip.
|
||||
- Similarity: cosine similarity (`numpy` dot product on L2-normalized vectors).
|
||||
|
||||
### Matching Modes
|
||||
|
||||
- **Average mode:** Compare each window to the mean of all reference MFCC vectors. Fast, good when references are homogeneous.
|
||||
- **Nearest mode:** Compare each window to every reference vector, take the max score. Better when references have variety within the style.
|
||||
|
||||
### Parameters
|
||||
|
||||
- `threshold` (float, 0.0–1.0): minimum cosine similarity to include a segment. Default 0.7.
|
||||
- `hop` (float, seconds): step size for the sliding window. Default 1.0s.
|
||||
- Window size fixed at 8s to match reference clip length.
|
||||
|
||||
---
|
||||
|
||||
## UI Integration in `main.py`
|
||||
|
||||
### Controls
|
||||
|
||||
Added near the existing tracking checkbox area:
|
||||
|
||||
- **"Scan" button** — triggers audio scan on current video.
|
||||
- **Threshold slider** (0.0–1.0, step 0.05) — controls match strictness.
|
||||
- **Mode combobox** — "Average" / "Nearest".
|
||||
- **Reference source combobox** — "Current Profile" / "Custom Folder" (shows folder picker when "Custom Folder" selected).
|
||||
|
||||
### Scan Workflow
|
||||
|
||||
1. User clicks Scan.
|
||||
2. Reference clips collected: either all export `output_path` values from the current profile (via DB) or all audio/video files in a custom folder.
|
||||
3. Scan runs in a `QThread` so UI stays responsive.
|
||||
4. On completion, results sent to Timeline widget via signal.
|
||||
|
||||
### Timeline Display
|
||||
|
||||
- New `set_scan_regions(regions: list[tuple[float, float, float]])` method on Timeline.
|
||||
- Drawn as semi-transparent colored rectangles behind existing markers.
|
||||
- Color intensity proportional to score (brighter = higher match).
|
||||
- Cleared on file change or re-scan.
|
||||
|
||||
### Keyboard Shortcut
|
||||
|
||||
- `S` — jump cursor to the next scan region (similar to `M` for next marker).
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
Reference clips (DB export paths or folder)
|
||||
|
|
||||
librosa.load() each -> MFCC vectors (20-dim)
|
||||
|
|
||||
Profile: { mean_vector, clip_vectors[] }
|
||||
|
|
||||
Current video -> librosa.load() full audio (mono 22050Hz)
|
||||
|
|
||||
Sliding 8s window (hop=1s) -> MFCC per window
|
||||
|
|
||||
Cosine similarity vs profile -> score per position
|
||||
|
|
||||
Threshold filter -> [(start, end, score), ...]
|
||||
|
|
||||
Timeline: semi-transparent highlight regions
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
- 2-hour video at 22050Hz mono ~ 380MB memory.
|
||||
- MFCC extraction + sliding window: ~10-30s.
|
||||
- QThread keeps UI responsive.
|
||||
|
||||
## What This Does NOT Do
|
||||
|
||||
- No DB schema changes — scan results are ephemeral (visual only).
|
||||
- No auto-export — user decides what to cut.
|
||||
- No server integration — runs entirely in the Python client.
|
||||
- No GPU/ML model dependency — just librosa + numpy.
|
||||
@@ -0,0 +1,739 @@
|
||||
# Audio Similarity Scanning — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Scan a video's audio track to find segments matching a reference sound profile, displayed as highlighted regions on the timeline.
|
||||
|
||||
**Architecture:** New `core/audio_scan.py` module extracts MFCC features from reference clips and slides an 8s window across the target video's audio, scoring each position via cosine similarity. A `ScanWorker` QThread runs the scan in the background, and results are drawn as semi-transparent rectangles on the existing Timeline widget.
|
||||
|
||||
**Tech Stack:** Python 3, librosa 0.11, numpy, PyQt6
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Core audio_scan module — build_profile
|
||||
|
||||
**Files:**
|
||||
- Create: `core/audio_scan.py`
|
||||
- Create: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the tests**
|
||||
|
||||
```python
|
||||
# tests/test_audio_scan.py
|
||||
import tempfile, os
|
||||
import numpy as np
|
||||
from core.audio_scan import build_profile, _extract_mfcc
|
||||
|
||||
|
||||
def _make_wav(path: str, duration: float = 8.0, sr: int = 22050):
|
||||
"""Create a short sine-wave WAV file for testing."""
|
||||
import soundfile as sf
|
||||
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
|
||||
audio = 0.5 * np.sin(2 * np.pi * 440 * t)
|
||||
sf.write(path, audio, sr)
|
||||
|
||||
|
||||
def test_extract_mfcc_returns_1d_vector():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
vec = _extract_mfcc(f.name)
|
||||
assert vec.shape == (20,)
|
||||
assert not np.isnan(vec).any()
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_single_clip():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
profile = build_profile([f.name])
|
||||
assert "mean_vector" in profile
|
||||
assert "clip_vectors" in profile
|
||||
assert profile["mean_vector"].shape == (20,)
|
||||
assert len(profile["clip_vectors"]) == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_multiple_clips():
|
||||
paths = []
|
||||
try:
|
||||
for i in range(3):
|
||||
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
freq = 440 + i * 200
|
||||
import soundfile as sf
|
||||
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||
sf.write(f.name, 0.5 * np.sin(2 * np.pi * freq * t), 22050)
|
||||
paths.append(f.name)
|
||||
f.close()
|
||||
|
||||
profile = build_profile(paths)
|
||||
assert len(profile["clip_vectors"]) == 3
|
||||
assert profile["mean_vector"].shape == (20,)
|
||||
finally:
|
||||
for p in paths:
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
def test_build_profile_skips_missing_files():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
profile = build_profile([f.name, "/no/such/file.wav"])
|
||||
assert len(profile["clip_vectors"]) == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_empty_returns_none():
|
||||
result = build_profile([])
|
||||
assert result is None
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify they fail**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: FAIL with `ModuleNotFoundError: No module named 'core.audio_scan'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
```python
|
||||
# core/audio_scan.py
|
||||
"""Audio similarity scanning — MFCC-based profile matching."""
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
from .paths import _log
|
||||
|
||||
_N_MFCC = 20
|
||||
_SR = 22050
|
||||
|
||||
|
||||
def _extract_mfcc(path: str, sr: int = _SR) -> np.ndarray:
|
||||
"""Load audio from a file and return a mean MFCC vector (20-dim)."""
|
||||
y, _ = librosa.load(path, sr=sr, mono=True)
|
||||
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=_N_MFCC)
|
||||
return mfcc.mean(axis=1) # average over time → (20,)
|
||||
|
||||
|
||||
def build_profile(clip_paths: list[str]) -> dict | None:
|
||||
"""Extract MFCCs from reference clips.
|
||||
|
||||
Returns dict with:
|
||||
- mean_vector: averaged MFCC across all clips (20,)
|
||||
- clip_vectors: list of individual MFCC vectors
|
||||
Returns None if no clips could be loaded.
|
||||
"""
|
||||
vectors = []
|
||||
for p in clip_paths:
|
||||
try:
|
||||
vec = _extract_mfcc(p)
|
||||
vectors.append(vec)
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: skip {p}: {e}")
|
||||
if not vectors:
|
||||
return None
|
||||
arr = np.stack(vectors)
|
||||
return {
|
||||
"mean_vector": arr.mean(axis=0),
|
||||
"clip_vectors": vectors,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run tests to verify they pass**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: all 5 PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add audio_scan module with build_profile"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Core audio_scan module — scan_video
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py`
|
||||
- Modify: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the tests**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
from core.audio_scan import scan_video
|
||||
|
||||
|
||||
def test_scan_video_finds_matching_region():
|
||||
"""A video made of the same sine wave as the reference should match."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
_make_wav(ref.name, duration=8.0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
_make_wav(vid.name, duration=20.0)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0)
|
||||
assert len(regions) > 0
|
||||
for start, end, score in regions:
|
||||
assert abs((end - start) - 8.0) < 1e-9
|
||||
assert score >= 0.5
|
||||
assert score >= 0.5
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
|
||||
|
||||
def test_scan_video_nearest_mode():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
_make_wav(ref.name, duration=8.0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
_make_wav(vid.name, duration=20.0)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.5, hop=1.0)
|
||||
assert len(regions) > 0
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
|
||||
|
||||
def test_scan_video_high_threshold_no_match():
|
||||
"""Different frequencies with very high threshold should not match."""
|
||||
import soundfile as sf
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||
sf.write(ref.name, 0.5 * np.sin(2 * np.pi * 440 * t), 22050)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
# White noise — very different from sine wave
|
||||
sf.write(vid.name, np.random.randn(22050 * 20).astype(np.float32) * 0.1, 22050)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="average", threshold=0.99, hop=1.0)
|
||||
assert len(regions) == 0
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify they fail**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_scan_video_finds_matching_region -v`
|
||||
Expected: FAIL with `ImportError: cannot import name 'scan_video'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
Add to `core/audio_scan.py`:
|
||||
|
||||
```python
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Cosine similarity between two vectors.
|
||||
|
||||
Returns value in [-1, 1]. Negative means anti-correlated (very
|
||||
dissimilar). For threshold filtering this is fine — negative scores
|
||||
never exceed the threshold. Scores near 0 may be uncorrelated or
|
||||
weakly anti-correlated.
|
||||
"""
|
||||
na = np.linalg.norm(a)
|
||||
nb = np.linalg.norm(b)
|
||||
if na == 0 or nb == 0:
|
||||
return 0.0
|
||||
return float(np.dot(a, b) / (na * nb))
|
||||
|
||||
|
||||
def scan_video(
|
||||
video_path: str,
|
||||
profile: dict,
|
||||
mode: str = "average",
|
||||
threshold: float = 0.7,
|
||||
hop: float = 1.0,
|
||||
window: float = 8.0,
|
||||
cancel_flag: object = None,
|
||||
) -> list[tuple[float, float, float]]:
|
||||
"""Slide a window across the video audio and score against the profile.
|
||||
|
||||
Args:
|
||||
video_path: path to video/audio file
|
||||
profile: dict from build_profile()
|
||||
mode: "average" (compare to mean) or "nearest" (max over all clips)
|
||||
threshold: minimum cosine similarity to include
|
||||
hop: step size in seconds
|
||||
window: window size in seconds (default 8s)
|
||||
cancel_flag: object with _cancel bool attribute; checked each iteration
|
||||
|
||||
Returns:
|
||||
list of (start_time, end_time, score) for regions above threshold
|
||||
"""
|
||||
_log(f"audio_scan: loading {video_path}")
|
||||
y, sr = librosa.load(video_path, sr=_SR, mono=True)
|
||||
duration = len(y) / sr
|
||||
_log(f"audio_scan: {duration:.1f}s loaded, scanning with hop={hop}s")
|
||||
|
||||
win_samples = int(window * sr)
|
||||
hop_samples = int(hop * sr)
|
||||
|
||||
results = []
|
||||
pos = 0
|
||||
while pos + win_samples <= len(y):
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
_log("audio_scan: cancelled")
|
||||
return results
|
||||
|
||||
chunk = y[pos : pos + win_samples]
|
||||
mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=_N_MFCC)
|
||||
vec = mfcc.mean(axis=1)
|
||||
|
||||
if mode == "nearest":
|
||||
score = max(
|
||||
_cosine_similarity(vec, cv) for cv in profile["clip_vectors"]
|
||||
)
|
||||
else: # average
|
||||
score = _cosine_similarity(vec, profile["mean_vector"])
|
||||
|
||||
if score >= threshold:
|
||||
start_t = pos / sr
|
||||
results.append((start_t, start_t + window, score))
|
||||
|
||||
pos += hop_samples
|
||||
|
||||
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
|
||||
return results
|
||||
```
|
||||
|
||||
**Step 4: Run tests to verify they pass**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: all 8 PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add scan_video with average and nearest modes"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Timeline — draw scan regions
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (Timeline class, around lines 209-260 and 300-375)
|
||||
|
||||
**Step 1: Add scan region storage to Timeline.__init__**
|
||||
|
||||
In `main.py`, find the Timeline class `__init__` method (around line 198). After `self._markers` initialization (line 209), add:
|
||||
|
||||
```python
|
||||
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
|
||||
```
|
||||
|
||||
**Step 2: Add set_scan_regions method**
|
||||
|
||||
After the `set_markers` method (line 249-252), add:
|
||||
|
||||
```python
|
||||
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
|
||||
"""regions: list of (start_time, end_time, score)"""
|
||||
self._scan_regions = regions
|
||||
self.update()
|
||||
|
||||
def clear_scan_regions(self) -> None:
|
||||
self._scan_regions = []
|
||||
self.update()
|
||||
```
|
||||
|
||||
**Step 3: Draw scan regions in paintEvent**
|
||||
|
||||
In `paintEvent` (starts around line 282), find the marker drawing section (line 363, comment `# ── export markers`). BEFORE that section, add:
|
||||
|
||||
```python
|
||||
# ── scan regions ──────────────────────────────────────────────
|
||||
if self._scan_regions and self._duration > 0:
|
||||
for (start, end, score) in self._scan_regions:
|
||||
x1 = int(start / self._duration * w)
|
||||
x2 = int(end / self._duration * w)
|
||||
alpha = int(40 + score * 80) # 40–120 opacity
|
||||
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
|
||||
```
|
||||
|
||||
**Step 4: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: app starts without errors. No scan regions visible yet (none set).
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: timeline scan region rendering"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: ScanWorker QThread
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (add ScanWorker class, after ExportWorker around line 165)
|
||||
|
||||
**Step 1: Add the ScanWorker class**
|
||||
|
||||
After the `ExportWorker` class (ends around line 165), add:
|
||||
|
||||
```python
|
||||
class ScanWorker(QThread):
|
||||
"""Runs audio similarity scan off the main thread."""
|
||||
finished = pyqtSignal(list) # emits list of (start, end, score)
|
||||
error = pyqtSignal(str)
|
||||
progress = pyqtSignal(str) # status message
|
||||
|
||||
def __init__(self, video_path: str, clip_paths: list[str],
|
||||
mode: str = "average", threshold: float = 0.7):
|
||||
super().__init__()
|
||||
self._video_path = video_path
|
||||
self._clip_paths = clip_paths
|
||||
self._mode = mode
|
||||
self._threshold = threshold
|
||||
self._cancel = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._cancel = True
|
||||
|
||||
def run(self):
|
||||
from core.audio_scan import build_profile, scan_video
|
||||
try:
|
||||
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
|
||||
profile = build_profile(self._clip_paths)
|
||||
if self._cancel:
|
||||
return
|
||||
if profile is None:
|
||||
self.error.emit("No valid reference clips found")
|
||||
return
|
||||
self.progress.emit("Scanning audio...")
|
||||
regions = scan_video(
|
||||
self._video_path, profile,
|
||||
mode=self._mode, threshold=self._threshold,
|
||||
cancel_flag=self,
|
||||
)
|
||||
if not self._cancel:
|
||||
self.finished.emit(regions)
|
||||
except Exception as e:
|
||||
if not self._cancel:
|
||||
self.error.emit(str(e))
|
||||
```
|
||||
|
||||
**Step 2: Verify import works**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -c "from main import ScanWorker; print('ok')"`
|
||||
Expected: `ok`
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add ScanWorker QThread for background scanning"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: DB helper — get_all_export_paths
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/db.py`
|
||||
- Modify: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the test**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
from core.db import ProcessedDB
|
||||
db = ProcessedDB(path)
|
||||
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||
paths = db.get_all_export_paths("test")
|
||||
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||
finally:
|
||||
os.unlink(path)
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||
Expected: FAIL with `AttributeError: 'ProcessedDB' object has no attribute 'get_all_export_paths'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
Add to `core/db.py`, after the `get_markers` method. Note: no lock needed — follows
|
||||
the codebase convention where read-only methods don't acquire the lock.
|
||||
|
||||
```python
|
||||
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||
"""Return all unique output_path values for a given profile."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/db.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add get_all_export_paths to ProcessedDB"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: UI controls for audio scanning
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (MainWindow class — control creation ~1490-1575, layout ~1620-1640)
|
||||
|
||||
**Step 1: Add scan control widgets**
|
||||
|
||||
In the MainWindow `__init__`, find the control creation section. After `self._chk_track` (around line 1501), add:
|
||||
|
||||
```python
|
||||
# ── audio scan controls ──────────────────────────────────────
|
||||
self._btn_scan = QPushButton("Scan")
|
||||
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
|
||||
self._btn_scan.clicked.connect(self._start_scan)
|
||||
|
||||
self._sld_threshold = QDoubleSpinBox()
|
||||
self._sld_threshold.setRange(0.0, 1.0)
|
||||
self._sld_threshold.setSingleStep(0.05)
|
||||
self._sld_threshold.setValue(0.7)
|
||||
self._sld_threshold.setPrefix("Thr: ")
|
||||
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
|
||||
|
||||
self._cmb_scan_mode = QComboBox()
|
||||
self._cmb_scan_mode.addItems(["Average", "Nearest"])
|
||||
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
|
||||
|
||||
self._cmb_scan_ref = QComboBox()
|
||||
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
|
||||
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
|
||||
self._scan_folder: str = ""
|
||||
|
||||
self._scan_worker: ScanWorker | None = None
|
||||
```
|
||||
|
||||
**Step 2: Add controls to settings_row layout**
|
||||
|
||||
Find the `settings_row` assembly (around line 1620). Before `settings_row.addStretch()` (around line 1635), add:
|
||||
|
||||
```python
|
||||
settings_row.addWidget(self._btn_scan)
|
||||
settings_row.addWidget(self._sld_threshold)
|
||||
settings_row.addWidget(self._cmb_scan_mode)
|
||||
settings_row.addWidget(self._cmb_scan_ref)
|
||||
```
|
||||
|
||||
**Step 3: Add handler methods**
|
||||
|
||||
Add these methods to MainWindow (after `_jump_to_next_marker` around line 2410):
|
||||
|
||||
```python
|
||||
def _on_scan_ref_changed(self, index: int) -> None:
|
||||
if index == 1: # Custom Folder
|
||||
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
|
||||
if folder:
|
||||
self._scan_folder = folder
|
||||
else:
|
||||
self._cmb_scan_ref.setCurrentIndex(0)
|
||||
|
||||
def _cleanup_scan_worker(self) -> None:
|
||||
"""Disconnect signals and schedule deletion of old scan worker."""
|
||||
if self._scan_worker is not None:
|
||||
try:
|
||||
self._scan_worker.finished.disconnect()
|
||||
self._scan_worker.error.disconnect()
|
||||
self._scan_worker.progress.disconnect()
|
||||
except TypeError:
|
||||
pass # already disconnected
|
||||
self._scan_worker.deleteLater()
|
||||
self._scan_worker = None
|
||||
|
||||
def _start_scan(self) -> None:
|
||||
if not self._file_path:
|
||||
self._show_status("No video loaded")
|
||||
return
|
||||
if self._scan_worker and self._scan_worker.isRunning():
|
||||
self._show_status("Scan already running")
|
||||
return
|
||||
|
||||
# Clean up previous worker
|
||||
self._cleanup_scan_worker()
|
||||
|
||||
# Collect reference clip paths
|
||||
if self._cmb_scan_ref.currentIndex() == 0:
|
||||
# Current profile — all exports across all files in this profile
|
||||
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
|
||||
if os.path.exists(p)]
|
||||
else:
|
||||
# Custom folder
|
||||
if not self._scan_folder:
|
||||
self._show_status("No reference folder selected")
|
||||
return
|
||||
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
|
||||
clip_paths = [
|
||||
os.path.join(self._scan_folder, f)
|
||||
for f in sorted(os.listdir(self._scan_folder))
|
||||
if f.lower().endswith(exts)
|
||||
]
|
||||
|
||||
if not clip_paths:
|
||||
self._show_status("No reference clips found")
|
||||
return
|
||||
|
||||
mode = self._cmb_scan_mode.currentText().lower()
|
||||
threshold = self._sld_threshold.value()
|
||||
|
||||
self._btn_scan.setEnabled(False)
|
||||
self._scan_file_path = self._file_path # remember which file we're scanning
|
||||
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
|
||||
|
||||
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
|
||||
self._scan_worker.finished.connect(self._on_scan_done)
|
||||
self._scan_worker.error.connect(self._on_scan_error)
|
||||
self._scan_worker.progress.connect(self._show_status)
|
||||
self._scan_worker.start()
|
||||
|
||||
def _on_scan_done(self, regions: list) -> None:
|
||||
self._btn_scan.setEnabled(True)
|
||||
# Ignore stale results if the user switched files during scan
|
||||
if self._file_path != getattr(self, '_scan_file_path', None):
|
||||
return
|
||||
self._timeline.set_scan_regions(regions)
|
||||
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
||||
|
||||
def _on_scan_error(self, msg: str) -> None:
|
||||
self._btn_scan.setEnabled(True)
|
||||
self._show_status(f"Scan error: {msg}")
|
||||
```
|
||||
|
||||
**Step 4: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: Scan button, threshold spinner, mode dropdown, and reference source dropdown visible in the settings row. Clicking Scan with no file loaded shows "No video loaded" in status.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add scan UI controls and start_scan handler"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Keyboard shortcut — jump to next scan region
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py`
|
||||
|
||||
**Step 1: Add the keyboard shortcut**
|
||||
|
||||
Find the shortcut definitions (around line 1728, where `QShortcut(QKeySequence("M"), ...)` is defined). Add after it:
|
||||
|
||||
```python
|
||||
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
|
||||
```
|
||||
|
||||
**Step 2: Add the jump method**
|
||||
|
||||
After `_on_scan_error` (or after `_jump_to_next_marker`), add:
|
||||
|
||||
```python
|
||||
def _jump_to_next_scan_region(self) -> None:
|
||||
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
|
||||
if not regions:
|
||||
return
|
||||
for (start, _end, _score) in regions:
|
||||
if start > self._cursor + 0.1:
|
||||
self._step_cursor(start - self._cursor)
|
||||
return
|
||||
# Wrap to first region
|
||||
self._step_cursor(regions[0][0] - self._cursor)
|
||||
```
|
||||
|
||||
**Step 3: Update help text**
|
||||
|
||||
Find the help/shortcuts tooltip (around line 1757). Add a row:
|
||||
|
||||
```python
|
||||
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
|
||||
```
|
||||
|
||||
**Step 4: Clear scan regions and cancel running scan on file change**
|
||||
|
||||
Find `_load_file` method (around line 1931). After the existing marker/state resets, add:
|
||||
|
||||
```python
|
||||
self._timeline.clear_scan_regions()
|
||||
if self._scan_worker and self._scan_worker.isRunning():
|
||||
self._scan_worker.cancel()
|
||||
self._cleanup_scan_worker()
|
||||
self._btn_scan.setEnabled(True)
|
||||
```
|
||||
|
||||
**Step 5: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: S key does nothing when no scan regions exist. After a scan, S jumps through matched regions.
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add S shortcut and clear scan on file change"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Final integration test
|
||||
|
||||
**Step 1: End-to-end manual test**
|
||||
|
||||
1. Open the app: `cd /media/p5/8-cut && python main.py`
|
||||
2. Load a video file
|
||||
3. Export a few clips (these become the reference)
|
||||
4. Set reference source to "Current Profile"
|
||||
5. Click "Scan"
|
||||
6. Verify: status shows progress messages, then "Scan complete: N matching regions"
|
||||
7. Verify: cyan-tinted regions appear on the timeline
|
||||
8. Press S to jump through scan regions
|
||||
9. Change threshold and re-scan — verify different number of regions
|
||||
10. Switch mode to "Nearest" and re-scan
|
||||
11. Switch reference to "Custom Folder", pick a folder with clips
|
||||
12. Re-scan and verify results
|
||||
|
||||
**Step 2: Run all tests**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/ -v`
|
||||
Expected: all tests PASS
|
||||
|
||||
**Step 3: Final commit**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "feat: audio similarity scanning complete"
|
||||
```
|
||||
@@ -0,0 +1,98 @@
|
||||
# Audio Pipeline Improvements Design
|
||||
|
||||
Date: 2026-04-19
|
||||
|
||||
## Goal
|
||||
|
||||
Improve audio scan classification accuracy, especially for non-speech sounds (suction, gagging, impacts), through three changes:
|
||||
|
||||
1. Multi-layer feature extraction from existing HuBERT/Wav2Vec2 models
|
||||
2. Two new embedding models: AST (AudioSet-supervised) and EAT (self-supervised + AudioSet finetuned)
|
||||
3. Calibrated classifier for better threshold behavior
|
||||
|
||||
## 1. Multi-Layer Feature Extraction
|
||||
|
||||
### Current behavior
|
||||
|
||||
`model(waveforms)` extracts embeddings from the **last transformer layer only**.
|
||||
|
||||
### Change
|
||||
|
||||
Use `model.extract_features(waveforms)` (torchaudio API) to get all layer outputs. Select layers at quartile boundaries, mean-pool each over time, concatenate.
|
||||
|
||||
| Model | Layers | Single-layer dim | Multi-layer dim (4 quartiles) |
|
||||
|-------|--------|-------------------|-------------------------------|
|
||||
| HUBERT_XLARGE | 48 | 1280 | 5120 |
|
||||
| HUBERT_LARGE | 24 | 1024 | 4096 |
|
||||
| HUBERT_BASE | 12 | 768 | 3072 |
|
||||
| WAV2VEC2_BASE | 12 | 768 | 3072 |
|
||||
|
||||
### Implementation
|
||||
|
||||
- New entries in `_EMBED_MODELS`: `"HUBERT_XLARGE_ML"` -> 5120, etc.
|
||||
- `_extract_w2v_windows`: when model name ends with `_ML`, call `extract_features()` instead of `model()`, select quartile layers, concat
|
||||
- Cache key: model name includes `_ML` suffix -> separate cache files
|
||||
- No change to classifier or training pipeline (HistGBT handles high-dim fine)
|
||||
|
||||
## 2. AST (Audio Spectrogram Transformer)
|
||||
|
||||
### What
|
||||
|
||||
`MIT/ast-finetuned-audioset-10-10-0.4593` via HuggingFace `transformers`. 86M params, 768-dim, supervised on AudioSet 527 sound classes.
|
||||
|
||||
### Integration
|
||||
|
||||
- Load: `ASTModel.from_pretrained()` + `ASTFeatureExtractor`
|
||||
- Preprocessing: `ASTFeatureExtractor` handles mel spectrogram from 16kHz raw audio
|
||||
- Batching: prepare `input_values` per window, stack into batch, forward through model
|
||||
- Multi-layer: `output_hidden_states=True` returns 13 layers; `AST_ML` variant concats quartile layers -> 3072-dim
|
||||
- Model cached via `_get_w2v_model()` same lazy-load pattern
|
||||
|
||||
### Entries
|
||||
|
||||
- `"AST"` -> 768
|
||||
- `"AST_ML"` -> 3072
|
||||
|
||||
## 3. EAT (Efficient Audio Transformer)
|
||||
|
||||
### What
|
||||
|
||||
`worstchan/EAT-base_epoch30_finetune_AS2M` via HuggingFace with `trust_remote_code=True`. 88M params, 768-dim, self-supervised + AudioSet finetuned.
|
||||
|
||||
### Integration
|
||||
|
||||
- Load: `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
||||
- Preprocessing: manual 128-bin Kaldi fbank mel spectrogram via torchaudio, normalize with EAT constants `(mel - (-4.268)) / (4.569 * 2)`, reshape to `[B, 1, T, 128]`
|
||||
- Feature extraction: `model.extract_features(mel)` returns `[B, seq, 768]`; CLS token `[:, 0, :]` for utterance-level, or mean-pool `[:, 1:, :]` for frame-level. Use mean-pool for consistency with other models.
|
||||
- Multi-layer: not natively supported, skip for now
|
||||
|
||||
### Entry
|
||||
|
||||
- `"EAT"` -> 768
|
||||
|
||||
## 4. Calibrated Classifier
|
||||
|
||||
Wrap `HistGradientBoostingClassifier` in `CalibratedClassifierCV(clf, cv=3, method='isotonic')` after fitting. Gives well-calibrated probabilities -> threshold slider maps more linearly to precision/recall.
|
||||
|
||||
One change in `train_classifier()`, no UI changes needed.
|
||||
|
||||
## 5. Requirements
|
||||
|
||||
Add to `requirements.txt`:
|
||||
```
|
||||
transformers>=4.30
|
||||
timm>=0.9
|
||||
```
|
||||
|
||||
Both AST and EAT need `transformers`. EAT additionally needs `timm` (used internally by its custom model code). Both setup scripts (`setup_env.sh`, `setup-windows.ps1`) install from `requirements.txt` so no changes needed there.
|
||||
|
||||
## Cache Compatibility
|
||||
|
||||
- All new model variants get distinct cache keys via model name in the hash
|
||||
- Existing caches for HUBERT_XLARGE, BEATs, etc. remain valid and untouched
|
||||
- New models create new `.npz` files in the same `cache/w2v/` directory
|
||||
|
||||
## UI Changes
|
||||
|
||||
- `_EMBED_MODELS` dict additions appear automatically in Train dialog model dropdown and scan model dropdown
|
||||
- No other UI changes needed
|
||||
@@ -0,0 +1,588 @@
|
||||
# Audio Pipeline Improvements Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Improve audio scan accuracy with multi-layer extraction, AST/EAT models, and calibrated classifier.
|
||||
|
||||
**Architecture:** All changes are in `core/audio_scan.py`. The embedding extraction functions gain new model-type branches (AST, EAT, multi-layer). The classifier gets a calibration wrapper. `_EMBED_MODELS` dict and `_get_w2v_model()` are extended. No UI changes needed — new models appear automatically in dropdowns.
|
||||
|
||||
**Tech Stack:** torchaudio (existing), transformers (new dep), timm (new dep), sklearn.calibration (existing dep)
|
||||
|
||||
**Key design notes:**
|
||||
- `_get_w2v_model()` resolves `_ML` suffixed names to their base model for loading (e.g. `HUBERT_XLARGE_ML` loads `HUBERT_XLARGE`). Both share the same GPU model — only the extraction path differs (last-layer vs multi-layer). The global `_w2v_model_name` stores the **base** name so switching between `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` does NOT trigger a reload.
|
||||
- Cache keys use the **full** model name (including `_ML`), so single-layer and multi-layer caches coexist as separate `.npz` files.
|
||||
- AST and EAT are separate model types that do NOT share the torchaudio loading path — they get their own `elif` branches in `_get_w2v_model()`.
|
||||
- Both `_extract_w2v_windows` and `_extract_w2v_targeted` need identical changes to their batch inference blocks. Keep them in sync.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add transformers and timm to requirements
|
||||
|
||||
**Files:**
|
||||
- Modify: `requirements.txt`
|
||||
|
||||
**Step 1: Add dependencies**
|
||||
|
||||
Add after the `torchaudio` line in `requirements.txt`:
|
||||
|
||||
```
|
||||
transformers>=4.30
|
||||
timm>=0.9
|
||||
```
|
||||
|
||||
**Step 2: Verify install**
|
||||
|
||||
Run: `pip install transformers timm`
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add requirements.txt
|
||||
git commit -m "deps: add transformers and timm for AST/EAT models"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Multi-layer extraction for torchaudio models
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-58` (_EMBED_MODELS dict)
|
||||
- Modify: `core/audio_scan.py:96-100` (_embed_dim)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model)
|
||||
- Modify: `core/audio_scan.py:189-205` (_extract_w2v_windows batch loop)
|
||||
- Modify: `core/audio_scan.py:278-293` (_extract_w2v_targeted batch loop)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
def test_embed_dim_multi_layer():
|
||||
from core.audio_scan import _embed_dim
|
||||
# Multi-layer models should report concatenated dimension
|
||||
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||
# Single-layer unchanged
|
||||
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||
Expected: FAIL — `_embed_dim("HUBERT_XLARGE_ML")` returns 768 (default fallback)
|
||||
|
||||
**Step 3: Add multi-layer entries to _EMBED_MODELS**
|
||||
|
||||
In `core/audio_scan.py:50-58`, add after existing entries:
|
||||
|
||||
```python
|
||||
_EMBED_MODELS = {
|
||||
"WAV2VEC2_BASE": 768,
|
||||
"WAV2VEC2_LARGE": 1024,
|
||||
"WAV2VEC2_LARGE_LV60K": 1024,
|
||||
"HUBERT_BASE": 768,
|
||||
"HUBERT_LARGE": 1024,
|
||||
"HUBERT_XLARGE": 1280,
|
||||
"BEATS": 768,
|
||||
# Multi-layer variants (4 quartile layers concatenated)
|
||||
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Add helper to resolve base model and layer indices**
|
||||
|
||||
Add after `_embed_dim()` (around line 101):
|
||||
|
||||
```python
|
||||
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||
|
||||
Returns None for single-layer models.
|
||||
Layer indices are 0-based into the list returned by extract_features().
|
||||
"""
|
||||
if not model_name.endswith("_ML"):
|
||||
return None
|
||||
base = model_name[:-3] # strip "_ML"
|
||||
if base not in _EMBED_MODELS:
|
||||
return None
|
||||
# Layer counts per model family
|
||||
layer_counts = {
|
||||
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||
"AST": 12,
|
||||
}
|
||||
n = layer_counts.get(base)
|
||||
if n is None:
|
||||
return None
|
||||
# Select 4 layers at quartile boundaries (0-indexed)
|
||||
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||
return base, indices
|
||||
```
|
||||
|
||||
Note: AST is included in the layer_counts dict here already so Task 3 doesn't need to modify it again.
|
||||
|
||||
**Step 6: Write test for _ml_config**
|
||||
|
||||
```python
|
||||
def test_ml_config():
|
||||
from core.audio_scan import _ml_config
|
||||
assert _ml_config("HUBERT_XLARGE") is None
|
||||
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||
assert base == "HUBERT_XLARGE"
|
||||
assert layers == [11, 23, 35, 47]
|
||||
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||
assert base == "HUBERT_BASE"
|
||||
assert layers == [2, 5, 8, 11]
|
||||
```
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_ml_config -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 7: Modify _get_w2v_model to resolve ML base names**
|
||||
|
||||
In `_get_w2v_model()` (line 68), the comparison key must use the resolved base name so that `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` share the same loaded model without reloading:
|
||||
|
||||
```python
|
||||
def _get_w2v_model(model_name: str | None = None):
|
||||
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||
global _w2v_model, _w2v_device, _w2v_model_name
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
# Multi-layer variants use the same base model weights
|
||||
ml = _ml_config(model_name)
|
||||
load_name = ml[0] if ml else model_name
|
||||
if _w2v_model is None or _w2v_model_name != load_name:
|
||||
import torch
|
||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if load_name == "BEATS":
|
||||
... # existing BEATs code unchanged
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||
_w2v_model.eval()
|
||||
_w2v_model_name = load_name
|
||||
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||
return _w2v_model, _w2v_device
|
||||
```
|
||||
|
||||
**Step 8: Modify _extract_w2v_windows batch inference**
|
||||
|
||||
In `_extract_w2v_windows`, compute `ml_cfg` **once** before the batch loop (after line 173 `is_beats = ...`):
|
||||
|
||||
```python
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
```
|
||||
|
||||
Then replace the batch inference block (lines 197-204):
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings.append(batch_emb)
|
||||
```
|
||||
|
||||
**Step 9: Modify _extract_w2v_targeted batch inference (keep in sync)**
|
||||
|
||||
In `_extract_w2v_targeted`, add `ml_cfg` computation after line 276 `is_beats = ...`:
|
||||
|
||||
```python
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
```
|
||||
|
||||
Then replace the batch inference block (lines 285-292) with the same branching logic as Step 8:
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings_list.append(batch_emb)
|
||||
```
|
||||
|
||||
Note: `_extract_w2v_targeted` appends to `embeddings_list` (not `embeddings`).
|
||||
|
||||
**Step 10: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 11: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: multi-layer extraction for HuBERT/Wav2Vec2 models"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: AST model integration
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add AST entries)
|
||||
- Modify: `core/audio_scan.py:45-47` (add _ast_feature_extractor global)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add AST loading branch)
|
||||
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add AST inference branch)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
```python
|
||||
def test_embed_dim_ast():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("AST") == 768
|
||||
assert _embed_dim("AST_ML") == 3072
|
||||
```
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_ast -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 2: Add AST entries to _EMBED_MODELS**
|
||||
|
||||
Add to the dict (after the ML entries):
|
||||
|
||||
```python
|
||||
# Transformers-based models
|
||||
"AST": 768,
|
||||
"AST_ML": 3072, # 768 * 4
|
||||
```
|
||||
|
||||
Run test again — should PASS now.
|
||||
|
||||
**Step 3: Add module-level global for AST feature extractor**
|
||||
|
||||
Near line 47 (after `_w2v_model_name = None`):
|
||||
|
||||
```python
|
||||
_ast_feature_extractor = None
|
||||
```
|
||||
|
||||
**Step 4: Add AST loading branch in _get_w2v_model**
|
||||
|
||||
In `_get_w2v_model()`, add an `elif` branch **before** the torchaudio fallback `else`:
|
||||
|
||||
```python
|
||||
elif load_name == "AST":
|
||||
from transformers import ASTModel, ASTFeatureExtractor
|
||||
_w2v_model = ASTModel.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
).to(_w2v_device)
|
||||
global _ast_feature_extractor
|
||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
)
|
||||
```
|
||||
|
||||
Note: `_ast_feature_extractor` is recreated on every model load (not cached separately) — simple and correct since the feature extractor is lightweight and model reloads are rare.
|
||||
|
||||
**Step 5: Add AST inference branch in both extraction functions**
|
||||
|
||||
In both `_extract_w2v_windows` and `_extract_w2v_targeted`, compute `is_ast` once before the loop:
|
||||
|
||||
```python
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
```
|
||||
|
||||
Then in the batch inference block, add after the `elif ml_cfg` branch and before `else`:
|
||||
|
||||
```python
|
||||
elif is_ast:
|
||||
# AST uses its own feature extractor for mel spectrogram
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
```
|
||||
|
||||
Important: `chunks` is already a list of numpy arrays (built in the loop at lines 194-196). Pass it directly as `list(chunks)` — the `ASTFeatureExtractor` accepts a list of numpy arrays and handles batching/padding internally. Verified: `ASTFeatureExtractor([np.array, np.array, ...], sampling_rate=16000, return_tensors="pt", padding=True)` returns `input_values` of shape `[B, 1024, 128]`.
|
||||
|
||||
**Step 6: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 7: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add AST (Audio Spectrogram Transformer) embedding model"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: EAT model integration
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add EAT entry)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add EAT loading branch)
|
||||
- Add: `core/audio_scan.py` (_eat_preprocess helper function)
|
||||
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add EAT inference branch)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
```python
|
||||
def test_embed_dim_eat():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("EAT") == 768
|
||||
```
|
||||
|
||||
**Step 2: Add EAT entry to _EMBED_MODELS**
|
||||
|
||||
```python
|
||||
"EAT": 768,
|
||||
```
|
||||
|
||||
Note: No `EAT_ML` variant — EAT's `extract_features()` does not natively support multi-layer output. Can be added later if needed by monkey-patching.
|
||||
|
||||
**Step 3: Add EAT loading branch in _get_w2v_model**
|
||||
|
||||
Add after the AST branch, before the torchaudio `else`:
|
||||
|
||||
```python
|
||||
elif load_name == "EAT":
|
||||
from transformers import AutoModel
|
||||
_w2v_model = AutoModel.from_pretrained(
|
||||
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||
trust_remote_code=True,
|
||||
).to(_w2v_device)
|
||||
```
|
||||
|
||||
**Step 4: Add EAT preprocessing helper**
|
||||
|
||||
Add as a module-level function near `_get_w2v_model`:
|
||||
|
||||
```python
|
||||
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||
|
||||
Returns tensor of shape [B, 1, T, 128].
|
||||
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||
"""
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
TARGET_LEN = 1024
|
||||
MEAN, STD = -4.268, 4.569
|
||||
|
||||
mels = []
|
||||
for chunk in chunks:
|
||||
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||
fbank = kaldi.fbank(
|
||||
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||
)
|
||||
# Pad or truncate to TARGET_LEN
|
||||
if fbank.shape[0] < TARGET_LEN:
|
||||
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||
else:
|
||||
fbank = fbank[:TARGET_LEN]
|
||||
fbank = (fbank - MEAN) / (STD * 2)
|
||||
mels.append(fbank)
|
||||
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||
```
|
||||
|
||||
**Step 5: Add EAT inference branch in both extraction functions**
|
||||
|
||||
Compute `is_eat` once before the loop:
|
||||
|
||||
```python
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||
```
|
||||
|
||||
Then in the batch inference block, add after the `elif is_ast` branch and before `else`:
|
||||
|
||||
```python
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
# Mean-pool frame-level tokens (skip CLS at index 0)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
```
|
||||
|
||||
Important: `model.extract_features()` returns a plain `torch.Tensor` of shape `[B, 513, 768]` (not a tuple). Index 0 is the CLS token, indices 1-512 are frame-level patch embeddings. We mean-pool the frame tokens for consistency with how other models are pooled.
|
||||
|
||||
**Step 6: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 7: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add EAT (Efficient Audio Transformer) embedding model"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Calibrated classifier
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:424-429` (train_classifier, wrap clf)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Modify train_classifier**
|
||||
|
||||
After the existing `clf.fit()` call (line 428), add calibration with a safe guard:
|
||||
|
||||
```python
|
||||
clf.fit(X[train_idx], y_arr[train_idx])
|
||||
_log("audio_scan: classifier trained")
|
||||
|
||||
# Calibrate probabilities for better threshold behavior
|
||||
# Requires at least 6 samples per class for stable 3-fold isotonic calibration
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
min_class = min(int(n_pos), int(n_neg_sample))
|
||||
if min_class >= 6:
|
||||
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||
clf = cal_clf
|
||||
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||
else:
|
||||
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||
```
|
||||
|
||||
Why `min_class >= 6`: `CalibratedClassifierCV` uses stratified k-fold internally. With `cv=3`, each fold needs at least 2 samples per class. `min_class >= 6` guarantees this. With fewer samples, the uncalibrated HistGBT probabilities are still reasonable — calibration is an enhancement, not a requirement.
|
||||
|
||||
Previous plan bug: `cv=min(3, n_pos, n_neg_sample)` could produce `cv=1` when `n_pos=1`, which raises `ValueError` (minimum is 2). Even `cv=2` with 2 positives causes one fold to have only 1 positive, making isotonic regression unstable. The `>= 6` guard avoids all these edge cases.
|
||||
|
||||
**Step 2: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py
|
||||
git commit -m "feat: calibrate classifier probabilities with isotonic regression"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Integration test with real model (manual)
|
||||
|
||||
This task is manual — it requires GPU and a real video file.
|
||||
|
||||
**Step 1: Test multi-layer extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='HUBERT_XLARGE_ML')
|
||||
print(f'HUBERT_XLARGE_ML: {emb.shape}') # expect (13, 5120)
|
||||
assert emb.shape[1] == _embed_dim('HUBERT_XLARGE_ML')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 2: Test AST extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='AST')
|
||||
print(f'AST: {emb.shape}') # expect (13, 768)
|
||||
assert emb.shape[1] == _embed_dim('AST')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 3: Test AST multi-layer**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='AST_ML')
|
||||
print(f'AST_ML: {emb.shape}') # expect (13, 3072)
|
||||
assert emb.shape[1] == _embed_dim('AST_ML')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 4: Test EAT extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='EAT')
|
||||
print(f'EAT: {emb.shape}') # expect (13, 768)
|
||||
assert emb.shape[1] == _embed_dim('EAT')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 5: Test model switching doesn't reload unnecessarily**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _get_w2v_model
|
||||
import core.audio_scan as m
|
||||
# Load HUBERT_XLARGE
|
||||
_get_w2v_model('HUBERT_XLARGE')
|
||||
name1 = m._w2v_model_name
|
||||
# Switch to ML variant — should NOT reload
|
||||
_get_w2v_model('HUBERT_XLARGE_ML')
|
||||
name2 = m._w2v_model_name
|
||||
assert name1 == name2 == 'HUBERT_XLARGE', f'Expected no reload, got {name1} -> {name2}'
|
||||
print('PASS: no reload on ML switch')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 6: Test full train+scan cycle in app**
|
||||
|
||||
Load app, select each new model from scan model dropdown, scan a video, train, verify results display correctly.
|
||||
|
||||
**Step 7: Final commit and push**
|
||||
|
||||
```bash
|
||||
git push
|
||||
```
|
||||
@@ -0,0 +1,226 @@
|
||||
# ComfyUI-8cut Node Pack Design
|
||||
|
||||
Date: 2026-04-19
|
||||
|
||||
## Goal
|
||||
|
||||
Port 8-cut's video scanning, training, review, and export workflow to a ComfyUI node pack. The primary motivation is **remote access** — ComfyUI's web UI allows browser-based operation over the network, and HTML5 `<video>` handles streaming compression natively. No tensor-based image pipeline; videos stay as file paths throughout.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Approach
|
||||
|
||||
Monolithic Review Node + simple pipeline nodes. One central **VideoReview** node embeds the full interactive player/timeline/region table as a large DOM widget. Other nodes (Scan, Train, Export) are headless pipeline nodes that pass lightweight metadata.
|
||||
|
||||
### Core reuse
|
||||
|
||||
The entire `8-cut/core/` package is Qt-free and reusable as-is:
|
||||
- `core/audio_scan.py` — `scan_video()`, `train_classifier()`, `load_classifier()`
|
||||
- `core/db.py` — `ProcessedDB` (SQLite, all scan/training/export persistence)
|
||||
- `core/ffmpeg.py` — `build_ffmpeg_command()` (clip export)
|
||||
- `core/tracking.py` — YOLO-based subject tracking
|
||||
- `core/paths.py` — path helpers, `format_time()`
|
||||
|
||||
No porting required — these are imported directly.
|
||||
|
||||
---
|
||||
|
||||
## Node Pack Structure
|
||||
|
||||
```
|
||||
ComfyUI-8cut/
|
||||
__init__.py # NODE_CLASS_MAPPINGS, WEB_DIRECTORY
|
||||
core/ # symlink or copy of 8-cut/core/
|
||||
data/
|
||||
8cut.db # separate SQLite DB (can copy from ~/.8cut.db)
|
||||
models/ # trained classifiers (.joblib)
|
||||
nodes/
|
||||
load_video.py
|
||||
audio_scan.py
|
||||
video_review.py
|
||||
train_model.py
|
||||
export_clips.py
|
||||
server_routes.py # custom API routes
|
||||
web/
|
||||
js/
|
||||
video_review.js # timeline + player + scan panel widget
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Custom Types
|
||||
|
||||
No tensors anywhere in the pipeline. All data flows as lightweight metadata:
|
||||
|
||||
| Type | Python value | Purpose |
|
||||
|------|-------------|---------|
|
||||
| `VIDEO_PATH` | `str` (absolute path) | Video file reference |
|
||||
| `SCAN_REGIONS` | `list[dict]` with start/end/score/model/disabled | Scan output / review edits |
|
||||
| `SCAN_MODEL` | `str` (path to .joblib) | Trained classifier |
|
||||
|
||||
---
|
||||
|
||||
## Nodes
|
||||
|
||||
### LoadVideo
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `video_path` (STRING, file browser), `profile` (STRING combo from DB profiles) |
|
||||
| **Output** | `VIDEO_PATH`, `filename` (STRING) |
|
||||
| **Logic** | Validates path exists, returns it. Populates profile combo via API route. |
|
||||
|
||||
### AudioScan
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_MODEL`, `threshold` (FLOAT 0-1), `hop` (FLOAT) |
|
||||
| **Output** | `SCAN_REGIONS` |
|
||||
| **Logic** | Calls `core.audio_scan.scan_video()` directly. Progress via `PromptServer.send_sync("progress", ...)`. |
|
||||
|
||||
### VideoReview (interactive, blocking)
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS` (optional) |
|
||||
| **Output** | `SCAN_REGIONS` (edited) |
|
||||
| **OUTPUT_NODE** | `True` |
|
||||
| **Logic** | Execution pauses here. User interacts via the widget. Clicks "Continue" to pass edited regions downstream. |
|
||||
|
||||
The widget layout:
|
||||
|
||||
```
|
||||
+-------------------------------------+
|
||||
| [video player (HTML5 <video>)] |
|
||||
| +- timeline with scan regions ----+|
|
||||
| | cursor + region drag/resize ||
|
||||
| +---------------------------------+|
|
||||
| +- model tabs [EAT_LARGE][HuBERT]+|
|
||||
| | Time | End | Score ||
|
||||
| | 1:23 | 1:31 | 0.92 ||
|
||||
| | 3:45 | 3:53 | 0.87 ||
|
||||
| | [Add Negative] [Export] [Continue]|
|
||||
| +---------------------------------+|
|
||||
+-------------------------------------+
|
||||
```
|
||||
|
||||
Widget size: ~640x500px minimum, resizable via LiteGraph.
|
||||
|
||||
**Blocking mechanism**: The node's `run()` method blocks on a server-side event/queue. The frontend signals completion via `POST /8cut/review_done/{node_id}`, which unblocks `run()` and returns the edited `SCAN_REGIONS`.
|
||||
|
||||
### TrainModel
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `profile` (STRING combo), `positive_folder` (STRING combo), `negative_folder` (STRING combo, optional), `embed_model` (STRING combo from `_EMBED_MODELS`), `use_hard_negatives` (BOOL) |
|
||||
| **Output** | `SCAN_MODEL` |
|
||||
| **Logic** | Queries `db.get_training_data()` to assemble `video_infos`, calls `core.audio_scan.train_classifier()`. Saves to `models/{profile}_{embed_model}.joblib` with version rotation. Progress via ComfyUI progress bar. |
|
||||
|
||||
### ExportClips
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS`, `output_folder` (STRING), `short_side` (INT), `format` (combo MP4/WEBM), `spread` (FLOAT), `clip_count` (INT), `fuse_gap` (FLOAT) |
|
||||
| **Output** | exported file paths (list) |
|
||||
| **Logic** | Region fusion via `_build_export_spans()`, then `core.ffmpeg.build_ffmpeg_command()` per clip. Records each clip in DB via `db.add()`. |
|
||||
|
||||
### Typical workflow
|
||||
|
||||
```
|
||||
[LoadVideo] --> [AudioScan] --> [VideoReview] --> [ExportClips]
|
||||
^
|
||||
[TrainModel]
|
||||
```
|
||||
|
||||
### Training loop (hard negatives round-trip)
|
||||
|
||||
1. Scan with existing model -> regions in VideoReview
|
||||
2. Review -> mark false positives as negatives (DB)
|
||||
3. Train -> new model uses hard negatives
|
||||
4. Rescan -> better results
|
||||
5. Repeat
|
||||
|
||||
---
|
||||
|
||||
## API Routes
|
||||
|
||||
### Video serving
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/video` | GET | Serve raw video file via `web.FileResponse`. Query param: `path`. Browser decodes mp4/h264 natively — key for remote streaming. |
|
||||
| `/8cut/video_transcode` | GET | Fallback: transcode to webm on-the-fly via ffmpeg `StreamResponse` for browser-incompatible formats (some MKV, odd codecs). |
|
||||
|
||||
### Region editing (from VideoReview widget)
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/toggle_region` | POST | `toggle_scan_result_disabled()` |
|
||||
| `/8cut/resize_region` | POST | `update_scan_result()` |
|
||||
| `/8cut/delete_region` | POST | `delete_scan_result()` |
|
||||
| `/8cut/add_negatives` | POST | `add_hard_negatives()` |
|
||||
| `/8cut/scan_versions` | GET | `get_scan_versions()` |
|
||||
| `/8cut/review_done/{node_id}` | POST | Unblock the VideoReview node's `run()`, pass final regions |
|
||||
|
||||
### Data queries (for combo widget population)
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/profiles` | GET | `db.get_profiles()` |
|
||||
| `/8cut/export_folders` | GET | `db.get_export_folders()` |
|
||||
| `/8cut/models` | GET | List available `.joblib` models |
|
||||
|
||||
---
|
||||
|
||||
## Frontend JS Widget (`web/js/video_review.js`)
|
||||
|
||||
Registered via `app.registerExtension()`. Hooks into the VideoReview node's `onNodeCreated` and `onExecuted` callbacks.
|
||||
|
||||
### Components
|
||||
|
||||
1. **Video player** — HTML5 `<video>` element, src pointed at `/8cut/video?path=...`
|
||||
2. **Timeline** — `<canvas>` overlay below the video. Renders:
|
||||
- Scan region rectangles (color-coded by score, red for negatives, gray for disabled)
|
||||
- Cursor line (click to seek)
|
||||
- Drag handles on region edges (resize)
|
||||
- Waveform (optional, fetched via separate route)
|
||||
3. **Region table** — HTML table with model tabs. Click row to seek. Columns: Time, End, Score.
|
||||
4. **Action buttons** — Add Negative, Export, Continue
|
||||
5. **Version combo** — dropdown to switch scan history versions
|
||||
|
||||
### Interaction flow
|
||||
|
||||
- Widget activates when `onExecuted` fires with scan regions
|
||||
- User clicks/drags timeline, edits regions, marks negatives
|
||||
- Each edit hits an API route (immediate DB persistence)
|
||||
- "Continue" sends `POST /8cut/review_done/{node_id}` with final region state
|
||||
- Node's `run()` unblocks, passes `SCAN_REGIONS` downstream
|
||||
|
||||
---
|
||||
|
||||
## DB
|
||||
|
||||
Separate SQLite DB at `ComfyUI-8cut/data/8cut.db`. Uses the existing `ProcessedDB` class unchanged — same schema, same migration code. Users can copy their existing `~/.8cut.db` to carry over scan history, training data, and hard negatives.
|
||||
|
||||
---
|
||||
|
||||
## Dependencies
|
||||
|
||||
Same as 8-cut's `requirements.txt` minus PyQt6/python-mpv:
|
||||
- `torch`, `torchaudio`, `torchvision` (from CUDA index)
|
||||
- `transformers>=4.30,<5.0`, `timm>=0.9`
|
||||
- `librosa`, `scikit-learn`, `joblib`, `soundfile`, `numpy`
|
||||
- `ultralytics` (YOLO tracking)
|
||||
|
||||
ComfyUI already provides torch. The node pack's install script just needs the audio/ML extras.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
1. **Node pack skeleton** — structure, `__init__.py`, custom types, API routes for video serving
|
||||
2. **LoadVideo + AudioScan** — headless nodes, no widget needed yet
|
||||
3. **VideoReview widget (minimal)** — video player + static region display + Continue button
|
||||
4. **VideoReview interactivity** — timeline click/drag, region editing, negative marking
|
||||
5. **TrainModel + ExportClips** — complete the pipeline
|
||||
6. **Polish** — version history, waveform overlay, transcode fallback
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,205 @@
|
||||
# Scan History & Hard Negative Management — Final Design
|
||||
|
||||
Date: 2026-04-19 (implemented on `feat/training-ui`)
|
||||
|
||||
## Goal
|
||||
|
||||
1. Keep scan result history per `(file, model)` so users can track classifier improvement across training iterations
|
||||
2. Make hard negatives manageable — viewable, removable, and optionally disabled per training run
|
||||
3. Fix latent bug: `get_export_folders()` doesn't filter by `scan_export`
|
||||
|
||||
---
|
||||
|
||||
## 1. Ghost Folder Fix
|
||||
|
||||
### Bug
|
||||
|
||||
`get_export_folders()` queried all `output_path` rows without filtering `scan_export`. Folders that only contained scan-exported clips appeared in training dropdowns with 0 clips.
|
||||
|
||||
### Implementation (`core/db.py`)
|
||||
|
||||
**`get_export_folders(profile, include_scan_exports=False)`** — new parameter. When `False` (default), the SQL query adds `AND scan_export = 0` to exclude scan-only folders. The `get_training_stats()` method passes this through and also filters its return dict to remove folders with 0 clips:
|
||||
|
||||
```python
|
||||
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||
```
|
||||
|
||||
### Test
|
||||
|
||||
`tests/test_db.py::test_export_folders_excludes_scan_exports` — verifies scan-only folders are excluded by default and included when `include_scan_exports=True`.
|
||||
|
||||
---
|
||||
|
||||
## 2. Scan Result History
|
||||
|
||||
### Schema
|
||||
|
||||
Added column to `scan_results`:
|
||||
|
||||
```sql
|
||||
scan_timestamp TEXT NOT NULL DEFAULT ''
|
||||
```
|
||||
|
||||
All rows from the same scan share one timestamp string with **microsecond precision** (`%Y%m%d_%H%M%S_%f`, e.g. `"20260419_143022_123456"`). Microsecond precision prevents version collisions on fast successive scans.
|
||||
|
||||
Migration adds the column via `ALTER TABLE` for existing databases. Legacy rows keep `scan_timestamp = ''`.
|
||||
|
||||
### DB methods (`core/db.py`)
|
||||
|
||||
**`save_scan_results(filename, profile, model, regions, max_versions=5)`**
|
||||
1. Inserts new rows with current microsecond-precision timestamp
|
||||
2. Counts distinct timestamps for this `(filename, profile, model)`
|
||||
3. Prunes oldest timestamps beyond `max_versions`
|
||||
|
||||
No more DELETE-then-INSERT — all versions coexist in the table.
|
||||
|
||||
**`get_scan_versions(filename, profile, model)`**
|
||||
Returns `[{timestamp, count, max_score}, ...]` ordered newest first. Filters `scan_timestamp != ''` so legacy rows don't appear as named versions.
|
||||
|
||||
**`get_scan_results(filename, profile, scan_timestamp=None)`**
|
||||
- With `scan_timestamp`: returns rows matching that exact version
|
||||
- Without (default): uses `INNER JOIN` subquery with `MAX(scan_timestamp)` per model to return only the latest version. Legacy rows (empty timestamp) sort before any real timestamp, so they're returned when no versioned scans exist.
|
||||
|
||||
### UI (`main.py` — `ScanResultsPanel`)
|
||||
|
||||
Each model tab wraps its `QTableWidget` in a container `QWidget` with a `QComboBox` for version selection:
|
||||
|
||||
```
|
||||
container (QWidget)
|
||||
├── cmb_version (QComboBox) — hidden when ≤ 1 version
|
||||
└── table (QTableWidget)
|
||||
```
|
||||
|
||||
**Helper methods** unwrap this container:
|
||||
- `_current_table()` — returns `QTableWidget` from active tab (handles both raw table and container)
|
||||
- `_tab_table(index)` — same by tab index
|
||||
|
||||
**Version combo** is populated by `_populate_version_combos()` after every `load_for_file()` and `add_scan_results()` call. Labels use `datetime.strptime` parsing with try/except fallback for robustness:
|
||||
|
||||
```
|
||||
2026-04-19 14:30 (12 regions, best: 0.95)
|
||||
```
|
||||
|
||||
**Version switching** via `_on_version_changed(model, idx)`:
|
||||
1. Reads `scan_timestamp` from combo's `userData`
|
||||
2. Calls `get_scan_results(filename, profile, scan_timestamp=ts)`
|
||||
3. Repopulates the table in-place
|
||||
4. **Clears the undo stack** — stale undo entries from a different version would corrupt data
|
||||
5. Emits `regions_edited` to refresh the timeline
|
||||
|
||||
**Tab switch** connects `tab_changed` signal to `_on_scan_regions_edited` (not just `_update_scan_export_count`), so the timeline updates scan regions when switching model tabs.
|
||||
|
||||
### Cache interaction
|
||||
|
||||
Embedding cache is per `(file, model)` and doesn't change across scans. History stores classified regions (start, end, score), not embeddings.
|
||||
|
||||
### Test
|
||||
|
||||
`tests/test_db.py::test_scan_result_history` — saves 3 versions, verifies counts, ordering, and latest-by-default behavior.
|
||||
|
||||
---
|
||||
|
||||
## 3. Hard Negative Management
|
||||
|
||||
### Schema
|
||||
|
||||
Added column to `hard_negatives`:
|
||||
|
||||
```sql
|
||||
source_model TEXT NOT NULL DEFAULT ''
|
||||
```
|
||||
|
||||
Migration adds the column via `ALTER TABLE` for existing databases.
|
||||
|
||||
### DB methods (`core/db.py`)
|
||||
|
||||
**`add_hard_negatives(filename, profile, times, source_path="", source_model="")`** — now stores which embedding model produced the scan that led to the negative marking.
|
||||
|
||||
**`get_hard_negatives(profile)`** — returns all rows as `[{id, filename, start_time, source_path, source_model}, ...]` for the management dialog.
|
||||
|
||||
**`delete_hard_negatives_by_ids(ids)`** — bulk delete by row IDs.
|
||||
|
||||
**`get_training_data(..., use_hard_negatives=True)`** — new parameter. When `False`, the hard negatives query is skipped entirely. Non-destructive — negatives remain in DB.
|
||||
|
||||
### Source model tracking (`main.py`)
|
||||
|
||||
`_on_scan_negatives()` now passes `source_model=self._scan_panel.current_model_name()` when marking negatives from scan results. `current_model_name()` extracts the model name from the active tab text (stripping the count suffix).
|
||||
|
||||
### Training toggle (`main.py` — `TrainDialog`)
|
||||
|
||||
Checkbox **"Use hard negatives in training"** (default checked) with "Manage..." button in an HBox layout. The toggle:
|
||||
- Updates live training stats preview via debounced `_update_stats()`
|
||||
- Passes `use_hard_negatives` through `_open_train_dialog()` to `get_training_data()`
|
||||
|
||||
### Management dialog (`main.py` — `HardNegativesDialog`)
|
||||
|
||||
Accessible from TrainDialog's "Manage..." button. Features:
|
||||
|
||||
| Component | Details |
|
||||
|-----------|---------|
|
||||
| **Filter combo** | `(all)` + each distinct `source_model` found in data |
|
||||
| **Summary label** | `<b>N</b> hard negatives` |
|
||||
| **Table** | File, Time (`{:.1f}s`), Source Model, hidden ID column |
|
||||
| **Delete Selected** | Multi-select aware, skips hidden (filtered) rows |
|
||||
| **Clear All** | **Filter-aware**: if a model filter is active, only deletes negatives for that model with an appropriate confirmation message. If `(all)`, deletes everything. |
|
||||
| **Close** | Closes dialog, triggers stats refresh in parent TrainDialog |
|
||||
|
||||
`blockSignals(True)` guards prevent spurious filter callbacks during `_load()` repopulation.
|
||||
|
||||
### Tests
|
||||
|
||||
- `test_hard_negatives_source_model` — verifies source_model stored and retrieved
|
||||
- `test_training_data_skips_hard_negatives` — verifies `use_hard_negatives=False` excludes them
|
||||
- `test_delete_hard_negatives_by_ids` — verifies bulk deletion by ID
|
||||
|
||||
---
|
||||
|
||||
## 4. Runtime Fixes (discovered during testing)
|
||||
|
||||
### EAT/torchvision ABI mismatch
|
||||
|
||||
**Problem:** `torchvision` installed from PyPI (CPU build) was incompatible with `torch` from CUDA wheel index, causing `operator torchvision::nms does not exist`.
|
||||
|
||||
**Fix:** Added `torchvision` to the explicit torch install line in both setup scripts:
|
||||
```bash
|
||||
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||
```
|
||||
|
||||
Also added `--extra-index-url "$TORCH_INDEX"` to the `pip install -r requirements.txt` line to prevent transitive dependencies (timm, ultralytics) from pulling CPU-only torch packages.
|
||||
|
||||
Applied to: `setup_env.sh` (both conda and venv paths), `setup-windows.ps1`.
|
||||
|
||||
### EAT / transformers 5.x incompatibility
|
||||
|
||||
**Problem:** transformers 5.x broke EAT's remote model code (`'EATModel' object has no attribute 'all_tied_weights_keys'`).
|
||||
|
||||
**Fix:** Pinned `transformers>=4.30,<5.0` in `requirements.txt`.
|
||||
|
||||
### NumPy non-writable array warning
|
||||
|
||||
**Problem:** Cached HuBERT/EAT embeddings loaded from disk are read-only numpy arrays. `torch.from_numpy()` on a non-writable array triggers a deprecation warning.
|
||||
|
||||
**Fix:** In `core/audio_scan.py`, changed EAT preprocessing to copy the array:
|
||||
```python
|
||||
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||
```
|
||||
|
||||
### Timeline not updating on tab switch
|
||||
|
||||
**Problem:** Switching model tabs in the scan results panel didn't refresh the timeline's highlighted regions because `tab_changed` was only connected to `_update_scan_export_count`.
|
||||
|
||||
**Fix:** Connected `tab_changed` to `_on_scan_regions_edited` instead, which handles both timeline refresh and export count update.
|
||||
|
||||
---
|
||||
|
||||
## File Summary
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `core/db.py` | Schema migrations, `get_export_folders` filter, versioned `save_scan_results`, `get_scan_versions`, version-aware `get_scan_results`, `add_hard_negatives` with `source_model`, `get_hard_negatives`, `delete_hard_negatives_by_ids`, `get_training_data` with `use_hard_negatives` |
|
||||
| `main.py` | `HardNegativesDialog` class, `TrainDialog` hard neg toggle + manage button, `ScanResultsPanel` container/combo architecture, version combo population and switching, `current_model_name()`, tab-switch timeline fix |
|
||||
| `core/audio_scan.py` | `np.array(chunk)` copy for read-only numpy arrays in EAT preprocessing |
|
||||
| `requirements.txt` | `transformers>=4.30,<5.0` pin |
|
||||
| `setup_env.sh` | `torchvision` in torch install, `--extra-index-url` on requirements install |
|
||||
| `setup-windows.ps1` | `torchvision` in torch install, `--extra-index-url` on requirements install, removed skip-if-exists guard |
|
||||
| `tests/test_db.py` | 5 tests covering all DB-layer changes |
|
||||
@@ -0,0 +1,94 @@
|
||||
# Scan History & Hard Negative Management — Implementation Log
|
||||
|
||||
> All tasks complete. See the design doc for the final specification.
|
||||
|
||||
**Branch:** `feat/training-ui`
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Fix ghost folder bug in get_export_folders -- DONE
|
||||
|
||||
**Commit:** `2614a76 fix: get_export_folders respects scan_export filter`
|
||||
|
||||
- `core/db.py` — `get_export_folders(profile, include_scan_exports=False)`: filters `scan_export = 0` by default
|
||||
- `core/db.py` — `get_training_stats()`: passes `include_scan_exports` through, filters out 0-clip folders
|
||||
- `tests/test_db.py` — `test_export_folders_excludes_scan_exports`
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Scan result history — schema and DB methods -- DONE
|
||||
|
||||
**Commit:** `4fb2ae1 feat: scan result history — keep N versions per (file, model)`
|
||||
|
||||
- `core/db.py` — added `scan_timestamp TEXT NOT NULL DEFAULT ''` column with migration
|
||||
- `core/db.py` — `save_scan_results()`: versioned insert with microsecond-precision timestamp (`%Y%m%d_%H%M%S_%f`), auto-prunes beyond `max_versions=5`
|
||||
- `core/db.py` — `get_scan_versions()`: returns `[{timestamp, count, max_score}, ...]` newest first
|
||||
- `core/db.py` — `get_scan_results(scan_timestamp=None)`: `INNER JOIN` subquery with `MAX(scan_timestamp)` for latest-by-default
|
||||
- `tests/test_db.py` — `test_scan_result_history`
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Scan history UI — version selector in ScanResultsPanel -- DONE
|
||||
|
||||
**Commit:** `8ed9fbf feat: scan version selector in results panel`
|
||||
|
||||
- `main.py` — `_add_tab()`: wraps table in container `QWidget` with version `QComboBox` (hidden when ≤ 1 version)
|
||||
- `main.py` — `_current_table()` / `_tab_table(idx)`: unwrap container to get `QTableWidget`
|
||||
- `main.py` — `_populate_version_combos()`: queries `get_scan_versions()`, formats labels with `datetime.strptime` + try/except fallback
|
||||
- `main.py` — `_on_version_changed()`: reloads table from specific version, clears undo stack, emits `regions_edited`
|
||||
- `main.py` — `current_model_name()`: extracts model name from tab text
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Hard negatives — schema and training toggle -- DONE
|
||||
|
||||
**Commit:** `edc5784 feat: hard negative source_model tracking, training toggle`
|
||||
|
||||
- `core/db.py` — added `source_model TEXT NOT NULL DEFAULT ''` column to `hard_negatives` with migration
|
||||
- `core/db.py` — `add_hard_negatives(source_model="")`: stores originating model
|
||||
- `core/db.py` — `get_hard_negatives(profile)`: returns full rows as list of dicts
|
||||
- `core/db.py` — `delete_hard_negatives_by_ids(ids)`: bulk delete by row IDs
|
||||
- `core/db.py` — `get_training_data(use_hard_negatives=True)`: conditionally skips hard negatives query
|
||||
- `main.py` — `TrainDialog`: "Use hard negatives" checkbox + "Manage..." button in HBox layout
|
||||
- `main.py` — `_on_scan_negatives()`: passes `source_model=self._scan_panel.current_model_name()`
|
||||
- `tests/test_db.py` — `test_hard_negatives_source_model`, `test_training_data_skips_hard_negatives`, `test_delete_hard_negatives_by_ids`
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Hard negatives management dialog -- DONE
|
||||
|
||||
**Commit:** `e6db83f feat: hard negatives management dialog with filter and bulk delete`
|
||||
|
||||
- `main.py` — `HardNegativesDialog`: table with File/Time/Source Model/hidden ID columns, model filter combo, delete selected, filter-aware clear all, close button
|
||||
- Filter-aware "Clear All": respects active model filter, shows appropriate confirmation message
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Code review fixes -- DONE
|
||||
|
||||
**Commit:** `5d45b8d fix: timestamp collision, undo stack invalidation, label parsing, filter-aware clear`
|
||||
|
||||
Four issues found during code review:
|
||||
1. **Timestamp collision** — second-precision timestamps could merge versions on sub-second calls. Fixed with microsecond precision `%f`
|
||||
2. **Undo stack invalidation** — switching scan versions left stale undo entries. Fixed by clearing undo stack in `_on_version_changed()`
|
||||
3. **Timestamp label fragile parsing** — hard-coded string slicing. Fixed with `datetime.strptime` + try/except fallback
|
||||
4. **Clear All ignoring filter** — deleted all negatives regardless of model filter. Fixed to respect active filter
|
||||
|
||||
---
|
||||
|
||||
### Runtime fixes (discovered during manual testing)
|
||||
|
||||
| Commit | Fix |
|
||||
|--------|-----|
|
||||
| `a3c657c` | Install `torchvision` from CUDA wheel index (was pulling CPU build from PyPI) |
|
||||
| `3c3b1d7` | Remove "skip if torch exists" guard in Windows setup so re-runs fix broken envs |
|
||||
| `fd043f4` | Pin `transformers>=4.30,<5.0` — EAT remote model code incompatible with transformers 5.x |
|
||||
| `7d6fee9` | Copy read-only numpy array before `torch.from_numpy()` in EAT preprocessing |
|
||||
| `bd345ab` | Connect `tab_changed` to `_on_scan_regions_edited` so timeline refreshes on tab switch |
|
||||
| `d8b3972` | Add `--extra-index-url` to `pip install -r requirements.txt` in both setup scripts |
|
||||
|
||||
---
|
||||
|
||||
### Test results
|
||||
|
||||
All 68 tests pass (5 new DB tests + 63 existing).
|
||||
+20
-1
@@ -1,4 +1,23 @@
|
||||
# Core GUI
|
||||
PyQt6>=6.4
|
||||
python-mpv>=1.0
|
||||
pytest>=7.0
|
||||
|
||||
# Audio & ML
|
||||
librosa>=0.10
|
||||
numpy>=1.24
|
||||
scikit-learn>=1.3
|
||||
joblib>=1.3
|
||||
soundfile>=0.12
|
||||
|
||||
# Deep learning — install via setup_env.sh for correct CUDA version,
|
||||
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
torch>=2.0
|
||||
torchaudio>=2.0
|
||||
transformers>=4.30,<5.0 # EAT remote model code incompatible with transformers 5.x
|
||||
timm>=0.9
|
||||
|
||||
# Object detection
|
||||
ultralytics>=8.0
|
||||
|
||||
# Dev
|
||||
pytest>=7.0
|
||||
|
||||
+40
-8
@@ -1,19 +1,46 @@
|
||||
# 8-cut Windows setup script
|
||||
# Run once: powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||
#
|
||||
# Prerequisites: Python 3.10+ must be installed and on PATH
|
||||
# Prerequisites: Python 3.11+ must be installed and on PATH
|
||||
# https://www.python.org/downloads/
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
trap { Write-Host "`n$_" -ForegroundColor Red; Read-Host "Press Enter to close"; exit 1 }
|
||||
$root = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||
|
||||
Write-Host "=== 8-cut Windows Setup ===" -ForegroundColor Cyan
|
||||
|
||||
# ── Python deps ────────────────────────────────────────────
|
||||
Write-Host "`nInstalling Python dependencies..."
|
||||
pip install PyQt6 python-mpv
|
||||
# ── Virtual environment ───────────────────────────────────
|
||||
$venvDir = Join-Path $root ".venv"
|
||||
if (Test-Path (Join-Path $venvDir "Scripts\python.exe")) {
|
||||
Write-Host "`nVirtual environment already exists, activating..." -ForegroundColor Green
|
||||
} else {
|
||||
Write-Host "`nCreating virtual environment..."
|
||||
python -m venv $venvDir
|
||||
Write-Host "Virtual environment created at $venvDir" -ForegroundColor Green
|
||||
}
|
||||
& "$venvDir\Scripts\Activate.ps1"
|
||||
|
||||
# ── libmpv ─────────────────────────────────────────────────
|
||||
# ── PyTorch ───────────────────────────────────────────────
|
||||
# Detect NVIDIA GPU via nvidia-smi
|
||||
$hasNvidia = Get-Command nvidia-smi -ErrorAction SilentlyContinue
|
||||
if ($hasNvidia) {
|
||||
$torchIndex = "https://download.pytorch.org/whl/cu128"
|
||||
Write-Host "`nNVIDIA GPU detected — using CUDA 12.8 PyTorch index" -ForegroundColor Green
|
||||
} else {
|
||||
$torchIndex = "https://download.pytorch.org/whl/cpu"
|
||||
Write-Host "`nNo NVIDIA GPU detected — using CPU-only PyTorch index" -ForegroundColor Yellow
|
||||
}
|
||||
# Always install/upgrade torch stack from correct index
|
||||
# (pip install is a no-op if already at the right version)
|
||||
Write-Host "Installing PyTorch + torchaudio + torchvision..."
|
||||
pip install torch torchaudio torchvision --index-url $torchIndex
|
||||
|
||||
# ── Python deps ───────────────────────────────────────────
|
||||
Write-Host "`nInstalling project dependencies..."
|
||||
pip install -r (Join-Path $root "requirements.txt") --extra-index-url $torchIndex
|
||||
|
||||
# ── libmpv ────────────────────────────────────────────────
|
||||
$mpvDll = Join-Path $root "libmpv-2.dll"
|
||||
if (Test-Path $mpvDll) {
|
||||
Write-Host "`nlibmpv-2.dll already present, skipping." -ForegroundColor Green
|
||||
@@ -30,12 +57,11 @@ if (Test-Path $mpvDll) {
|
||||
Write-Host "libmpv-2.dll downloaded." -ForegroundColor Green
|
||||
}
|
||||
|
||||
# ── ffmpeg ─────────────────────────────────────────────────
|
||||
# ── ffmpeg ────────────────────────────────────────────────
|
||||
$ffmpeg = Join-Path $root "ffmpeg.exe"
|
||||
if (Test-Path $ffmpeg) {
|
||||
Write-Host "`nffmpeg.exe already present, skipping." -ForegroundColor Green
|
||||
} else {
|
||||
# Check if ffmpeg is on PATH
|
||||
$onPath = Get-Command ffmpeg -ErrorAction SilentlyContinue
|
||||
if ($onPath) {
|
||||
Write-Host "`nffmpeg found on PATH: $($onPath.Source)" -ForegroundColor Green
|
||||
@@ -54,6 +80,12 @@ if (Test-Path $ffmpeg) {
|
||||
}
|
||||
}
|
||||
|
||||
# ── Verify ────────────────────────────────────────────────
|
||||
Write-Host "`n--- Verification ---" -ForegroundColor Cyan
|
||||
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||
python -c "import sklearn, librosa, torchaudio; print('All imports OK')"
|
||||
|
||||
Write-Host "`n=== Setup complete ===" -ForegroundColor Cyan
|
||||
Write-Host "Run 8-cut with: python main.py"
|
||||
Write-Host "Run 8-cut with: .venv\Scripts\python.exe main.py"
|
||||
Write-Host "Or double-click: 8cut.bat"
|
||||
Read-Host "`nPress Enter to close"
|
||||
|
||||
Executable
+114
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# 8-cut environment setup — supports conda (miniforge) or python venv
|
||||
#
|
||||
# Usage:
|
||||
# ./setup_env.sh # auto-detect (prefers conda if available)
|
||||
# ./setup_env.sh --conda # force conda
|
||||
# ./setup_env.sh --venv # force python venv
|
||||
# ─��────────────────────────────��───────────────────────────────────────
|
||||
|
||||
ENV_NAME="8cut"
|
||||
PYTHON_VERSION="3.12"
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||
|
||||
# Auto-detect GPU for PyTorch index URL
|
||||
if command -v nvidia-smi &>/dev/null; then
|
||||
TORCH_INDEX="https://download.pytorch.org/whl/cu128"
|
||||
echo "NVIDIA GPU detected — will install PyTorch with CUDA 12.8"
|
||||
else
|
||||
TORCH_INDEX="https://download.pytorch.org/whl/cpu"
|
||||
echo "No NVIDIA GPU detected — will install CPU-only PyTorch"
|
||||
fi
|
||||
|
||||
# ── Parse args ────────────────────────────────────────────────────────
|
||||
|
||||
MODE=""
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--conda) MODE="conda" ;;
|
||||
--venv) MODE="venv" ;;
|
||||
*) echo "Unknown arg: $arg"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -z "$MODE" ]; then
|
||||
if command -v conda &>/dev/null; then
|
||||
MODE="conda"
|
||||
else
|
||||
MODE="venv"
|
||||
fi
|
||||
echo "Auto-detected mode: $MODE"
|
||||
fi
|
||||
|
||||
# ── Conda setup ─────────────��─────────────────────────────────────────
|
||||
|
||||
setup_conda() {
|
||||
echo "==> Setting up conda environment: $ENV_NAME"
|
||||
|
||||
# Source conda shell hooks if not already active
|
||||
if ! command -v conda &>/dev/null; then
|
||||
echo "conda not found in PATH"
|
||||
exit 1
|
||||
fi
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
||||
if conda env list | grep -qw "$ENV_NAME"; then
|
||||
echo " Environment '$ENV_NAME' already exists, updating..."
|
||||
conda activate "$ENV_NAME"
|
||||
else
|
||||
echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..."
|
||||
conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION"
|
||||
conda activate "$ENV_NAME"
|
||||
fi
|
||||
|
||||
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||
|
||||
echo " Installing project dependencies..."
|
||||
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||
|
||||
echo ""
|
||||
echo "Done! Activate with:"
|
||||
echo " conda activate $ENV_NAME"
|
||||
}
|
||||
|
||||
# ── Venv setup ───────��────────────────────────────────────────────────
|
||||
|
||||
setup_venv() {
|
||||
echo "==> Setting up Python venv at: $VENV_DIR"
|
||||
|
||||
if [ ! -d "$VENV_DIR" ]; then
|
||||
python3 -m venv "$VENV_DIR"
|
||||
echo " Created venv"
|
||||
else
|
||||
echo " Venv already exists, updating..."
|
||||
fi
|
||||
|
||||
source "$VENV_DIR/bin/activate"
|
||||
|
||||
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||
|
||||
echo " Installing project dependencies..."
|
||||
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||
|
||||
echo ""
|
||||
echo "Done! Activate with:"
|
||||
echo " source $VENV_DIR/bin/activate"
|
||||
}
|
||||
|
||||
# ── Run ───────────────────────────────────────────────────────────────
|
||||
|
||||
case "$MODE" in
|
||||
conda) setup_conda ;;
|
||||
venv) setup_venv ;;
|
||||
esac
|
||||
|
||||
echo ""
|
||||
echo "Verify with:"
|
||||
echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\""
|
||||
echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\""
|
||||
@@ -0,0 +1,73 @@
|
||||
import tempfile, os
|
||||
import numpy as np
|
||||
from core.audio_scan import scan_video, load_classifier, default_model_path
|
||||
|
||||
|
||||
def test_scan_video_no_model_returns_empty():
|
||||
"""scan_video with no model should return empty list."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
import soundfile as sf
|
||||
sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000)
|
||||
try:
|
||||
regions = scan_video(vid.name, model=None)
|
||||
assert regions == []
|
||||
finally:
|
||||
os.unlink(vid.name)
|
||||
|
||||
|
||||
def test_load_classifier_missing_returns_none():
|
||||
assert load_classifier("/no/such/model.joblib") is None
|
||||
|
||||
|
||||
def test_default_model_path_contains_profile():
|
||||
path = default_model_path("test_profile")
|
||||
assert "test_profile" in path
|
||||
assert path.endswith(".joblib")
|
||||
|
||||
|
||||
def test_embed_dim_multi_layer():
|
||||
from core.audio_scan import _embed_dim
|
||||
# Multi-layer models should report concatenated dimension
|
||||
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||
# Single-layer unchanged
|
||||
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||
|
||||
|
||||
def test_ml_config():
|
||||
from core.audio_scan import _ml_config
|
||||
assert _ml_config("HUBERT_XLARGE") is None
|
||||
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||
assert base == "HUBERT_XLARGE"
|
||||
assert layers == [11, 23, 35, 47]
|
||||
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||
assert base == "HUBERT_BASE"
|
||||
assert layers == [2, 5, 8, 11]
|
||||
|
||||
|
||||
def test_embed_dim_ast():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("AST") == 768
|
||||
assert _embed_dim("AST_ML") == 3072
|
||||
|
||||
|
||||
def test_embed_dim_eat():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("EAT") == 768
|
||||
|
||||
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
from core.db import ProcessedDB
|
||||
db = ProcessedDB(path)
|
||||
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||
paths = db.get_all_export_paths("test")
|
||||
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||
finally:
|
||||
os.unlink(path)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from core.db import ProcessedDB
|
||||
|
||||
|
||||
def test_export_folders_excludes_scan_exports():
|
||||
"""Scan-export-only folders should not appear when include_scan_exports=False."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
# Manual export
|
||||
db.add("a.mp4", 10.0, "/out/mp4_Intense/g1/clip.mp4", profile="test")
|
||||
# Scan export to different folder
|
||||
db.add("a.mp4", 20.0, "/out/mp4_ScanOnly/g1/clip.mp4", profile="test",
|
||||
scan_export=True)
|
||||
folders = db.get_export_folders("test")
|
||||
assert "mp4_Intense" in folders
|
||||
assert "mp4_ScanOnly" not in folders, "scan-only folder should be excluded"
|
||||
# With include_scan_exports=True, both should appear
|
||||
folders_all = db.get_export_folders("test", include_scan_exports=True)
|
||||
assert "mp4_ScanOnly" in folders_all
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_scan_result_history():
|
||||
"""save_scan_results should keep multiple versions."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
# Save three versions (microsecond-precision timestamps avoid collisions)
|
||||
db.save_scan_results("v.mp4", "test", "MODEL_A", [(0, 8, 0.9)])
|
||||
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||
[(0, 8, 0.8), (10, 18, 0.7)])
|
||||
db.save_scan_results("v.mp4", "test", "MODEL_A", [(5, 13, 0.95)])
|
||||
versions = db.get_scan_versions("v.mp4", "test", "MODEL_A")
|
||||
assert len(versions) == 3
|
||||
# Most recent first
|
||||
assert versions[0]["count"] == 1 # latest: 1 region
|
||||
assert versions[1]["count"] == 2 # middle: 2 regions
|
||||
assert versions[2]["count"] == 1 # oldest: 1 region
|
||||
# get_scan_results returns latest version by default
|
||||
results = db.get_scan_results("v.mp4", "test")
|
||||
assert len(results.get("MODEL_A", [])) == 1
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_hard_negatives_source_model():
|
||||
"""Hard negatives should store source_model."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||
rows = db.get_hard_negatives("test")
|
||||
assert len(rows) == 2
|
||||
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_training_data_skips_hard_negatives():
|
||||
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
# Create a source file that "exists" — use the temp db file itself
|
||||
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||
source_path=path)
|
||||
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||
# With hard negatives
|
||||
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||
# Without hard negatives
|
||||
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||
assert len(data_with) >= 1
|
||||
# The "with" case should have the hard negative time in neg list
|
||||
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_delete_hard_negatives_by_ids():
|
||||
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||
source_path="/a.mp4")
|
||||
rows = db.get_hard_negatives("test")
|
||||
assert len(rows) == 3
|
||||
# Delete first two
|
||||
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||
remaining = db.get_hard_negatives("test")
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0]["start_time"] == 30.0
|
||||
finally:
|
||||
os.unlink(path)
|
||||
+25
-24
@@ -1,24 +1,25 @@
|
||||
import tempfile, os, json
|
||||
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, build_annotation_json_path, upsert_clip_annotation, resolve_keyframe, apply_keyframes_to_jobs
|
||||
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, resolve_keyframe, apply_keyframes_to_jobs
|
||||
from core.annotations import build_annotation_json_path, upsert_clip_annotation
|
||||
from main import ProcessedDB
|
||||
|
||||
|
||||
def test_build_export_path_first():
|
||||
assert build_export_path("/out", "clip", 1) == "/out/clip_001/clip_001.mp4"
|
||||
assert build_export_path("/out", "clip", 1) == "/out/clip_001.mp4"
|
||||
|
||||
def test_build_export_path_counter():
|
||||
assert build_export_path("/out", "clip", 42) == "/out/clip_042/clip_042.mp4"
|
||||
assert build_export_path("/out", "clip", 42) == "/out/clip_042.mp4"
|
||||
|
||||
def test_build_export_path_deep_counter():
|
||||
assert build_export_path("/out", "shot", 999) == "/out/shot_999/shot_999.mp4"
|
||||
assert build_export_path("/out", "shot", 999) == "/out/shot_999.mp4"
|
||||
|
||||
def test_build_export_path_sub():
|
||||
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0.mp4"
|
||||
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001/clip_001_2.mp4"
|
||||
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001_0.mp4"
|
||||
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001_2.mp4"
|
||||
|
||||
def test_build_sequence_dir_sub():
|
||||
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0"
|
||||
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001/clip_001_1"
|
||||
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001_0"
|
||||
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001_1"
|
||||
|
||||
def test_format_time_seconds():
|
||||
assert format_time(0.0) == "0:00.0"
|
||||
@@ -177,10 +178,10 @@ def test_audio_extract_timing():
|
||||
|
||||
|
||||
def test_build_sequence_dir_basic():
|
||||
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001/clip_001"
|
||||
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001"
|
||||
|
||||
def test_build_sequence_dir_counter():
|
||||
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042/clip_042"
|
||||
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042"
|
||||
|
||||
def test_ffmpeg_command_image_sequence():
|
||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/seq_001", image_sequence=True)
|
||||
@@ -264,13 +265,13 @@ def test_db_get_group_returns_all_sub_clips():
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_2.mp4")
|
||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_2.mp4")
|
||||
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||
assert len(group) == 3
|
||||
assert "/out/clip_001/clip_001_0.mp4" in group
|
||||
assert "/out/clip_001/clip_001_2.mp4" in group
|
||||
assert "/out/vid_001/clip_001_0.mp4" in group
|
||||
assert "/out/vid_001/clip_001_2.mp4" in group
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
@@ -280,10 +281,10 @@ def test_db_get_group_isolates_by_start_time():
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||
assert len(group) == 2
|
||||
finally:
|
||||
os.unlink(path)
|
||||
@@ -294,10 +295,10 @@ def test_db_delete_group_removes_all():
|
||||
path = f.name
|
||||
try:
|
||||
db = ProcessedDB(path)
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
||||
deleted = db.delete_group("/out/clip_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||
deleted = db.delete_group("/out/vid_001/clip_001_0.mp4")
|
||||
assert len(deleted) == 2
|
||||
# clip_002 should still exist
|
||||
markers = db.get_markers("video.mp4")
|
||||
|
||||
Reference in New Issue
Block a user