Compare commits
211 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eab5c690c7 | |||
| 4445f0e7f4 | |||
| ed63d04abf | |||
| 7ae1720b9e | |||
| 514607eddd | |||
| 4299de5f97 | |||
| 86ab606059 | |||
| 87ccd8650c | |||
| ad9e564991 | |||
| 4baac54930 | |||
| 879684ce25 | |||
| 92774216d4 | |||
| 02fd0f0919 | |||
| c537ac678d | |||
| 755f7e5131 | |||
| 1eb7de2a1a | |||
| d7680283a2 | |||
| bf4b6dad2d | |||
| 4715c0ce49 | |||
| e5ce59c065 | |||
| cbbdfeadb1 | |||
| 8a7d761815 | |||
| 140a424469 | |||
| bc6e30a2d4 | |||
| 2ea3a9149a | |||
| e820c106af | |||
| 780832d4aa | |||
| 6037f15e7b | |||
| 035eaf3894 | |||
| 35ea1baec8 | |||
| 6a71386ed8 | |||
| d1fb35af8e | |||
| c55693094d | |||
| 5832d08b26 | |||
| b4cfa7561a | |||
| 0ccc29709e | |||
| 7e917d00a6 | |||
| 2ffb81eaa3 | |||
| b448085242 | |||
| 7cf90c1e5c | |||
| 5aa6878cf6 | |||
| 0e903812fa | |||
| d23ae2e88a | |||
| d97de8de10 | |||
| c6673228fa | |||
| fa4104eded | |||
| 9f7d2e1185 | |||
| c2e6c62c00 | |||
| 8aa8d8805b | |||
| 35c67f4bd5 | |||
| b738a19304 | |||
| dbd8e6a8ac | |||
| 73dfea4ae9 | |||
| 2170e72cbd | |||
| c9915914c4 | |||
| 251747bb0b | |||
| 13c4d3f7f6 | |||
| 1d49ce7cee | |||
| 109bc658c3 | |||
| ec7138f51b | |||
| 68c633ab46 | |||
| d0a94e7b68 | |||
| 632c2dc076 | |||
| 0f335c5e66 | |||
| f1f8fd5244 | |||
| 299779cf29 | |||
| 56218c18f4 | |||
| 2c45aff668 | |||
| 07e2f733b9 | |||
| 8c5a4c4524 | |||
| 4e5b631efb | |||
| ec77b8224f | |||
| 9becd5a06d | |||
| fae5560e2d | |||
| 07e3a1223c | |||
| 3af6e05fb7 | |||
| d787871735 | |||
| 85c08d7c48 | |||
| f6966a092a | |||
| 7cee3ab768 | |||
| 47f910644d | |||
| e972c7a2ae | |||
| cb805c5bda | |||
| bf14247b00 | |||
| 73396659dc | |||
| c8bc629419 | |||
| de8840e1eb | |||
| def966a913 | |||
| bc4ae21153 | |||
| a731fbfc32 | |||
| 1bdeb33a6f | |||
| 387ed7bc6a | |||
| f268d61fe4 | |||
| 24db32c09f | |||
| 0f6ae88ea6 | |||
| 4d99cf6015 | |||
| b75fa85ff5 | |||
| e7d47331c6 | |||
| 7cd31ebe55 | |||
| 3a37dddfd9 | |||
| b249705506 | |||
| aaf405dd3d | |||
| cb2060beb8 | |||
| 0db412baf4 | |||
| 876026d1f6 | |||
| 6c1d42adfe | |||
| d8b3972bdc | |||
| bd345abca2 | |||
| 7d6fee9df1 | |||
| fd043f4172 | |||
| 3c3b1d74bb | |||
| a3c657c66e | |||
| 5d45b8d8eb | |||
| e6db83f00b | |||
| edc5784ba6 | |||
| 8ed9fbf557 | |||
| 4fb2ae144f | |||
| 2614a765d5 | |||
| c020c0dfec | |||
| e7b791fbfa | |||
| f5361a963e | |||
| 8fb8581816 | |||
| 5b25e85e98 | |||
| e3f133ef84 | |||
| 4736f150b0 | |||
| 52aa982aa2 | |||
| 07457d0d6f | |||
| c5d613fc5f | |||
| 7855ea62c2 | |||
| 70be5974cf | |||
| a0286d5cf9 | |||
| 2b7dfb330d | |||
| 518554f788 | |||
| 282156e8ed | |||
| 3417a0f603 | |||
| cd0552197f | |||
| 7dffcb08eb | |||
| 93bcb23fa7 | |||
| eda7826a40 | |||
| e7e20b0fe6 | |||
| 814ef946eb | |||
| 2e738df9ae | |||
| 6ddfcde8ee | |||
| b161412d94 | |||
| 5a9e068903 | |||
| 6870e5aaf3 | |||
| f597ff29e8 | |||
| e1789d4e71 | |||
| 7834b1d05c | |||
| 12ed183f1b | |||
| f2c38aee79 | |||
| 8ab5bdba77 | |||
| c6c5934fe8 | |||
| 73d5367424 | |||
| 1e2cebd424 | |||
| c439aca9b9 | |||
| afda9b2d9f | |||
| fd42791c9f | |||
| 4cf54f2642 | |||
| e7f4de9ec1 | |||
| 9cf9e3233f | |||
| e17d8f67aa | |||
| b1980de6d1 | |||
| 85e0641440 | |||
| 834b89b682 | |||
| a67e189aa0 | |||
| 2b6c56cd15 | |||
| 0f6082061f | |||
| 9662b815db | |||
| 9776b83ac5 | |||
| 39f873bec2 | |||
| 409eb82e5c | |||
| 297aafa51c | |||
| b4cf972d59 | |||
| 5cc1e52e75 | |||
| 6bf0b0ae99 | |||
| b6fbda01dd | |||
| 51d41f0a56 | |||
| 16bd1a9ae0 | |||
| 2036c49b52 | |||
| b12758c53c | |||
| 3d484952c2 | |||
| 12dae93671 | |||
| 1e65fd6b0f | |||
| f7756320e5 | |||
| cd0331d4ce | |||
| 38c6174f83 | |||
| 5b22bceed2 | |||
| 80f21915e3 | |||
| b09ba3fa9e | |||
| 5b7a55a05d | |||
| 2200da491f | |||
| 3d6469c60c | |||
| 6a4ac8b8ed | |||
| 1f6906c946 | |||
| dfba88a601 | |||
| e94c088df0 | |||
| 9569103edd | |||
| 079afeee7c | |||
| fbbfa6fdce | |||
| 56920a5247 | |||
| 08c1dd8b33 | |||
| 2b63ad1857 | |||
| 72f6a4e8f5 | |||
| 799a2ab353 | |||
| 066f4431ba | |||
| 97f9ef7073 | |||
| 592e40c1a6 | |||
| 73dd7a1569 | |||
| 7abf0b4d4c | |||
| 9e5bd4a8ec |
@@ -0,0 +1,36 @@
|
|||||||
|
name: Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # manual only — build locally and push to ghcr.io
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- uses: docker/metadata-action@v5
|
||||||
|
id: meta
|
||||||
|
with:
|
||||||
|
images: ghcr.io/${{ github.repository }}-server
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
|
type=sha,prefix=
|
||||||
|
|
||||||
|
- uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
@@ -3,3 +3,8 @@ __pycache__/
|
|||||||
*.pyo
|
*.pyo
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.worktrees/
|
.worktrees/
|
||||||
|
.venv/
|
||||||
|
models/
|
||||||
|
cache/
|
||||||
|
*.joblib
|
||||||
|
*.pt
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
@echo off
|
@echo off
|
||||||
cd /d "%~dp0"
|
cd /d "%~dp0"
|
||||||
|
if exist ".venv\Scripts\python.exe" (
|
||||||
|
.venv\Scripts\python.exe main.py %*
|
||||||
|
) else (
|
||||||
python main.py %*
|
python main.py %*
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Launch 8-cut with auto-detected venv/conda environment
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
ENV_NAME="8cut"
|
||||||
|
CONDA_PREFIX_BASE="/media/p5/miniforge3"
|
||||||
|
export LD_PRELOAD=/usr/lib/libstdc++.so.6
|
||||||
|
|
||||||
|
# 1. Try .venv in project dir
|
||||||
|
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
|
||||||
|
source "$SCRIPT_DIR/.venv/bin/activate"
|
||||||
|
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 2. Try conda env (works without shell init)
|
||||||
|
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
|
||||||
|
if [ -x "$CONDA_PYTHON" ]; then
|
||||||
|
exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3. Try conda via shell hook (interactive shells)
|
||||||
|
if command -v conda &>/dev/null; then
|
||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
exec python "$SCRIPT_DIR/main.py" "$@"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Fallback to system Python
|
||||||
|
exec python3 "$SCRIPT_DIR/main.py" "$@"
|
||||||
@@ -30,6 +30,11 @@ mpv_dir = Path(os.environ.get("MPV_DIR", base))
|
|||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
|
|
||||||
|
# Bundled assets (icons, logo) — must exist at runtime under sys._MEIPASS/assets
|
||||||
|
assets_dir = base / "assets"
|
||||||
|
if assets_dir.exists():
|
||||||
|
datas.append((str(assets_dir), "assets"))
|
||||||
|
|
||||||
# YOLOv8 model (optional — large, skip if missing)
|
# YOLOv8 model (optional — large, skip if missing)
|
||||||
yolo = base / "yolov8n.pt"
|
yolo = base / "yolov8n.pt"
|
||||||
if yolo.exists():
|
if yolo.exists():
|
||||||
|
|||||||
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Calibration — per-video normalized features + classifier."""
|
||||||
|
import sys, os, time, warnings
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
from sklearn.ensemble import GradientBoostingClassifier
|
||||||
|
|
||||||
|
from core.audio_scan import _SR, _WINDOW
|
||||||
|
|
||||||
|
_HOP_LENGTH = 1024
|
||||||
|
_N_FFT = 2048
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||||
|
PROFILE_NAME = "JAV_missionary"
|
||||||
|
TOLERANCE = 12.0
|
||||||
|
NEG_MARGIN = 120.0
|
||||||
|
|
||||||
|
|
||||||
|
def extract_rich_features(y, sr=_SR):
|
||||||
|
"""Per-frame features: onset, energy, spectral shape, mel bands (22 features)."""
|
||||||
|
hop = _HOP_LENGTH
|
||||||
|
S = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop)) ** 2
|
||||||
|
rms = librosa.feature.rms(S=S, hop_length=hop)
|
||||||
|
cent = librosa.feature.spectral_centroid(S=S, sr=sr)
|
||||||
|
bw = librosa.feature.spectral_bandwidth(S=S, sr=sr)
|
||||||
|
rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr)
|
||||||
|
flatness = librosa.feature.spectral_flatness(S=S)
|
||||||
|
zcr = librosa.feature.zero_crossing_rate(y, hop_length=hop)
|
||||||
|
onset = librosa.onset.onset_strength(S=librosa.power_to_db(S), sr=sr, hop_length=hop).reshape(1, -1)
|
||||||
|
|
||||||
|
mel_S = librosa.feature.melspectrogram(S=S, sr=sr, hop_length=hop, n_mels=128)
|
||||||
|
mel_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
|
||||||
|
bands = [(0, 100), (100, 300), (300, 600), (600, 1200),
|
||||||
|
(1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
|
||||||
|
band_feats = []
|
||||||
|
for flo, fhi in bands:
|
||||||
|
mask = (mel_freqs >= flo) & (mel_freqs < fhi)
|
||||||
|
if mask.sum() > 0:
|
||||||
|
band_feats.append(librosa.power_to_db(mel_S[mask].mean(axis=0, keepdims=True) + 1e-10))
|
||||||
|
else:
|
||||||
|
band_feats.append(np.zeros((1, mel_S.shape[1])))
|
||||||
|
|
||||||
|
sc = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=hop)
|
||||||
|
|
||||||
|
min_t = min(rms.shape[1], cent.shape[1], onset.shape[1], sc.shape[1],
|
||||||
|
band_feats[0].shape[1])
|
||||||
|
return np.vstack([
|
||||||
|
rms[:, :min_t], cent[:, :min_t], bw[:, :min_t], rolloff[:, :min_t],
|
||||||
|
flatness[:, :min_t], zcr[:, :min_t], onset[:, :min_t],
|
||||||
|
] + [b[:, :min_t] for b in band_feats]
|
||||||
|
+ [sc[:, :min_t]])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_window_stats(feat, hop=1.0):
|
||||||
|
"""Sliding window mean/std → (timestamps, feature_vectors)."""
|
||||||
|
n_feats, T = feat.shape
|
||||||
|
fps = _SR / _HOP_LENGTH
|
||||||
|
win_frames = int(_WINDOW * fps)
|
||||||
|
hop_frames = int(hop * fps)
|
||||||
|
if win_frames > T:
|
||||||
|
return np.array([]), np.array([])
|
||||||
|
|
||||||
|
cumsum = np.zeros((n_feats, T + 1))
|
||||||
|
cumsum[:, 1:] = np.cumsum(feat, axis=1)
|
||||||
|
cumsq = np.zeros((n_feats, T + 1))
|
||||||
|
cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)
|
||||||
|
|
||||||
|
starts = np.arange(0, T - win_frames + 1, hop_frames)
|
||||||
|
ends = starts + win_frames
|
||||||
|
sums = cumsum[:, ends] - cumsum[:, starts]
|
||||||
|
sq_sums = cumsq[:, ends] - cumsq[:, starts]
|
||||||
|
means = sums / win_frames
|
||||||
|
stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)
|
||||||
|
|
||||||
|
return starts / fps, np.vstack([means, stds]).T
|
||||||
|
|
||||||
|
|
||||||
|
def label_windows(timestamps, gt_intense, gt_soft):
|
||||||
|
all_gt = list(gt_intense) + list(gt_soft)
|
||||||
|
labels = np.zeros(len(timestamps), dtype=int)
|
||||||
|
for i, t in enumerate(timestamps):
|
||||||
|
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||||
|
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||||
|
if di < TOLERANCE:
|
||||||
|
labels[i] = 1
|
||||||
|
elif da > NEG_MARGIN:
|
||||||
|
labels[i] = -1
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
db = ProcessedDB()
|
||||||
|
rows = db._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path FROM processed WHERE profile = ?",
|
||||||
|
(PROFILE_NAME,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
intense_by_video, soft_by_video = {}, {}
|
||||||
|
for fn, st, op in rows:
|
||||||
|
if '/mp4_Intense/' in op:
|
||||||
|
intense_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif '/mp4_Soft/' in op:
|
||||||
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
|
videos = [fn for fn in intense_by_video
|
||||||
|
if os.path.exists(os.path.join(PLEX_DIR, fn))]
|
||||||
|
n_vids = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
|
||||||
|
videos = videos[:n_vids]
|
||||||
|
print(f"Processing {len(videos)} videos...")
|
||||||
|
|
||||||
|
all_data_raw = [] # raw features
|
||||||
|
all_data_norm = [] # per-video z-scored features
|
||||||
|
|
||||||
|
for vi, vname in enumerate(videos):
|
||||||
|
vpath = os.path.join(PLEX_DIR, vname)
|
||||||
|
gt_intense = sorted(intense_by_video.get(vname, set()))
|
||||||
|
gt_soft = sorted(soft_by_video.get(vname, set()))
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
y, _ = librosa.load(vpath, sr=_SR, mono=True)
|
||||||
|
feat = extract_rich_features(y)
|
||||||
|
timestamps, window_vectors = compute_window_stats(feat, hop=1.0)
|
||||||
|
dt = time.time() - t0
|
||||||
|
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
labels = label_windows(timestamps, gt_intense, gt_soft)
|
||||||
|
|
||||||
|
# Per-video z-score normalization
|
||||||
|
vid_mean = window_vectors.mean(axis=0)
|
||||||
|
vid_std = window_vectors.std(axis=0)
|
||||||
|
vid_std = np.maximum(vid_std, 1e-6)
|
||||||
|
normed = (window_vectors - vid_mean) / vid_std
|
||||||
|
|
||||||
|
n_pos = (labels == 1).sum()
|
||||||
|
n_neg = (labels == -1).sum()
|
||||||
|
print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")
|
||||||
|
|
||||||
|
all_data_raw.append((vi, vname, timestamps, window_vectors, labels))
|
||||||
|
all_data_norm.append((vi, vname, timestamps, normed, labels))
|
||||||
|
|
||||||
|
# Run CV for both raw and normalized
|
||||||
|
for label, data in [("RAW features", all_data_raw),
|
||||||
|
("PER-VIDEO NORMALIZED features", all_data_norm)]:
|
||||||
|
print(f"\n{'='*70}")
|
||||||
|
print(f" {label}")
|
||||||
|
print(f"{'='*70}")
|
||||||
|
|
||||||
|
all_y_true, all_y_prob = [], []
|
||||||
|
|
||||||
|
for test_idx in range(len(data)):
|
||||||
|
_, vname, _, test_X, test_labels = data[test_idx]
|
||||||
|
test_mask = test_labels != 0
|
||||||
|
if test_mask.sum() == 0 or (test_labels[test_mask] == 1).sum() == 0:
|
||||||
|
continue
|
||||||
|
X_test = test_X[test_mask]
|
||||||
|
y_test = (test_labels[test_mask] == 1).astype(int)
|
||||||
|
|
||||||
|
X_parts, y_parts = [], []
|
||||||
|
for i, (_, _, _, feats, labs) in enumerate(data):
|
||||||
|
if i == test_idx:
|
||||||
|
continue
|
||||||
|
m = labs != 0
|
||||||
|
if m.sum() == 0:
|
||||||
|
continue
|
||||||
|
X_parts.append(feats[m])
|
||||||
|
y_parts.append((labs[m] == 1).astype(int))
|
||||||
|
|
||||||
|
if not X_parts:
|
||||||
|
continue
|
||||||
|
X_train = np.vstack(X_parts)
|
||||||
|
y_train = np.concatenate(y_parts)
|
||||||
|
|
||||||
|
pos_idx = np.where(y_train == 1)[0]
|
||||||
|
neg_idx = np.where(y_train == 0)[0]
|
||||||
|
if len(pos_idx) == 0 or len(neg_idx) == 0:
|
||||||
|
continue
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
n_neg = min(len(neg_idx), len(pos_idx) * 3)
|
||||||
|
neg_sample = rng.choice(neg_idx, n_neg, replace=False)
|
||||||
|
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||||
|
|
||||||
|
clf = GradientBoostingClassifier(
|
||||||
|
n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
|
||||||
|
)
|
||||||
|
clf.fit(X_train[train_idx], y_train[train_idx])
|
||||||
|
probs = clf.predict_proba(X_test)[:, 1]
|
||||||
|
|
||||||
|
tp = ((probs >= 0.5) & (y_test == 1)).sum()
|
||||||
|
fp = ((probs >= 0.5) & (y_test == 0)).sum()
|
||||||
|
fn_count = ((probs < 0.5) & (y_test == 1)).sum()
|
||||||
|
pos_s = probs[y_test == 1].mean() if (y_test == 1).sum() > 0 else 0
|
||||||
|
neg_s = probs[y_test == 0].mean() if (y_test == 0).sum() > 0 else 0
|
||||||
|
print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")
|
||||||
|
|
||||||
|
all_y_true.extend(y_test)
|
||||||
|
all_y_prob.extend(probs)
|
||||||
|
|
||||||
|
if not all_y_true:
|
||||||
|
print(" No test results.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true = np.array(all_y_true)
|
||||||
|
y_prob = np.array(all_y_prob)
|
||||||
|
pos_probs = y_prob[y_true == 1]
|
||||||
|
neg_probs = y_prob[y_true == 0]
|
||||||
|
|
||||||
|
if len(pos_probs) > 0 and len(neg_probs) > 0:
|
||||||
|
print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
|
||||||
|
f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
|
||||||
|
print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
|
||||||
|
f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")
|
||||||
|
|
||||||
|
best_f1, best_thr = 0, 0
|
||||||
|
print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
|
||||||
|
for thr in np.arange(0.10, 0.91, 0.05):
|
||||||
|
tp = ((y_prob >= thr) & (y_true == 1)).sum()
|
||||||
|
fp = ((y_prob >= thr) & (y_true == 0)).sum()
|
||||||
|
fn_count = ((y_prob < thr) & (y_true == 1)).sum()
|
||||||
|
prec = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||||
|
rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
|
||||||
|
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
|
||||||
|
if f1 > best_f1:
|
||||||
|
best_f1, best_thr = f1, thr
|
||||||
|
print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
|
||||||
|
print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")
|
||||||
|
|
||||||
|
# Feature importance
|
||||||
|
X_all = np.vstack([f[l != 0] for _, _, _, f, l in data])
|
||||||
|
y_all = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
|
||||||
|
pos_idx = np.where(y_all == 1)[0]
|
||||||
|
neg_idx = np.where(y_all == 0)[0]
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
|
||||||
|
clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
|
||||||
|
clf.fit(X_all[np.concatenate([pos_idx, neg_sub])], y_all[np.concatenate([pos_idx, neg_sub])])
|
||||||
|
|
||||||
|
feat_names = (
|
||||||
|
["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
|
||||||
|
+ [f"mel{i}" for i in range(8)]
|
||||||
|
+ [f"sc{i}" for i in range(7)]
|
||||||
|
)
|
||||||
|
stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
|
||||||
|
imp = clf.feature_importances_
|
||||||
|
top = sorted(zip(stat_names, imp), key=lambda x: -x[1])[:10]
|
||||||
|
print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Train an audio scan classifier from DB ground truth.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python 8cut_train.py # default model, auto-detect positive
|
||||||
|
python 8cut_train.py --model BEATS # specific embedding model
|
||||||
|
python 8cut_train.py --positive mp4_Intense # explicit positive folder
|
||||||
|
python 8cut_train.py --positive mp4_Intense --model BEATS # both
|
||||||
|
"""
|
||||||
|
import sys, os, warnings
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
PROFILE_NAME = "JAV_missionary"
|
||||||
|
|
||||||
|
# Fallback for old DB rows without source_path
|
||||||
|
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
embed_model = None
|
||||||
|
if "--model" in sys.argv:
|
||||||
|
idx = sys.argv.index("--model")
|
||||||
|
if idx + 1 < len(sys.argv):
|
||||||
|
embed_model = sys.argv[idx + 1]
|
||||||
|
if embed_model not in _EMBED_MODELS:
|
||||||
|
print(f"Unknown model: {embed_model}")
|
||||||
|
print(f"Available: {', '.join(_EMBED_MODELS)}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
positive_suffix = None
|
||||||
|
if "--positive" in sys.argv:
|
||||||
|
idx = sys.argv.index("--positive")
|
||||||
|
if idx + 1 < len(sys.argv):
|
||||||
|
positive_suffix = sys.argv[idx + 1]
|
||||||
|
|
||||||
|
db = ProcessedDB()
|
||||||
|
|
||||||
|
# If --positive given, use the new DB helper
|
||||||
|
if positive_suffix:
|
||||||
|
video_infos = db.get_training_data(
|
||||||
|
PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
|
||||||
|
)
|
||||||
|
if not video_infos:
|
||||||
|
print(f"No training data found for positive='{positive_suffix}'")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# Legacy fallback: classify by folder path pattern
|
||||||
|
rows = db._con.execute(
|
||||||
|
"SELECT filename, start_time, output_path, source_path"
|
||||||
|
" FROM processed WHERE profile = ?",
|
||||||
|
(PROFILE_NAME,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
intense_by_video, soft_by_video = {}, {}
|
||||||
|
source_by_fn = {}
|
||||||
|
for fn, st, op, sp in rows:
|
||||||
|
if sp:
|
||||||
|
source_by_fn[fn] = sp
|
||||||
|
if "/mp4_Intense/" in op or "_Intense/" in op:
|
||||||
|
intense_by_video.setdefault(fn, set()).add(st)
|
||||||
|
elif "/mp4_Soft/" in op or "_Soft/" in op:
|
||||||
|
soft_by_video.setdefault(fn, set()).add(st)
|
||||||
|
|
||||||
|
video_infos = []
|
||||||
|
for fn in intense_by_video:
|
||||||
|
# Try source_path from DB first, fall back to PLEX_DIR
|
||||||
|
vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
|
||||||
|
if not os.path.exists(vpath):
|
||||||
|
print(f" skip (not found): {fn}")
|
||||||
|
continue
|
||||||
|
gt_intense = sorted(intense_by_video[fn])
|
||||||
|
gt_soft = sorted(soft_by_video.get(fn, set()))
|
||||||
|
video_infos.append((vpath, gt_intense, gt_soft))
|
||||||
|
|
||||||
|
label = embed_model or "WAV2VEC2_BASE"
|
||||||
|
print(f"Training {label} model on {len(video_infos)} videos...")
|
||||||
|
model_path = default_model_path(PROFILE_NAME)
|
||||||
|
result = train_classifier(
|
||||||
|
video_infos, model_path=model_path, embed_model=embed_model,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
print("Training failed: no valid samples or missing class balance")
|
||||||
|
sys.exit(1)
|
||||||
|
print(f"Model saved to {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -8,7 +8,7 @@
|
|||||||
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets.
|
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets. Includes audio classification for automated scanning and batch export.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
@@ -22,19 +22,54 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
### Clip export
|
||||||
|
|
||||||
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
||||||
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
||||||
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
||||||
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
||||||
- **Random portrait** — optionally apply a random portrait crop to a subset of each batch
|
- **Random portrait/square** — optionally apply a random crop to a subset of each batch
|
||||||
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
||||||
- **Sound annotation** — label and category fields saved to the clip database; label also written to `dataset.json`
|
- **Hardware encoding** — GPU-accelerated export via NVENC, VAAPI, QSV, AMF, or VideoToolbox
|
||||||
- **Export history** — timeline markers show previously exported clips; double-click to enter overwrite mode; right-click to delete
|
- **Subject tracking** — auto-adjust crop center using YOLOv8 detection (optional)
|
||||||
- **End-frame preview** — floating window shows the last frame of the selection region
|
|
||||||
- **Playlist** — drag-and-drop or use the Open Files button; right-click to remove items
|
### Audio scanning
|
||||||
- **Playback loop** — plays the exact selection region on loop so you can preview what will be exported
|
|
||||||
- **Group operations** — delete or overwrite acts on all sub-clips in a batch, not just one
|
- **Embedding models** — WAV2VEC2 (base/large), HuBERT (base/large/xlarge), BEATs
|
||||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait") for the same video
|
- **Train classifier** — train a gradient boosting classifier on your exported clips to find similar audio
|
||||||
|
- **Scan video** — detect regions matching your trained model with configurable threshold
|
||||||
|
- **Scan All** — batch scan every video in the playlist
|
||||||
|
- **Region fusion** — merge overlapping detections into contiguous regions
|
||||||
|
- **Hard negatives** — mark false positives to refine training
|
||||||
|
- **Model versioning** — timestamped backups with rollback support
|
||||||
|
- **Scan export** — batch export from scan results with spread and minimum duration filtering
|
||||||
|
|
||||||
|
### Scan results panel
|
||||||
|
|
||||||
|
- **Tabbed results** — one tab per model, showing start/end/score per region
|
||||||
|
- **Disable regions** — Delete/Backspace toggles regions off (greyed out, excluded from export) without removing them
|
||||||
|
- **Resize regions** — double-click Time or End cells to edit, or drag region edges directly on the timeline
|
||||||
|
- **Grey ghost** — trimmed portions of resized regions shown as grey overlay on timeline
|
||||||
|
- **Undo** — Ctrl+Z reverts the last disable, resize, drag, or negative toggle
|
||||||
|
|
||||||
|
### Organization
|
||||||
|
|
||||||
|
- **Sound annotation** — label and category fields saved to the clip database and `dataset.json`
|
||||||
|
- **Export history** — timeline markers show previously exported clips; double-click to overwrite; right-click to delete
|
||||||
|
- **Playlist** — drag-and-drop video queue with progress tracking
|
||||||
|
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait")
|
||||||
|
- **Subprofiles** — lightweight export folder variants for multiple output targets
|
||||||
|
- **Review mode** — clean timeline view for navigating scan results without export clutter
|
||||||
|
|
||||||
|
### Interface
|
||||||
|
|
||||||
|
- **Menu bar** — File / Edit / Scan / View / Help hold the occasional actions (open files, train, scan all, profiles); the profile selector and shortcuts (`?`) sit in the top-right corner
|
||||||
|
- **Control deck** — a compact tabbed panel under the video groups the settings into **Export** (label, name, folder, format, resize, duration/clips/spread, workers), **Crop & Track**, and **Scan** (model, threshold, fuse, scan/auto/speech/review)
|
||||||
|
- **Side-by-side panels** — pin deck panels to view them as resizable columns: right-click a deck tab → *Show side-by-side*, or toggle them under *View ▸ Side-by-side panels*; drag the dividers to reallocate space, and the layout persists between sessions
|
||||||
|
- **Per-tab export folder** — each file-list tab remembers its own output folder; switching tabs follows that tab's folder, and a guardrail warns when the loaded video doesn't match the destination
|
||||||
|
- **Duplicate tab** — right-click a file-list tab → *Duplicate tab* to clone its files into a new tab with its own export folder
|
||||||
|
- **LTX-2 export mode** — per-tab **Foley | LTX-2** toggle (right-click a tab, shown with an `[LTX2]` badge): LTX-2 clips are frame-exact (`frames % 8 == 1`), forced to 25 fps, and center-cropped so width & height are divisible by 32 — for LTX-2 video-to-audio datasets; applies to manual, re-export, and auto-export
|
||||||
|
- **Status bar** — export/scan progress and messages, with the current file · profile · worker count always shown
|
||||||
|
|
||||||
## Keyboard shortcuts
|
## Keyboard shortcuts
|
||||||
|
|
||||||
@@ -50,37 +85,158 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
|||||||
| `M` | Jump to next marker (wraps) |
|
| `M` | Jump to next marker (wraps) |
|
||||||
| `N` | Next file in playlist |
|
| `N` | Next file in playlist |
|
||||||
| `G` | Toggle cursor lock |
|
| `G` | Toggle cursor lock |
|
||||||
|
| `Delete` / `Backspace` | Toggle disable on selected scan regions |
|
||||||
|
| `Ctrl+Z` | Undo last scan panel action |
|
||||||
| `?` / `F1` | Show keyboard shortcuts |
|
| `?` / `F1` | Show keyboard shortcuts |
|
||||||
|
|
||||||
Shortcuts are suppressed when a text field has focus.
|
Shortcuts are suppressed when a text field has focus.
|
||||||
|
|
||||||
## Requirements
|
## Installation
|
||||||
|
|
||||||
- Python 3.11+
|
### Prerequisites
|
||||||
- `ffmpeg` on `PATH`
|
|
||||||
- PyQt6
|
|
||||||
- python-mpv (requires libmpv)
|
|
||||||
|
|
||||||
|
- **Python 3.11+** — [python.org/downloads](https://www.python.org/downloads/)
|
||||||
|
- **ffmpeg** — video encoding
|
||||||
|
- **libmpv** — video playback
|
||||||
|
|
||||||
|
### Quick start (all platforms)
|
||||||
|
|
||||||
|
The setup script creates a virtual environment and installs everything including PyTorch with CUDA support:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux / macOS
|
||||||
|
./setup_env.sh
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux / macOS
|
||||||
|
./8cut.sh
|
||||||
|
|
||||||
|
# Windows
|
||||||
|
8cut.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
The launch scripts auto-detect your venv or conda environment.
|
||||||
|
|
||||||
|
### Manual installation
|
||||||
|
|
||||||
|
#### 1. Install system dependencies
|
||||||
|
|
||||||
|
**Linux (Arch):**
|
||||||
|
```bash
|
||||||
|
pacman -S python mpv ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Debian/Ubuntu):**
|
||||||
|
```bash
|
||||||
|
apt install python3 python3-venv libmpv-dev ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Windows:**
|
||||||
|
```powershell
|
||||||
|
# ffmpeg
|
||||||
|
winget install ffmpeg
|
||||||
|
|
||||||
|
# libmpv — download mpv-2.dll and place next to main.py
|
||||||
|
# https://sourceforge.net/projects/mpv-player-windows/files/libmpv/
|
||||||
|
```
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install python mpv ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Create a virtual environment
|
||||||
|
|
||||||
|
A virtual environment keeps 8-cut's dependencies isolated from your system Python. This is strongly recommended — PyTorch alone is several GB and can conflict with other projects.
|
||||||
|
|
||||||
|
**Using venv (recommended):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create the venv in the project directory
|
||||||
|
python3 -m venv .venv
|
||||||
|
|
||||||
|
# Activate it
|
||||||
|
# Linux / macOS:
|
||||||
|
source .venv/bin/activate
|
||||||
|
# Windows (cmd):
|
||||||
|
.venv\Scripts\activate.bat
|
||||||
|
# Windows (PowerShell):
|
||||||
|
.venv\Scripts\Activate.ps1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using conda / miniforge:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n 8cut python=3.12
|
||||||
|
conda activate 8cut
|
||||||
|
```
|
||||||
|
|
||||||
|
You must activate the environment every time you open a new terminal before running 8-cut. The `8cut.sh` launcher does this automatically.
|
||||||
|
|
||||||
|
#### 3. Install PyTorch
|
||||||
|
|
||||||
|
PyTorch must be installed separately with the correct CUDA version for GPU acceleration. Without CUDA, audio scanning will fall back to CPU (much slower).
|
||||||
|
|
||||||
|
**With NVIDIA GPU (CUDA 12.8):**
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
```
|
||||||
|
|
||||||
|
**CPU only (no GPU):**
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
Check available CUDA versions at [pytorch.org/get-started](https://pytorch.org/get-started/locally/).
|
||||||
|
|
||||||
|
#### 4. Install project dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Platform notes
|
#### 5. Verify
|
||||||
|
|
||||||
| Platform | libmpv |
|
```bash
|
||||||
|----------|--------|
|
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||||
| **Linux** | `apt install libmpv-dev` or `pacman -S mpv` |
|
python -c "import librosa, torchaudio, sklearn; print('All imports OK')"
|
||||||
| **macOS** | `brew install mpv` |
|
```
|
||||||
| **Windows** | Download `mpv-2.dll` from [mpv Windows builds](https://sourceforge.net/projects/mpv-player-windows/files/libmpv/) and place it in `PATH` or next to `main.py` |
|
|
||||||
|
|
||||||
Windows also needs `ffmpeg.exe` on `PATH` (e.g. `winget install ffmpeg`).
|
### Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# With venv activated:
|
||||||
|
python main.py
|
||||||
|
|
||||||
|
# Or use the launcher (auto-activates venv/conda):
|
||||||
|
./8cut.sh # Linux / macOS
|
||||||
|
8cut.bat # Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
### GPU encoding
|
||||||
|
|
||||||
|
Hardware encoders are auto-detected from ffmpeg. Available encoders by platform:
|
||||||
|
|
||||||
|
| Platform | Encoders |
|
||||||
|
|----------|----------|
|
||||||
|
| **Linux** | `h264_nvenc` (NVIDIA), `h264_vaapi` (AMD/Intel), `h264_qsv` (Intel) |
|
||||||
|
| **Windows** | `h264_nvenc` (NVIDIA), `h264_qsv` (Intel), `h264_amf` (AMD) |
|
||||||
|
| **macOS** | `h264_videotoolbox` |
|
||||||
|
|
||||||
|
Enable the **HW** checkbox in the export controls to use GPU encoding.
|
||||||
|
|
||||||
|
### Optional: audio scanning
|
||||||
|
|
||||||
|
Audio scanning requires PyTorch (installed above). Embedding models are downloaded on first use and cached in `cache/downloads/`. A CUDA-capable GPU is strongly recommended for training and scanning speed.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
```
|
|
||||||
python main.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
||||||
|
|
||||||
### Export layout
|
### Export layout
|
||||||
@@ -109,6 +265,20 @@ output/
|
|||||||
clip_001_0.wav
|
clip_001_0.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Scan export layout
|
||||||
|
|
||||||
|
Scan exports create one group folder per detected area:
|
||||||
|
|
||||||
|
```
|
||||||
|
output/
|
||||||
|
clip_037/
|
||||||
|
clip_037_a1_0.mp4 # area 1, clip 0
|
||||||
|
clip_037_a1_1.mp4 # area 1, clip 1
|
||||||
|
clip_038/
|
||||||
|
clip_038_a2_0.mp4 # area 2, clip 0
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
### Sound annotation
|
### Sound annotation
|
||||||
|
|
||||||
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
||||||
@@ -124,9 +294,73 @@ Labels persist between exports so you can cut many clips of the same class witho
|
|||||||
- **Right-click** a marker to delete it from the database
|
- **Right-click** a marker to delete it from the database
|
||||||
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
||||||
|
|
||||||
|
## Audio scan workflow
|
||||||
|
|
||||||
|
### 1. Build a dataset
|
||||||
|
|
||||||
|
Export clips manually from several videos. Clips from the same export folder (e.g. `mp4_Intense`) become your positive training class.
|
||||||
|
|
||||||
|
**Minimum dataset:** ~20 clips from 2–3 different videos. This is enough for the classifier to learn a basic boundary, but expect noisy results — you'll need to mark hard negatives and retrain.
|
||||||
|
|
||||||
|
**Ideal dataset:** 50–100+ clips from 5+ videos covering the full range of variation in your target sound (different recording conditions, distances, intensities). More variety in your positives makes the model generalize better to unseen footage. Negatives are sampled automatically from regions far from your markers, but adding explicit negatives of confusable sounds (e.g. thunder when training for explosions) significantly reduces false positives.
|
||||||
|
|
||||||
|
The classifier improves iteratively: export a small initial set → train → scan → mark false positives as hard negatives → retrain. Each cycle sharpens the decision boundary without needing a large upfront dataset.
|
||||||
|
|
||||||
|
### 2. Train a classifier
|
||||||
|
|
||||||
|
Click **Train** to open the training dialog:
|
||||||
|
|
||||||
|
- **Positive class** — select the export folder containing your target sounds
|
||||||
|
- **Negative class** — optional explicit negatives, or leave as "(auto only)" for automatic sampling
|
||||||
|
- **Model** — embedding model to use (HuBERT XLARGE recommended)
|
||||||
|
- **Auto-neg margin** — distance from markers to sample automatic negatives (30s default)
|
||||||
|
- **Include scan-exported clips** — whether to include previously scan-exported clips in training
|
||||||
|
|
||||||
|
The classifier trains a `HistGradientBoostingClassifier` on audio embeddings and saves to `models/`.
|
||||||
|
|
||||||
|
### 3. Scan videos
|
||||||
|
|
||||||
|
Select a trained model from the dropdown and click **Scan**. Adjust the threshold slider to control sensitivity. Detected regions appear as colored bands on the timeline and as rows in the results panel.
|
||||||
|
|
||||||
|
Audio embeddings are computed once per video and cached to disk (`cache/w2v/`). Subsequent scans with the same embedding model skip the GPU entirely and only re-run the classifier, which takes milliseconds. This makes the retrain → rescan loop nearly free after the first pass.
|
||||||
|
|
||||||
|
### 4. Review and refine
|
||||||
|
|
||||||
|
- Toggle **Review** mode for a clean timeline focused on scan results
|
||||||
|
- **Disable** false positive regions (Delete key) — they stay in the list but are excluded from export
|
||||||
|
- **Resize** regions by dragging edges on the timeline or editing times in the table
|
||||||
|
- **Mark as negative** — add false positives to the hard negative set for retraining
|
||||||
|
- **Ctrl+Z** to undo any of the above
|
||||||
|
|
||||||
|
### 5. Export results
|
||||||
|
|
||||||
|
Click **Export Scan Results** to batch export all enabled regions. The button shows the estimated clip count based on spread and minimum duration settings.
|
||||||
|
|
||||||
|
### 6. Retrain with feedback
|
||||||
|
|
||||||
|
Train again — hard negatives are automatically included. Each training run saves with a timestamp. Click the **⏲** button next to the model dropdown to restore a previous version if results degrade — restoring automatically rescans with the selected version.
|
||||||
|
|
||||||
## Database
|
## Database
|
||||||
|
|
||||||
Export history is stored in `~/.8cut.db` (SQLite). The database records filename, start time, output path, label, category, and all encoding settings for every clip. When you open a file, 8-cut matches the filename and pre-populates the timeline with existing markers.
|
Export history is stored in `~/.8cut.db` (SQLite). Tables:
|
||||||
|
|
||||||
|
| Table | Purpose |
|
||||||
|
|-------|---------|
|
||||||
|
| `processed` | Every exported clip with full encoding settings |
|
||||||
|
| `scan_results` | Audio scan detections per video/model |
|
||||||
|
| `hard_negatives` | Timestamps marked as false positives for training |
|
||||||
|
| `hidden_files` | Playlist files hidden by the user |
|
||||||
|
|
||||||
|
The database auto-migrates when new columns are added.
|
||||||
|
|
||||||
|
## File locations
|
||||||
|
|
||||||
|
| Path | Contents |
|
||||||
|
|------|----------|
|
||||||
|
| `~/.8cut.db` | SQLite database |
|
||||||
|
| `models/` | Trained classifier models (`.joblib`) |
|
||||||
|
| `cache/w2v/` | Embedding cache (`.npz`, keyed by video hash) |
|
||||||
|
| `cache/downloads/` | Downloaded pretrained models |
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
|
||||||
|
<defs>
|
||||||
|
<linearGradient id="g8" x1="0" y1="0" x2="0" y2="1">
|
||||||
|
<stop offset="0%" stop-color="#ffd230"/>
|
||||||
|
<stop offset="100%" stop-color="#e6a800"/>
|
||||||
|
</linearGradient>
|
||||||
|
</defs>
|
||||||
|
<rect width="64" height="64" rx="13" fill="#161616"/>
|
||||||
|
<rect x="8" y="42" width="48" height="11" rx="2" fill="#2a2a2a" stroke="#333" stroke-width="1"/>
|
||||||
|
<rect x="26" y="42" width="16" height="11" fill="#3c82dc" fill-opacity="0.45"/>
|
||||||
|
<line x1="26" y1="38" x2="26" y2="55" stroke="#ffd230" stroke-width="2"/>
|
||||||
|
<polygon points="22,38 30,38 26,44" fill="#ffd230"/>
|
||||||
|
<text x="32" y="33" font-family="'Helvetica Neue',Helvetica,Arial,sans-serif" font-size="34" font-weight="bold" fill="url(#g8)" text-anchor="middle">8</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 790 B |
@@ -0,0 +1,6 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
||||||
|
<path d="M7.5 10 V7.5 a4.5 4.5 0 0 1 9 0 V10" stroke="#ffd230" stroke-width="2"/>
|
||||||
|
<rect x="5" y="10" width="14" height="10" rx="2" fill="#ffd230"/>
|
||||||
|
<circle cx="12" cy="14.3" r="1.4" fill="#161616"/>
|
||||||
|
<rect x="11.2" y="14.3" width="1.6" height="3.4" rx="0.8" fill="#161616"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 362 B |
@@ -0,0 +1,6 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
||||||
|
<path d="M7.5 10 V7.5 a4.5 4.5 0 0 1 8.6 -1.8" stroke="#8a8a8a" stroke-width="2"/>
|
||||||
|
<rect x="5" y="10" width="14" height="10" rx="2" fill="#8a8a8a"/>
|
||||||
|
<circle cx="12" cy="14.3" r="1.4" fill="#1e1e1e"/>
|
||||||
|
<rect x="11.2" y="14.3" width="1.6" height="3.4" rx="0.8" fill="#1e1e1e"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 363 B |
@@ -0,0 +1,4 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||||
|
<rect x="6.5" y="5" width="4" height="14" rx="1.2" fill="#ffd230"/>
|
||||||
|
<rect x="13.5" y="5" width="4" height="14" rx="1.2" fill="#ffd230"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 209 B |
@@ -0,0 +1,3 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||||
|
<path d="M7 5 L19 12 L7 19 Z" fill="#ffd230" stroke="#ffd230" stroke-width="1.5" stroke-linejoin="round"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 177 B |
@@ -0,0 +1,4 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#aad4ff" stroke-width="2" stroke-linecap="round">
|
||||||
|
<circle cx="10.5" cy="10.5" r="6"/>
|
||||||
|
<line x1="15" y1="15" x2="20" y2="20"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 217 B |
@@ -0,0 +1,6 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#ffd230" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<circle cx="6.5" cy="6.5" r="2.6"/>
|
||||||
|
<circle cx="6.5" cy="17.5" r="2.6"/>
|
||||||
|
<line x1="8.8" y1="8" x2="20" y2="17"/>
|
||||||
|
<line x1="8.8" y1="16" x2="20" y2="7"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 322 B |
@@ -0,0 +1,4 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#ffd230" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<polyline points="4,17 10,11 14,14 20,6"/>
|
||||||
|
<polyline points="15,6 20,6 20,11"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 245 B |
@@ -1,2 +1,6 @@
|
|||||||
import sys, os
|
import sys, os
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
config.addinivalue_line("markers", "gui: constructs Qt widgets; needs a display")
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def build_annotation_json_path(folder: str) -> str:
|
||||||
|
return os.path.join(folder, "dataset.json")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||||
|
"""Remove the entry for *clip_path* from <folder>/dataset.json if present."""
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
if not os.path.exists(json_path):
|
||||||
|
return
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return
|
||||||
|
entries = [e for e in entries if e.get("path") != abs_path]
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||||
|
"""Insert or update one entry in <folder>/dataset.json.
|
||||||
|
|
||||||
|
Each entry stores a path relative to *folder* and the sound label.
|
||||||
|
Matches on ``path``; if an entry for the same clip already exists it is
|
||||||
|
replaced (overwrite-export case). Nothing is written when *label* is
|
||||||
|
empty.
|
||||||
|
"""
|
||||||
|
if not label.strip():
|
||||||
|
return
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
entries: list[dict] = []
|
||||||
|
if os.path.exists(json_path):
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
entries = []
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
entry: dict = {"path": abs_path, "label": label}
|
||||||
|
for i, e in enumerate(entries):
|
||||||
|
if e.get("path") == abs_path:
|
||||||
|
entries[i] = entry
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
entries.append(entry)
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
@@ -0,0 +1,809 @@
|
|||||||
|
"""Audio scanning — embedding-based classifier for audio event detection."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_SR = 16000 # lower sr = faster
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
|
||||||
|
"""Load audio from any file as mono float32 numpy array using ffmpeg directly."""
|
||||||
|
cmd = [
|
||||||
|
_bin("ffmpeg"), "-i", path,
|
||||||
|
"-vn", # skip video
|
||||||
|
"-ac", "1", # mono
|
||||||
|
"-ar", str(sr), # resample
|
||||||
|
"-f", "f32le", # raw 32-bit float little-endian
|
||||||
|
"-loglevel", "error",
|
||||||
|
"pipe:1",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
proc = subprocess.run(cmd, capture_output=True, timeout=300)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise RuntimeError(f"ffmpeg timed out (300s) on {os.path.basename(path)}")
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode().strip()}")
|
||||||
|
return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||||
|
_WINDOW = 8.0 # seconds
|
||||||
|
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
_MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
|
||||||
|
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
|
||||||
|
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")
|
||||||
|
|
||||||
|
# Redirect torch hub and huggingface downloads into the project
|
||||||
|
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
|
||||||
|
os.environ.setdefault("HF_HOME", os.path.join(_DL_CACHE_DIR, "huggingface"))
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Embedding extraction (lazy-loaded)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_w2v_model = None
|
||||||
|
_w2v_device = None
|
||||||
|
_w2v_model_name = None
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
|
||||||
|
# Supported embedding models — name → embed_dim
|
||||||
|
_EMBED_MODELS = {
|
||||||
|
"WAV2VEC2_BASE": 768,
|
||||||
|
"WAV2VEC2_LARGE": 1024,
|
||||||
|
"WAV2VEC2_LARGE_LV60K":1024,
|
||||||
|
"HUBERT_BASE": 768,
|
||||||
|
"HUBERT_LARGE": 1024,
|
||||||
|
"HUBERT_XLARGE": 1280,
|
||||||
|
"BEATS": 768,
|
||||||
|
# Multi-layer variants (4 quartile layers concatenated)
|
||||||
|
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
"EAT": 768,
|
||||||
|
"EAT_LARGE": 1024,
|
||||||
|
}
|
||||||
|
_DEFAULT_EMBED_MODEL = "EAT_LARGE"
|
||||||
|
|
||||||
|
_BEATS_CHECKPOINT = os.path.join(
|
||||||
|
_DL_CACHE_DIR, "huggingface", "hub",
|
||||||
|
"models--lpepino--beats_ckpts", "snapshots",
|
||||||
|
"5b53b0404df452a3a607d7e67687227730e5bad1", "BEATs_iter3_plus_AS2M.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_w2v_model(model_name: str | None = None):
|
||||||
|
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||||
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
# Multi-layer variants use the same base model weights
|
||||||
|
ml = _ml_config(model_name)
|
||||||
|
load_name = ml[0] if ml else model_name
|
||||||
|
if _w2v_model is None or _w2v_model_name != load_name:
|
||||||
|
import torch
|
||||||
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
if load_name == "BEATS":
|
||||||
|
from .beats_model import BEATs, BEATsConfig
|
||||||
|
checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=_w2v_device,
|
||||||
|
weights_only=False)
|
||||||
|
cfg = BEATsConfig(checkpoint['cfg'])
|
||||||
|
_w2v_model = BEATs(cfg)
|
||||||
|
_w2v_model.load_state_dict(checkpoint['model'])
|
||||||
|
_w2v_model.to(_w2v_device)
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
|
elif load_name in ("EAT", "EAT_LARGE"):
|
||||||
|
from transformers import AutoModel
|
||||||
|
eat_repo = ("worstchan/EAT-large_epoch20_finetune_AS2M"
|
||||||
|
if load_name == "EAT_LARGE"
|
||||||
|
else "worstchan/EAT-base_epoch30_finetune_AS2M")
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
eat_repo, trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
|
else:
|
||||||
|
import torchaudio
|
||||||
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
|
||||||
|
_w2v_model.eval()
|
||||||
|
_w2v_model_name = load_name
|
||||||
|
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||||
|
return _w2v_model, _w2v_device
|
||||||
|
|
||||||
|
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||||
|
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||||
|
|
||||||
|
Returns tensor of shape [B, 1, T, 128].
|
||||||
|
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
TARGET_LEN = 1024
|
||||||
|
MEAN, STD = -4.268, 4.569
|
||||||
|
|
||||||
|
mels = []
|
||||||
|
for chunk in chunks:
|
||||||
|
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||||
|
fbank = kaldi.fbank(
|
||||||
|
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||||
|
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||||
|
)
|
||||||
|
# Pad or truncate to TARGET_LEN
|
||||||
|
if fbank.shape[0] < TARGET_LEN:
|
||||||
|
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||||
|
else:
|
||||||
|
fbank = fbank[:TARGET_LEN]
|
||||||
|
fbank = (fbank - MEAN) / (STD * 2)
|
||||||
|
mels.append(fbank)
|
||||||
|
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_dim(model_name: str | None = None) -> int:
|
||||||
|
"""Return embedding dimension for a model name."""
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
return _EMBED_MODELS.get(model_name, 768)
|
||||||
|
|
||||||
|
|
||||||
|
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||||
|
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||||
|
|
||||||
|
Returns None for single-layer models.
|
||||||
|
Layer indices are 0-based into the list returned by extract_features().
|
||||||
|
"""
|
||||||
|
if not model_name.endswith("_ML"):
|
||||||
|
return None
|
||||||
|
base = model_name[:-3] # strip "_ML"
|
||||||
|
if base not in _EMBED_MODELS:
|
||||||
|
return None
|
||||||
|
# Layer counts per model family
|
||||||
|
layer_counts = {
|
||||||
|
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||||
|
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||||
|
"AST": 12,
|
||||||
|
}
|
||||||
|
n = layer_counts.get(base)
|
||||||
|
if n is None:
|
||||||
|
return None
|
||||||
|
# Select 4 layers at quartile boundaries (0-indexed)
|
||||||
|
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||||
|
return base, indices
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_path(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> str:
|
||||||
|
"""Return cache file path for a video's embeddings (includes model name)."""
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
abspath = os.path.abspath(video_path)
|
||||||
|
mtime = os.path.getmtime(abspath)
|
||||||
|
key = f"{abspath}|{mtime}|{hop}|{window}|{model_name}"
|
||||||
|
h = hashlib.sha256(key.encode()).hexdigest()[:16]
|
||||||
|
return os.path.join(_W2V_CACHE_DIR, f"{h}.npz")
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_exists(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> bool:
|
||||||
|
"""Check if embedding cache exists for a video."""
|
||||||
|
try:
|
||||||
|
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
return os.path.exists(path)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _w2v_cache_load(video_path: str, hop: float, window: float,
|
||||||
|
model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
|
||||||
|
"""Load embeddings from cache. Returns (timestamps, embeddings) or None."""
|
||||||
|
try:
|
||||||
|
path = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
if os.path.exists(path):
|
||||||
|
data = np.load(path)
|
||||||
|
_log(f"audio_scan: cache hit ({path})")
|
||||||
|
return data["timestamps"], data["embeddings"]
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache read failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
||||||
|
hop: float = 1.0, window: float = _WINDOW,
|
||||||
|
video_path: str | None = None,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
model_name: str | None = None,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Extract embeddings for all sliding windows using a torchaudio model.
|
||||||
|
|
||||||
|
If video_path is given, results are cached to disk for fast re-scans.
|
||||||
|
Returns (timestamps, embeddings) where embeddings is (N, D).
|
||||||
|
"""
|
||||||
|
edim = _embed_dim(model_name)
|
||||||
|
|
||||||
|
# Try loading from cache
|
||||||
|
cache_file = None
|
||||||
|
if video_path:
|
||||||
|
try:
|
||||||
|
cache_file = _w2v_cache_path(video_path, hop, window, model_name)
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
data = np.load(cache_file)
|
||||||
|
_log(f"audio_scan: cache hit ({cache_file})")
|
||||||
|
return data["timestamps"], data["embeddings"]
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache read failed: {e}")
|
||||||
|
|
||||||
|
win_samples = int(window * sr)
|
||||||
|
hop_samples = int(hop * sr)
|
||||||
|
n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)
|
||||||
|
|
||||||
|
if n_windows == 0:
|
||||||
|
return np.array([]), np.empty((0, edim))
|
||||||
|
|
||||||
|
import torch
|
||||||
|
model, device = _get_w2v_model(model_name)
|
||||||
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
# Auto-size batches based on available GPU memory
|
||||||
|
batch_size = 16
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
|
||||||
|
if vram_gb >= 16:
|
||||||
|
batch_size = 64
|
||||||
|
elif vram_gb >= 8:
|
||||||
|
batch_size = 32
|
||||||
|
_log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
timestamps = np.arange(n_windows) * hop
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
for batch_start in range(0, n_windows, batch_size):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return np.array([]), np.empty((0, edim))
|
||||||
|
batch_end = min(batch_start + batch_size, n_windows)
|
||||||
|
chunks = []
|
||||||
|
for i in range(batch_start, batch_end):
|
||||||
|
start = i * hop_samples
|
||||||
|
chunks.append(y[start:start + win_samples])
|
||||||
|
with torch.no_grad():
|
||||||
|
if is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings.append(batch_emb)
|
||||||
|
|
||||||
|
result_ts = timestamps
|
||||||
|
result_emb = np.vstack(embeddings)
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
if cache_file:
|
||||||
|
try:
|
||||||
|
os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
|
||||||
|
np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
|
||||||
|
_log(f"audio_scan: w2v cache saved ({cache_file})")
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: cache write failed: {e}")
|
||||||
|
|
||||||
|
return result_ts, result_emb
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
|
||||||
|
gt_soft: list[float], tolerance: float = 12.0,
|
||||||
|
neg_margin: float = 120.0,
|
||||||
|
model_name: str | None = None,
|
||||||
|
gt_negative: list[float] | None = None,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""Extract embeddings only near positives and distant negatives.
|
||||||
|
|
||||||
|
Returns (timestamps, embeddings, labels) where labels: 1=pos, -1=neg, 0=ambig.
|
||||||
|
"""
|
||||||
|
edim = _embed_dim(model_name)
|
||||||
|
duration = len(y) / sr
|
||||||
|
win_samples = int(_WINDOW * sr)
|
||||||
|
all_gt = list(gt_intense) + list(gt_soft)
|
||||||
|
|
||||||
|
# Positive windows: every second near intense markers
|
||||||
|
pos_times = set()
|
||||||
|
for gt in gt_intense:
|
||||||
|
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||||
|
t = gt + offset
|
||||||
|
if 0 <= t <= duration - _WINDOW:
|
||||||
|
pos_times.add(int(t))
|
||||||
|
|
||||||
|
# Manual negative windows: near explicit negative markers
|
||||||
|
manual_neg_times = set()
|
||||||
|
if gt_negative:
|
||||||
|
for gt in gt_negative:
|
||||||
|
for offset in range(-int(tolerance), int(tolerance) + 1):
|
||||||
|
t = gt + offset
|
||||||
|
if 0 <= t <= duration - _WINDOW:
|
||||||
|
manual_neg_times.add(int(t))
|
||||||
|
# Don't let manual negatives overlap with positives
|
||||||
|
manual_neg_times -= pos_times
|
||||||
|
|
||||||
|
# Auto negative windows: every 4s, far from any marker (skip if margin <= 0 or no markers)
|
||||||
|
neg_times = set()
|
||||||
|
if all_gt and neg_margin > 0:
|
||||||
|
for t in range(0, int(duration - _WINDOW), 4):
|
||||||
|
if min(abs(t - g) for g in all_gt) > neg_margin:
|
||||||
|
neg_times.add(t)
|
||||||
|
|
||||||
|
all_times = sorted(pos_times | neg_times | manual_neg_times)
|
||||||
|
# Filter out windows that go past the end
|
||||||
|
valid_times = [t for t in all_times if int(t * sr) + win_samples <= len(y)]
|
||||||
|
|
||||||
|
if not valid_times:
|
||||||
|
return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
model, device = _get_w2v_model(model_name)
|
||||||
|
batch_size = 16
|
||||||
|
timestamps_list: list[float] = []
|
||||||
|
embeddings_list: list[np.ndarray] = []
|
||||||
|
|
||||||
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) in ("EAT", "EAT_LARGE")
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
|
||||||
|
for batch_start in range(0, len(valid_times), batch_size):
|
||||||
|
batch_end = min(batch_start + batch_size, len(valid_times))
|
||||||
|
chunks = []
|
||||||
|
for t in valid_times[batch_start:batch_end]:
|
||||||
|
start = int(t * sr)
|
||||||
|
chunks.append(y[start:start + win_samples])
|
||||||
|
timestamps_list.append(float(t))
|
||||||
|
with torch.no_grad():
|
||||||
|
if is_ast:
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings_list.append(batch_emb)
|
||||||
|
|
||||||
|
timestamps = np.array(timestamps_list)
|
||||||
|
embeddings = np.vstack(embeddings_list)
|
||||||
|
|
||||||
|
labels = np.zeros(len(timestamps), dtype=int)
|
||||||
|
for i, t in enumerate(timestamps):
|
||||||
|
di = min((abs(t - g) for g in gt_intense), default=9999)
|
||||||
|
da = min((abs(t - g) for g in all_gt), default=9999)
|
||||||
|
dm = min((abs(t - g) for g in (gt_negative or [])), default=9999)
|
||||||
|
if di < tolerance:
|
||||||
|
labels[i] = 1
|
||||||
|
elif dm < tolerance or (neg_margin > 0 and da > neg_margin):
|
||||||
|
labels[i] = -1
|
||||||
|
return timestamps, embeddings, labels
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Classifier mode — train / save / load / scan
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def train_classifier(video_infos: list[tuple[str, list[float], list[float]]],
|
||||||
|
model_path: str | None = None,
|
||||||
|
tolerance: float = 12.0,
|
||||||
|
neg_margin: float = 120.0,
|
||||||
|
embed_model: str | None = None,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
n_workers: int = 4,
|
||||||
|
progress_cb: object = None) -> dict:
|
||||||
|
"""Train a classifier from labeled videos.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_infos: list of (video_path, intense_times, soft_times)
|
||||||
|
model_path: if given, save model to this path
|
||||||
|
tolerance/neg_margin: labeling parameters
|
||||||
|
embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS"), defaults to WAV2VEC2_BASE
|
||||||
|
cancel_flag: object with _cancel attribute; if set, training aborts early
|
||||||
|
n_workers: number of threads for parallel audio loading
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with 'classifier', 'embed_model', and metadata, or None on failure.
|
||||||
|
"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||||
|
|
||||||
|
def _progress(msg: str) -> None:
|
||||||
|
_log(msg)
|
||||||
|
if progress_cb:
|
||||||
|
progress_cb(msg)
|
||||||
|
|
||||||
|
def _load_audio(path: str) -> np.ndarray:
|
||||||
|
return _load_audio_ffmpeg(path, sr=_SR)
|
||||||
|
|
||||||
|
# Phase 1: load all audio in parallel (cap workers — disk I/O bound)
|
||||||
|
n = len(video_infos)
|
||||||
|
load_workers = min(n_workers, 4)
|
||||||
|
_progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
|
||||||
|
audio_data: dict[int, np.ndarray] = {}
|
||||||
|
with ThreadPoolExecutor(max_workers=load_workers) as pool:
|
||||||
|
future_to_idx = {
|
||||||
|
pool.submit(_load_audio, vi[0]): i
|
||||||
|
for i, vi in enumerate(video_infos)
|
||||||
|
}
|
||||||
|
failed = set()
|
||||||
|
for future in as_completed(future_to_idx):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: training cancelled")
|
||||||
|
return None
|
||||||
|
idx = future_to_idx[future]
|
||||||
|
try:
|
||||||
|
audio_data[idx] = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
|
||||||
|
failed.add(idx)
|
||||||
|
_progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")
|
||||||
|
|
||||||
|
# Phase 2: extract embeddings sequentially on GPU
|
||||||
|
_progress(f"Extracting embeddings: 0/{n}")
|
||||||
|
all_X, all_y = [], []
|
||||||
|
for vi, vinfo in enumerate(video_infos):
|
||||||
|
if vi in failed:
|
||||||
|
continue
|
||||||
|
vpath, gt_intense, gt_soft = vinfo[0], vinfo[1], vinfo[2]
|
||||||
|
gt_negative = vinfo[3] if len(vinfo) > 3 else []
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: training cancelled")
|
||||||
|
return None
|
||||||
|
_progress(f"Extracting embeddings: {vi+1}/{n}")
|
||||||
|
y = audio_data.pop(vi)
|
||||||
|
|
||||||
|
timestamps, embeddings, labels = _extract_w2v_targeted(
|
||||||
|
y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
|
||||||
|
model_name=embed_model, gt_negative=gt_negative,
|
||||||
|
)
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
continue
|
||||||
|
# Per-video z-score normalize
|
||||||
|
vid_mean = embeddings.mean(axis=0)
|
||||||
|
vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
|
||||||
|
normed = (embeddings - vid_mean) / vid_std
|
||||||
|
for i in range(len(labels)):
|
||||||
|
if labels[i] == 1:
|
||||||
|
all_X.append(normed[i])
|
||||||
|
all_y.append(1)
|
||||||
|
elif labels[i] == -1:
|
||||||
|
all_X.append(normed[i])
|
||||||
|
all_y.append(0)
|
||||||
|
|
||||||
|
if not all_X:
|
||||||
|
_log("audio_scan: no training samples collected")
|
||||||
|
return None
|
||||||
|
|
||||||
|
X = np.stack(all_X)
|
||||||
|
y_arr = np.array(all_y)
|
||||||
|
n_pos = (y_arr == 1).sum()
|
||||||
|
n_neg = (y_arr == 0).sum()
|
||||||
|
_log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")
|
||||||
|
|
||||||
|
if n_pos == 0 or n_neg == 0:
|
||||||
|
_log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Subsample negatives for balance
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
pos_idx = np.where(y_arr == 1)[0]
|
||||||
|
neg_idx = np.where(y_arr == 0)[0]
|
||||||
|
n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
|
||||||
|
neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
|
||||||
|
train_idx = np.concatenate([pos_idx, neg_sample])
|
||||||
|
rng.shuffle(train_idx)
|
||||||
|
|
||||||
|
_progress(f"Fitting classifier on {len(train_idx)} samples...")
|
||||||
|
clf = HistGradientBoostingClassifier(
|
||||||
|
max_iter=200, max_depth=5, learning_rate=0.1, random_state=42,
|
||||||
|
)
|
||||||
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
# Calibrate probabilities for better threshold behavior
|
||||||
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
min_class = min(int(n_pos), int(n_neg_sample))
|
||||||
|
if min_class >= 6:
|
||||||
|
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||||
|
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
clf = cal_clf
|
||||||
|
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||||
|
|
||||||
|
model = {"classifier": clf, "n_features": X.shape[1],
|
||||||
|
"embed_model": embed_model or _DEFAULT_EMBED_MODEL}
|
||||||
|
|
||||||
|
if model_path:
|
||||||
|
import joblib
|
||||||
|
from datetime import datetime
|
||||||
|
parent = os.path.dirname(model_path)
|
||||||
|
if parent:
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
# Save with timestamp in name; keep a symlink/copy as the "latest"
|
||||||
|
stem, ext = os.path.splitext(model_path)
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
versioned = f"{stem}_{ts}{ext}"
|
||||||
|
joblib.dump(model, versioned)
|
||||||
|
_log(f"audio_scan: model saved to {versioned}")
|
||||||
|
# Update the base path to point to latest version (copy)
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(versioned, model_path)
|
||||||
|
_log(f"audio_scan: latest model updated: {model_path}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_classifier(model_path: str) -> dict | None:
|
||||||
|
"""Load a saved classifier model."""
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
return None
|
||||||
|
import joblib
|
||||||
|
return joblib.load(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def default_model_path(profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> str:
|
||||||
|
"""Return the path for a profile's classifier model.
|
||||||
|
|
||||||
|
When embed_model is given the file is ``{profile}_{model}.joblib``,
|
||||||
|
otherwise ``{profile}.joblib`` (legacy single-model layout).
|
||||||
|
"""
|
||||||
|
if embed_model:
|
||||||
|
return os.path.join(_MODEL_DIR, f"{profile_name}_{embed_model}.joblib")
|
||||||
|
return os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def list_model_versions(profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> list[tuple[str, str]]:
|
||||||
|
"""Return available backup versions for a model, newest first.
|
||||||
|
|
||||||
|
Returns list of (timestamp_label, file_path).
|
||||||
|
The current (active) model is listed first as "current".
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
current = default_model_path(profile_name, embed_model)
|
||||||
|
stem, ext = os.path.splitext(current)
|
||||||
|
versions: list[tuple[str, str]] = []
|
||||||
|
if os.path.exists(current):
|
||||||
|
versions.append(("current", current))
|
||||||
|
if not os.path.isdir(_MODEL_DIR):
|
||||||
|
return versions
|
||||||
|
pattern = re.compile(re.escape(os.path.basename(stem)) + r"_(\d{8}_\d{6})" + re.escape(ext) + "$")
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
m = pattern.match(fname)
|
||||||
|
if m:
|
||||||
|
versions.append((m.group(1), os.path.join(_MODEL_DIR, fname)))
|
||||||
|
# Sort backups newest first (after "current")
|
||||||
|
current_entry = versions[:1]
|
||||||
|
backups = sorted(versions[1:], key=lambda v: v[0], reverse=True)
|
||||||
|
return current_entry + backups
|
||||||
|
|
||||||
|
|
||||||
|
def restore_model_version(version_path: str, profile_name: str = "default",
|
||||||
|
embed_model: str | None = None) -> None:
|
||||||
|
"""Restore a backup version as the active model."""
|
||||||
|
import filecmp, shutil
|
||||||
|
from datetime import datetime
|
||||||
|
current = default_model_path(profile_name, embed_model)
|
||||||
|
if version_path == current:
|
||||||
|
return
|
||||||
|
# Back up current before replacing — but only if no identical backup exists
|
||||||
|
if os.path.exists(current):
|
||||||
|
stem, ext = os.path.splitext(current)
|
||||||
|
already_saved = False
|
||||||
|
if os.path.isdir(_MODEL_DIR):
|
||||||
|
import re
|
||||||
|
pat = re.compile(re.escape(os.path.basename(stem)) + r"_\d{8}_\d{6}" + re.escape(ext) + "$")
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
if pat.match(fname):
|
||||||
|
candidate = os.path.join(_MODEL_DIR, fname)
|
||||||
|
if filecmp.cmp(current, candidate, shallow=False):
|
||||||
|
already_saved = True
|
||||||
|
break
|
||||||
|
if not already_saved:
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
shutil.move(current, f"{stem}_{ts}{ext}")
|
||||||
|
shutil.copy2(version_path, current)
|
||||||
|
_log(f"audio_scan: restored {os.path.basename(version_path)} as active model")
|
||||||
|
|
||||||
|
|
||||||
|
def list_trained_models(profile_name: str = "default") -> list[str]:
|
||||||
|
"""Return embedding model keys that have a trained .joblib for *profile_name*.
|
||||||
|
|
||||||
|
Looks for files matching ``{profile}_{KEY}.joblib`` in the models dir.
|
||||||
|
KEY is either a bare embed model name (e.g. ``EAT_LARGE``) or
|
||||||
|
``{MODEL}_{name}`` for user-named variants.
|
||||||
|
"""
|
||||||
|
prefix = f"{profile_name}_"
|
||||||
|
suffix = ".joblib"
|
||||||
|
result = []
|
||||||
|
if not os.path.isdir(_MODEL_DIR):
|
||||||
|
return result
|
||||||
|
for fname in os.listdir(_MODEL_DIR):
|
||||||
|
if fname.startswith(prefix) and fname.endswith(suffix):
|
||||||
|
key = fname[len(prefix):-len(suffix)]
|
||||||
|
if key in _EMBED_MODELS:
|
||||||
|
result.append(key)
|
||||||
|
else:
|
||||||
|
for m in _EMBED_MODELS:
|
||||||
|
if key.startswith(m + "_"):
|
||||||
|
result.append(key)
|
||||||
|
break
|
||||||
|
# Also check legacy {profile}.joblib
|
||||||
|
legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
|
||||||
|
if os.path.exists(legacy) and not result:
|
||||||
|
result.append("")
|
||||||
|
return sorted(result)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scanning
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _fuse_regions(regions: list[tuple[float, float, float]]
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Merge overlapping/adjacent regions, keeping max score."""
|
||||||
|
if not regions:
|
||||||
|
return []
|
||||||
|
by_start = sorted(regions, key=lambda r: r[0])
|
||||||
|
fused: list[tuple[float, float, float]] = []
|
||||||
|
s, e, sc = by_start[0]
|
||||||
|
for s2, e2, sc2 in by_start[1:]:
|
||||||
|
if s2 <= e: # overlapping or touching
|
||||||
|
e = max(e, e2)
|
||||||
|
sc = max(sc, sc2)
|
||||||
|
else:
|
||||||
|
fused.append((s, e, sc))
|
||||||
|
s, e, sc = s2, e2, sc2
|
||||||
|
fused.append((s, e, sc))
|
||||||
|
return fused
|
||||||
|
|
||||||
|
|
||||||
|
def prefetch_audio(video_path: str, embed_model: str | None = None,
|
||||||
|
hop: float = 1.0, window: float = _WINDOW) -> np.ndarray | None:
|
||||||
|
"""Pre-load audio for a video if embeddings aren't cached.
|
||||||
|
|
||||||
|
Returns the raw audio array, or None if cache already exists.
|
||||||
|
Call from a background thread while the GPU is busy with another video.
|
||||||
|
"""
|
||||||
|
if _w2v_cache_exists(video_path, hop, window, embed_model):
|
||||||
|
return None
|
||||||
|
_log(f"audio_scan: prefetching {os.path.basename(video_path)}")
|
||||||
|
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||||
|
_log(f"audio_scan: prefetched {len(y)/_SR:.1f}s")
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def scan_video(
|
||||||
|
video_path: str,
|
||||||
|
model: dict = None,
|
||||||
|
threshold: float = 0.50,
|
||||||
|
hop: float = 1.0,
|
||||||
|
window: float = _WINDOW,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
prefetched_audio: np.ndarray | None = None,
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Scan a video for matching audio regions using a trained classifier.
|
||||||
|
|
||||||
|
Returns list of (start_time, end_time, score) above threshold.
|
||||||
|
If prefetched_audio is provided, skips the ffmpeg decode step.
|
||||||
|
"""
|
||||||
|
if model is None:
|
||||||
|
_log("audio_scan: no model provided")
|
||||||
|
return []
|
||||||
|
|
||||||
|
clf = model["classifier"]
|
||||||
|
embed_model = model.get("embed_model")
|
||||||
|
|
||||||
|
# Try cache first — skip expensive audio loading if embeddings exist
|
||||||
|
cached = _w2v_cache_load(video_path, hop, window, embed_model)
|
||||||
|
if cached is not None:
|
||||||
|
timestamps, window_vectors = cached
|
||||||
|
else:
|
||||||
|
if prefetched_audio is not None:
|
||||||
|
_log(f"audio_scan: using prefetched audio")
|
||||||
|
y = prefetched_audio
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: loading {video_path}")
|
||||||
|
y = _load_audio_ffmpeg(video_path, sr=_SR)
|
||||||
|
sr = _SR
|
||||||
|
_log(f"audio_scan: {len(y)/sr:.1f}s loaded")
|
||||||
|
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return []
|
||||||
|
|
||||||
|
_log(f"audio_scan: extracting embeddings ({embed_model or 'default'})...")
|
||||||
|
timestamps, window_vectors = _extract_w2v_windows(
|
||||||
|
y, sr, hop=hop, window=window, video_path=video_path,
|
||||||
|
cancel_flag=cancel_flag, model_name=embed_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(timestamps) == 0:
|
||||||
|
_log("audio_scan: video shorter than window")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Per-video z-score normalize
|
||||||
|
vid_mean = window_vectors.mean(axis=0)
|
||||||
|
vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
|
||||||
|
normed = (window_vectors - vid_mean) / vid_std
|
||||||
|
|
||||||
|
_log(f"audio_scan: classifying {len(normed)} windows...")
|
||||||
|
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
return []
|
||||||
|
|
||||||
|
probs = clf.predict_proba(normed)[:, 1]
|
||||||
|
mask = probs >= threshold
|
||||||
|
raw = [
|
||||||
|
(timestamps[i], timestamps[i] + window, float(probs[i]))
|
||||||
|
for i in np.nonzero(mask)[0]
|
||||||
|
]
|
||||||
|
results = _fuse_regions(raw)
|
||||||
|
_log(f"audio_scan: {len(results)} regions above threshold {threshold} (from {len(raw)} raw)")
|
||||||
|
return results
|
||||||
@@ -0,0 +1,783 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import LayerNorm, Parameter
|
||||||
|
from .beats_modules import (
|
||||||
|
GradMultiply,
|
||||||
|
SamePad,
|
||||||
|
get_activation_fn,
|
||||||
|
GLU_Linear,
|
||||||
|
quant_noise,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.embedding_dim = args.encoder_embed_dim
|
||||||
|
|
||||||
|
self.pos_conv = nn.Conv1d(
|
||||||
|
self.embedding_dim,
|
||||||
|
self.embedding_dim,
|
||||||
|
kernel_size=args.conv_pos,
|
||||||
|
padding=args.conv_pos // 2,
|
||||||
|
groups=args.conv_pos_groups,
|
||||||
|
)
|
||||||
|
dropout = 0
|
||||||
|
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||||
|
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||||
|
nn.init.constant_(self.pos_conv.bias, 0)
|
||||||
|
|
||||||
|
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||||
|
|
||||||
|
if hasattr(args, "relative_position_embedding"):
|
||||||
|
self.relative_position_embedding = args.relative_position_embedding
|
||||||
|
self.num_buckets = args.num_buckets
|
||||||
|
self.max_distance = args.max_distance
|
||||||
|
else:
|
||||||
|
self.relative_position_embedding = False
|
||||||
|
self.num_buckets = 0
|
||||||
|
self.max_distance = 0
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerSentenceEncoderLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
deep_norm=args.deep_norm,
|
||||||
|
has_relative_attention_bias=self.relative_position_embedding,
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
|
gru_rel_pos=args.gru_rel_pos,
|
||||||
|
encoder_layers=args.encoder_layers,
|
||||||
|
)
|
||||||
|
for i in range(args.encoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.relative_position_embedding:
|
||||||
|
for i in range(1, args.encoder_layers):
|
||||||
|
del self.layers[i].self_attn.relative_attention_bias
|
||||||
|
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
|
||||||
|
|
||||||
|
self.layer_norm_first = args.layer_norm_first
|
||||||
|
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
self.layerdrop = args.encoder_layerdrop
|
||||||
|
|
||||||
|
self.apply(init_bert_params)
|
||||||
|
|
||||||
|
if args.deep_norm:
|
||||||
|
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
||||||
|
for i in range(args.encoder_layers):
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
|
||||||
|
|
||||||
|
def forward(self, x, padding_mask=None, layer=None):
|
||||||
|
x, layer_results = self.extract_features(x, padding_mask, layer)
|
||||||
|
|
||||||
|
if self.layer_norm_first and layer is None:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
x[padding_mask] = 0
|
||||||
|
|
||||||
|
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||||
|
x_conv = x_conv.transpose(1, 2)
|
||||||
|
x = x + x_conv
|
||||||
|
|
||||||
|
if not self.layer_norm_first:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# B x T x C -> T x B x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
layer_results = []
|
||||||
|
z = None
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
r = None
|
||||||
|
pos_bias = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
if self.layer_wise_gradient_decay_ratio != 1.0:
|
||||||
|
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
||||||
|
dropout_probability = np.random.random()
|
||||||
|
if not self.training or (dropout_probability > self.layerdrop):
|
||||||
|
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
if i == tgt_layer:
|
||||||
|
r = x
|
||||||
|
break
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
x = r
|
||||||
|
|
||||||
|
# T x B x C -> B x T x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerSentenceEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: float = 768,
|
||||||
|
ffn_embedding_dim: float = 3072,
|
||||||
|
num_attention_heads: float = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
attention_dropout: float = 0.1,
|
||||||
|
activation_dropout: float = 0.1,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
layer_norm_first: bool = False,
|
||||||
|
deep_norm: bool = False,
|
||||||
|
has_relative_attention_bias: bool = False,
|
||||||
|
num_buckets: int = 0,
|
||||||
|
max_distance: int = 0,
|
||||||
|
rescale_init: bool = False,
|
||||||
|
gru_rel_pos: bool = False,
|
||||||
|
encoder_layers: int = 0,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
|
||||||
|
self.activation_name = activation_fn
|
||||||
|
self.activation_fn = get_activation_fn(activation_fn)
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
self.embedding_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
dropout=attention_dropout,
|
||||||
|
self_attention=True,
|
||||||
|
has_relative_attention_bias=has_relative_attention_bias,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
max_distance=max_distance,
|
||||||
|
rescale_init=rescale_init,
|
||||||
|
gru_rel_pos=gru_rel_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.layer_norm_first = layer_norm_first
|
||||||
|
|
||||||
|
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||||
|
else:
|
||||||
|
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||||
|
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||||
|
|
||||||
|
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
self.deep_norm = deep_norm
|
||||||
|
if self.deep_norm:
|
||||||
|
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
||||||
|
else:
|
||||||
|
self.deep_norm_alpha = 1
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
self_attn_mask: torch.Tensor = None,
|
||||||
|
self_attn_padding_mask: torch.Tensor = None,
|
||||||
|
need_weights: bool = False,
|
||||||
|
pos_bias=None
|
||||||
|
):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if self.layer_norm_first:
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual + x
|
||||||
|
else:
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
|
return x, attn, pos_bias
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
"""Multi-headed attention.
|
||||||
|
|
||||||
|
See "Attention Is All You Need" for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
self_attention=False,
|
||||||
|
encoder_decoder_attention=False,
|
||||||
|
q_noise=0.0,
|
||||||
|
qn_block_size=8,
|
||||||
|
has_relative_attention_bias=False,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128,
|
||||||
|
gru_rel_pos=False,
|
||||||
|
rescale_init=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout_module = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||||
|
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.q_head_dim = self.head_dim
|
||||||
|
self.k_head_dim = self.head_dim
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.self_attention = self_attention
|
||||||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
|
||||||
|
assert not self.self_attention or self.qkv_same_dim, (
|
||||||
|
"Self-attention requires query, key and " "value to be of the same size"
|
||||||
|
)
|
||||||
|
|
||||||
|
k_bias = True
|
||||||
|
if rescale_init:
|
||||||
|
k_bias = False
|
||||||
|
|
||||||
|
k_embed_dim = embed_dim
|
||||||
|
q_embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.k_proj = quant_noise(
|
||||||
|
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.v_proj = quant_noise(
|
||||||
|
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.q_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self.gru_rel_pos = gru_rel_pos
|
||||||
|
if self.gru_rel_pos:
|
||||||
|
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||||
|
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if self.qkv_same_dim:
|
||||||
|
# Empirically observed the convergence to be much better with
|
||||||
|
# the scaled initialization
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
else:
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||||
|
|
||||||
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
|
if self.out_proj.bias is not None:
|
||||||
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
if self.bias_k is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_v)
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||||
|
|
||||||
|
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||||
|
num_buckets = self.num_buckets
|
||||||
|
max_distance = self.max_distance
|
||||||
|
relative_buckets = 0
|
||||||
|
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets = num_buckets // 2
|
||||||
|
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||||
|
relative_positions = torch.abs(relative_positions)
|
||||||
|
else:
|
||||||
|
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_positions < max_exact
|
||||||
|
|
||||||
|
relative_postion_if_large = max_exact + (
|
||||||
|
torch.log(relative_positions.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
relative_postion_if_large = torch.min(
|
||||||
|
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length):
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||||
|
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||||
|
relative_position = memory_position - context_position
|
||||||
|
relative_position_bucket = self._relative_positions_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=True
|
||||||
|
)
|
||||||
|
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket)
|
||||||
|
values = values.permute([2, 0, 1])
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key: Optional[Tensor],
|
||||||
|
value: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
static_kv: bool = False,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
before_softmax: bool = False,
|
||||||
|
need_head_weights: bool = False,
|
||||||
|
position_bias: Optional[Tensor] = None
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||||
|
"""Input shape: Time x Batch x Channel
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||||
|
keys that are pads, of shape `(batch, src_len)`, where
|
||||||
|
padding elements are indicated by 1s.
|
||||||
|
need_weights (bool, optional): return the attention weights,
|
||||||
|
averaged over heads (default: False).
|
||||||
|
attn_mask (ByteTensor, optional): typically used to
|
||||||
|
implement causal attention, where the mask prevents the
|
||||||
|
attention from looking forward in time (default: None).
|
||||||
|
before_softmax (bool, optional): return the raw attention
|
||||||
|
weights and values before the attention softmax.
|
||||||
|
need_head_weights (bool, optional): return the attention
|
||||||
|
weights for each head. Implies *need_weights*. Default:
|
||||||
|
return the average attention weights over all heads.
|
||||||
|
"""
|
||||||
|
if need_head_weights:
|
||||||
|
need_weights = True
|
||||||
|
|
||||||
|
is_tpu = query.device.type == "xla"
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
src_len = tgt_len
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||||
|
if key is not None:
|
||||||
|
src_len, key_bsz, _ = key.size()
|
||||||
|
if not torch.jit.is_scripting():
|
||||||
|
assert key_bsz == bsz
|
||||||
|
assert value is not None
|
||||||
|
assert src_len, bsz == value.shape[:2]
|
||||||
|
|
||||||
|
if self.has_relative_attention_bias and position_bias is None:
|
||||||
|
position_bias = self.compute_bias(tgt_len, src_len)
|
||||||
|
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
saved_state = self._get_input_buffer(incremental_state)
|
||||||
|
if saved_state is not None and "prev_key" in saved_state:
|
||||||
|
# previous time steps are cached - no need to recompute
|
||||||
|
# key and value if they are static
|
||||||
|
if static_kv:
|
||||||
|
assert self.encoder_decoder_attention and not self.self_attention
|
||||||
|
key = value = None
|
||||||
|
else:
|
||||||
|
saved_state = None
|
||||||
|
|
||||||
|
if self.self_attention:
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(query)
|
||||||
|
v = self.v_proj(query)
|
||||||
|
elif self.encoder_decoder_attention:
|
||||||
|
# encoder-decoder attention
|
||||||
|
q = self.q_proj(query)
|
||||||
|
if key is None:
|
||||||
|
assert value is None
|
||||||
|
k = v = None
|
||||||
|
else:
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(key)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert key is not None and value is not None
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(value)
|
||||||
|
q *= self.scaling
|
||||||
|
alpha = 32
|
||||||
|
q *= 1 / alpha
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
assert self.bias_v is not None
|
||||||
|
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||||
|
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
q.contiguous()
|
||||||
|
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if k is not None:
|
||||||
|
k = (
|
||||||
|
k.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if v is not None:
|
||||||
|
v = (
|
||||||
|
v.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if saved_state is not None:
|
||||||
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||||
|
if "prev_key" in saved_state:
|
||||||
|
_prev_key = saved_state["prev_key"]
|
||||||
|
assert _prev_key is not None
|
||||||
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
k = prev_key
|
||||||
|
else:
|
||||||
|
assert k is not None
|
||||||
|
k = torch.cat([prev_key, k], dim=1)
|
||||||
|
src_len = k.size(1)
|
||||||
|
if "prev_value" in saved_state:
|
||||||
|
_prev_value = saved_state["prev_value"]
|
||||||
|
assert _prev_value is not None
|
||||||
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
v = prev_value
|
||||||
|
else:
|
||||||
|
assert v is not None
|
||||||
|
v = torch.cat([prev_value, v], dim=1)
|
||||||
|
prev_key_padding_mask: Optional[Tensor] = None
|
||||||
|
if "prev_key_padding_mask" in saved_state:
|
||||||
|
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||||
|
assert k is not None and v is not None
|
||||||
|
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
prev_key_padding_mask=prev_key_padding_mask,
|
||||||
|
batch_size=bsz,
|
||||||
|
src_len=k.size(1),
|
||||||
|
static_kv=static_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||||
|
# In this branch incremental_state is never None
|
||||||
|
assert incremental_state is not None
|
||||||
|
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||||
|
assert k is not None
|
||||||
|
assert k.size(1) == src_len
|
||||||
|
|
||||||
|
# This is part of a workaround to get around fork/join parallelism
|
||||||
|
# not supporting Optional types.
|
||||||
|
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||||
|
key_padding_mask = None
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz
|
||||||
|
assert key_padding_mask.size(1) == src_len
|
||||||
|
|
||||||
|
if self.add_zero_attn:
|
||||||
|
assert v is not None
|
||||||
|
src_len += 1
|
||||||
|
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||||
|
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
||||||
|
key_padding_mask
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
||||||
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||||
|
|
||||||
|
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# don't attend to padding symbols
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
if not is_tpu:
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if before_softmax:
|
||||||
|
return attn_weights, v, position_bias
|
||||||
|
|
||||||
|
if position_bias is not None:
|
||||||
|
attn_mask_rel_pos = position_bias
|
||||||
|
if self.gru_rel_pos == 1:
|
||||||
|
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
||||||
|
_B, _H, _L, __ = query_layer.size()
|
||||||
|
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
||||||
|
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
||||||
|
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||||
|
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
||||||
|
|
||||||
|
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attn_mask_rel_pos
|
||||||
|
|
||||||
|
attn_weights_float = F.softmax(
|
||||||
|
attn_weights, dim=-1
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||||
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
|
assert v is not None
|
||||||
|
attn = torch.bmm(attn_probs, v)
|
||||||
|
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
|
attn = self.out_proj(attn)
|
||||||
|
attn_weights: Optional[Tensor] = None
|
||||||
|
if need_weights:
|
||||||
|
attn_weights = attn_weights_float.view(
|
||||||
|
bsz, self.num_heads, tgt_len, src_len
|
||||||
|
).transpose(1, 0)
|
||||||
|
if not need_head_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_weights = attn_weights.mean(dim=0)
|
||||||
|
|
||||||
|
return attn, attn_weights, position_bias
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_prev_key_padding_mask(
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
prev_key_padding_mask: Optional[Tensor],
|
||||||
|
batch_size: int,
|
||||||
|
src_len: int,
|
||||||
|
static_kv: bool,
|
||||||
|
) -> Optional[Tensor]:
|
||||||
|
# saved key padding masks have shape (bsz, seq_len)
|
||||||
|
if prev_key_padding_mask is not None and static_kv:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
# During incremental decoding, as the padding token enters and
|
||||||
|
# leaves the frame, there will be a time when prev or current
|
||||||
|
# is None
|
||||||
|
elif prev_key_padding_mask is not None:
|
||||||
|
if src_len > prev_key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||||
|
device=prev_key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), filler.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask.float()
|
||||||
|
elif key_padding_mask is not None:
|
||||||
|
if src_len > key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - key_padding_mask.size(1)),
|
||||||
|
device=key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[filler.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = key_padding_mask.float()
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
return new_key_padding_mask
|
||||||
|
|
||||||
|
def _get_input_buffer(
|
||||||
|
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||||
|
) -> Dict[str, Optional[Tensor]]:
|
||||||
|
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||||
|
return empty_result
|
||||||
|
|
||||||
|
def _set_input_buffer(
|
||||||
|
self,
|
||||||
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||||
|
buffer: Dict[str, Optional[Tensor]],
|
||||||
|
):
|
||||||
|
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||||
|
|
||||||
|
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||||
|
return attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def init_bert_params(module):
|
||||||
|
"""
|
||||||
|
Initialize the weights specific to the BERT Model.
|
||||||
|
This overrides the default initializations depending on the specified arguments.
|
||||||
|
1. If normal_init_linear_weights is set then weights of linear
|
||||||
|
layer will be initialized using the normal distribution and
|
||||||
|
bais will be set to the specified value.
|
||||||
|
2. If normal_init_embed_weights is set then weights of embedding
|
||||||
|
layer will be initialized using the normal distribution.
|
||||||
|
3. If normal_init_proj_weights is set then weights of
|
||||||
|
in_project_weight for MultiHeadAttention initialized using
|
||||||
|
the normal distribution (to be validated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def normal_(data):
|
||||||
|
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||||
|
# so that the RNG is consistent with and without FSDP
|
||||||
|
data.copy_(
|
||||||
|
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
if isinstance(module, MultiheadAttention):
|
||||||
|
normal_(module.q_proj.weight.data)
|
||||||
|
normal_(module.k_proj.weight.data)
|
||||||
|
normal_(module.v_proj.weight.data)
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
import torchaudio.compliance.kaldi as ta_kaldi
|
||||||
|
|
||||||
|
from .beats_backbone import (
|
||||||
|
TransformerEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATsConfig:
|
||||||
|
def __init__(self, cfg=None):
|
||||||
|
self.input_patch_size: int = -1 # path size of patch embedding
|
||||||
|
self.embed_dim: int = 512 # patch embedding dimension
|
||||||
|
self.conv_bias: bool = False # include bias in conv encoder
|
||||||
|
|
||||||
|
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||||
|
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||||
|
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||||
|
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||||
|
self.activation_fn: str = "gelu" # activation function to use
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
||||||
|
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||||
|
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||||
|
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||||
|
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||||
|
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
||||||
|
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||||
|
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||||
|
|
||||||
|
# relative position embedding
|
||||||
|
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||||
|
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||||
|
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||||
|
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||||
|
|
||||||
|
# label predictor
|
||||||
|
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
||||||
|
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
||||||
|
self.predictor_class: int = 527 # target class number for the predictor
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
self.update(cfg)
|
||||||
|
|
||||||
|
def update(self, cfg: dict):
|
||||||
|
self.__dict__.update(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATs(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: BEATsConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info(f"BEATs Config: {cfg.__dict__}")
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
self.embed = cfg.embed_dim
|
||||||
|
self.post_extract_proj = (
|
||||||
|
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||||
|
if self.embed != cfg.encoder_embed_dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_patch_size = cfg.input_patch_size
|
||||||
|
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
||||||
|
bias=cfg.conv_bias)
|
||||||
|
|
||||||
|
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||||
|
|
||||||
|
assert not cfg.deep_norm or not cfg.layer_norm_first
|
||||||
|
self.encoder = TransformerEncoder(cfg)
|
||||||
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
|
if cfg.finetuned_model:
|
||||||
|
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
||||||
|
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
||||||
|
else:
|
||||||
|
self.predictor = None
|
||||||
|
|
||||||
|
def forward_padding_mask(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
padding_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
extra = padding_mask.size(1) % features.size(1)
|
||||||
|
if extra > 0:
|
||||||
|
padding_mask = padding_mask[:, :-extra]
|
||||||
|
padding_mask = padding_mask.view(
|
||||||
|
padding_mask.size(0), features.size(1), -1
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask.all(-1)
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
fbanks = []
|
||||||
|
for waveform in source:
|
||||||
|
waveform = waveform.unsqueeze(0) * 2 ** 15
|
||||||
|
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
||||||
|
fbanks.append(fbank)
|
||||||
|
fbank = torch.stack(fbanks, dim=0)
|
||||||
|
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
):
|
||||||
|
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
||||||
|
|
||||||
|
fbank = fbank.unsqueeze(1)
|
||||||
|
features = self.patch_embedding(fbank)
|
||||||
|
features = features.reshape(features.shape[0], features.shape[1], -1)
|
||||||
|
features = features.transpose(1, 2)
|
||||||
|
features = self.layer_norm(features)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||||
|
|
||||||
|
if self.post_extract_proj is not None:
|
||||||
|
features = self.post_extract_proj(features)
|
||||||
|
|
||||||
|
x = self.dropout_input(features)
|
||||||
|
|
||||||
|
x, layer_results = self.encoder(
|
||||||
|
x,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.predictor is not None:
|
||||||
|
x = self.predictor_dropout(x)
|
||||||
|
logits = self.predictor(x)
|
||||||
|
|
||||||
|
if padding_mask is not None and padding_mask.any():
|
||||||
|
logits[padding_mask] = 0
|
||||||
|
logits = logits.sum(dim=1)
|
||||||
|
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
|
||||||
|
else:
|
||||||
|
logits = logits.mean(dim=1)
|
||||||
|
|
||||||
|
lprobs = torch.sigmoid(logits)
|
||||||
|
|
||||||
|
return lprobs, padding_mask
|
||||||
|
else:
|
||||||
|
return x, padding_mask
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class GradMultiply(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, scale):
|
||||||
|
ctx.scale = scale
|
||||||
|
res = x.new(x)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
return grad * ctx.scale, None
|
||||||
|
|
||||||
|
|
||||||
|
class SamePad(nn.Module):
|
||||||
|
def __init__(self, kernel_size, causal=False):
|
||||||
|
super().__init__()
|
||||||
|
if causal:
|
||||||
|
self.remove = kernel_size - 1
|
||||||
|
else:
|
||||||
|
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.remove > 0:
|
||||||
|
x = x[:, :, : -self.remove]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Swish, self).__init__()
|
||||||
|
self.act = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.act(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU_Linear(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||||
|
super(GLU_Linear, self).__init__()
|
||||||
|
|
||||||
|
self.glu_type = glu_type
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
if glu_type == "sigmoid":
|
||||||
|
self.glu_act = torch.nn.Sigmoid()
|
||||||
|
elif glu_type == "swish":
|
||||||
|
self.glu_act = Swish()
|
||||||
|
elif glu_type == "relu":
|
||||||
|
self.glu_act = torch.nn.ReLU()
|
||||||
|
elif glu_type == "gelu":
|
||||||
|
self.glu_act = torch.nn.GELU()
|
||||||
|
|
||||||
|
if bias_in_glu:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||||
|
else:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||||
|
x = self.linear(x)
|
||||||
|
|
||||||
|
if self.glu_type == "bilinear":
|
||||||
|
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||||
|
else:
|
||||||
|
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_accurate(x):
|
||||||
|
if not hasattr(gelu_accurate, "_a"):
|
||||||
|
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||||
|
return (
|
||||||
|
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_fn(activation: str):
|
||||||
|
"""Returns the activation function corresponding to `activation`"""
|
||||||
|
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
return gelu
|
||||||
|
elif activation == "gelu_fast":
|
||||||
|
warnings.warn(
|
||||||
|
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||||
|
)
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "gelu_accurate":
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "tanh":
|
||||||
|
return torch.tanh
|
||||||
|
elif activation == "linear":
|
||||||
|
return lambda x: x
|
||||||
|
elif activation == "glu":
|
||||||
|
return lambda x: x
|
||||||
|
else:
|
||||||
|
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||||
|
|
||||||
|
|
||||||
|
def quant_noise(module, p, block_size):
|
||||||
|
"""
|
||||||
|
Wraps modules and applies quantization noise to the weights for
|
||||||
|
subsequent quantization with Iterative Product Quantization as
|
||||||
|
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- module: nn.Module
|
||||||
|
- p: amount of Quantization Noise
|
||||||
|
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||||
|
|
||||||
|
Remarks:
|
||||||
|
- Module weights must have the right sizes wrt the block size
|
||||||
|
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||||
|
- For more detail on how to quantize by blocks with convolutional weights,
|
||||||
|
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||||
|
- We implement the simplest form of noise here as stated in the paper
|
||||||
|
which consists in randomly dropping blocks
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if no quantization noise, don't register hook
|
||||||
|
if p <= 0:
|
||||||
|
return module
|
||||||
|
|
||||||
|
# supported modules
|
||||||
|
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||||
|
|
||||||
|
# test whether module.weight has the right sizes wrt block_size
|
||||||
|
is_conv = module.weight.ndim == 4
|
||||||
|
|
||||||
|
# 2D matrix
|
||||||
|
if not is_conv:
|
||||||
|
assert (
|
||||||
|
module.weight.size(1) % block_size == 0
|
||||||
|
), "Input features must be a multiple of block sizes"
|
||||||
|
|
||||||
|
# 4D matrix
|
||||||
|
else:
|
||||||
|
# 1x1 convolutions
|
||||||
|
if module.kernel_size == (1, 1):
|
||||||
|
assert (
|
||||||
|
module.in_channels % block_size == 0
|
||||||
|
), "Input channels must be a multiple of block sizes"
|
||||||
|
# regular convolutions
|
||||||
|
else:
|
||||||
|
k = module.kernel_size[0] * module.kernel_size[1]
|
||||||
|
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||||
|
|
||||||
|
def _forward_pre_hook(mod, input):
|
||||||
|
# no noise for evaluation
|
||||||
|
if mod.training:
|
||||||
|
if not is_conv:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_features = weight.size(1)
|
||||||
|
out_features = weight.size(0)
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
mask = torch.zeros(
|
||||||
|
in_features // block_size * out_features, device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_channels = mod.in_channels
|
||||||
|
out_channels = mod.out_channels
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
if mod.kernel_size == (1, 1):
|
||||||
|
mask = torch.zeros(
|
||||||
|
int(in_channels // block_size * out_channels),
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros(
|
||||||
|
weight.size(0), weight.size(1), device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = (
|
||||||
|
mask.unsqueeze(2)
|
||||||
|
.unsqueeze(3)
|
||||||
|
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale weights and apply mask
|
||||||
|
mask = mask.to(
|
||||||
|
torch.bool
|
||||||
|
) # x.bool() is not currently supported in TorchScript
|
||||||
|
s = 1 / (1 - p)
|
||||||
|
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||||
|
|
||||||
|
module.register_forward_pre_hook(_forward_pre_hook)
|
||||||
|
return module
|
||||||
|
|
||||||
@@ -0,0 +1,250 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
_RATIOS: dict[str, tuple[int, int]] = {
|
||||||
|
"9:16": (9, 16),
|
||||||
|
"4:5": (4, 5),
|
||||||
|
"1:1": (1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||||
|
"""Return an ffmpeg crop= filter expression for the given portrait ratio.
|
||||||
|
|
||||||
|
Uses ffmpeg expression syntax so source dimensions are resolved at runtime.
|
||||||
|
Commas inside min()/max() are escaped with \\, to prevent ffmpeg's
|
||||||
|
filtergraph parser from treating them as filter-chain separators.
|
||||||
|
"""
|
||||||
|
num, den = _RATIOS[ratio]
|
||||||
|
cw = f"ih*{num}/{den}"
|
||||||
|
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||||
|
return f"crop={cw}:ih:{x}:0"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_keyframe(
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
t: float,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||||
|
"""Return the latest keyframe at or before *t*, or None."""
|
||||||
|
result = None
|
||||||
|
for kf in keyframes:
|
||||||
|
if kf[0] <= t + tolerance:
|
||||||
|
result = kf
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_keyframes_to_jobs(
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
base_center: float,
|
||||||
|
base_ratio: str | None,
|
||||||
|
base_rand_p: bool,
|
||||||
|
base_rand_s: bool,
|
||||||
|
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||||
|
"""Resolve each job's crop state from keyframes, returning widened tuples.
|
||||||
|
|
||||||
|
Returns list of (start, path, ratio, center, rand_portrait, rand_square).
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for s, o, _r, _c in jobs:
|
||||||
|
kf = resolve_keyframe(keyframes, s)
|
||||||
|
if kf is not None:
|
||||||
|
_, center, ratio, rp, rs = kf
|
||||||
|
else:
|
||||||
|
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||||
|
result.append((s, o, ratio, center, rp, rs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _find_vaapi_device() -> str:
|
||||||
|
"""Return the first available VAAPI render device path (Linux)."""
|
||||||
|
import glob
|
||||||
|
devices = sorted(glob.glob("/dev/dri/renderD*"))
|
||||||
|
return devices[0] if devices else "/dev/dri/renderD128"
|
||||||
|
|
||||||
|
|
||||||
|
def build_ffmpeg_command(
|
||||||
|
input_path: str, start: float, output_path: str,
|
||||||
|
short_side: int | None = None,
|
||||||
|
portrait_ratio: str | None = None,
|
||||||
|
crop_center: float = 0.5,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
duration: float = 8.0,
|
||||||
|
target_fps: float | None = None,
|
||||||
|
snap32: bool = False,
|
||||||
|
frames: int | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
# -ss before -i: fast input-seeking. Safe here because we always re-encode,
|
||||||
|
# so there is no keyframe-alignment issue from pre-input seek.
|
||||||
|
# Image sequences always use libwebp, so skip HW encoder setup.
|
||||||
|
use_hw_vaapi = (encoder == "h264_vaapi" and not image_sequence
|
||||||
|
and sys.platform == "linux")
|
||||||
|
cmd = [_bin("ffmpeg"), "-y"]
|
||||||
|
|
||||||
|
# VAAPI needs a render device for hardware context (Linux only).
|
||||||
|
if use_hw_vaapi:
|
||||||
|
vaapi_dev = _find_vaapi_device()
|
||||||
|
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||||
|
"-vaapi_device", vaapi_dev]
|
||||||
|
|
||||||
|
cmd += [
|
||||||
|
"-threads", "0",
|
||||||
|
"-ss", str(start),
|
||||||
|
"-i", input_path,
|
||||||
|
"-t", str(duration),
|
||||||
|
]
|
||||||
|
|
||||||
|
filters: list[str] = []
|
||||||
|
if portrait_ratio is not None:
|
||||||
|
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||||
|
if short_side is not None:
|
||||||
|
# Scale so the shorter dimension equals short_side.
|
||||||
|
filters.append(
|
||||||
|
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||||
|
)
|
||||||
|
|
||||||
|
# LTX-2: centered crop to ÷32 (no rescale → no aspect distortion) then fps.
|
||||||
|
# Placed among CPU filters, after scale and before the VAAPI hwupload block.
|
||||||
|
if snap32:
|
||||||
|
filters.append("crop=trunc(iw/32)*32:trunc(ih/32)*32")
|
||||||
|
if target_fps is not None:
|
||||||
|
filters.append(f"fps={target_fps:g}")
|
||||||
|
|
||||||
|
# VAAPI: decoded frames are GPU surfaces. CPU filters need hwdownload first.
|
||||||
|
if use_hw_vaapi:
|
||||||
|
if filters:
|
||||||
|
filters.insert(0, "hwdownload")
|
||||||
|
filters.insert(1, "format=nv12")
|
||||||
|
filters.append("format=nv12")
|
||||||
|
filters.append("hwupload")
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
cmd += ["-vf", ",".join(filters)]
|
||||||
|
|
||||||
|
# LTX-2 output rate + exact frame cap (apply to both clip and webp-seq paths).
|
||||||
|
if target_fps is not None:
|
||||||
|
cmd += ["-r", f"{target_fps:g}"]
|
||||||
|
if frames is not None:
|
||||||
|
cmd += ["-frames:v", str(frames)]
|
||||||
|
|
||||||
|
if image_sequence:
|
||||||
|
cmd += [
|
||||||
|
"-an",
|
||||||
|
"-c:v", "libwebp",
|
||||||
|
"-quality", "92",
|
||||||
|
"-compression_level", "1",
|
||||||
|
os.path.join(output_path, "frame_%04d.webp"),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
cmd += ["-c:v", encoder]
|
||||||
|
if "nvenc" in encoder:
|
||||||
|
cmd += ["-preset", "p4", "-cq", "28"]
|
||||||
|
elif "vaapi" in encoder:
|
||||||
|
cmd += ["-qp", "28"]
|
||||||
|
elif "qsv" in encoder:
|
||||||
|
cmd += ["-global_quality", "28"]
|
||||||
|
elif "amf" in encoder:
|
||||||
|
cmd += ["-qp_i", "28", "-qp_p", "28"]
|
||||||
|
cmd += ["-c:a", "pcm_s16le", output_path]
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str,
|
||||||
|
duration: float = 8.0) -> list[str]:
|
||||||
|
"""Return an ffmpeg command that extracts audio to <sequence_dir>.wav."""
|
||||||
|
audio_path = sequence_dir + ".wav"
|
||||||
|
return [
|
||||||
|
_bin("ffmpeg"), "-y",
|
||||||
|
"-ss", str(start),
|
||||||
|
"-i", input_path,
|
||||||
|
"-t", str(duration),
|
||||||
|
"-vn",
|
||||||
|
"-c:a", "pcm_s16le",
|
||||||
|
audio_path,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Audio codec chosen per output extension for the manual "Extract audio area"
|
||||||
|
# tool. Empty list -> let ffmpeg pick a default encoder from the extension.
|
||||||
|
_AUDIO_CODEC_BY_EXT: dict[str, list[str]] = {
|
||||||
|
".wav": ["-c:a", "pcm_s16le"],
|
||||||
|
".flac": ["-c:a", "flac"],
|
||||||
|
".mp3": ["-c:a", "libmp3lame", "-q:a", "2"],
|
||||||
|
".m4a": ["-c:a", "aac", "-b:a", "256k"],
|
||||||
|
".aac": ["-c:a", "aac", "-b:a", "256k"],
|
||||||
|
".ogg": ["-c:a", "libvorbis", "-q:a", "5"],
|
||||||
|
".opus": ["-c:a", "libopus", "-b:a", "192k"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def probe_duration(path: str) -> float | None:
|
||||||
|
"""Return the media duration in seconds via ffprobe, or None on failure."""
|
||||||
|
try:
|
||||||
|
r = subprocess.run(
|
||||||
|
[_bin("ffprobe"), "-v", "error", "-show_entries", "format=duration",
|
||||||
|
"-of", "default=nw=1:nk=1", path],
|
||||||
|
capture_output=True, text=True, timeout=30,
|
||||||
|
)
|
||||||
|
if r.returncode == 0 and r.stdout.strip():
|
||||||
|
return float(r.stdout.strip())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_clip_command(input_path: str, start: float, duration: float,
|
||||||
|
out_path: str) -> list[str]:
|
||||||
|
"""ffmpeg command to extract exactly *duration* seconds of audio starting
|
||||||
|
at *start*, re-encoded per *out_path*'s extension (wav/mp3/flac/…)."""
|
||||||
|
ext = os.path.splitext(out_path)[1].lower()
|
||||||
|
codec = _AUDIO_CODEC_BY_EXT.get(ext, [])
|
||||||
|
return [
|
||||||
|
_bin("ffmpeg"), "-y",
|
||||||
|
"-ss", str(start),
|
||||||
|
"-i", input_path,
|
||||||
|
"-t", str(duration),
|
||||||
|
"-vn",
|
||||||
|
*codec,
|
||||||
|
out_path,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_hw_encoders() -> list[str]:
|
||||||
|
"""Probe ffmpeg for available H.264 hardware encoders.
|
||||||
|
|
||||||
|
Returns only encoders relevant to the current platform:
|
||||||
|
- Windows: h264_nvenc, h264_qsv, h264_amf
|
||||||
|
- Linux: h264_nvenc, h264_vaapi, h264_qsv
|
||||||
|
- macOS: h264_videotoolbox
|
||||||
|
"""
|
||||||
|
if sys.platform == "win32":
|
||||||
|
candidates = ["h264_nvenc", "h264_qsv", "h264_amf"]
|
||||||
|
elif sys.platform == "darwin":
|
||||||
|
candidates = ["h264_videotoolbox"]
|
||||||
|
else:
|
||||||
|
candidates = ["h264_nvenc", "h264_vaapi", "h264_qsv"]
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||||
|
capture_output=True, text=True, timeout=5,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return []
|
||||||
|
output = result.stdout
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
available = [enc for enc in candidates if re.search(rf'\b{enc}\b', output)]
|
||||||
|
if available:
|
||||||
|
_log(f"HW encoders detected: {', '.join(available)}")
|
||||||
|
else:
|
||||||
|
_log("No HW encoders detected — GPU export unavailable")
|
||||||
|
return available
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""LTX-2 frame-count math. Legal F satisfy F % 8 == 1 (8x temporal + 1)."""
|
||||||
|
|
||||||
|
|
||||||
|
def is_legal_frames(f: int) -> bool:
|
||||||
|
return f >= 9 and f % 8 == 1
|
||||||
|
|
||||||
|
|
||||||
|
def legal_frames(min_f: int = 9, max_f: int = 1000) -> list[int]:
|
||||||
|
start = max(9, min_f + ((1 - min_f) % 8)) # first 8k+1 >= min_f
|
||||||
|
return list(range(start, max_f + 1, 8))
|
||||||
|
|
||||||
|
|
||||||
|
def nearest_legal_frames(f: int) -> int:
|
||||||
|
if f <= 9:
|
||||||
|
return 9
|
||||||
|
low = ((f - 1) // 8) * 8 + 1
|
||||||
|
high = low + 8
|
||||||
|
return low if (f - low) <= (high - f) else high
|
||||||
|
|
||||||
|
|
||||||
|
def duration_for_frames(frames: int, fps: float) -> float:
|
||||||
|
return frames / fps
|
||||||
|
|
||||||
|
|
||||||
|
def frames_for_duration(duration: float, fps: float) -> int:
|
||||||
|
return nearest_legal_frames(round(duration * fps))
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _frozen_path() -> Path:
|
||||||
|
if getattr(sys, "frozen", False):
|
||||||
|
return Path(sys._MEIPASS)
|
||||||
|
return Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _bin(name: str) -> str:
|
||||||
|
"""Resolve a binary name (e.g. 'ffmpeg') to its full path in frozen builds."""
|
||||||
|
p = _frozen_path() / name
|
||||||
|
if p.exists():
|
||||||
|
return str(p)
|
||||||
|
return name # fall back to PATH
|
||||||
|
|
||||||
|
|
||||||
|
def _log(*args) -> None:
|
||||||
|
"""Print a timestamped log line to stderr."""
|
||||||
|
ts = datetime.now().strftime("%H:%M:%S")
|
||||||
|
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def build_export_path(folder: str, basename: str, counter: int,
|
||||||
|
sub: int | None = None, tag: str | None = None) -> str:
|
||||||
|
"""Build clip output path. *folder* should be the vid folder (e.g. .../mp4/vid_001)."""
|
||||||
|
name = f"{basename}_{counter:03d}"
|
||||||
|
if tag is not None:
|
||||||
|
name = f"{name}_{tag}"
|
||||||
|
if sub is not None:
|
||||||
|
name = f"{name}_{sub}"
|
||||||
|
return os.path.join(folder, name + ".mp4")
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequence_dir(folder: str, basename: str, counter: int,
|
||||||
|
sub: int | None = None, tag: str | None = None) -> str:
|
||||||
|
"""Build WebP sequence output dir. *folder* should be the vid folder."""
|
||||||
|
name = f"{basename}_{counter:03d}"
|
||||||
|
if tag is not None:
|
||||||
|
name = f"{name}_{tag}"
|
||||||
|
if sub is not None:
|
||||||
|
name = f"{name}_{sub}"
|
||||||
|
return os.path.join(folder, name)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
m = int(seconds // 60)
|
||||||
|
# Floor-truncate to 1 dp (not round) — prevents "X:60.0" rollover when
|
||||||
|
# seconds is e.g. 59.95.
|
||||||
|
s = int(seconds % 60 * 10) / 10
|
||||||
|
return f"{m}:{s:04.1f}"
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_yolo():
|
||||||
|
"""Lazy-load YOLOv8-nano. Returns None if ultralytics is not installed."""
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model is None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
_yolo_model = YOLO("yolov8n.pt")
|
||||||
|
_log("YOLO model loaded")
|
||||||
|
except ImportError:
|
||||||
|
_log("ultralytics not installed — tracking disabled")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"YOLO load failed: {e}")
|
||||||
|
return None
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frame_cv(video_path: str, time: float):
|
||||||
|
"""Extract a single frame as a numpy array (BGR) via ffmpeg -> temp PNG -> cv2."""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||||
|
os.close(fd)
|
||||||
|
try:
|
||||||
|
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||||
|
"-frames:v", "1", tmp]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return cv2.imread(tmp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_subject_center(
|
||||||
|
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||||
|
) -> tuple[int | None, float, float] | None:
|
||||||
|
"""Detect objects at *time* and return (class_id, norm_x, norm_y) of the
|
||||||
|
best match to (target_cls, last_x, last_y). Returns None on failure."""
|
||||||
|
model = _get_yolo()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
frame = extract_frame_cv(video_path, time)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
results = model(frame, verbose=False)
|
||||||
|
if not results or len(results[0].boxes) == 0:
|
||||||
|
return None
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
dets = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
cx = (x1 + x2) / 2 / w
|
||||||
|
cy = (y1 + y2) / 2 / h
|
||||||
|
dets.append((cls, cx, cy))
|
||||||
|
# Prefer same class, nearest to last known position.
|
||||||
|
def score(d):
|
||||||
|
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||||
|
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||||
|
return cls_penalty + dist
|
||||||
|
best = min(dets, key=score)
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def track_centers_for_jobs(
|
||||||
|
video_path: str, cursor: float, crop_center: float,
|
||||||
|
starts: list[float],
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run detection at the cursor (to identify the target) then at each start
|
||||||
|
time. Returns a list of horizontal crop centers (one per start)."""
|
||||||
|
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||||
|
if ref is None:
|
||||||
|
_log("Tracking: no detection at cursor, using fixed center")
|
||||||
|
return [crop_center] * len(starts)
|
||||||
|
target_cls, last_x, last_y = ref
|
||||||
|
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||||
|
centers = []
|
||||||
|
for t in starts:
|
||||||
|
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||||
|
if det is not None:
|
||||||
|
_, cx, cy = det
|
||||||
|
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||||
|
centers.append(cx)
|
||||||
|
last_x, last_y = cx, cy
|
||||||
|
else:
|
||||||
|
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||||
|
centers.append(last_x)
|
||||||
|
return centers
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
# 8-cut Client Design
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Build a Tauri + Svelte desktop client that connects to the 8-cut server API for remote video editing. Full feature parity with the Qt app. Targets Linux first, then Mac.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Tauri app (Rust shell + Svelte webview)
|
||||||
|
├── mpv sidecar (bundled binary)
|
||||||
|
│ ├── plays video: http://server/api/stream/{path}?quality=low
|
||||||
|
│ ├── plays audio: http://server/api/audio/{path}
|
||||||
|
│ └── controlled via JSON IPC socket
|
||||||
|
├── Svelte UI
|
||||||
|
│ ├── File browser
|
||||||
|
│ ├── Canvas timeline (markers, cursor, play region)
|
||||||
|
│ ├── Canvas crop overlay
|
||||||
|
│ ├── Export controls + WebSocket progress
|
||||||
|
│ └── Settings panel (profile, subprofiles, quality)
|
||||||
|
└── Rust backend
|
||||||
|
├── Spawn/manage mpv process + IPC
|
||||||
|
├── Proxy server API calls (avoid CORS)
|
||||||
|
└── Tauri commands exposed to Svelte frontend
|
||||||
|
```
|
||||||
|
|
||||||
|
## Playback
|
||||||
|
|
||||||
|
mpv runs as a sidecar process, controlled via JSON IPC socket. Two streams:
|
||||||
|
- Video: `http://server/api/stream/{path}?root={root}&quality={quality}` (transcoded, no audio)
|
||||||
|
- Audio: `http://server/api/audio/{path}?root={root}` (full quality WAV)
|
||||||
|
|
||||||
|
mpv's `--audio-file=` flag syncs both streams with frame-accurate seeking.
|
||||||
|
|
||||||
|
Quality presets: potato (480p), low (720p), medium (1080p), high (original).
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### File management
|
||||||
|
- Browse server video roots (`GET /api/roots`, `GET /api/files`)
|
||||||
|
- Hide/unhide files per profile (`POST/DELETE /api/hidden/{filename}`)
|
||||||
|
- Sort by name/size, filter hidden
|
||||||
|
|
||||||
|
### Playback
|
||||||
|
- Play/pause/resume from pause point
|
||||||
|
- AB-loop with current spread/clips settings
|
||||||
|
- Play region adapts to spread changes without restarting
|
||||||
|
- Quality selector
|
||||||
|
|
||||||
|
### Timeline (Canvas)
|
||||||
|
- Cursor position, markers, play position indicator
|
||||||
|
- Click to seek, drag cursor
|
||||||
|
- Lock mode: cursor locked to marker, double-click jumps to end of clip span
|
||||||
|
- Autoclip: when paused, auto-adjust clip count to fit pause position
|
||||||
|
|
||||||
|
### Crop & keyframes
|
||||||
|
- Portrait ratio selector (9:16, 4:5, 1:1, off)
|
||||||
|
- Crop center slider with live canvas overlay
|
||||||
|
- Crop keyframes at arbitrary timeline positions
|
||||||
|
- Subject tracking (triggered server-side)
|
||||||
|
- Random portrait/square toggles
|
||||||
|
|
||||||
|
### Export
|
||||||
|
- Configurable: clips, spread, short side, format (MP4/WebP sequence)
|
||||||
|
- Label + category annotation
|
||||||
|
- Encoder selection (libx264 / h264_nvenc)
|
||||||
|
- Subprofiles with folder suffix routing
|
||||||
|
- Number keys 1-9 for subprofile quick export, E for main
|
||||||
|
- WebSocket progress (`WS /ws/export`), per-clip completion
|
||||||
|
- Delete/re-export from marker context menu
|
||||||
|
|
||||||
|
### Profiles
|
||||||
|
- Profile switcher, markers reload per profile
|
||||||
|
- Subprofile management (add/remove)
|
||||||
|
|
||||||
|
### Settings
|
||||||
|
- Server URL (configurable)
|
||||||
|
- Default quality preset
|
||||||
|
- All settings persisted client-side via Tauri store
|
||||||
|
|
||||||
|
## Server API endpoints used
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/roots
|
||||||
|
GET /api/files?root={root}
|
||||||
|
GET /api/video/{path}?root={root}
|
||||||
|
GET /api/stream/{path}?root={root}&quality={quality}
|
||||||
|
GET /api/audio/{path}?root={root}
|
||||||
|
GET /api/cache/status/{path}?root={root}
|
||||||
|
GET /api/markers/{filename}?profile={profile}
|
||||||
|
GET /api/profiles
|
||||||
|
GET /api/labels
|
||||||
|
POST /api/export
|
||||||
|
GET /api/export/{job_id}
|
||||||
|
DELETE /api/export?output_path={path}
|
||||||
|
POST /api/hidden/{filename}?profile={profile}
|
||||||
|
DELETE /api/hidden/{filename}?profile={profile}
|
||||||
|
GET /api/hidden?profile={profile}
|
||||||
|
WS /ws/export
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
client/
|
||||||
|
├── src-tauri/
|
||||||
|
│ ├── src/
|
||||||
|
│ │ ├── main.rs (Tauri entry, app setup)
|
||||||
|
│ │ ├── mpv.rs (mpv sidecar spawn + IPC)
|
||||||
|
│ │ ├── commands.rs (Tauri commands for Svelte)
|
||||||
|
│ │ └── lib.rs
|
||||||
|
│ ├── Cargo.toml
|
||||||
|
│ └── tauri.conf.json
|
||||||
|
├── src/
|
||||||
|
│ ├── App.svelte
|
||||||
|
│ ├── lib/
|
||||||
|
│ │ ├── api.ts (server API client)
|
||||||
|
│ │ ├── mpv.ts (mpv IPC bridge via Tauri commands)
|
||||||
|
│ │ ├── ws.ts (WebSocket export progress)
|
||||||
|
│ │ └── stores.ts (Svelte stores: files, markers, settings)
|
||||||
|
│ ├── components/
|
||||||
|
│ │ ├── FileBrowser.svelte
|
||||||
|
│ │ ├── Timeline.svelte
|
||||||
|
│ │ ├── CropOverlay.svelte
|
||||||
|
│ │ ├── ExportPanel.svelte
|
||||||
|
│ │ ├── SettingsPanel.svelte
|
||||||
|
│ │ └── ProfileBar.svelte
|
||||||
|
│ └── main.ts
|
||||||
|
├── package.json
|
||||||
|
└── vite.config.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation order
|
||||||
|
|
||||||
|
1. Scaffold Tauri + Svelte project
|
||||||
|
2. mpv sidecar: spawn, IPC, basic play/pause/seek
|
||||||
|
3. API client module + server connection
|
||||||
|
4. File browser component
|
||||||
|
5. Video playback: load file → stream URL → mpv
|
||||||
|
6. Canvas timeline: cursor, seek, markers
|
||||||
|
7. Export panel + WebSocket progress
|
||||||
|
8. Crop overlay + keyframes
|
||||||
|
9. Lock mode, autoclip, play region
|
||||||
|
10. Profiles, subprofiles, hidden files
|
||||||
|
11. Keyboard shortcuts
|
||||||
|
12. Settings persistence
|
||||||
|
13. Package for Linux (.deb / .AppImage)
|
||||||
|
14. Package for Mac (.dmg)
|
||||||
@@ -0,0 +1,207 @@
|
|||||||
|
# 8-cut Server API Design
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Run 8-cut as a FastAPI server on Unraid (Docker) so a Tauri desktop client on Mac can edit remotely over WireGuard — no file transfers, no auth.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Unraid (Docker container):
|
||||||
|
FastAPI + ffmpeg + SQLite
|
||||||
|
├── /api/files list videos from mounted volumes
|
||||||
|
├── /api/stream/{path} transcoded video (cached, no audio)
|
||||||
|
├── /api/audio/{path} full-quality audio (cached, passthrough)
|
||||||
|
├── /api/video/{path} raw file (for reference/download)
|
||||||
|
├── /api/markers CRUD markers per profile
|
||||||
|
├── /api/profiles list/create profiles
|
||||||
|
├── /api/export trigger + manage exports
|
||||||
|
├── /api/labels label history
|
||||||
|
├── /api/hidden hidden file management
|
||||||
|
└── ws://…/ws/export real-time export progress
|
||||||
|
|
||||||
|
Mac (Tauri + Svelte + libmpv):
|
||||||
|
├── mpv plays stream URL (video) + audio URL separately
|
||||||
|
├── Canvas timeline + crop overlay + keyframes
|
||||||
|
├── Full UI: profiles, subprofiles, settings
|
||||||
|
└── Stateless — all state lives on server
|
||||||
|
```
|
||||||
|
|
||||||
|
## Docker mounts
|
||||||
|
|
||||||
|
| Mount | Purpose | Env var |
|
||||||
|
|-------------|--------------------------------|--------------|
|
||||||
|
| `/videos` | Source video files (read-only) | `MEDIA_DIRS` |
|
||||||
|
| `/exports` | Export output | `EXPORT_DIR` |
|
||||||
|
| `/data` | SQLite DB + transcode cache | `DB_PATH`, `CACHE_DIR` |
|
||||||
|
|
||||||
|
`MEDIA_DIRS` supports multiple paths: `/videos1,/videos2`.
|
||||||
|
|
||||||
|
## Video streaming with transcode cache
|
||||||
|
|
||||||
|
The client needs low-bitrate video for scrubbing over the network but full-quality audio for accurate editing.
|
||||||
|
|
||||||
|
**Flow:**
|
||||||
|
1. Client requests `/api/stream/{path}?quality=low`
|
||||||
|
2. Server checks cache: `{CACHE_DIR}/{quality}/{hash}.mp4`
|
||||||
|
3. If cached → serve with range requests (instant seeking)
|
||||||
|
4. If not → start background ffmpeg transcode, return `202 Accepted` with job ID
|
||||||
|
5. Client polls or gets WebSocket notification when ready
|
||||||
|
6. Audio: `/api/audio/{path}` extracts audio (passthrough, fast) to cache on first request
|
||||||
|
|
||||||
|
**Quality presets:**
|
||||||
|
|
||||||
|
| Preset | Resolution | Bitrate |
|
||||||
|
|----------|-----------|----------|
|
||||||
|
| `potato` | 480p | ~500 Kbps |
|
||||||
|
| `low` | 720p | ~2 Mbps |
|
||||||
|
| `medium` | 1080p | ~5 Mbps |
|
||||||
|
| `high` | original | ~10 Mbps |
|
||||||
|
|
||||||
|
Each quality level cached separately. Client can switch quality — mpv reloads the URL.
|
||||||
|
|
||||||
|
**mpv on client:**
|
||||||
|
```
|
||||||
|
video = http://server/api/stream/file.mp4?quality=low
|
||||||
|
audio = http://server/api/audio/file.mp4
|
||||||
|
```
|
||||||
|
mpv's `--audio-file=` flag plays both in sync with frame-accurate seeking.
|
||||||
|
|
||||||
|
## API endpoints
|
||||||
|
|
||||||
|
### Files
|
||||||
|
```
|
||||||
|
GET /api/files?root={root}
|
||||||
|
→ [{path, name, size, duration?, markers_count}]
|
||||||
|
|
||||||
|
GET /api/video/{path}
|
||||||
|
→ raw file with range requests
|
||||||
|
|
||||||
|
GET /api/stream/{path}?quality=low|medium|high|potato
|
||||||
|
→ cached transcoded video (no audio), range requests
|
||||||
|
→ 202 if transcode in progress
|
||||||
|
|
||||||
|
GET /api/audio/{path}
|
||||||
|
→ cached full-quality audio, range requests
|
||||||
|
→ 202 if extraction in progress
|
||||||
|
|
||||||
|
GET /api/cache/status/{path}
|
||||||
|
→ {qualities: {potato: "ready", low: "transcoding", ...}, audio: "ready"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Markers & profiles
|
||||||
|
```
|
||||||
|
GET /api/markers/{filename}?profile=default
|
||||||
|
→ [{start_time, marker_number, output_path}]
|
||||||
|
|
||||||
|
GET /api/profiles
|
||||||
|
→ ["default", "intense", ...]
|
||||||
|
|
||||||
|
GET /api/labels
|
||||||
|
→ ["dog barking", "rain", ...]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export
|
||||||
|
```
|
||||||
|
POST /api/export
|
||||||
|
body: {input_path, cursor, folder_suffix?, name, clips, spread,
|
||||||
|
short_side?, portrait_ratio?, crop_center, format,
|
||||||
|
label?, category?, profile, crop_keyframes?,
|
||||||
|
rand_portrait?, rand_square?, track_subject?}
|
||||||
|
→ {job_id}
|
||||||
|
|
||||||
|
GET /api/export/{job_id}
|
||||||
|
→ {status, completed, total, outputs: [...]}
|
||||||
|
|
||||||
|
DELETE /api/export/{output_path}
|
||||||
|
→ delete from DB + disk
|
||||||
|
|
||||||
|
WS /ws/export
|
||||||
|
→ server pushes: {type: "clip_done", path: "..."} | {type: "all_done"} | {type: "error", msg: "..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hidden files
|
||||||
|
```
|
||||||
|
POST /api/hidden/{filename}?profile=default
|
||||||
|
DELETE /api/hidden/{filename}?profile=default
|
||||||
|
GET /api/hidden?profile=default
|
||||||
|
→ ["file1.mp4", "file2.mp4"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code reuse from main.py
|
||||||
|
|
||||||
|
**Extracted to shared module (used by both server and Qt app):**
|
||||||
|
- `ProcessedDB` — SQLite operations
|
||||||
|
- `build_ffmpeg_command` — ffmpeg command construction
|
||||||
|
- `build_audio_extract_command`
|
||||||
|
- `build_export_path` / `build_sequence_dir`
|
||||||
|
- `detect_hw_encoders`
|
||||||
|
- `upsert_clip_annotation` / `remove_clip_annotation`
|
||||||
|
- `apply_keyframes_to_jobs` / `resolve_keyframe`
|
||||||
|
- `track_centers_for_jobs` (subject tracking)
|
||||||
|
|
||||||
|
**Server-specific (new):**
|
||||||
|
- FastAPI app + route handlers
|
||||||
|
- Transcode cache manager
|
||||||
|
- Export worker (plain threading, replaces QThread-based ExportWorker)
|
||||||
|
- File listing / media root scanning
|
||||||
|
- WebSocket export progress broadcaster
|
||||||
|
|
||||||
|
**Tauri client (new, Svelte):**
|
||||||
|
- mpv integration via Tauri plugin or sidecar
|
||||||
|
- Canvas-based timeline widget
|
||||||
|
- Canvas-based crop overlay
|
||||||
|
- All UI controls
|
||||||
|
- API client module
|
||||||
|
|
||||||
|
## Dockerfile
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
FROM python:3.12-slim
|
||||||
|
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||||
|
WORKDIR /app
|
||||||
|
COPY server/ .
|
||||||
|
RUN pip install --no-cache-dir fastapi uvicorn
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
8-cut/
|
||||||
|
├── main.py (existing Qt app, unchanged)
|
||||||
|
├── core/ (shared logic, extracted from main.py)
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── db.py (ProcessedDB)
|
||||||
|
│ ├── ffmpeg.py (build commands, detect encoders)
|
||||||
|
│ ├── export.py (ExportWorker — plain threading)
|
||||||
|
│ ├── paths.py (build_export_path, build_sequence_dir)
|
||||||
|
│ └── annotations.py (dataset.json helpers)
|
||||||
|
├── server/
|
||||||
|
│ ├── app.py (FastAPI app)
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── files.py
|
||||||
|
│ │ ├── stream.py
|
||||||
|
│ │ ├── markers.py
|
||||||
|
│ │ ├── export.py
|
||||||
|
│ │ └── hidden.py
|
||||||
|
│ ├── cache.py (transcode cache manager)
|
||||||
|
│ ├── ws.py (WebSocket handler)
|
||||||
|
│ └── config.py (env vars, settings)
|
||||||
|
├── client/ (Tauri + Svelte — future)
|
||||||
|
│ └── ...
|
||||||
|
├── Dockerfile
|
||||||
|
└── docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation order
|
||||||
|
|
||||||
|
1. Extract shared logic from main.py → `core/`
|
||||||
|
2. Update main.py to import from `core/` (verify Qt app still works)
|
||||||
|
3. Build FastAPI server with file listing + video serving
|
||||||
|
4. Add transcode cache + audio extraction
|
||||||
|
5. Add markers/profiles/labels/hidden API
|
||||||
|
6. Add export endpoint + WebSocket progress
|
||||||
|
7. Dockerfile + docker-compose
|
||||||
|
8. (Later) Tauri client
|
||||||
@@ -0,0 +1,948 @@
|
|||||||
|
# Server API Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Extract shared logic from main.py into a `core/` package, then build the FastAPI server that serves video files, manages the DB, and runs exports.
|
||||||
|
|
||||||
|
**Architecture:** Shared logic (DB, ffmpeg, paths, annotations, tracking) moves to `core/`. Both `main.py` (Qt app) and `server/` import from `core/`. The server adds HTTP video streaming with transcode cache, REST endpoints, and WebSocket export progress.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.12, FastAPI, uvicorn, SQLite, ffmpeg
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Create core/ package — paths and helpers
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/__init__.py`
|
||||||
|
- Create: `core/paths.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/__init__.py**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create core/paths.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _frozen_path() -> Path:
|
||||||
|
if getattr(sys, "frozen", False):
|
||||||
|
return Path(sys._MEIPASS)
|
||||||
|
return Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _bin(name: str) -> str:
|
||||||
|
p = _frozen_path() / name
|
||||||
|
if p.exists():
|
||||||
|
return str(p)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _log(*args) -> None:
|
||||||
|
ts = datetime.now().strftime("%H:%M:%S")
|
||||||
|
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name + ".mp4")
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||||
|
group = f"{basename}_{counter:03d}"
|
||||||
|
name = f"{group}_{sub}" if sub is not None else group
|
||||||
|
return os.path.join(folder, group, name)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
m = int(seconds // 60)
|
||||||
|
s = int(seconds % 60 * 10) / 10
|
||||||
|
return f"{m}:{s:04.1f}"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/
|
||||||
|
git commit -m "feat: create core/paths module with shared path helpers"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Create core/ffmpeg.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/ffmpeg.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/ffmpeg.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 77-112 and 244-289: `_RATIOS`, `_portrait_crop_filter`, `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`, `detect_hw_encoders`. (Lines 115-188 are also ffmpeg-related. Lines 191-241 are annotations — extracted separately in Task 4.)
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
_RATIOS: dict[str, tuple[int, int]] = {
|
||||||
|
"9:16": (9, 16),
|
||||||
|
"4:5": (4, 5),
|
||||||
|
"1:1": (1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||||
|
num, den = _RATIOS[ratio]
|
||||||
|
cw = f"ih*{num}/{den}"
|
||||||
|
x = f"max(0\\,min((iw-{cw})*{crop_center}\\,iw-{cw}))"
|
||||||
|
return f"crop={cw}:ih:{x}:0"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_keyframe(
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
t: float,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
) -> tuple[float, float, str | None, bool, bool] | None:
|
||||||
|
result = None
|
||||||
|
for kf in keyframes:
|
||||||
|
if kf[0] <= t + tolerance:
|
||||||
|
result = kf
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_keyframes_to_jobs(
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
keyframes: list[tuple[float, float, str | None, bool, bool]],
|
||||||
|
base_center: float,
|
||||||
|
base_ratio: str | None,
|
||||||
|
base_rand_p: bool,
|
||||||
|
base_rand_s: bool,
|
||||||
|
) -> list[tuple[float, str, str | None, float, bool, bool]]:
|
||||||
|
result = []
|
||||||
|
for s, o, _r, _c in jobs:
|
||||||
|
kf = resolve_keyframe(keyframes, s)
|
||||||
|
if kf is not None:
|
||||||
|
_, center, ratio, rp, rs = kf
|
||||||
|
else:
|
||||||
|
center, ratio, rp, rs = base_center, base_ratio, base_rand_p, base_rand_s
|
||||||
|
result.append((s, o, ratio, center, rp, rs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def build_ffmpeg_command(
|
||||||
|
input_path: str, start: float, output_path: str,
|
||||||
|
short_side: int | None = None,
|
||||||
|
portrait_ratio: str | None = None,
|
||||||
|
crop_center: float = 0.5,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
) -> list[str]:
|
||||||
|
use_hw_vaapi = encoder == "h264_vaapi" and not image_sequence
|
||||||
|
cmd = [_bin("ffmpeg"), "-y"]
|
||||||
|
if use_hw_vaapi:
|
||||||
|
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||||
|
"-vaapi_device", "/dev/dri/renderD128"]
|
||||||
|
cmd += ["-threads", "0", "-ss", str(start), "-i", input_path, "-t", "8"]
|
||||||
|
filters: list[str] = []
|
||||||
|
if portrait_ratio is not None:
|
||||||
|
filters.append(_portrait_crop_filter(portrait_ratio, crop_center))
|
||||||
|
if short_side is not None:
|
||||||
|
filters.append(
|
||||||
|
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||||
|
)
|
||||||
|
if use_hw_vaapi:
|
||||||
|
if filters:
|
||||||
|
filters.insert(0, "hwdownload")
|
||||||
|
filters.insert(1, "format=nv12")
|
||||||
|
filters.append("format=nv12")
|
||||||
|
filters.append("hwupload")
|
||||||
|
if filters:
|
||||||
|
cmd += ["-vf", ",".join(filters)]
|
||||||
|
if image_sequence:
|
||||||
|
cmd += ["-an", "-c:v", "libwebp", "-quality", "92", "-compression_level", "1",
|
||||||
|
os.path.join(output_path, "frame_%04d.webp")]
|
||||||
|
else:
|
||||||
|
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||||
|
audio_path = sequence_dir + ".wav"
|
||||||
|
return [_bin("ffmpeg"), "-y", "-ss", str(start), "-i", input_path,
|
||||||
|
"-t", "8", "-vn", "-c:a", "pcm_s16le", audio_path]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_hw_encoders() -> list[str]:
|
||||||
|
_HW_ENCODERS = ["h264_nvenc", "h264_vaapi", "h264_qsv", "h264_amf", "h264_videotoolbox"]
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||||
|
capture_output=True, text=True, timeout=5,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return []
|
||||||
|
output = result.stdout
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
available = []
|
||||||
|
for enc in _HW_ENCODERS:
|
||||||
|
if re.search(rf'\b{enc}\b', output):
|
||||||
|
available.append(enc)
|
||||||
|
if available:
|
||||||
|
_log(f"HW encoders detected: {', '.join(available)}")
|
||||||
|
else:
|
||||||
|
_log("No HW encoders detected — GPU export unavailable")
|
||||||
|
return available
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/ffmpeg.py
|
||||||
|
git commit -m "feat: create core/ffmpeg module with ffmpeg helpers"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Create core/db.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/db.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/db.py**
|
||||||
|
|
||||||
|
Extract the entire `ProcessedDB` class from main.py lines 398-626. Import `_log` from `core.paths`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .paths import _log
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessedDB:
|
||||||
|
_SCHEMA_VERSION = 3
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | None = None):
|
||||||
|
# ... exact copy of existing class ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Copy the full class body verbatim — all methods unchanged.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/db.py
|
||||||
|
git commit -m "feat: create core/db module with ProcessedDB"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Create core/annotations.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/annotations.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/annotations.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 191-241: `build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def build_annotation_json_path(folder: str) -> str:
|
||||||
|
return os.path.join(folder, "dataset.json")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_clip_annotation(folder: str, clip_path: str) -> None:
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
if not os.path.exists(json_path):
|
||||||
|
return
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return
|
||||||
|
entries = [e for e in entries if e.get("path") != abs_path]
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
||||||
|
if not label.strip():
|
||||||
|
return
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
json_path = build_annotation_json_path(folder)
|
||||||
|
entries: list[dict] = []
|
||||||
|
if os.path.exists(json_path):
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
entries = json.load(f)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
entries = []
|
||||||
|
abs_path = os.path.abspath(clip_path)
|
||||||
|
entry: dict = {"path": abs_path, "label": label}
|
||||||
|
for i, e in enumerate(entries):
|
||||||
|
if e.get("path") == abs_path:
|
||||||
|
entries[i] = entry
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
entries.append(entry)
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(entries, f, indent=2, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/annotations.py
|
||||||
|
git commit -m "feat: create core/annotations module"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Create core/export.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/export.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/export.py**
|
||||||
|
|
||||||
|
A plain-threading version of `ExportWorker` (no QThread dependency). Used by the server. The Qt app continues using its own QThread-based worker.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from .ffmpeg import build_ffmpeg_command, build_audio_extract_command
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
|
||||||
|
class ExportRunner:
|
||||||
|
"""Run ffmpeg export jobs in a background thread pool.
|
||||||
|
|
||||||
|
Callbacks:
|
||||||
|
on_clip_done(path: str)
|
||||||
|
on_all_done()
|
||||||
|
on_error(msg: str)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_path: str,
|
||||||
|
jobs: list[tuple[float, str, str | None, float]],
|
||||||
|
short_side: int | None = None,
|
||||||
|
image_sequence: bool = False,
|
||||||
|
max_workers: int | None = None,
|
||||||
|
encoder: str = "libx264",
|
||||||
|
on_clip_done: Callable[[str], None] | None = None,
|
||||||
|
on_all_done: Callable[[], None] | None = None,
|
||||||
|
on_error: Callable[[str], None] | None = None,
|
||||||
|
):
|
||||||
|
self._input = input_path
|
||||||
|
self._jobs = jobs
|
||||||
|
self._short_side = short_side
|
||||||
|
self._image_sequence = image_sequence
|
||||||
|
self._max_workers = max_workers
|
||||||
|
self._encoder = encoder
|
||||||
|
self._on_clip_done = on_clip_done
|
||||||
|
self._on_all_done = on_all_done
|
||||||
|
self._on_error = on_error
|
||||||
|
self._cancel = False
|
||||||
|
self._procs: list[subprocess.Popen] = []
|
||||||
|
self._procs_lock = threading.Lock()
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
self._cancel = True
|
||||||
|
with self._procs_lock:
|
||||||
|
for proc in self._procs:
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._thread is not None and self._thread.is_alive()
|
||||||
|
|
||||||
|
def _run_one(self, start: float, output: str,
|
||||||
|
portrait_ratio: str | None, crop_center: float) -> str:
|
||||||
|
if self._cancel:
|
||||||
|
raise RuntimeError("cancelled")
|
||||||
|
if self._image_sequence:
|
||||||
|
os.makedirs(output, exist_ok=True)
|
||||||
|
cmd = build_ffmpeg_command(
|
||||||
|
self._input, start, output,
|
||||||
|
short_side=self._short_side,
|
||||||
|
portrait_ratio=portrait_ratio,
|
||||||
|
crop_center=crop_center,
|
||||||
|
image_sequence=self._image_sequence,
|
||||||
|
encoder=self._encoder,
|
||||||
|
)
|
||||||
|
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
with self._procs_lock:
|
||||||
|
self._procs.append(proc)
|
||||||
|
try:
|
||||||
|
_, stderr = proc.communicate(timeout=120)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
raise RuntimeError("ffmpeg timed out")
|
||||||
|
finally:
|
||||||
|
with self._procs_lock:
|
||||||
|
self._procs.remove(proc)
|
||||||
|
if self._cancel:
|
||||||
|
raise RuntimeError("cancelled")
|
||||||
|
if proc.returncode != 0:
|
||||||
|
msg = stderr.decode(errors='replace')[-500:] if stderr else "ffmpeg failed"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
if self._image_sequence:
|
||||||
|
audio_cmd = build_audio_extract_command(self._input, start, output)
|
||||||
|
subprocess.run(audio_cmd, capture_output=True, text=True, timeout=60)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
cap = self._max_workers or (os.cpu_count() or 2)
|
||||||
|
workers = min(len(self._jobs), cap)
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||||
|
futures = {
|
||||||
|
pool.submit(self._run_one, s, o, pr, cc): o
|
||||||
|
for s, o, pr, cc in self._jobs
|
||||||
|
}
|
||||||
|
for fut in as_completed(futures):
|
||||||
|
if self._cancel:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
path = fut.result()
|
||||||
|
if self._on_clip_done:
|
||||||
|
self._on_clip_done(path)
|
||||||
|
except Exception as e:
|
||||||
|
if "cancelled" not in str(e) and self._on_error:
|
||||||
|
self._on_error(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
if self._on_error:
|
||||||
|
self._on_error(str(e))
|
||||||
|
return
|
||||||
|
if self._cancel:
|
||||||
|
return
|
||||||
|
if self._on_all_done:
|
||||||
|
self._on_all_done()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/export.py
|
||||||
|
git commit -m "feat: create core/export module with ExportRunner"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Create core/tracking.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/tracking.py`
|
||||||
|
|
||||||
|
**Step 1: Create core/tracking.py**
|
||||||
|
|
||||||
|
Extract from main.py lines 294-395: YOLO tracking functions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_yolo():
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model is None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
_yolo_model = YOLO("yolov8n.pt")
|
||||||
|
_log("YOLO model loaded")
|
||||||
|
except ImportError:
|
||||||
|
_log("ultralytics not installed — tracking disabled")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"YOLO load failed: {e}")
|
||||||
|
return None
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frame_cv(video_path: str, time: float):
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||||
|
os.close(fd)
|
||||||
|
try:
|
||||||
|
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||||
|
"-frames:v", "1", tmp]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return cv2.imread(tmp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_subject_center(
|
||||||
|
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||||
|
) -> tuple[int | None, float, float] | None:
|
||||||
|
model = _get_yolo()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
frame = extract_frame_cv(video_path, time)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
results = model(frame, verbose=False)
|
||||||
|
if not results or len(results[0].boxes) == 0:
|
||||||
|
return None
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
dets = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
cx = (x1 + x2) / 2 / w
|
||||||
|
cy = (y1 + y2) / 2 / h
|
||||||
|
dets.append((cls, cx, cy))
|
||||||
|
def score(d):
|
||||||
|
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||||
|
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||||
|
return cls_penalty + dist
|
||||||
|
best = min(dets, key=score)
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def track_centers_for_jobs(
|
||||||
|
video_path: str, cursor: float, crop_center: float,
|
||||||
|
starts: list[float],
|
||||||
|
) -> list[float]:
|
||||||
|
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||||
|
if ref is None:
|
||||||
|
_log("Tracking: no detection at cursor, using fixed center")
|
||||||
|
return [crop_center] * len(starts)
|
||||||
|
target_cls, last_x, last_y = ref
|
||||||
|
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||||
|
centers = []
|
||||||
|
for t in starts:
|
||||||
|
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||||
|
if det is not None:
|
||||||
|
_, cx, cy = det
|
||||||
|
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||||
|
centers.append(cx)
|
||||||
|
last_x, last_y = cx, cy
|
||||||
|
else:
|
||||||
|
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||||
|
centers.append(last_x)
|
||||||
|
return centers
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/tracking.py
|
||||||
|
git commit -m "feat: create core/tracking module with YOLO subject tracking"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Update main.py to import from core/
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py`
|
||||||
|
|
||||||
|
**Step 1: Replace function definitions with imports**
|
||||||
|
|
||||||
|
At the top of main.py, after the existing stdlib imports (line 17), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from core.paths import _bin, _log, build_export_path, build_sequence_dir, format_time
|
||||||
|
from core.ffmpeg import (
|
||||||
|
_RATIOS, resolve_keyframe, apply_keyframes_to_jobs,
|
||||||
|
build_ffmpeg_command, build_audio_extract_command, detect_hw_encoders,
|
||||||
|
)
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
from core.annotations import remove_clip_annotation, upsert_clip_annotation
|
||||||
|
from core.tracking import track_centers_for_jobs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Delete the extracted function definitions and dead imports**
|
||||||
|
|
||||||
|
Remove definitions from main.py:
|
||||||
|
- Lines 36-74: `_frozen_path`, `_bin`, `_log`, `build_export_path`, `build_sequence_dir`, `format_time`
|
||||||
|
- Lines 77-188: `resolve_keyframe`, `apply_keyframes_to_jobs`, `build_ffmpeg_command`, `build_audio_extract_command`
|
||||||
|
- Lines 191-241: annotation functions (`build_annotation_json_path`, `remove_clip_annotation`, `upsert_clip_annotation`)
|
||||||
|
- Lines 244-289: `detect_hw_encoders`, `_RATIOS`, `_portrait_crop_filter`
|
||||||
|
- Lines 294-395: tracking functions (`_yolo_model`, `_get_yolo`, `extract_frame_cv`, `detect_subject_center`, `track_centers_for_jobs`)
|
||||||
|
- Lines 398-626: `ProcessedDB` class
|
||||||
|
|
||||||
|
Remove now-dead stdlib imports from the top of main.py:
|
||||||
|
- `re` (only used in `detect_hw_encoders`)
|
||||||
|
- `json` (only used in annotation functions)
|
||||||
|
- `sqlite3` (only used in `ProcessedDB`)
|
||||||
|
- `tempfile` (only used in `extract_frame_cv`)
|
||||||
|
- `datetime`, `timezone` from the datetime import (only used in `_log` and `ProcessedDB`)
|
||||||
|
|
||||||
|
Keep in main.py:
|
||||||
|
- `_SELVA_CATEGORIES` (UI constant, line 291)
|
||||||
|
- `_RATIOS` reference — imported from core.ffmpeg
|
||||||
|
- `ExportWorker` (QThread-based, stays in main.py — the server uses `core.export.ExportRunner` instead)
|
||||||
|
- `_DBWorker` and `FrameGrabber` (QThread-based, stay in main.py)
|
||||||
|
|
||||||
|
**Step 3: Verify Qt app still works**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Open a video, export a clip, check markers — verify nothing broke.
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "refactor: import shared logic from core/ instead of inline definitions"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: Create server/config.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/__init__.py` (empty package marker)
|
||||||
|
- Create: `server/config.py`
|
||||||
|
|
||||||
|
**Step 1: Create `server/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create config**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
MEDIA_DIRS: list[str] = [
|
||||||
|
d.strip() for d in os.environ.get("MEDIA_DIRS", str(Path.home())).split(",") if d.strip()
|
||||||
|
]
|
||||||
|
EXPORT_DIR: str = os.environ.get("EXPORT_DIR", str(Path.home() / "8cut-exports"))
|
||||||
|
DB_PATH: str = os.environ.get("DB_PATH", str(Path.home() / ".8cut.db"))
|
||||||
|
CACHE_DIR: str = os.environ.get("CACHE_DIR", str(Path.home() / ".8cut-cache"))
|
||||||
|
HOST: str = os.environ.get("HOST", "0.0.0.0")
|
||||||
|
PORT: int = int(os.environ.get("PORT", "8000"))
|
||||||
|
|
||||||
|
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".ts", ".flv", ".wmv"}
|
||||||
|
|
||||||
|
QUALITY_PRESETS = {
|
||||||
|
"potato": {"height": 480, "bitrate": "500k"},
|
||||||
|
"low": {"height": 720, "bitrate": "2M"},
|
||||||
|
"medium": {"height": 1080, "bitrate": "5M"},
|
||||||
|
"high": {"height": 0, "bitrate": "10M"}, # 0 = original resolution
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/
|
||||||
|
git commit -m "feat: create server/config with env var settings and quality presets"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 9: Create server/app.py — FastAPI skeleton + file listing
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/app.py`
|
||||||
|
- Create: `server/routes/__init__.py`
|
||||||
|
- Create: `server/routes/files.py`
|
||||||
|
|
||||||
|
**Step 1: Create FastAPI app**
|
||||||
|
|
||||||
|
`server/app.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from .routes import files, stream, markers, export, hidden
|
||||||
|
|
||||||
|
app = FastAPI(title="8-cut Server")
|
||||||
|
app.include_router(files.router, prefix="/api")
|
||||||
|
app.include_router(stream.router, prefix="/api")
|
||||||
|
app.include_router(markers.router, prefix="/api")
|
||||||
|
app.include_router(export.router, prefix="/api")
|
||||||
|
app.include_router(hidden.router, prefix="/api")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create file listing route**
|
||||||
|
|
||||||
|
`server/routes/files.py`:
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter, Query
|
||||||
|
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_videos(root: str) -> list[dict]:
|
||||||
|
results = []
|
||||||
|
for dirpath, _, filenames in os.walk(root):
|
||||||
|
for f in sorted(filenames):
|
||||||
|
if os.path.splitext(f)[1].lower() in VIDEO_EXTENSIONS:
|
||||||
|
full = os.path.join(dirpath, f)
|
||||||
|
rel = os.path.relpath(full, root)
|
||||||
|
results.append({
|
||||||
|
"name": f,
|
||||||
|
"path": rel,
|
||||||
|
"root": root,
|
||||||
|
"size": os.path.getsize(full),
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files")
|
||||||
|
def list_files(root: str | None = Query(None)):
|
||||||
|
dirs = [root] if root and root in MEDIA_DIRS else MEDIA_DIRS
|
||||||
|
files = []
|
||||||
|
for d in dirs:
|
||||||
|
files.extend(_scan_videos(d))
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/roots")
|
||||||
|
def list_roots():
|
||||||
|
return MEDIA_DIRS
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Create `server/routes/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# empty — package marker
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Create stub routers** so app.py imports don't fail. Each file gets a minimal router — later tasks fill in the real endpoints.
|
||||||
|
|
||||||
|
`server/routes/stream.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/markers.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/export.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
`server/routes/hidden.py`:
|
||||||
|
```python
|
||||||
|
from fastapi import APIRouter
|
||||||
|
router = APIRouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/
|
||||||
|
git commit -m "feat: add FastAPI app with file listing endpoint"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 10: Create server/routes/stream.py — video serving + transcode cache
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/cache.py`
|
||||||
|
- Create: `server/routes/stream.py`
|
||||||
|
|
||||||
|
**Step 1: Create cache manager**
|
||||||
|
|
||||||
|
`server/cache.py` handles:
|
||||||
|
- Computing cache paths from source file hash + quality
|
||||||
|
- Checking cache status
|
||||||
|
- Launching background ffmpeg transcodes
|
||||||
|
- Tracking in-progress jobs
|
||||||
|
|
||||||
|
**Step 2: Create stream routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/video/{path} — raw file, range requests
|
||||||
|
GET /api/stream/{path}?quality=low — cached transcode, range requests (202 if not ready)
|
||||||
|
GET /api/audio/{path} — cached audio extraction, range requests (202 if not ready)
|
||||||
|
GET /api/cache/status/{path} — cache status for all qualities
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/cache.py server/routes/stream.py
|
||||||
|
git commit -m "feat: add video streaming with transcode cache and audio extraction"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 11: Create server/routes/markers.py — DB endpoints
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/markers.py`
|
||||||
|
|
||||||
|
**Step 1: Create markers/profiles/labels routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/markers/{filename}?profile=default
|
||||||
|
GET /api/profiles
|
||||||
|
GET /api/labels
|
||||||
|
```
|
||||||
|
|
||||||
|
Uses `ProcessedDB` singleton from `core.db`.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/markers.py
|
||||||
|
git commit -m "feat: add markers, profiles, and labels API endpoints"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 12: Create server/routes/export.py + WebSocket
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/export.py`
|
||||||
|
- Create: `server/ws.py`
|
||||||
|
|
||||||
|
**Step 1: Create export routes + WS**
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/export — start export job
|
||||||
|
GET /api/export/{id} — check job status
|
||||||
|
DELETE /api/export/{path} — delete export from DB + disk
|
||||||
|
WS /ws/export — real-time progress
|
||||||
|
```
|
||||||
|
|
||||||
|
Uses `ExportRunner` from `core.export`.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/export.py server/ws.py
|
||||||
|
git commit -m "feat: add export endpoint with WebSocket progress"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 13: Create server/routes/hidden.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `server/routes/hidden.py`
|
||||||
|
|
||||||
|
**Step 1: Create hidden file routes**
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/hidden/{filename}?profile=default
|
||||||
|
DELETE /api/hidden/{filename}?profile=default
|
||||||
|
GET /api/hidden?profile=default
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add server/routes/hidden.py
|
||||||
|
git commit -m "feat: add hidden files API endpoints"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 14: Create Dockerfile + docker-compose.yml
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `Dockerfile`
|
||||||
|
- Create: `docker-compose.yml`
|
||||||
|
|
||||||
|
**Step 1: Create Dockerfile**
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
FROM python:3.12-slim
|
||||||
|
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||||
|
WORKDIR /app
|
||||||
|
COPY core/ core/
|
||||||
|
COPY server/ server/
|
||||||
|
# Note: ultralytics + opencv-python needed only if subject tracking is used.
|
||||||
|
# Add them here if tracking is required on the server.
|
||||||
|
RUN pip install --no-cache-dir fastapi uvicorn
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Create docker-compose.yml**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
services:
|
||||||
|
8cut:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- /path/to/videos:/videos:ro
|
||||||
|
- /path/to/exports:/exports
|
||||||
|
- 8cut-data:/data
|
||||||
|
environment:
|
||||||
|
MEDIA_DIRS: /videos
|
||||||
|
EXPORT_DIR: /exports
|
||||||
|
DB_PATH: /data/8cut.db
|
||||||
|
CACHE_DIR: /data/cache
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
8cut-data:
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add Dockerfile docker-compose.yml
|
||||||
|
git commit -m "feat: add Dockerfile and docker-compose for server deployment"
|
||||||
|
```
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
# Audio Similarity Scanning — Design
|
||||||
|
|
||||||
|
**Goal:** Scan a video's audio track and highlight segments that match the sound profile of existing reference clips, so the user can quickly find similar moments without scrubbing manually.
|
||||||
|
|
||||||
|
**Runs in:** Python/Qt client (`main.py`), not the server.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Core Module: `core/audio_scan.py`
|
||||||
|
|
||||||
|
New module alongside `core/tracking.py`. Two main functions:
|
||||||
|
|
||||||
|
- `build_profile(clip_paths: list[str]) -> dict` — extracts MFCCs (20 coefficients) from each clip using `librosa`, returns a profile containing both the averaged vector and individual clip vectors.
|
||||||
|
- `scan_video(video_path: str, profile: dict, mode: str, threshold: float, hop: float) -> list[tuple[float, float, float]]` — slides an 8s window across the video's audio, returns `(start_time, end_time, score)` tuples for segments above threshold.
|
||||||
|
|
||||||
|
### Feature Extraction
|
||||||
|
|
||||||
|
- Audio loaded via `librosa.load()` (handles video files directly, mono, 22050Hz).
|
||||||
|
- MFCCs: `librosa.feature.mfcc(n_mfcc=20)`, averaged over time axis to produce a single vector per window/clip.
|
||||||
|
- Similarity: cosine similarity (`numpy` dot product on L2-normalized vectors).
|
||||||
|
|
||||||
|
### Matching Modes
|
||||||
|
|
||||||
|
- **Average mode:** Compare each window to the mean of all reference MFCC vectors. Fast, good when references are homogeneous.
|
||||||
|
- **Nearest mode:** Compare each window to every reference vector, take the max score. Better when references have variety within the style.
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
|
||||||
|
- `threshold` (float, 0.0–1.0): minimum cosine similarity to include a segment. Default 0.7.
|
||||||
|
- `hop` (float, seconds): step size for the sliding window. Default 1.0s.
|
||||||
|
- Window size fixed at 8s to match reference clip length.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## UI Integration in `main.py`
|
||||||
|
|
||||||
|
### Controls
|
||||||
|
|
||||||
|
Added near the existing tracking checkbox area:
|
||||||
|
|
||||||
|
- **"Scan" button** — triggers audio scan on current video.
|
||||||
|
- **Threshold slider** (0.0–1.0, step 0.05) — controls match strictness.
|
||||||
|
- **Mode combobox** — "Average" / "Nearest".
|
||||||
|
- **Reference source combobox** — "Current Profile" / "Custom Folder" (shows folder picker when "Custom Folder" selected).
|
||||||
|
|
||||||
|
### Scan Workflow
|
||||||
|
|
||||||
|
1. User clicks Scan.
|
||||||
|
2. Reference clips collected: either all export `output_path` values from the current profile (via DB) or all audio/video files in a custom folder.
|
||||||
|
3. Scan runs in a `QThread` so UI stays responsive.
|
||||||
|
4. On completion, results sent to Timeline widget via signal.
|
||||||
|
|
||||||
|
### Timeline Display
|
||||||
|
|
||||||
|
- New `set_scan_regions(regions: list[tuple[float, float, float]])` method on Timeline.
|
||||||
|
- Drawn as semi-transparent colored rectangles behind existing markers.
|
||||||
|
- Color intensity proportional to score (brighter = higher match).
|
||||||
|
- Cleared on file change or re-scan.
|
||||||
|
|
||||||
|
### Keyboard Shortcut
|
||||||
|
|
||||||
|
- `S` — jump cursor to the next scan region (similar to `M` for next marker).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
Reference clips (DB export paths or folder)
|
||||||
|
|
|
||||||
|
librosa.load() each -> MFCC vectors (20-dim)
|
||||||
|
|
|
||||||
|
Profile: { mean_vector, clip_vectors[] }
|
||||||
|
|
|
||||||
|
Current video -> librosa.load() full audio (mono 22050Hz)
|
||||||
|
|
|
||||||
|
Sliding 8s window (hop=1s) -> MFCC per window
|
||||||
|
|
|
||||||
|
Cosine similarity vs profile -> score per position
|
||||||
|
|
|
||||||
|
Threshold filter -> [(start, end, score), ...]
|
||||||
|
|
|
||||||
|
Timeline: semi-transparent highlight regions
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
- 2-hour video at 22050Hz mono ~ 380MB memory.
|
||||||
|
- MFCC extraction + sliding window: ~10-30s.
|
||||||
|
- QThread keeps UI responsive.
|
||||||
|
|
||||||
|
## What This Does NOT Do
|
||||||
|
|
||||||
|
- No DB schema changes — scan results are ephemeral (visual only).
|
||||||
|
- No auto-export — user decides what to cut.
|
||||||
|
- No server integration — runs entirely in the Python client.
|
||||||
|
- No GPU/ML model dependency — just librosa + numpy.
|
||||||
@@ -0,0 +1,739 @@
|
|||||||
|
# Audio Similarity Scanning — Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Scan a video's audio track to find segments matching a reference sound profile, displayed as highlighted regions on the timeline.
|
||||||
|
|
||||||
|
**Architecture:** New `core/audio_scan.py` module extracts MFCC features from reference clips and slides an 8s window across the target video's audio, scoring each position via cosine similarity. A `ScanWorker` QThread runs the scan in the background, and results are drawn as semi-transparent rectangles on the existing Timeline widget.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3, librosa 0.11, numpy, PyQt6
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Core audio_scan module — build_profile
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `core/audio_scan.py`
|
||||||
|
- Create: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the tests**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# tests/test_audio_scan.py
|
||||||
|
import tempfile, os
|
||||||
|
import numpy as np
|
||||||
|
from core.audio_scan import build_profile, _extract_mfcc
|
||||||
|
|
||||||
|
|
||||||
|
def _make_wav(path: str, duration: float = 8.0, sr: int = 22050):
|
||||||
|
"""Create a short sine-wave WAV file for testing."""
|
||||||
|
import soundfile as sf
|
||||||
|
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
|
||||||
|
audio = 0.5 * np.sin(2 * np.pi * 440 * t)
|
||||||
|
sf.write(path, audio, sr)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_mfcc_returns_1d_vector():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
vec = _extract_mfcc(f.name)
|
||||||
|
assert vec.shape == (20,)
|
||||||
|
assert not np.isnan(vec).any()
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_single_clip():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
profile = build_profile([f.name])
|
||||||
|
assert "mean_vector" in profile
|
||||||
|
assert "clip_vectors" in profile
|
||||||
|
assert profile["mean_vector"].shape == (20,)
|
||||||
|
assert len(profile["clip_vectors"]) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_multiple_clips():
|
||||||
|
paths = []
|
||||||
|
try:
|
||||||
|
for i in range(3):
|
||||||
|
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||||
|
freq = 440 + i * 200
|
||||||
|
import soundfile as sf
|
||||||
|
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||||
|
sf.write(f.name, 0.5 * np.sin(2 * np.pi * freq * t), 22050)
|
||||||
|
paths.append(f.name)
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
profile = build_profile(paths)
|
||||||
|
assert len(profile["clip_vectors"]) == 3
|
||||||
|
assert profile["mean_vector"].shape == (20,)
|
||||||
|
finally:
|
||||||
|
for p in paths:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_skips_missing_files():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
_make_wav(f.name)
|
||||||
|
try:
|
||||||
|
profile = build_profile([f.name, "/no/such/file.wav"])
|
||||||
|
assert len(profile["clip_vectors"]) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_profile_empty_returns_none():
|
||||||
|
result = build_profile([])
|
||||||
|
assert result is None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run tests to verify they fail**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: FAIL with `ModuleNotFoundError: No module named 'core.audio_scan'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# core/audio_scan.py
|
||||||
|
"""Audio similarity scanning — MFCC-based profile matching."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
from .paths import _log
|
||||||
|
|
||||||
|
_N_MFCC = 20
|
||||||
|
_SR = 22050
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_mfcc(path: str, sr: int = _SR) -> np.ndarray:
|
||||||
|
"""Load audio from a file and return a mean MFCC vector (20-dim)."""
|
||||||
|
y, _ = librosa.load(path, sr=sr, mono=True)
|
||||||
|
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=_N_MFCC)
|
||||||
|
return mfcc.mean(axis=1) # average over time → (20,)
|
||||||
|
|
||||||
|
|
||||||
|
def build_profile(clip_paths: list[str]) -> dict | None:
|
||||||
|
"""Extract MFCCs from reference clips.
|
||||||
|
|
||||||
|
Returns dict with:
|
||||||
|
- mean_vector: averaged MFCC across all clips (20,)
|
||||||
|
- clip_vectors: list of individual MFCC vectors
|
||||||
|
Returns None if no clips could be loaded.
|
||||||
|
"""
|
||||||
|
vectors = []
|
||||||
|
for p in clip_paths:
|
||||||
|
try:
|
||||||
|
vec = _extract_mfcc(p)
|
||||||
|
vectors.append(vec)
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"audio_scan: skip {p}: {e}")
|
||||||
|
if not vectors:
|
||||||
|
return None
|
||||||
|
arr = np.stack(vectors)
|
||||||
|
return {
|
||||||
|
"mean_vector": arr.mean(axis=0),
|
||||||
|
"clip_vectors": vectors,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run tests to verify they pass**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: all 5 PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add audio_scan module with build_profile"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Core audio_scan module — scan_video
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py`
|
||||||
|
- Modify: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the tests**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from core.audio_scan import scan_video
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_finds_matching_region():
|
||||||
|
"""A video made of the same sine wave as the reference should match."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
_make_wav(ref.name, duration=8.0)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
_make_wav(vid.name, duration=20.0)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0)
|
||||||
|
assert len(regions) > 0
|
||||||
|
for start, end, score in regions:
|
||||||
|
assert abs((end - start) - 8.0) < 1e-9
|
||||||
|
assert score >= 0.5
|
||||||
|
assert score >= 0.5
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_nearest_mode():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
_make_wav(ref.name, duration=8.0)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
_make_wav(vid.name, duration=20.0)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.5, hop=1.0)
|
||||||
|
assert len(regions) > 0
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_high_threshold_no_match():
|
||||||
|
"""Different frequencies with very high threshold should not match."""
|
||||||
|
import soundfile as sf
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||||
|
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||||
|
sf.write(ref.name, 0.5 * np.sin(2 * np.pi * 440 * t), 22050)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
# White noise — very different from sine wave
|
||||||
|
sf.write(vid.name, np.random.randn(22050 * 20).astype(np.float32) * 0.1, 22050)
|
||||||
|
try:
|
||||||
|
profile = build_profile([ref.name])
|
||||||
|
regions = scan_video(vid.name, profile, mode="average", threshold=0.99, hop=1.0)
|
||||||
|
assert len(regions) == 0
|
||||||
|
finally:
|
||||||
|
os.unlink(ref.name)
|
||||||
|
os.unlink(vid.name)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run tests to verify they fail**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_scan_video_finds_matching_region -v`
|
||||||
|
Expected: FAIL with `ImportError: cannot import name 'scan_video'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
Add to `core/audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""Cosine similarity between two vectors.
|
||||||
|
|
||||||
|
Returns value in [-1, 1]. Negative means anti-correlated (very
|
||||||
|
dissimilar). For threshold filtering this is fine — negative scores
|
||||||
|
never exceed the threshold. Scores near 0 may be uncorrelated or
|
||||||
|
weakly anti-correlated.
|
||||||
|
"""
|
||||||
|
na = np.linalg.norm(a)
|
||||||
|
nb = np.linalg.norm(b)
|
||||||
|
if na == 0 or nb == 0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.dot(a, b) / (na * nb))
|
||||||
|
|
||||||
|
|
||||||
|
def scan_video(
|
||||||
|
video_path: str,
|
||||||
|
profile: dict,
|
||||||
|
mode: str = "average",
|
||||||
|
threshold: float = 0.7,
|
||||||
|
hop: float = 1.0,
|
||||||
|
window: float = 8.0,
|
||||||
|
cancel_flag: object = None,
|
||||||
|
) -> list[tuple[float, float, float]]:
|
||||||
|
"""Slide a window across the video audio and score against the profile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: path to video/audio file
|
||||||
|
profile: dict from build_profile()
|
||||||
|
mode: "average" (compare to mean) or "nearest" (max over all clips)
|
||||||
|
threshold: minimum cosine similarity to include
|
||||||
|
hop: step size in seconds
|
||||||
|
window: window size in seconds (default 8s)
|
||||||
|
cancel_flag: object with _cancel bool attribute; checked each iteration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (start_time, end_time, score) for regions above threshold
|
||||||
|
"""
|
||||||
|
_log(f"audio_scan: loading {video_path}")
|
||||||
|
y, sr = librosa.load(video_path, sr=_SR, mono=True)
|
||||||
|
duration = len(y) / sr
|
||||||
|
_log(f"audio_scan: {duration:.1f}s loaded, scanning with hop={hop}s")
|
||||||
|
|
||||||
|
win_samples = int(window * sr)
|
||||||
|
hop_samples = int(hop * sr)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
pos = 0
|
||||||
|
while pos + win_samples <= len(y):
|
||||||
|
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||||
|
_log("audio_scan: cancelled")
|
||||||
|
return results
|
||||||
|
|
||||||
|
chunk = y[pos : pos + win_samples]
|
||||||
|
mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=_N_MFCC)
|
||||||
|
vec = mfcc.mean(axis=1)
|
||||||
|
|
||||||
|
if mode == "nearest":
|
||||||
|
score = max(
|
||||||
|
_cosine_similarity(vec, cv) for cv in profile["clip_vectors"]
|
||||||
|
)
|
||||||
|
else: # average
|
||||||
|
score = _cosine_similarity(vec, profile["mean_vector"])
|
||||||
|
|
||||||
|
if score >= threshold:
|
||||||
|
start_t = pos / sr
|
||||||
|
results.append((start_t, start_t + window, score))
|
||||||
|
|
||||||
|
pos += hop_samples
|
||||||
|
|
||||||
|
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
|
||||||
|
return results
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run tests to verify they pass**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||||
|
Expected: all 8 PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add scan_video with average and nearest modes"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Timeline — draw scan regions
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (Timeline class, around lines 209-260 and 300-375)
|
||||||
|
|
||||||
|
**Step 1: Add scan region storage to Timeline.__init__**
|
||||||
|
|
||||||
|
In `main.py`, find the Timeline class `__init__` method (around line 198). After `self._markers` initialization (line 209), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add set_scan_regions method**
|
||||||
|
|
||||||
|
After the `set_markers` method (line 249-252), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
|
||||||
|
"""regions: list of (start_time, end_time, score)"""
|
||||||
|
self._scan_regions = regions
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def clear_scan_regions(self) -> None:
|
||||||
|
self._scan_regions = []
|
||||||
|
self.update()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Draw scan regions in paintEvent**
|
||||||
|
|
||||||
|
In `paintEvent` (starts around line 282), find the marker drawing section (line 363, comment `# ── export markers`). BEFORE that section, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ── scan regions ──────────────────────────────────────────────
|
||||||
|
if self._scan_regions and self._duration > 0:
|
||||||
|
for (start, end, score) in self._scan_regions:
|
||||||
|
x1 = int(start / self._duration * w)
|
||||||
|
x2 = int(end / self._duration * w)
|
||||||
|
alpha = int(40 + score * 80) # 40–120 opacity
|
||||||
|
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: app starts without errors. No scan regions visible yet (none set).
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: timeline scan region rendering"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: ScanWorker QThread
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (add ScanWorker class, after ExportWorker around line 165)
|
||||||
|
|
||||||
|
**Step 1: Add the ScanWorker class**
|
||||||
|
|
||||||
|
After the `ExportWorker` class (ends around line 165), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ScanWorker(QThread):
|
||||||
|
"""Runs audio similarity scan off the main thread."""
|
||||||
|
finished = pyqtSignal(list) # emits list of (start, end, score)
|
||||||
|
error = pyqtSignal(str)
|
||||||
|
progress = pyqtSignal(str) # status message
|
||||||
|
|
||||||
|
def __init__(self, video_path: str, clip_paths: list[str],
|
||||||
|
mode: str = "average", threshold: float = 0.7):
|
||||||
|
super().__init__()
|
||||||
|
self._video_path = video_path
|
||||||
|
self._clip_paths = clip_paths
|
||||||
|
self._mode = mode
|
||||||
|
self._threshold = threshold
|
||||||
|
self._cancel = False
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
self._cancel = True
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
from core.audio_scan import build_profile, scan_video
|
||||||
|
try:
|
||||||
|
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
|
||||||
|
profile = build_profile(self._clip_paths)
|
||||||
|
if self._cancel:
|
||||||
|
return
|
||||||
|
if profile is None:
|
||||||
|
self.error.emit("No valid reference clips found")
|
||||||
|
return
|
||||||
|
self.progress.emit("Scanning audio...")
|
||||||
|
regions = scan_video(
|
||||||
|
self._video_path, profile,
|
||||||
|
mode=self._mode, threshold=self._threshold,
|
||||||
|
cancel_flag=self,
|
||||||
|
)
|
||||||
|
if not self._cancel:
|
||||||
|
self.finished.emit(regions)
|
||||||
|
except Exception as e:
|
||||||
|
if not self._cancel:
|
||||||
|
self.error.emit(str(e))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify import works**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -c "from main import ScanWorker; print('ok')"`
|
||||||
|
Expected: `ok`
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add ScanWorker QThread for background scanning"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: DB helper — get_all_export_paths
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/db.py`
|
||||||
|
- Modify: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write the test**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_db_get_all_export_paths():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||||
|
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||||
|
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||||
|
paths = db.get_all_export_paths("test")
|
||||||
|
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run test to verify it fails**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||||
|
Expected: FAIL with `AttributeError: 'ProcessedDB' object has no attribute 'get_all_export_paths'`
|
||||||
|
|
||||||
|
**Step 3: Write the implementation**
|
||||||
|
|
||||||
|
Add to `core/db.py`, after the `get_markers` method. Note: no lock needed — follows
|
||||||
|
the codebase convention where read-only methods don't acquire the lock.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||||
|
"""Return all unique output_path values for a given profile."""
|
||||||
|
if not self._enabled:
|
||||||
|
return []
|
||||||
|
rows = self._con.execute(
|
||||||
|
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||||
|
(profile,),
|
||||||
|
).fetchall()
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run test to verify it passes**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/db.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add get_all_export_paths to ProcessedDB"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: UI controls for audio scanning
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` (MainWindow class — control creation ~1490-1575, layout ~1620-1640)
|
||||||
|
|
||||||
|
**Step 1: Add scan control widgets**
|
||||||
|
|
||||||
|
In the MainWindow `__init__`, find the control creation section. After `self._chk_track` (around line 1501), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ── audio scan controls ──────────────────────────────────────
|
||||||
|
self._btn_scan = QPushButton("Scan")
|
||||||
|
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
|
||||||
|
self._btn_scan.clicked.connect(self._start_scan)
|
||||||
|
|
||||||
|
self._sld_threshold = QDoubleSpinBox()
|
||||||
|
self._sld_threshold.setRange(0.0, 1.0)
|
||||||
|
self._sld_threshold.setSingleStep(0.05)
|
||||||
|
self._sld_threshold.setValue(0.7)
|
||||||
|
self._sld_threshold.setPrefix("Thr: ")
|
||||||
|
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
|
||||||
|
|
||||||
|
self._cmb_scan_mode = QComboBox()
|
||||||
|
self._cmb_scan_mode.addItems(["Average", "Nearest"])
|
||||||
|
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
|
||||||
|
|
||||||
|
self._cmb_scan_ref = QComboBox()
|
||||||
|
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
|
||||||
|
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
|
||||||
|
self._scan_folder: str = ""
|
||||||
|
|
||||||
|
self._scan_worker: ScanWorker | None = None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add controls to settings_row layout**
|
||||||
|
|
||||||
|
Find the `settings_row` assembly (around line 1620). Before `settings_row.addStretch()` (around line 1635), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
settings_row.addWidget(self._btn_scan)
|
||||||
|
settings_row.addWidget(self._sld_threshold)
|
||||||
|
settings_row.addWidget(self._cmb_scan_mode)
|
||||||
|
settings_row.addWidget(self._cmb_scan_ref)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Add handler methods**
|
||||||
|
|
||||||
|
Add these methods to MainWindow (after `_jump_to_next_marker` around line 2410):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _on_scan_ref_changed(self, index: int) -> None:
|
||||||
|
if index == 1: # Custom Folder
|
||||||
|
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
|
||||||
|
if folder:
|
||||||
|
self._scan_folder = folder
|
||||||
|
else:
|
||||||
|
self._cmb_scan_ref.setCurrentIndex(0)
|
||||||
|
|
||||||
|
def _cleanup_scan_worker(self) -> None:
|
||||||
|
"""Disconnect signals and schedule deletion of old scan worker."""
|
||||||
|
if self._scan_worker is not None:
|
||||||
|
try:
|
||||||
|
self._scan_worker.finished.disconnect()
|
||||||
|
self._scan_worker.error.disconnect()
|
||||||
|
self._scan_worker.progress.disconnect()
|
||||||
|
except TypeError:
|
||||||
|
pass # already disconnected
|
||||||
|
self._scan_worker.deleteLater()
|
||||||
|
self._scan_worker = None
|
||||||
|
|
||||||
|
def _start_scan(self) -> None:
|
||||||
|
if not self._file_path:
|
||||||
|
self._show_status("No video loaded")
|
||||||
|
return
|
||||||
|
if self._scan_worker and self._scan_worker.isRunning():
|
||||||
|
self._show_status("Scan already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clean up previous worker
|
||||||
|
self._cleanup_scan_worker()
|
||||||
|
|
||||||
|
# Collect reference clip paths
|
||||||
|
if self._cmb_scan_ref.currentIndex() == 0:
|
||||||
|
# Current profile — all exports across all files in this profile
|
||||||
|
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
|
||||||
|
if os.path.exists(p)]
|
||||||
|
else:
|
||||||
|
# Custom folder
|
||||||
|
if not self._scan_folder:
|
||||||
|
self._show_status("No reference folder selected")
|
||||||
|
return
|
||||||
|
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
|
||||||
|
clip_paths = [
|
||||||
|
os.path.join(self._scan_folder, f)
|
||||||
|
for f in sorted(os.listdir(self._scan_folder))
|
||||||
|
if f.lower().endswith(exts)
|
||||||
|
]
|
||||||
|
|
||||||
|
if not clip_paths:
|
||||||
|
self._show_status("No reference clips found")
|
||||||
|
return
|
||||||
|
|
||||||
|
mode = self._cmb_scan_mode.currentText().lower()
|
||||||
|
threshold = self._sld_threshold.value()
|
||||||
|
|
||||||
|
self._btn_scan.setEnabled(False)
|
||||||
|
self._scan_file_path = self._file_path # remember which file we're scanning
|
||||||
|
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
|
||||||
|
|
||||||
|
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
|
||||||
|
self._scan_worker.finished.connect(self._on_scan_done)
|
||||||
|
self._scan_worker.error.connect(self._on_scan_error)
|
||||||
|
self._scan_worker.progress.connect(self._show_status)
|
||||||
|
self._scan_worker.start()
|
||||||
|
|
||||||
|
def _on_scan_done(self, regions: list) -> None:
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
# Ignore stale results if the user switched files during scan
|
||||||
|
if self._file_path != getattr(self, '_scan_file_path', None):
|
||||||
|
return
|
||||||
|
self._timeline.set_scan_regions(regions)
|
||||||
|
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
||||||
|
|
||||||
|
def _on_scan_error(self, msg: str) -> None:
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
self._show_status(f"Scan error: {msg}")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: Scan button, threshold spinner, mode dropdown, and reference source dropdown visible in the settings row. Clicking Scan with no file loaded shows "No video loaded" in status.
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add scan UI controls and start_scan handler"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Keyboard shortcut — jump to next scan region
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py`
|
||||||
|
|
||||||
|
**Step 1: Add the keyboard shortcut**
|
||||||
|
|
||||||
|
Find the shortcut definitions (around line 1728, where `QShortcut(QKeySequence("M"), ...)` is defined). Add after it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add the jump method**
|
||||||
|
|
||||||
|
After `_on_scan_error` (or after `_jump_to_next_marker`), add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _jump_to_next_scan_region(self) -> None:
|
||||||
|
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
|
||||||
|
if not regions:
|
||||||
|
return
|
||||||
|
for (start, _end, _score) in regions:
|
||||||
|
if start > self._cursor + 0.1:
|
||||||
|
self._step_cursor(start - self._cursor)
|
||||||
|
return
|
||||||
|
# Wrap to first region
|
||||||
|
self._step_cursor(regions[0][0] - self._cursor)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Update help text**
|
||||||
|
|
||||||
|
Find the help/shortcuts tooltip (around line 1757). Add a row:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Clear scan regions and cancel running scan on file change**
|
||||||
|
|
||||||
|
Find `_load_file` method (around line 1931). After the existing marker/state resets, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self._timeline.clear_scan_regions()
|
||||||
|
if self._scan_worker and self._scan_worker.isRunning():
|
||||||
|
self._scan_worker.cancel()
|
||||||
|
self._cleanup_scan_worker()
|
||||||
|
self._btn_scan.setEnabled(True)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Verify manually**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python main.py`
|
||||||
|
Expected: S key does nothing when no scan regions exist. After a scan, S jumps through matched regions.
|
||||||
|
|
||||||
|
**Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add main.py
|
||||||
|
git commit -m "feat: add S shortcut and clear scan on file change"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: Final integration test
|
||||||
|
|
||||||
|
**Step 1: End-to-end manual test**
|
||||||
|
|
||||||
|
1. Open the app: `cd /media/p5/8-cut && python main.py`
|
||||||
|
2. Load a video file
|
||||||
|
3. Export a few clips (these become the reference)
|
||||||
|
4. Set reference source to "Current Profile"
|
||||||
|
5. Click "Scan"
|
||||||
|
6. Verify: status shows progress messages, then "Scan complete: N matching regions"
|
||||||
|
7. Verify: cyan-tinted regions appear on the timeline
|
||||||
|
8. Press S to jump through scan regions
|
||||||
|
9. Change threshold and re-scan — verify different number of regions
|
||||||
|
10. Switch mode to "Nearest" and re-scan
|
||||||
|
11. Switch reference to "Custom Folder", pick a folder with clips
|
||||||
|
12. Re-scan and verify results
|
||||||
|
|
||||||
|
**Step 2: Run all tests**
|
||||||
|
|
||||||
|
Run: `cd /media/p5/8-cut && python -m pytest tests/ -v`
|
||||||
|
Expected: all tests PASS
|
||||||
|
|
||||||
|
**Step 3: Final commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add -A
|
||||||
|
git commit -m "feat: audio similarity scanning complete"
|
||||||
|
```
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# Audio Pipeline Improvements Design
|
||||||
|
|
||||||
|
Date: 2026-04-19
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Improve audio scan classification accuracy, especially for non-speech sounds (suction, gagging, impacts), through three changes:
|
||||||
|
|
||||||
|
1. Multi-layer feature extraction from existing HuBERT/Wav2Vec2 models
|
||||||
|
2. Two new embedding models: AST (AudioSet-supervised) and EAT (self-supervised + AudioSet finetuned)
|
||||||
|
3. Calibrated classifier for better threshold behavior
|
||||||
|
|
||||||
|
## 1. Multi-Layer Feature Extraction
|
||||||
|
|
||||||
|
### Current behavior
|
||||||
|
|
||||||
|
`model(waveforms)` extracts embeddings from the **last transformer layer only**.
|
||||||
|
|
||||||
|
### Change
|
||||||
|
|
||||||
|
Use `model.extract_features(waveforms)` (torchaudio API) to get all layer outputs. Select layers at quartile boundaries, mean-pool each over time, concatenate.
|
||||||
|
|
||||||
|
| Model | Layers | Single-layer dim | Multi-layer dim (4 quartiles) |
|
||||||
|
|-------|--------|-------------------|-------------------------------|
|
||||||
|
| HUBERT_XLARGE | 48 | 1280 | 5120 |
|
||||||
|
| HUBERT_LARGE | 24 | 1024 | 4096 |
|
||||||
|
| HUBERT_BASE | 12 | 768 | 3072 |
|
||||||
|
| WAV2VEC2_BASE | 12 | 768 | 3072 |
|
||||||
|
|
||||||
|
### Implementation
|
||||||
|
|
||||||
|
- New entries in `_EMBED_MODELS`: `"HUBERT_XLARGE_ML"` -> 5120, etc.
|
||||||
|
- `_extract_w2v_windows`: when model name ends with `_ML`, call `extract_features()` instead of `model()`, select quartile layers, concat
|
||||||
|
- Cache key: model name includes `_ML` suffix -> separate cache files
|
||||||
|
- No change to classifier or training pipeline (HistGBT handles high-dim fine)
|
||||||
|
|
||||||
|
## 2. AST (Audio Spectrogram Transformer)
|
||||||
|
|
||||||
|
### What
|
||||||
|
|
||||||
|
`MIT/ast-finetuned-audioset-10-10-0.4593` via HuggingFace `transformers`. 86M params, 768-dim, supervised on AudioSet 527 sound classes.
|
||||||
|
|
||||||
|
### Integration
|
||||||
|
|
||||||
|
- Load: `ASTModel.from_pretrained()` + `ASTFeatureExtractor`
|
||||||
|
- Preprocessing: `ASTFeatureExtractor` handles mel spectrogram from 16kHz raw audio
|
||||||
|
- Batching: prepare `input_values` per window, stack into batch, forward through model
|
||||||
|
- Multi-layer: `output_hidden_states=True` returns 13 layers; `AST_ML` variant concats quartile layers -> 3072-dim
|
||||||
|
- Model cached via `_get_w2v_model()` same lazy-load pattern
|
||||||
|
|
||||||
|
### Entries
|
||||||
|
|
||||||
|
- `"AST"` -> 768
|
||||||
|
- `"AST_ML"` -> 3072
|
||||||
|
|
||||||
|
## 3. EAT (Efficient Audio Transformer)
|
||||||
|
|
||||||
|
### What
|
||||||
|
|
||||||
|
`worstchan/EAT-base_epoch30_finetune_AS2M` via HuggingFace with `trust_remote_code=True`. 88M params, 768-dim, self-supervised + AudioSet finetuned.
|
||||||
|
|
||||||
|
### Integration
|
||||||
|
|
||||||
|
- Load: `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
||||||
|
- Preprocessing: manual 128-bin Kaldi fbank mel spectrogram via torchaudio, normalize with EAT constants `(mel - (-4.268)) / (4.569 * 2)`, reshape to `[B, 1, T, 128]`
|
||||||
|
- Feature extraction: `model.extract_features(mel)` returns `[B, seq, 768]`; CLS token `[:, 0, :]` for utterance-level, or mean-pool `[:, 1:, :]` for frame-level. Use mean-pool for consistency with other models.
|
||||||
|
- Multi-layer: not natively supported, skip for now
|
||||||
|
|
||||||
|
### Entry
|
||||||
|
|
||||||
|
- `"EAT"` -> 768
|
||||||
|
|
||||||
|
## 4. Calibrated Classifier
|
||||||
|
|
||||||
|
Wrap `HistGradientBoostingClassifier` in `CalibratedClassifierCV(clf, cv=3, method='isotonic')` after fitting. Gives well-calibrated probabilities -> threshold slider maps more linearly to precision/recall.
|
||||||
|
|
||||||
|
One change in `train_classifier()`, no UI changes needed.
|
||||||
|
|
||||||
|
## 5. Requirements
|
||||||
|
|
||||||
|
Add to `requirements.txt`:
|
||||||
|
```
|
||||||
|
transformers>=4.30
|
||||||
|
timm>=0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
Both AST and EAT need `transformers`. EAT additionally needs `timm` (used internally by its custom model code). Both setup scripts (`setup_env.sh`, `setup-windows.ps1`) install from `requirements.txt` so no changes needed there.
|
||||||
|
|
||||||
|
## Cache Compatibility
|
||||||
|
|
||||||
|
- All new model variants get distinct cache keys via model name in the hash
|
||||||
|
- Existing caches for HUBERT_XLARGE, BEATs, etc. remain valid and untouched
|
||||||
|
- New models create new `.npz` files in the same `cache/w2v/` directory
|
||||||
|
|
||||||
|
## UI Changes
|
||||||
|
|
||||||
|
- `_EMBED_MODELS` dict additions appear automatically in Train dialog model dropdown and scan model dropdown
|
||||||
|
- No other UI changes needed
|
||||||
@@ -0,0 +1,588 @@
|
|||||||
|
# Audio Pipeline Improvements Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Improve audio scan accuracy with multi-layer extraction, AST/EAT models, and calibrated classifier.
|
||||||
|
|
||||||
|
**Architecture:** All changes are in `core/audio_scan.py`. The embedding extraction functions gain new model-type branches (AST, EAT, multi-layer). The classifier gets a calibration wrapper. `_EMBED_MODELS` dict and `_get_w2v_model()` are extended. No UI changes needed — new models appear automatically in dropdowns.
|
||||||
|
|
||||||
|
**Tech Stack:** torchaudio (existing), transformers (new dep), timm (new dep), sklearn.calibration (existing dep)
|
||||||
|
|
||||||
|
**Key design notes:**
|
||||||
|
- `_get_w2v_model()` resolves `_ML` suffixed names to their base model for loading (e.g. `HUBERT_XLARGE_ML` loads `HUBERT_XLARGE`). Both share the same GPU model — only the extraction path differs (last-layer vs multi-layer). The global `_w2v_model_name` stores the **base** name so switching between `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` does NOT trigger a reload.
|
||||||
|
- Cache keys use the **full** model name (including `_ML`), so single-layer and multi-layer caches coexist as separate `.npz` files.
|
||||||
|
- AST and EAT are separate model types that do NOT share the torchaudio loading path — they get their own `elif` branches in `_get_w2v_model()`.
|
||||||
|
- Both `_extract_w2v_windows` and `_extract_w2v_targeted` need identical changes to their batch inference blocks. Keep them in sync.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add transformers and timm to requirements
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `requirements.txt`
|
||||||
|
|
||||||
|
**Step 1: Add dependencies**
|
||||||
|
|
||||||
|
Add after the `torchaudio` line in `requirements.txt`:
|
||||||
|
|
||||||
|
```
|
||||||
|
transformers>=4.30
|
||||||
|
timm>=0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify install**
|
||||||
|
|
||||||
|
Run: `pip install transformers timm`
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add requirements.txt
|
||||||
|
git commit -m "deps: add transformers and timm for AST/EAT models"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Multi-layer extraction for torchaudio models
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-58` (_EMBED_MODELS dict)
|
||||||
|
- Modify: `core/audio_scan.py:96-100` (_embed_dim)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model)
|
||||||
|
- Modify: `core/audio_scan.py:189-205` (_extract_w2v_windows batch loop)
|
||||||
|
- Modify: `core/audio_scan.py:278-293` (_extract_w2v_targeted batch loop)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
Add to `tests/test_audio_scan.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_multi_layer():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
# Multi-layer models should report concatenated dimension
|
||||||
|
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||||
|
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||||
|
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||||
|
# Single-layer unchanged
|
||||||
|
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run test to verify it fails**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||||
|
Expected: FAIL — `_embed_dim("HUBERT_XLARGE_ML")` returns 768 (default fallback)
|
||||||
|
|
||||||
|
**Step 3: Add multi-layer entries to _EMBED_MODELS**
|
||||||
|
|
||||||
|
In `core/audio_scan.py:50-58`, add after existing entries:
|
||||||
|
|
||||||
|
```python
|
||||||
|
_EMBED_MODELS = {
|
||||||
|
"WAV2VEC2_BASE": 768,
|
||||||
|
"WAV2VEC2_LARGE": 1024,
|
||||||
|
"WAV2VEC2_LARGE_LV60K": 1024,
|
||||||
|
"HUBERT_BASE": 768,
|
||||||
|
"HUBERT_LARGE": 1024,
|
||||||
|
"HUBERT_XLARGE": 1280,
|
||||||
|
"BEATS": 768,
|
||||||
|
# Multi-layer variants (4 quartile layers concatenated)
|
||||||
|
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||||
|
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||||
|
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Run test to verify it passes**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 5: Add helper to resolve base model and layer indices**
|
||||||
|
|
||||||
|
Add after `_embed_dim()` (around line 101):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||||
|
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||||
|
|
||||||
|
Returns None for single-layer models.
|
||||||
|
Layer indices are 0-based into the list returned by extract_features().
|
||||||
|
"""
|
||||||
|
if not model_name.endswith("_ML"):
|
||||||
|
return None
|
||||||
|
base = model_name[:-3] # strip "_ML"
|
||||||
|
if base not in _EMBED_MODELS:
|
||||||
|
return None
|
||||||
|
# Layer counts per model family
|
||||||
|
layer_counts = {
|
||||||
|
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||||
|
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||||
|
"AST": 12,
|
||||||
|
}
|
||||||
|
n = layer_counts.get(base)
|
||||||
|
if n is None:
|
||||||
|
return None
|
||||||
|
# Select 4 layers at quartile boundaries (0-indexed)
|
||||||
|
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||||
|
return base, indices
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: AST is included in the layer_counts dict here already so Task 3 doesn't need to modify it again.
|
||||||
|
|
||||||
|
**Step 6: Write test for _ml_config**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_ml_config():
|
||||||
|
from core.audio_scan import _ml_config
|
||||||
|
assert _ml_config("HUBERT_XLARGE") is None
|
||||||
|
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||||
|
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||||
|
assert base == "HUBERT_XLARGE"
|
||||||
|
assert layers == [11, 23, 35, 47]
|
||||||
|
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||||
|
assert base == "HUBERT_BASE"
|
||||||
|
assert layers == [2, 5, 8, 11]
|
||||||
|
```
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_ml_config -v`
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 7: Modify _get_w2v_model to resolve ML base names**
|
||||||
|
|
||||||
|
In `_get_w2v_model()` (line 68), the comparison key must use the resolved base name so that `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` share the same loaded model without reloading:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _get_w2v_model(model_name: str | None = None):
|
||||||
|
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||||
|
global _w2v_model, _w2v_device, _w2v_model_name
|
||||||
|
if model_name is None:
|
||||||
|
model_name = _DEFAULT_EMBED_MODEL
|
||||||
|
# Multi-layer variants use the same base model weights
|
||||||
|
ml = _ml_config(model_name)
|
||||||
|
load_name = ml[0] if ml else model_name
|
||||||
|
if _w2v_model is None or _w2v_model_name != load_name:
|
||||||
|
import torch
|
||||||
|
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if load_name == "BEATS":
|
||||||
|
... # existing BEATs code unchanged
|
||||||
|
else:
|
||||||
|
import torchaudio
|
||||||
|
bundle = getattr(torchaudio.pipelines, load_name)
|
||||||
|
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||||
|
_w2v_model.eval()
|
||||||
|
_w2v_model_name = load_name
|
||||||
|
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||||
|
return _w2v_model, _w2v_device
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 8: Modify _extract_w2v_windows batch inference**
|
||||||
|
|
||||||
|
In `_extract_w2v_windows`, compute `ml_cfg` **once** before the batch loop (after line 173 `is_beats = ...`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then replace the batch inference block (lines 197-204):
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings.append(batch_emb)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 9: Modify _extract_w2v_targeted batch inference (keep in sync)**
|
||||||
|
|
||||||
|
In `_extract_w2v_targeted`, add `ml_cfg` computation after line 276 `is_beats = ...`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then replace the batch inference block (lines 285-292) with the same branching logic as Step 8:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.no_grad():
|
||||||
|
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||||
|
if is_beats:
|
||||||
|
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||||
|
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
elif ml_cfg is not None:
|
||||||
|
all_layers, _ = model.extract_features(waveforms)
|
||||||
|
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
features, _ = model(waveforms)
|
||||||
|
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||||
|
embeddings_list.append(batch_emb)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: `_extract_w2v_targeted` appends to `embeddings_list` (not `embeddings`).
|
||||||
|
|
||||||
|
**Step 10: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 11: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: multi-layer extraction for HuBERT/Wav2Vec2 models"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: AST model integration
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add AST entries)
|
||||||
|
- Modify: `core/audio_scan.py:45-47` (add _ast_feature_extractor global)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add AST loading branch)
|
||||||
|
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add AST inference branch)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_ast():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("AST") == 768
|
||||||
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
```
|
||||||
|
|
||||||
|
Run: `pytest tests/test_audio_scan.py::test_embed_dim_ast -v`
|
||||||
|
Expected: FAIL
|
||||||
|
|
||||||
|
**Step 2: Add AST entries to _EMBED_MODELS**
|
||||||
|
|
||||||
|
Add to the dict (after the ML entries):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Transformers-based models
|
||||||
|
"AST": 768,
|
||||||
|
"AST_ML": 3072, # 768 * 4
|
||||||
|
```
|
||||||
|
|
||||||
|
Run test again — should PASS now.
|
||||||
|
|
||||||
|
**Step 3: Add module-level global for AST feature extractor**
|
||||||
|
|
||||||
|
Near line 47 (after `_w2v_model_name = None`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
_ast_feature_extractor = None
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add AST loading branch in _get_w2v_model**
|
||||||
|
|
||||||
|
In `_get_w2v_model()`, add an `elif` branch **before** the torchaudio fallback `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif load_name == "AST":
|
||||||
|
from transformers import ASTModel, ASTFeatureExtractor
|
||||||
|
_w2v_model = ASTModel.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
).to(_w2v_device)
|
||||||
|
global _ast_feature_extractor
|
||||||
|
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||||
|
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: `_ast_feature_extractor` is recreated on every model load (not cached separately) — simple and correct since the feature extractor is lightweight and model reloads are rare.
|
||||||
|
|
||||||
|
**Step 5: Add AST inference branch in both extraction functions**
|
||||||
|
|
||||||
|
In both `_extract_w2v_windows` and `_extract_w2v_targeted`, compute `is_ast` once before the loop:
|
||||||
|
|
||||||
|
```python
|
||||||
|
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in the batch inference block, add after the `elif ml_cfg` branch and before `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif is_ast:
|
||||||
|
# AST uses its own feature extractor for mel spectrogram
|
||||||
|
inputs = _ast_feature_extractor(
|
||||||
|
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
input_values = inputs.input_values.to(device)
|
||||||
|
if ml_cfg is not None:
|
||||||
|
out = model(input_values, output_hidden_states=True)
|
||||||
|
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||||
|
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
out = model(input_values)
|
||||||
|
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
```
|
||||||
|
|
||||||
|
Important: `chunks` is already a list of numpy arrays (built in the loop at lines 194-196). Pass it directly as `list(chunks)` — the `ASTFeatureExtractor` accepts a list of numpy arrays and handles batching/padding internally. Verified: `ASTFeatureExtractor([np.array, np.array, ...], sampling_rate=16000, return_tensors="pt", padding=True)` returns `input_values` of shape `[B, 1024, 128]`.
|
||||||
|
|
||||||
|
**Step 6: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 7: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add AST (Audio Spectrogram Transformer) embedding model"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: EAT model integration
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add EAT entry)
|
||||||
|
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add EAT loading branch)
|
||||||
|
- Add: `core/audio_scan.py` (_eat_preprocess helper function)
|
||||||
|
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add EAT inference branch)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Write failing test**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_embed_dim_eat():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("EAT") == 768
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Add EAT entry to _EMBED_MODELS**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"EAT": 768,
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: No `EAT_ML` variant — EAT's `extract_features()` does not natively support multi-layer output. Can be added later if needed by monkey-patching.
|
||||||
|
|
||||||
|
**Step 3: Add EAT loading branch in _get_w2v_model**
|
||||||
|
|
||||||
|
Add after the AST branch, before the torchaudio `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif load_name == "EAT":
|
||||||
|
from transformers import AutoModel
|
||||||
|
_w2v_model = AutoModel.from_pretrained(
|
||||||
|
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||||
|
trust_remote_code=True,
|
||||||
|
).to(_w2v_device)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Add EAT preprocessing helper**
|
||||||
|
|
||||||
|
Add as a module-level function near `_get_w2v_model`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||||
|
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||||
|
|
||||||
|
Returns tensor of shape [B, 1, T, 128].
|
||||||
|
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
TARGET_LEN = 1024
|
||||||
|
MEAN, STD = -4.268, 4.569
|
||||||
|
|
||||||
|
mels = []
|
||||||
|
for chunk in chunks:
|
||||||
|
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||||
|
fbank = kaldi.fbank(
|
||||||
|
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||||
|
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||||
|
)
|
||||||
|
# Pad or truncate to TARGET_LEN
|
||||||
|
if fbank.shape[0] < TARGET_LEN:
|
||||||
|
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||||
|
else:
|
||||||
|
fbank = fbank[:TARGET_LEN]
|
||||||
|
fbank = (fbank - MEAN) / (STD * 2)
|
||||||
|
mels.append(fbank)
|
||||||
|
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Add EAT inference branch in both extraction functions**
|
||||||
|
|
||||||
|
Compute `is_eat` once before the loop:
|
||||||
|
|
||||||
|
```python
|
||||||
|
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in the batch inference block, add after the `elif is_ast` branch and before `else`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elif is_eat:
|
||||||
|
mel_input = _eat_preprocess(chunks, sr, device)
|
||||||
|
features = model.extract_features(mel_input)
|
||||||
|
# Mean-pool frame-level tokens (skip CLS at index 0)
|
||||||
|
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||||
|
```
|
||||||
|
|
||||||
|
Important: `model.extract_features()` returns a plain `torch.Tensor` of shape `[B, 513, 768]` (not a tuple). Index 0 is the CLS token, indices 1-512 are frame-level patch embeddings. We mean-pool the frame tokens for consistency with how other models are pooled.
|
||||||
|
|
||||||
|
**Step 6: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 7: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py tests/test_audio_scan.py
|
||||||
|
git commit -m "feat: add EAT (Efficient Audio Transformer) embedding model"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Calibrated classifier
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `core/audio_scan.py:424-429` (train_classifier, wrap clf)
|
||||||
|
- Test: `tests/test_audio_scan.py`
|
||||||
|
|
||||||
|
**Step 1: Modify train_classifier**
|
||||||
|
|
||||||
|
After the existing `clf.fit()` call (line 428), add calibration with a safe guard:
|
||||||
|
|
||||||
|
```python
|
||||||
|
clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
_log("audio_scan: classifier trained")
|
||||||
|
|
||||||
|
# Calibrate probabilities for better threshold behavior
|
||||||
|
# Requires at least 6 samples per class for stable 3-fold isotonic calibration
|
||||||
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
min_class = min(int(n_pos), int(n_neg_sample))
|
||||||
|
if min_class >= 6:
|
||||||
|
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||||
|
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||||
|
clf = cal_clf
|
||||||
|
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||||
|
else:
|
||||||
|
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||||
|
```
|
||||||
|
|
||||||
|
Why `min_class >= 6`: `CalibratedClassifierCV` uses stratified k-fold internally. With `cv=3`, each fold needs at least 2 samples per class. `min_class >= 6` guarantees this. With fewer samples, the uncalibrated HistGBT probabilities are still reasonable — calibration is an enhancement, not a requirement.
|
||||||
|
|
||||||
|
Previous plan bug: `cv=min(3, n_pos, n_neg_sample)` could produce `cv=1` when `n_pos=1`, which raises `ValueError` (minimum is 2). Even `cv=2` with 2 positives causes one fold to have only 1 positive, making isotonic regression unstable. The `>= 6` guard avoids all these edge cases.
|
||||||
|
|
||||||
|
**Step 2: Run all tests**
|
||||||
|
|
||||||
|
Run: `pytest tests/ -v`
|
||||||
|
Expected: All pass
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add core/audio_scan.py
|
||||||
|
git commit -m "feat: calibrate classifier probabilities with isotonic regression"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Integration test with real model (manual)
|
||||||
|
|
||||||
|
This task is manual — it requires GPU and a real video file.
|
||||||
|
|
||||||
|
**Step 1: Test multi-layer extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='HUBERT_XLARGE_ML')
|
||||||
|
print(f'HUBERT_XLARGE_ML: {emb.shape}') # expect (13, 5120)
|
||||||
|
assert emb.shape[1] == _embed_dim('HUBERT_XLARGE_ML')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Test AST extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='AST')
|
||||||
|
print(f'AST: {emb.shape}') # expect (13, 768)
|
||||||
|
assert emb.shape[1] == _embed_dim('AST')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Test AST multi-layer**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='AST_ML')
|
||||||
|
print(f'AST_ML: {emb.shape}') # expect (13, 3072)
|
||||||
|
assert emb.shape[1] == _embed_dim('AST_ML')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Test EAT extraction**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||||
|
import numpy as np
|
||||||
|
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||||
|
ts, emb = _extract_w2v_windows(y, model_name='EAT')
|
||||||
|
print(f'EAT: {emb.shape}') # expect (13, 768)
|
||||||
|
assert emb.shape[1] == _embed_dim('EAT')
|
||||||
|
print('PASS')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Test model switching doesn't reload unnecessarily**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
from core.audio_scan import _get_w2v_model
|
||||||
|
import core.audio_scan as m
|
||||||
|
# Load HUBERT_XLARGE
|
||||||
|
_get_w2v_model('HUBERT_XLARGE')
|
||||||
|
name1 = m._w2v_model_name
|
||||||
|
# Switch to ML variant — should NOT reload
|
||||||
|
_get_w2v_model('HUBERT_XLARGE_ML')
|
||||||
|
name2 = m._w2v_model_name
|
||||||
|
assert name1 == name2 == 'HUBERT_XLARGE', f'Expected no reload, got {name1} -> {name2}'
|
||||||
|
print('PASS: no reload on ML switch')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: Test full train+scan cycle in app**
|
||||||
|
|
||||||
|
Load app, select each new model from scan model dropdown, scan a video, train, verify results display correctly.
|
||||||
|
|
||||||
|
**Step 7: Final commit and push**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git push
|
||||||
|
```
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
# ComfyUI-8cut Node Pack Design
|
||||||
|
|
||||||
|
Date: 2026-04-19
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Port 8-cut's video scanning, training, review, and export workflow to a ComfyUI node pack. The primary motivation is **remote access** — ComfyUI's web UI allows browser-based operation over the network, and HTML5 `<video>` handles streaming compression natively. No tensor-based image pipeline; videos stay as file paths throughout.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Approach
|
||||||
|
|
||||||
|
Monolithic Review Node + simple pipeline nodes. One central **VideoReview** node embeds the full interactive player/timeline/region table as a large DOM widget. Other nodes (Scan, Train, Export) are headless pipeline nodes that pass lightweight metadata.
|
||||||
|
|
||||||
|
### Core reuse
|
||||||
|
|
||||||
|
The entire `8-cut/core/` package is Qt-free and reusable as-is:
|
||||||
|
- `core/audio_scan.py` — `scan_video()`, `train_classifier()`, `load_classifier()`
|
||||||
|
- `core/db.py` — `ProcessedDB` (SQLite, all scan/training/export persistence)
|
||||||
|
- `core/ffmpeg.py` — `build_ffmpeg_command()` (clip export)
|
||||||
|
- `core/tracking.py` — YOLO-based subject tracking
|
||||||
|
- `core/paths.py` — path helpers, `format_time()`
|
||||||
|
|
||||||
|
No porting required — these are imported directly.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Node Pack Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
ComfyUI-8cut/
|
||||||
|
__init__.py # NODE_CLASS_MAPPINGS, WEB_DIRECTORY
|
||||||
|
core/ # symlink or copy of 8-cut/core/
|
||||||
|
data/
|
||||||
|
8cut.db # separate SQLite DB (can copy from ~/.8cut.db)
|
||||||
|
models/ # trained classifiers (.joblib)
|
||||||
|
nodes/
|
||||||
|
load_video.py
|
||||||
|
audio_scan.py
|
||||||
|
video_review.py
|
||||||
|
train_model.py
|
||||||
|
export_clips.py
|
||||||
|
server_routes.py # custom API routes
|
||||||
|
web/
|
||||||
|
js/
|
||||||
|
video_review.js # timeline + player + scan panel widget
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Custom Types
|
||||||
|
|
||||||
|
No tensors anywhere in the pipeline. All data flows as lightweight metadata:
|
||||||
|
|
||||||
|
| Type | Python value | Purpose |
|
||||||
|
|------|-------------|---------|
|
||||||
|
| `VIDEO_PATH` | `str` (absolute path) | Video file reference |
|
||||||
|
| `SCAN_REGIONS` | `list[dict]` with start/end/score/model/disabled | Scan output / review edits |
|
||||||
|
| `SCAN_MODEL` | `str` (path to .joblib) | Trained classifier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Nodes
|
||||||
|
|
||||||
|
### LoadVideo
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `video_path` (STRING, file browser), `profile` (STRING combo from DB profiles) |
|
||||||
|
| **Output** | `VIDEO_PATH`, `filename` (STRING) |
|
||||||
|
| **Logic** | Validates path exists, returns it. Populates profile combo via API route. |
|
||||||
|
|
||||||
|
### AudioScan
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_MODEL`, `threshold` (FLOAT 0-1), `hop` (FLOAT) |
|
||||||
|
| **Output** | `SCAN_REGIONS` |
|
||||||
|
| **Logic** | Calls `core.audio_scan.scan_video()` directly. Progress via `PromptServer.send_sync("progress", ...)`. |
|
||||||
|
|
||||||
|
### VideoReview (interactive, blocking)
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS` (optional) |
|
||||||
|
| **Output** | `SCAN_REGIONS` (edited) |
|
||||||
|
| **OUTPUT_NODE** | `True` |
|
||||||
|
| **Logic** | Execution pauses here. User interacts via the widget. Clicks "Continue" to pass edited regions downstream. |
|
||||||
|
|
||||||
|
The widget layout:
|
||||||
|
|
||||||
|
```
|
||||||
|
+-------------------------------------+
|
||||||
|
| [video player (HTML5 <video>)] |
|
||||||
|
| +- timeline with scan regions ----+|
|
||||||
|
| | cursor + region drag/resize ||
|
||||||
|
| +---------------------------------+|
|
||||||
|
| +- model tabs [EAT_LARGE][HuBERT]+|
|
||||||
|
| | Time | End | Score ||
|
||||||
|
| | 1:23 | 1:31 | 0.92 ||
|
||||||
|
| | 3:45 | 3:53 | 0.87 ||
|
||||||
|
| | [Add Negative] [Export] [Continue]|
|
||||||
|
| +---------------------------------+|
|
||||||
|
+-------------------------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
Widget size: ~640x500px minimum, resizable via LiteGraph.
|
||||||
|
|
||||||
|
**Blocking mechanism**: The node's `run()` method blocks on a server-side event/queue. The frontend signals completion via `POST /8cut/review_done/{node_id}`, which unblocks `run()` and returns the edited `SCAN_REGIONS`.
|
||||||
|
|
||||||
|
### TrainModel
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `profile` (STRING combo), `positive_folder` (STRING combo), `negative_folder` (STRING combo, optional), `embed_model` (STRING combo from `_EMBED_MODELS`), `use_hard_negatives` (BOOL) |
|
||||||
|
| **Output** | `SCAN_MODEL` |
|
||||||
|
| **Logic** | Queries `db.get_training_data()` to assemble `video_infos`, calls `core.audio_scan.train_classifier()`. Saves to `models/{profile}_{embed_model}.joblib` with version rotation. Progress via ComfyUI progress bar. |
|
||||||
|
|
||||||
|
### ExportClips
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|---|---|
|
||||||
|
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS`, `output_folder` (STRING), `short_side` (INT), `format` (combo MP4/WEBM), `spread` (FLOAT), `clip_count` (INT), `fuse_gap` (FLOAT) |
|
||||||
|
| **Output** | exported file paths (list) |
|
||||||
|
| **Logic** | Region fusion via `_build_export_spans()`, then `core.ffmpeg.build_ffmpeg_command()` per clip. Records each clip in DB via `db.add()`. |
|
||||||
|
|
||||||
|
### Typical workflow
|
||||||
|
|
||||||
|
```
|
||||||
|
[LoadVideo] --> [AudioScan] --> [VideoReview] --> [ExportClips]
|
||||||
|
^
|
||||||
|
[TrainModel]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training loop (hard negatives round-trip)
|
||||||
|
|
||||||
|
1. Scan with existing model -> regions in VideoReview
|
||||||
|
2. Review -> mark false positives as negatives (DB)
|
||||||
|
3. Train -> new model uses hard negatives
|
||||||
|
4. Rescan -> better results
|
||||||
|
5. Repeat
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Routes
|
||||||
|
|
||||||
|
### Video serving
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/video` | GET | Serve raw video file via `web.FileResponse`. Query param: `path`. Browser decodes mp4/h264 natively — key for remote streaming. |
|
||||||
|
| `/8cut/video_transcode` | GET | Fallback: transcode to webm on-the-fly via ffmpeg `StreamResponse` for browser-incompatible formats (some MKV, odd codecs). |
|
||||||
|
|
||||||
|
### Region editing (from VideoReview widget)
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/toggle_region` | POST | `toggle_scan_result_disabled()` |
|
||||||
|
| `/8cut/resize_region` | POST | `update_scan_result()` |
|
||||||
|
| `/8cut/delete_region` | POST | `delete_scan_result()` |
|
||||||
|
| `/8cut/add_negatives` | POST | `add_hard_negatives()` |
|
||||||
|
| `/8cut/scan_versions` | GET | `get_scan_versions()` |
|
||||||
|
| `/8cut/review_done/{node_id}` | POST | Unblock the VideoReview node's `run()`, pass final regions |
|
||||||
|
|
||||||
|
### Data queries (for combo widget population)
|
||||||
|
|
||||||
|
| Route | Method | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `/8cut/profiles` | GET | `db.get_profiles()` |
|
||||||
|
| `/8cut/export_folders` | GET | `db.get_export_folders()` |
|
||||||
|
| `/8cut/models` | GET | List available `.joblib` models |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Frontend JS Widget (`web/js/video_review.js`)
|
||||||
|
|
||||||
|
Registered via `app.registerExtension()`. Hooks into the VideoReview node's `onNodeCreated` and `onExecuted` callbacks.
|
||||||
|
|
||||||
|
### Components
|
||||||
|
|
||||||
|
1. **Video player** — HTML5 `<video>` element, src pointed at `/8cut/video?path=...`
|
||||||
|
2. **Timeline** — `<canvas>` overlay below the video. Renders:
|
||||||
|
- Scan region rectangles (color-coded by score, red for negatives, gray for disabled)
|
||||||
|
- Cursor line (click to seek)
|
||||||
|
- Drag handles on region edges (resize)
|
||||||
|
- Waveform (optional, fetched via separate route)
|
||||||
|
3. **Region table** — HTML table with model tabs. Click row to seek. Columns: Time, End, Score.
|
||||||
|
4. **Action buttons** — Add Negative, Export, Continue
|
||||||
|
5. **Version combo** — dropdown to switch scan history versions
|
||||||
|
|
||||||
|
### Interaction flow
|
||||||
|
|
||||||
|
- Widget activates when `onExecuted` fires with scan regions
|
||||||
|
- User clicks/drags timeline, edits regions, marks negatives
|
||||||
|
- Each edit hits an API route (immediate DB persistence)
|
||||||
|
- "Continue" sends `POST /8cut/review_done/{node_id}` with final region state
|
||||||
|
- Node's `run()` unblocks, passes `SCAN_REGIONS` downstream
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## DB
|
||||||
|
|
||||||
|
Separate SQLite DB at `ComfyUI-8cut/data/8cut.db`. Uses the existing `ProcessedDB` class unchanged — same schema, same migration code. Users can copy their existing `~/.8cut.db` to carry over scan history, training data, and hard negatives.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
Same as 8-cut's `requirements.txt` minus PyQt6/python-mpv:
|
||||||
|
- `torch`, `torchaudio`, `torchvision` (from CUDA index)
|
||||||
|
- `transformers>=4.30,<5.0`, `timm>=0.9`
|
||||||
|
- `librosa`, `scikit-learn`, `joblib`, `soundfile`, `numpy`
|
||||||
|
- `ultralytics` (YOLO tracking)
|
||||||
|
|
||||||
|
ComfyUI already provides torch. The node pack's install script just needs the audio/ML extras.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Priority
|
||||||
|
|
||||||
|
1. **Node pack skeleton** — structure, `__init__.py`, custom types, API routes for video serving
|
||||||
|
2. **LoadVideo + AudioScan** — headless nodes, no widget needed yet
|
||||||
|
3. **VideoReview widget (minimal)** — video player + static region display + Continue button
|
||||||
|
4. **VideoReview interactivity** — timeline click/drag, region editing, negative marking
|
||||||
|
5. **TrainModel + ExportClips** — complete the pipeline
|
||||||
|
6. **Polish** — version history, waveform overlay, transcode fallback
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
# Scan History & Hard Negative Management — Final Design
|
||||||
|
|
||||||
|
Date: 2026-04-19 (implemented on `feat/training-ui`)
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
1. Keep scan result history per `(file, model)` so users can track classifier improvement across training iterations
|
||||||
|
2. Make hard negatives manageable — viewable, removable, and optionally disabled per training run
|
||||||
|
3. Fix latent bug: `get_export_folders()` doesn't filter by `scan_export`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Ghost Folder Fix
|
||||||
|
|
||||||
|
### Bug
|
||||||
|
|
||||||
|
`get_export_folders()` queried all `output_path` rows without filtering `scan_export`. Folders that only contained scan-exported clips appeared in training dropdowns with 0 clips.
|
||||||
|
|
||||||
|
### Implementation (`core/db.py`)
|
||||||
|
|
||||||
|
**`get_export_folders(profile, include_scan_exports=False)`** — new parameter. When `False` (default), the SQL query adds `AND scan_export = 0` to exclude scan-only folders. The `get_training_stats()` method passes this through and also filters its return dict to remove folders with 0 clips:
|
||||||
|
|
||||||
|
```python
|
||||||
|
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test
|
||||||
|
|
||||||
|
`tests/test_db.py::test_export_folders_excludes_scan_exports` — verifies scan-only folders are excluded by default and included when `include_scan_exports=True`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Scan Result History
|
||||||
|
|
||||||
|
### Schema
|
||||||
|
|
||||||
|
Added column to `scan_results`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
scan_timestamp TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
All rows from the same scan share one timestamp string with **microsecond precision** (`%Y%m%d_%H%M%S_%f`, e.g. `"20260419_143022_123456"`). Microsecond precision prevents version collisions on fast successive scans.
|
||||||
|
|
||||||
|
Migration adds the column via `ALTER TABLE` for existing databases. Legacy rows keep `scan_timestamp = ''`.
|
||||||
|
|
||||||
|
### DB methods (`core/db.py`)
|
||||||
|
|
||||||
|
**`save_scan_results(filename, profile, model, regions, max_versions=5)`**
|
||||||
|
1. Inserts new rows with current microsecond-precision timestamp
|
||||||
|
2. Counts distinct timestamps for this `(filename, profile, model)`
|
||||||
|
3. Prunes oldest timestamps beyond `max_versions`
|
||||||
|
|
||||||
|
No more DELETE-then-INSERT — all versions coexist in the table.
|
||||||
|
|
||||||
|
**`get_scan_versions(filename, profile, model)`**
|
||||||
|
Returns `[{timestamp, count, max_score}, ...]` ordered newest first. Filters `scan_timestamp != ''` so legacy rows don't appear as named versions.
|
||||||
|
|
||||||
|
**`get_scan_results(filename, profile, scan_timestamp=None)`**
|
||||||
|
- With `scan_timestamp`: returns rows matching that exact version
|
||||||
|
- Without (default): uses `INNER JOIN` subquery with `MAX(scan_timestamp)` per model to return only the latest version. Legacy rows (empty timestamp) sort before any real timestamp, so they're returned when no versioned scans exist.
|
||||||
|
|
||||||
|
### UI (`main.py` — `ScanResultsPanel`)
|
||||||
|
|
||||||
|
Each model tab wraps its `QTableWidget` in a container `QWidget` with a `QComboBox` for version selection:
|
||||||
|
|
||||||
|
```
|
||||||
|
container (QWidget)
|
||||||
|
├── cmb_version (QComboBox) — hidden when ≤ 1 version
|
||||||
|
└── table (QTableWidget)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Helper methods** unwrap this container:
|
||||||
|
- `_current_table()` — returns `QTableWidget` from active tab (handles both raw table and container)
|
||||||
|
- `_tab_table(index)` — same by tab index
|
||||||
|
|
||||||
|
**Version combo** is populated by `_populate_version_combos()` after every `load_for_file()` and `add_scan_results()` call. Labels use `datetime.strptime` parsing with try/except fallback for robustness:
|
||||||
|
|
||||||
|
```
|
||||||
|
2026-04-19 14:30 (12 regions, best: 0.95)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Version switching** via `_on_version_changed(model, idx)`:
|
||||||
|
1. Reads `scan_timestamp` from combo's `userData`
|
||||||
|
2. Calls `get_scan_results(filename, profile, scan_timestamp=ts)`
|
||||||
|
3. Repopulates the table in-place
|
||||||
|
4. **Clears the undo stack** — stale undo entries from a different version would corrupt data
|
||||||
|
5. Emits `regions_edited` to refresh the timeline
|
||||||
|
|
||||||
|
**Tab switch** connects `tab_changed` signal to `_on_scan_regions_edited` (not just `_update_scan_export_count`), so the timeline updates scan regions when switching model tabs.
|
||||||
|
|
||||||
|
### Cache interaction
|
||||||
|
|
||||||
|
Embedding cache is per `(file, model)` and doesn't change across scans. History stores classified regions (start, end, score), not embeddings.
|
||||||
|
|
||||||
|
### Test
|
||||||
|
|
||||||
|
`tests/test_db.py::test_scan_result_history` — saves 3 versions, verifies counts, ordering, and latest-by-default behavior.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Hard Negative Management
|
||||||
|
|
||||||
|
### Schema
|
||||||
|
|
||||||
|
Added column to `hard_negatives`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
source_model TEXT NOT NULL DEFAULT ''
|
||||||
|
```
|
||||||
|
|
||||||
|
Migration adds the column via `ALTER TABLE` for existing databases.
|
||||||
|
|
||||||
|
### DB methods (`core/db.py`)
|
||||||
|
|
||||||
|
**`add_hard_negatives(filename, profile, times, source_path="", source_model="")`** — now stores which embedding model produced the scan that led to the negative marking.
|
||||||
|
|
||||||
|
**`get_hard_negatives(profile)`** — returns all rows as `[{id, filename, start_time, source_path, source_model}, ...]` for the management dialog.
|
||||||
|
|
||||||
|
**`delete_hard_negatives_by_ids(ids)`** — bulk delete by row IDs.
|
||||||
|
|
||||||
|
**`get_training_data(..., use_hard_negatives=True)`** — new parameter. When `False`, the hard negatives query is skipped entirely. Non-destructive — negatives remain in DB.
|
||||||
|
|
||||||
|
### Source model tracking (`main.py`)
|
||||||
|
|
||||||
|
`_on_scan_negatives()` now passes `source_model=self._scan_panel.current_model_name()` when marking negatives from scan results. `current_model_name()` extracts the model name from the active tab text (stripping the count suffix).
|
||||||
|
|
||||||
|
### Training toggle (`main.py` — `TrainDialog`)
|
||||||
|
|
||||||
|
Checkbox **"Use hard negatives in training"** (default checked) with "Manage..." button in an HBox layout. The toggle:
|
||||||
|
- Updates live training stats preview via debounced `_update_stats()`
|
||||||
|
- Passes `use_hard_negatives` through `_open_train_dialog()` to `get_training_data()`
|
||||||
|
|
||||||
|
### Management dialog (`main.py` — `HardNegativesDialog`)
|
||||||
|
|
||||||
|
Accessible from TrainDialog's "Manage..." button. Features:
|
||||||
|
|
||||||
|
| Component | Details |
|
||||||
|
|-----------|---------|
|
||||||
|
| **Filter combo** | `(all)` + each distinct `source_model` found in data |
|
||||||
|
| **Summary label** | `<b>N</b> hard negatives` |
|
||||||
|
| **Table** | File, Time (`{:.1f}s`), Source Model, hidden ID column |
|
||||||
|
| **Delete Selected** | Multi-select aware, skips hidden (filtered) rows |
|
||||||
|
| **Clear All** | **Filter-aware**: if a model filter is active, only deletes negatives for that model with an appropriate confirmation message. If `(all)`, deletes everything. |
|
||||||
|
| **Close** | Closes dialog, triggers stats refresh in parent TrainDialog |
|
||||||
|
|
||||||
|
`blockSignals(True)` guards prevent spurious filter callbacks during `_load()` repopulation.
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- `test_hard_negatives_source_model` — verifies source_model stored and retrieved
|
||||||
|
- `test_training_data_skips_hard_negatives` — verifies `use_hard_negatives=False` excludes them
|
||||||
|
- `test_delete_hard_negatives_by_ids` — verifies bulk deletion by ID
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Runtime Fixes (discovered during testing)
|
||||||
|
|
||||||
|
### EAT/torchvision ABI mismatch
|
||||||
|
|
||||||
|
**Problem:** `torchvision` installed from PyPI (CPU build) was incompatible with `torch` from CUDA wheel index, causing `operator torchvision::nms does not exist`.
|
||||||
|
|
||||||
|
**Fix:** Added `torchvision` to the explicit torch install line in both setup scripts:
|
||||||
|
```bash
|
||||||
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
```
|
||||||
|
|
||||||
|
Also added `--extra-index-url "$TORCH_INDEX"` to the `pip install -r requirements.txt` line to prevent transitive dependencies (timm, ultralytics) from pulling CPU-only torch packages.
|
||||||
|
|
||||||
|
Applied to: `setup_env.sh` (both conda and venv paths), `setup-windows.ps1`.
|
||||||
|
|
||||||
|
### EAT / transformers 5.x incompatibility
|
||||||
|
|
||||||
|
**Problem:** transformers 5.x broke EAT's remote model code (`'EATModel' object has no attribute 'all_tied_weights_keys'`).
|
||||||
|
|
||||||
|
**Fix:** Pinned `transformers>=4.30,<5.0` in `requirements.txt`.
|
||||||
|
|
||||||
|
### NumPy non-writable array warning
|
||||||
|
|
||||||
|
**Problem:** Cached HuBERT/EAT embeddings loaded from disk are read-only numpy arrays. `torch.from_numpy()` on a non-writable array triggers a deprecation warning.
|
||||||
|
|
||||||
|
**Fix:** In `core/audio_scan.py`, changed EAT preprocessing to copy the array:
|
||||||
|
```python
|
||||||
|
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Timeline not updating on tab switch
|
||||||
|
|
||||||
|
**Problem:** Switching model tabs in the scan results panel didn't refresh the timeline's highlighted regions because `tab_changed` was only connected to `_update_scan_export_count`.
|
||||||
|
|
||||||
|
**Fix:** Connected `tab_changed` to `_on_scan_regions_edited` instead, which handles both timeline refresh and export count update.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## File Summary
|
||||||
|
|
||||||
|
| File | Changes |
|
||||||
|
|------|---------|
|
||||||
|
| `core/db.py` | Schema migrations, `get_export_folders` filter, versioned `save_scan_results`, `get_scan_versions`, version-aware `get_scan_results`, `add_hard_negatives` with `source_model`, `get_hard_negatives`, `delete_hard_negatives_by_ids`, `get_training_data` with `use_hard_negatives` |
|
||||||
|
| `main.py` | `HardNegativesDialog` class, `TrainDialog` hard neg toggle + manage button, `ScanResultsPanel` container/combo architecture, version combo population and switching, `current_model_name()`, tab-switch timeline fix |
|
||||||
|
| `core/audio_scan.py` | `np.array(chunk)` copy for read-only numpy arrays in EAT preprocessing |
|
||||||
|
| `requirements.txt` | `transformers>=4.30,<5.0` pin |
|
||||||
|
| `setup_env.sh` | `torchvision` in torch install, `--extra-index-url` on requirements install |
|
||||||
|
| `setup-windows.ps1` | `torchvision` in torch install, `--extra-index-url` on requirements install, removed skip-if-exists guard |
|
||||||
|
| `tests/test_db.py` | 5 tests covering all DB-layer changes |
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
# Scan History & Hard Negative Management — Implementation Log
|
||||||
|
|
||||||
|
> All tasks complete. See the design doc for the final specification.
|
||||||
|
|
||||||
|
**Branch:** `feat/training-ui`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Fix ghost folder bug in get_export_folders -- DONE
|
||||||
|
|
||||||
|
**Commit:** `2614a76 fix: get_export_folders respects scan_export filter`
|
||||||
|
|
||||||
|
- `core/db.py` — `get_export_folders(profile, include_scan_exports=False)`: filters `scan_export = 0` by default
|
||||||
|
- `core/db.py` — `get_training_stats()`: passes `include_scan_exports` through, filters out 0-clip folders
|
||||||
|
- `tests/test_db.py` — `test_export_folders_excludes_scan_exports`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Scan result history — schema and DB methods -- DONE
|
||||||
|
|
||||||
|
**Commit:** `4fb2ae1 feat: scan result history — keep N versions per (file, model)`
|
||||||
|
|
||||||
|
- `core/db.py` — added `scan_timestamp TEXT NOT NULL DEFAULT ''` column with migration
|
||||||
|
- `core/db.py` — `save_scan_results()`: versioned insert with microsecond-precision timestamp (`%Y%m%d_%H%M%S_%f`), auto-prunes beyond `max_versions=5`
|
||||||
|
- `core/db.py` — `get_scan_versions()`: returns `[{timestamp, count, max_score}, ...]` newest first
|
||||||
|
- `core/db.py` — `get_scan_results(scan_timestamp=None)`: `INNER JOIN` subquery with `MAX(scan_timestamp)` for latest-by-default
|
||||||
|
- `tests/test_db.py` — `test_scan_result_history`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Scan history UI — version selector in ScanResultsPanel -- DONE
|
||||||
|
|
||||||
|
**Commit:** `8ed9fbf feat: scan version selector in results panel`
|
||||||
|
|
||||||
|
- `main.py` — `_add_tab()`: wraps table in container `QWidget` with version `QComboBox` (hidden when ≤ 1 version)
|
||||||
|
- `main.py` — `_current_table()` / `_tab_table(idx)`: unwrap container to get `QTableWidget`
|
||||||
|
- `main.py` — `_populate_version_combos()`: queries `get_scan_versions()`, formats labels with `datetime.strptime` + try/except fallback
|
||||||
|
- `main.py` — `_on_version_changed()`: reloads table from specific version, clears undo stack, emits `regions_edited`
|
||||||
|
- `main.py` — `current_model_name()`: extracts model name from tab text
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Hard negatives — schema and training toggle -- DONE
|
||||||
|
|
||||||
|
**Commit:** `edc5784 feat: hard negative source_model tracking, training toggle`
|
||||||
|
|
||||||
|
- `core/db.py` — added `source_model TEXT NOT NULL DEFAULT ''` column to `hard_negatives` with migration
|
||||||
|
- `core/db.py` — `add_hard_negatives(source_model="")`: stores originating model
|
||||||
|
- `core/db.py` — `get_hard_negatives(profile)`: returns full rows as list of dicts
|
||||||
|
- `core/db.py` — `delete_hard_negatives_by_ids(ids)`: bulk delete by row IDs
|
||||||
|
- `core/db.py` — `get_training_data(use_hard_negatives=True)`: conditionally skips hard negatives query
|
||||||
|
- `main.py` — `TrainDialog`: "Use hard negatives" checkbox + "Manage..." button in HBox layout
|
||||||
|
- `main.py` — `_on_scan_negatives()`: passes `source_model=self._scan_panel.current_model_name()`
|
||||||
|
- `tests/test_db.py` — `test_hard_negatives_source_model`, `test_training_data_skips_hard_negatives`, `test_delete_hard_negatives_by_ids`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Hard negatives management dialog -- DONE
|
||||||
|
|
||||||
|
**Commit:** `e6db83f feat: hard negatives management dialog with filter and bulk delete`
|
||||||
|
|
||||||
|
- `main.py` — `HardNegativesDialog`: table with File/Time/Source Model/hidden ID columns, model filter combo, delete selected, filter-aware clear all, close button
|
||||||
|
- Filter-aware "Clear All": respects active model filter, shows appropriate confirmation message
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Code review fixes -- DONE
|
||||||
|
|
||||||
|
**Commit:** `5d45b8d fix: timestamp collision, undo stack invalidation, label parsing, filter-aware clear`
|
||||||
|
|
||||||
|
Four issues found during code review:
|
||||||
|
1. **Timestamp collision** — second-precision timestamps could merge versions on sub-second calls. Fixed with microsecond precision `%f`
|
||||||
|
2. **Undo stack invalidation** — switching scan versions left stale undo entries. Fixed by clearing undo stack in `_on_version_changed()`
|
||||||
|
3. **Timestamp label fragile parsing** — hard-coded string slicing. Fixed with `datetime.strptime` + try/except fallback
|
||||||
|
4. **Clear All ignoring filter** — deleted all negatives regardless of model filter. Fixed to respect active filter
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Runtime fixes (discovered during manual testing)
|
||||||
|
|
||||||
|
| Commit | Fix |
|
||||||
|
|--------|-----|
|
||||||
|
| `a3c657c` | Install `torchvision` from CUDA wheel index (was pulling CPU build from PyPI) |
|
||||||
|
| `3c3b1d7` | Remove "skip if torch exists" guard in Windows setup so re-runs fix broken envs |
|
||||||
|
| `fd043f4` | Pin `transformers>=4.30,<5.0` — EAT remote model code incompatible with transformers 5.x |
|
||||||
|
| `7d6fee9` | Copy read-only numpy array before `torch.from_numpy()` in EAT preprocessing |
|
||||||
|
| `bd345ab` | Connect `tab_changed` to `_on_scan_regions_edited` so timeline refreshes on tab switch |
|
||||||
|
| `d8b3972` | Add `--extra-index-url` to `pip install -r requirements.txt` in both setup scripts |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Test results
|
||||||
|
|
||||||
|
All 68 tests pass (5 new DB tests + 63 existing).
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
# Main Window UI Restructure — Design
|
||||||
|
|
||||||
|
**Goal:** Reorganize the `MainWindow` UI in `main.py` from a flat wall of ~50 always-visible controls into a legible, grouped layout — a menu bar for rare actions, a tabbed control deck for settings, an always-visible transport bar, and a real status bar — plus a visual polish pass. Keep every existing behavior, shortcut, and mouse interaction working.
|
||||||
|
|
||||||
|
**Scope:** Reorganization **and** visual polish. **Not** an interaction-model change — single-key shortcuts, timeline mouse overloading, and the export/scan logic are untouched.
|
||||||
|
|
||||||
|
**Audience:** Single power user. Optimize for density and speed. The goal is *order, not hiding*: keep everything fast to reach; push only genuinely rare actions into menus.
|
||||||
|
|
||||||
|
**Runs in:** Python/Qt client (`main.py`), `MainWindow` class only. No `core/` changes.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Problem (from audit)
|
||||||
|
|
||||||
|
- **No information architecture.** No menu bar, no toolbar; status bar explicitly disabled (`setStatusBar(None)`, main.py:4440). Every function is a permanently-visible widget at equal weight.
|
||||||
|
- **`settings_row` overloaded** (main.py:4334–4370): 24 widgets in one non-wrapping `QHBoxLayout` spanning three unrelated domains (encode/clip params, export variants, audio-scan ML). Needs >1500px; window opens at 1100px.
|
||||||
|
- **Stranded controls** — e.g. the workers spinbox sits between Cancel and Delete in the transport row (main.py:4316).
|
||||||
|
- **Weak feedback** — only an 11px `#888` status label at the far-right end of the overflowing settings row (main.py:4364).
|
||||||
|
- **Flat visual hierarchy** — single Fusion stylesheet, scattered inline `setStyleSheet` state swaps, no primary/secondary distinction, no grouping.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Chosen approach: Tabbed control deck
|
||||||
|
|
||||||
|
The 3-pane horizontal splitter (Queue · Center · Scan results) is unchanged. The center column is restructured:
|
||||||
|
|
||||||
|
```
|
||||||
|
╔═ File Edit Scan View Help ═══════════════════ Profile:[default▾] [?] ╗ menu bar (+ corner widgets)
|
||||||
|
║ ┌Queue──┐ │ current_file.mp4 │ ┌ Scan results ─────┐ ║
|
||||||
|
║ │+Open │ │ ┌──────────────────────────────────────┐ │ │ [model tabs] │ ║
|
||||||
|
║ │filter │ │ │ VIDEO (mpv) │ │ │ version▾ │ ║
|
||||||
|
║ │┌List┬+┐│ │ │ │ │ │ start end score │ ║
|
||||||
|
║ ││f1 ││ │ │ └──────────────────────────────────────┘ │ │ ... │ ║
|
||||||
|
║ ││f2 ││ │ │ [════════════ timeline ════════════════] │ │ │ ║
|
||||||
|
║ │└────┘ ││ │ [════════════ crop bar ════════════════] │ │ [Neg] [Export] │ ║
|
||||||
|
║ └───────┘ │ ┌─ transport (always visible) ──────────┐ │ └───────────────────┘ ║
|
||||||
|
║ │ │▶ ⏸ x2 x4 🔒 --/-- ··· [Export] +₁+₂ Cancel Delete│ ║
|
||||||
|
║ │ ├─[ Export ]─[ Crop & Track ]─[ Scan ]──┤ ← control deck (tabs) ║
|
||||||
|
║ │ │ (controls for the active tab here) │ ║
|
||||||
|
║ │ └───────────────────────────────────────┘ ║
|
||||||
|
╠═══════════════════════════════════════════════════════════════════════════════╣
|
||||||
|
║ Ready. current file · profile: default · 8 wk ║ status bar
|
||||||
|
╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why tabbed deck:** Replaces the three stacked rows with a compact tab strip. The transport bar (most-used controls) stays always visible above the tabs; settings group by concern behind tabs. Trade-off accepted: viewing Scan + Export controls simultaneously costs a tab switch.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Control mapping
|
||||||
|
|
||||||
|
Every current control has an explicit home; nothing is removed.
|
||||||
|
|
||||||
|
### Menu bar (rare / batch / management)
|
||||||
|
|
||||||
|
| Menu | Items |
|
||||||
|
|------|-------|
|
||||||
|
| **File** | Open Files… · Set export folder… · Quit |
|
||||||
|
| **Edit** | Undo *(Ctrl+Z → `_scan_panel.undo`)* · Subprofiles ▸ (Add… / Remove…) |
|
||||||
|
| **Scan** | Scan current · Auto-export · Scan All… · Train classifier… |
|
||||||
|
| **View** | Review mode ✓ · Subcategory markers ▸ · Hide exported ✓ · Show hidden ✓ |
|
||||||
|
| **Help** | Keyboard shortcuts *(? / F1)* · What's new · About |
|
||||||
|
| *corner (right)* | Profile ▾ · `?` |
|
||||||
|
|
||||||
|
*Hard Negatives and Dataset Stats remain inside the Train dialog (main.py:682, 762) — not surfaced separately. Profile new/delete remains driven by the profile combo's `activated` handler.*
|
||||||
|
|
||||||
|
### Transport bar (always visible — playback + one-press export actions)
|
||||||
|
|
||||||
|
`▶ Play · ⏸ Pause · x2 · x4 · 🔒 Lock · --/-- time · ⟨stretch⟩ · next-preview · **Export** · subprofile buttons ₁₂… · Cancel · Delete`
|
||||||
|
|
||||||
|
### Control deck — Export tab
|
||||||
|
`Label · Category · Name · Folder + browse · Format · HW encode · Resize · Duration · Clips · Spread · Workers · Re-export`
|
||||||
|
|
||||||
|
### Control deck — Crop & Track tab
|
||||||
|
`Portrait ratio · 1 random portrait · 1 random square · Track subject`
|
||||||
|
|
||||||
|
### Control deck — Scan tab
|
||||||
|
`Scan model ▾ · ⏲ history · Scan · Auto · Speech · Review · Fuse · Threshold`
|
||||||
|
|
||||||
|
### Left pane (Queue) — unchanged
|
||||||
|
`+ Open · filter · Hide exported · Show hidden · list tabs (tabbed / side-by-side)`
|
||||||
|
|
||||||
|
### Right pane (Scan results) — unchanged structurally
|
||||||
|
|
||||||
|
### Decisions
|
||||||
|
- **Train** → Scan menu only (no deck button).
|
||||||
|
- **Subcategory markers ("Sub")** → View menu submenu (off the deck).
|
||||||
|
- Items appearing in both a menu and a visible control (Hide exported, Review, Scan, Auto) share one handler and stay synced.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Status bar
|
||||||
|
|
||||||
|
Restores `QStatusBar` (removes `setStatusBar(None)`):
|
||||||
|
- **Left**: transient feedback — `Exporting 2/3…`, `Scan complete · 14 regions`, `Ready.` — with an optional inline `QProgressBar` for export/scan runs. Replaces `_lbl_status` and the `_status_timer` clear logic.
|
||||||
|
- **Right (permanent widget)**: `current file · profile: <name> · <n> workers`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Visual polish
|
||||||
|
|
||||||
|
Extends the existing dark Fusion theme — no theme change.
|
||||||
|
|
||||||
|
1. **Aligned tab layouts** — each deck tab uses `QFormLayout`/grid so `label : control` pairs align in columns (biggest legibility win vs. today's ragged horizontal runs).
|
||||||
|
2. **Primary/secondary button weight** — **Export** gets an accent style (blue, reusing `#3a6ea8`); Cancel/Delete read as secondary/destructive. The existing **red Export = "armed to overwrite"** state (main.py:5403) is preserved as a distinct state layered on top.
|
||||||
|
3. **Consistent toggle states** — x2 / x4 / 🔒 Lock / Review are checkable; one global `:checked` style replaces Lock's ad-hoc inline `#4a3000` swap (main.py:5705).
|
||||||
|
4. **Spacing rhythm** — uniform margins/spacing; **fixed deck height** (= tallest tab) so the video never resizes on tab switch.
|
||||||
|
5. **Label cleanup** — de-abbreviate where cheap (`Thr→Threshold`, `Dur→Duration`); replace cryptic `⏲` with a clearer history affordance.
|
||||||
|
6. **One stylesheet block** — fold scattered inline `setStyleSheet` calls into the central sheet (tabs, separators, status bar, toggles, primary button); keep per-widget overrides only for genuine state changes (overwrite-armed Export).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation notes & risks
|
||||||
|
|
||||||
|
- **Preserve all signal wiring.** Controls are re-parented into new layouts, but every existing `connect()` and the controls' object identities are kept — this is a layout move, not a rewrite of handlers.
|
||||||
|
- **Preserve all shortcuts.** The `QShortcut` block (main.py:4450–4483) and `_KeyFilter` focus suppression are untouched. Menu items reuse the same handler methods and may display the matching shortcut text.
|
||||||
|
- **Fixed deck height** prevents video-area jump when switching tabs.
|
||||||
|
- **Synced menu/button state** — checkable menu items (Review, Hide exported) and their visible toggles must reflect each other; route both through the existing handler and update both widgets.
|
||||||
|
- **Profile combo** moves to a menu-bar corner widget but keeps its existing `activated` → new/delete/switch logic intact.
|
||||||
|
- Risk: re-parenting a large `__init__` is error-prone. Mitigate by moving controls in small, independently-runnable stages (menu bar → status bar → deck tabs → transport bar → polish), launching the app after each.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## What this does NOT do
|
||||||
|
|
||||||
|
- No change to export, scan, tracking, or DB logic — `core/` untouched.
|
||||||
|
- No change to keyboard shortcuts or timeline mouse interactions.
|
||||||
|
- No theme change — stays dark Fusion.
|
||||||
|
- No new features — every control already exists; this is rehousing + polish.
|
||||||
|
- No change to the Queue or Scan-results panes' internal structure.
|
||||||
@@ -0,0 +1,547 @@
|
|||||||
|
# Main Window UI Restructure — Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Re-house `MainWindow`'s ~50 flat controls into a menu bar (rare actions), an always-visible transport bar, a 3-tab control deck (Export / Crop & Track / Scan), and a real status bar — then a visual-polish pass — without changing any behavior, shortcut, or `core/` logic.
|
||||||
|
|
||||||
|
**Architecture:** Pure layout reorganization inside `main.py`'s `MainWindow`. Existing widget objects and every `connect()` are **preserved and re-parented**, not recreated. The monster `__init__` is incrementally broken into `_build_*` helper methods (stays single-file — matches the project's architecture). Companion design doc: `docs/plans/2026-06-13-ui-restructure-design.md`.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.11+, PyQt6, pytest. App entry: `main.py`; launch via `./8cut.sh`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Conventions for every task
|
||||||
|
|
||||||
|
- **Line references drift** as edits land. Always locate by the named symbol (method/variable), not the line number alone. Numbers are the *starting* anchors as of this plan.
|
||||||
|
- **Authoritative verification is a manual launch.** After each task, run `./8cut.sh`, load a video, and confirm the task's controls work AND prior behavior is intact (play, scrub, export, scan). Use the `verify` skill for structured manual checks.
|
||||||
|
- **Structure test is the safety net.** `tests/test_ui_structure.py` (built in Task 0.2) constructs `MainWindow` and asserts containment invariants. It **skips gracefully** if construction fails (e.g. no GL for `MpvWidget` in headless CI), so it never blocks `core/` tests. Run with a display: `pytest tests/test_ui_structure.py -v`.
|
||||||
|
- **Commit after every task.** Small, reversible commits. Commit message convention matches the repo (`feat:`/`fix:`/`refactor:`/`change:`).
|
||||||
|
- **Do not touch** `core/`, export/scan/tracking logic, the `QShortcut` block (around main.py:4450–4483), `_KeyFilter`, or `TimelineWidget` mouse handling.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 0 — Branch & safety net
|
||||||
|
|
||||||
|
### Task 0.1: Create a working branch
|
||||||
|
|
||||||
|
**Step 1:** Confirm clean intent and branch off `master`:
|
||||||
|
```bash
|
||||||
|
git switch -c ui-restructure
|
||||||
|
```
|
||||||
|
**Step 2:** Verify: `git branch --show-current` → `ui-restructure`.
|
||||||
|
(The repo has pre-existing untracked/modified files; leave them alone — they are not part of this work.)
|
||||||
|
|
||||||
|
### Task 0.2: Add the structure-test safety net
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `tests/test_ui_structure.py`
|
||||||
|
|
||||||
|
**Step 1: Write the test harness + baseline invariant**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# A real platform is needed because MpvWidget creates a GL context.
|
||||||
|
# If construction fails for any environment reason, skip — this test is a
|
||||||
|
# best-effort structural net, not a gate on core/ tests.
|
||||||
|
pytestmark = pytest.mark.gui
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def app():
|
||||||
|
from PyQt6.QtWidgets import QApplication
|
||||||
|
inst = QApplication.instance() or QApplication([])
|
||||||
|
yield inst
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def win(app):
|
||||||
|
try:
|
||||||
|
from main import MainWindow
|
||||||
|
w = MainWindow()
|
||||||
|
except Exception as e: # GL/mpv/display unavailable, etc.
|
||||||
|
pytest.skip(f"MainWindow could not be constructed here: {e}")
|
||||||
|
yield w
|
||||||
|
w.close()
|
||||||
|
w.deleteLater()
|
||||||
|
|
||||||
|
|
||||||
|
def _descendant_object_names(widget):
|
||||||
|
"""All objectNames in a widget's child tree (for containment asserts)."""
|
||||||
|
return {c.objectName() for c in widget.findChildren(object) if c.objectName()}
|
||||||
|
|
||||||
|
|
||||||
|
def test_window_constructs(win):
|
||||||
|
assert win.windowTitle() == "8-cut"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Run it**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_ui_structure.py -v`
|
||||||
|
Expected: `test_window_constructs` PASSES (with a display) or SKIPS (headless). Either is acceptable — it must not ERROR.
|
||||||
|
|
||||||
|
**Step 3:** Register the `gui` marker to silence warnings.
|
||||||
|
|
||||||
|
Modify `conftest.py` — append:
|
||||||
|
```python
|
||||||
|
def pytest_configure(config):
|
||||||
|
config.addinivalue_line("markers", "gui: constructs Qt widgets; needs a display")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Confirm core tests still pass**
|
||||||
|
|
||||||
|
Run: `pytest tests/test_utils.py tests/test_db.py -q`
|
||||||
|
Expected: PASS (unchanged).
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
```bash
|
||||||
|
git add tests/test_ui_structure.py conftest.py
|
||||||
|
git commit -m "test: add MainWindow structure smoke test (skips headless)"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 1 — Menu bar
|
||||||
|
|
||||||
|
Add a `QMenuBar` whose actions reuse existing handler methods. Move the profile combo and `?` button into menu-bar corner widgets. Keep the original buttons that also live elsewhere (Scan, Auto) — menus and buttons share handlers.
|
||||||
|
|
||||||
|
### Task 1.1: Extract a `_build_menubar()` and add the five menus
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `main.py` `MainWindow.__init__` (call site) and add method `_build_menubar`
|
||||||
|
|
||||||
|
**Step 1:** Add the method (place near other `_build`/setup helpers, e.g. after `__init__`). Wire each action to the **existing** handler method:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _build_menubar(self) -> None:
|
||||||
|
from PyQt6.QtGui import QAction
|
||||||
|
mb = self.menuBar()
|
||||||
|
|
||||||
|
# File
|
||||||
|
m_file = mb.addMenu("&File")
|
||||||
|
m_file.addAction("Open Files…", self._on_open_files)
|
||||||
|
m_file.addAction("Set export folder…", self._pick_folder)
|
||||||
|
m_file.addSeparator()
|
||||||
|
m_file.addAction("Quit", self.close)
|
||||||
|
|
||||||
|
# Edit
|
||||||
|
m_edit = mb.addMenu("&Edit")
|
||||||
|
self._act_undo = m_edit.addAction("Undo scan edit", self._scan_panel.undo)
|
||||||
|
self._act_undo.setShortcut("Ctrl+Z")
|
||||||
|
m_edit.addSeparator()
|
||||||
|
m_subs = m_edit.addMenu("Subprofiles")
|
||||||
|
m_subs.addAction("Add…", self._new_subprofile)
|
||||||
|
self._menu_subprofiles_remove = m_subs.addMenu("Remove")
|
||||||
|
self._rebuild_remove_subprofile_menu() # built in Task 4.x
|
||||||
|
|
||||||
|
# Scan
|
||||||
|
m_scan = mb.addMenu("&Scan")
|
||||||
|
m_scan.addAction("Scan current", self._start_scan)
|
||||||
|
m_scan.addAction("Auto-export", self._auto_export)
|
||||||
|
m_scan.addSeparator()
|
||||||
|
m_scan.addAction("Scan All…", self._start_scan_all)
|
||||||
|
m_scan.addAction("Train classifier…", self._open_train_dialog)
|
||||||
|
|
||||||
|
# View
|
||||||
|
m_view = mb.addMenu("&View")
|
||||||
|
self._act_review = m_view.addAction("Review mode")
|
||||||
|
self._act_review.setCheckable(True)
|
||||||
|
self._act_review.toggled.connect(self._btn_scan_mode.setChecked)
|
||||||
|
m_view.addAction("Subcategory markers…", self._show_subcat_menu)
|
||||||
|
m_view.addSeparator()
|
||||||
|
self._act_hide_exported = m_view.addAction("Hide exported")
|
||||||
|
self._act_hide_exported.setCheckable(True)
|
||||||
|
self._act_hide_exported.toggled.connect(self._chk_hide_exported.setChecked)
|
||||||
|
self._chk_hide_exported.toggled.connect(self._act_hide_exported.setChecked)
|
||||||
|
self._act_show_hidden = m_view.addAction("Show hidden")
|
||||||
|
self._act_show_hidden.setCheckable(True)
|
||||||
|
self._act_show_hidden.toggled.connect(self._btn_show_hidden.setChecked)
|
||||||
|
self._btn_show_hidden.toggled.connect(self._act_show_hidden.setChecked)
|
||||||
|
|
||||||
|
# Help
|
||||||
|
m_help = mb.addMenu("&Help")
|
||||||
|
m_help.addAction("Keyboard shortcuts", self._show_shortcuts).setShortcut("F1")
|
||||||
|
m_help.addAction("What's new", self._show_changelog)
|
||||||
|
m_help.addAction("About", self._show_about) # tiny method, Task 1.3
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Sync note:** `QAction.toggled`/`QAbstractButton.toggled` do not re-emit when the value is unchanged, so the bidirectional `setChecked` connections (Review, Hide exported, Show hidden) cannot loop. `_btn_scan_mode` → `_act_review` reverse sync is added in Task 3.4 once the button is in the Scan tab.
|
||||||
|
|
||||||
|
**Step 2:** Stub the two small new methods referenced above:
|
||||||
|
```python
|
||||||
|
def _show_about(self) -> None:
|
||||||
|
QMessageBox.about(self, "About 8-cut",
|
||||||
|
f"<b>8-cut</b> v{self.APP_VERSION}<br>"
|
||||||
|
"8-second clips for foley datasets.")
|
||||||
|
|
||||||
|
def _rebuild_remove_subprofile_menu(self) -> None:
|
||||||
|
self._menu_subprofiles_remove.clear()
|
||||||
|
for name in self._subprofiles:
|
||||||
|
self._menu_subprofiles_remove.addAction(
|
||||||
|
name, lambda _=False, n=name: self._remove_subprofile(n))
|
||||||
|
self._menu_subprofiles_remove.setEnabled(bool(self._subprofiles))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3:** Call `self._build_menubar()` in `__init__`, **after** `self._scan_panel` and all referenced buttons exist (i.e. just before/after the splitter assembly around main.py:4429). The scan panel is created at main.py:4414, so place the call after that.
|
||||||
|
|
||||||
|
**Step 4 (manual verify):** `./8cut.sh` → menu bar shows File/Edit/Scan/View/Help; each item triggers its action; Ctrl+Z still undoes scan edits; F1 shows shortcuts.
|
||||||
|
|
||||||
|
**Step 5:** Commit: `feat: add menu bar wired to existing handlers`.
|
||||||
|
|
||||||
|
### Task 1.2: Move profile combo + `?` into menu-bar corner
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — `top_bar` assembly (around main.py:4290–4294) and `_build_menubar`.
|
||||||
|
|
||||||
|
**Step 1:** Remove `self._cmb_profile` and `self._btn_shortcuts` (and the `"Profile:"` `QLabel`) from `top_bar`. Keep `self._lbl_file` in `top_bar` (it stays as the slim filename header above the video).
|
||||||
|
|
||||||
|
**Step 2:** In `_build_menubar`, set a corner widget:
|
||||||
|
```python
|
||||||
|
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QLabel
|
||||||
|
corner = QWidget()
|
||||||
|
ch = QHBoxLayout(corner)
|
||||||
|
ch.setContentsMargins(0, 0, 6, 0)
|
||||||
|
ch.addWidget(QLabel("Profile:"))
|
||||||
|
ch.addWidget(self._cmb_profile)
|
||||||
|
ch.addWidget(self._btn_shortcuts)
|
||||||
|
mb.setCornerWidget(corner, Qt.Corner.TopRightCorner)
|
||||||
|
```
|
||||||
|
(Build the corner widget at the end of `_build_menubar`, after `self._cmb_profile` exists — it is created at main.py:4272.)
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** Profile dropdown works (switch/new/delete); `?` opens shortcuts; filename still shows above the video.
|
||||||
|
|
||||||
|
**Step 4:** Commit: `change: move profile selector and help into menu-bar corner`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 2 — Status bar
|
||||||
|
|
||||||
|
### Task 2.1: Restore `QStatusBar` and route `_show_status` to it
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — `__init__` (`setStatusBar(None)` at main.py:4440, `_lbl_status`/`_status_timer` at main.py:4364–4370) and `_show_status` (main.py:5065).
|
||||||
|
|
||||||
|
**Step 1:** Replace `self.setStatusBar(None)` with a real status bar built in a helper:
|
||||||
|
```python
|
||||||
|
def _build_status_bar(self) -> None:
|
||||||
|
sb = self.statusBar()
|
||||||
|
self._status_perm = QLabel("")
|
||||||
|
self._status_perm.setStyleSheet("color: #888;")
|
||||||
|
sb.addPermanentWidget(self._status_perm)
|
||||||
|
self._update_status_perm()
|
||||||
|
|
||||||
|
def _update_status_perm(self) -> None:
|
||||||
|
name = os.path.basename(self._file_path) if self._file_path else "—"
|
||||||
|
self._status_perm.setText(
|
||||||
|
f"{name} · profile: {self._profile()} · {self._spn_workers.value()} workers")
|
||||||
|
```
|
||||||
|
Call `self._build_status_bar()` in `__init__` near the menubar call.
|
||||||
|
|
||||||
|
**Step 2:** Rewrite `_show_status` to use the status bar (this subsumes `_status_timer`):
|
||||||
|
```python
|
||||||
|
def _show_status(self, msg: str, timeout: int = 0) -> None:
|
||||||
|
"""Show a transient message in the status bar. timeout in ms (0 = sticky)."""
|
||||||
|
self.statusBar().showMessage(msg, timeout)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3:** Delete `self._lbl_status`, `self._status_timer`, and `settings_row.addWidget(self._lbl_status)` (main.py:4364–4370). Remove the `_status_timer.timeout` connection.
|
||||||
|
|
||||||
|
**Step 4:** Keep `_update_status_perm()` fresh — call it where file/profile/workers change: end of `_after_load`, in `_on_profile_activated`, and in the `_spn_workers.valueChanged` lambda.
|
||||||
|
|
||||||
|
**Step 5 (manual verify):** Start an export → status text appears bottom-left and auto-clears; bottom-right shows file · profile · workers and updates on file/profile/worker change.
|
||||||
|
|
||||||
|
**Step 6:** Commit: `feat: real status bar replaces inline status label`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 3 — Control deck (the core move)
|
||||||
|
|
||||||
|
Build a fixed-height `QTabWidget` with three tab pages, then **re-parent** the existing controls from `path_row` and `settings_row` into them. Give each page an `objectName` for the structure test. Do tabs one at a time so the app stays runnable.
|
||||||
|
|
||||||
|
### Task 3.1: Build the empty deck and mount it
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — `right_layout` assembly (main.py:4372–4382).
|
||||||
|
|
||||||
|
**Step 1:** Add a helper that creates the deck and three empty pages:
|
||||||
|
```python
|
||||||
|
def _build_control_deck(self) -> "QTabWidget":
|
||||||
|
from PyQt6.QtWidgets import QTabWidget, QWidget
|
||||||
|
deck = QTabWidget()
|
||||||
|
deck.setObjectName("control_deck")
|
||||||
|
deck.setDocumentMode(True)
|
||||||
|
self._tab_export = QWidget(); self._tab_export.setObjectName("export_tab")
|
||||||
|
self._tab_crop = QWidget(); self._tab_crop.setObjectName("crop_tab")
|
||||||
|
self._tab_scan = QWidget(); self._tab_scan.setObjectName("scan_tab")
|
||||||
|
deck.addTab(self._tab_export, "Export")
|
||||||
|
deck.addTab(self._tab_crop, "Crop && Track")
|
||||||
|
deck.addTab(self._tab_scan, "Scan")
|
||||||
|
self._control_deck = deck
|
||||||
|
return deck
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2:** In `right_layout`, **keep** `transport_row` for now, but replace the `path_row` and `settings_row` additions with the deck:
|
||||||
|
- Remove `right_layout.addLayout(path_row)` and `right_layout.addLayout(settings_row)`.
|
||||||
|
- Add `right_layout.addWidget(self._build_control_deck())`.
|
||||||
|
- Leave the `path_row`/`settings_row` *construction* in place for this task (the widgets are still parented to nothing visible) — they get moved into tabs in 3.2–3.4. **App is briefly missing those controls between 3.1 and 3.4; that's expected mid-stage.**
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** App launches; three empty tabs appear under the transport bar; switching tabs doesn't resize the video (height fixed in Task 3.5).
|
||||||
|
|
||||||
|
**Step 4:** Commit: `refactor: add empty 3-tab control deck under transport`.
|
||||||
|
|
||||||
|
### Task 3.2: Populate the Export tab
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — move widgets from `path_row` (main.py:4322–4331) and the encode/clip parts of `settings_row` (main.py:4334–4348) plus `_spn_workers` (main.py:4213).
|
||||||
|
|
||||||
|
**Step 1:** Build the Export tab with an aligned grid:
|
||||||
|
```python
|
||||||
|
def _build_export_tab(self) -> None:
|
||||||
|
from PyQt6.QtWidgets import QGridLayout, QLabel, QHBoxLayout
|
||||||
|
g = QGridLayout(self._tab_export)
|
||||||
|
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||||
|
# Row 0: annotation
|
||||||
|
g.addWidget(QLabel("Label:"), 0, 0); g.addWidget(self._txt_label, 0, 1)
|
||||||
|
g.addWidget(QLabel("Cat:"), 0, 2); g.addWidget(self._cmb_category, 0, 3)
|
||||||
|
g.addWidget(QLabel("Name:"), 0, 4); g.addWidget(self._txt_name, 0, 5)
|
||||||
|
# Row 1: output path
|
||||||
|
folder_row = QHBoxLayout()
|
||||||
|
folder_row.addWidget(self._txt_folder, 1); folder_row.addWidget(self._btn_folder)
|
||||||
|
g.addWidget(QLabel("Folder:"), 1, 0); g.addLayout(folder_row, 1, 1, 1, 5)
|
||||||
|
# Row 2: encode / clip params
|
||||||
|
g.addWidget(QLabel("Format:"), 2, 0); g.addWidget(self._cmb_format, 2, 1)
|
||||||
|
g.addWidget(self._chk_hw, 2, 2)
|
||||||
|
g.addWidget(QLabel("Resize:"), 2, 3); g.addWidget(self._spn_resize, 2, 4)
|
||||||
|
# Row 3: batch params + actions
|
||||||
|
g.addWidget(QLabel("Duration:"), 3, 0); g.addWidget(self._spn_clip_dur, 3, 1)
|
||||||
|
g.addWidget(QLabel("Clips:"), 3, 2); g.addWidget(self._spn_clips, 3, 3)
|
||||||
|
g.addWidget(QLabel("Spread:"), 3, 4); g.addWidget(self._spn_spread, 3, 5)
|
||||||
|
g.addWidget(QLabel("Workers:"), 4, 0); g.addWidget(self._spn_workers, 4, 1)
|
||||||
|
g.addWidget(self._btn_reexport, 4, 5)
|
||||||
|
```
|
||||||
|
Call it from `_build_control_deck` (or right after, in `__init__`).
|
||||||
|
|
||||||
|
**Step 2:** Delete the now-duplicate `addWidget` calls for these widgets from `path_row` and `settings_row` construction. (Re-parenting via `addWidget` into the grid auto-removes them from the old layout, but remove the dead lines to keep `__init__` honest.)
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** Export tab shows aligned Label/Cat/Name, Folder+browse, Format/HW/Resize, Duration/Clips/Spread/Workers/Re-export. Change each → still persists to `QSettings` and updates the timeline span / next-label as before. Export still works (E).
|
||||||
|
|
||||||
|
**Step 4:** Commit: `refactor: move export & encode controls into Export tab`.
|
||||||
|
|
||||||
|
### Task 3.3: Populate the Crop & Track tab
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — move `_cmb_portrait`, `_chk_rand_portrait`, `_chk_rand_square`, `_chk_track` from `settings_row` (main.py:4337, 4349–4351).
|
||||||
|
|
||||||
|
**Step 1:**
|
||||||
|
```python
|
||||||
|
def _build_crop_tab(self) -> None:
|
||||||
|
from PyQt6.QtWidgets import QGridLayout, QLabel
|
||||||
|
g = QGridLayout(self._tab_crop)
|
||||||
|
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||||
|
g.addWidget(QLabel("Portrait:"), 0, 0); g.addWidget(self._cmb_portrait, 0, 1)
|
||||||
|
g.addWidget(self._chk_rand_portrait, 1, 0, 1, 2)
|
||||||
|
g.addWidget(self._chk_rand_square, 2, 0, 1, 2)
|
||||||
|
g.addWidget(self._chk_track, 3, 0, 1, 2)
|
||||||
|
g.setRowStretch(4, 1); g.setColumnStretch(2, 1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2:** Remove those four widgets' old `settings_row.addWidget` lines.
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** Crop & Track tab shows the four controls; portrait ratio still toggles the crop overlay/crop-bar; random/track checkboxes persist.
|
||||||
|
|
||||||
|
**Step 4:** Commit: `refactor: move crop & track controls into their tab`.
|
||||||
|
|
||||||
|
### Task 3.4: Populate the Scan tab (and drop menu-only buttons)
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — move scan widgets from `settings_row` (main.py:4352–4362). Buttons that became **menu-only** (Train, Scan All, Sub) are NOT added to the tab and are deleted.
|
||||||
|
|
||||||
|
**Step 1:**
|
||||||
|
```python
|
||||||
|
def _build_scan_tab(self) -> None:
|
||||||
|
from PyQt6.QtWidgets import QGridLayout, QLabel, QHBoxLayout
|
||||||
|
g = QGridLayout(self._tab_scan)
|
||||||
|
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||||
|
model_row = QHBoxLayout()
|
||||||
|
model_row.addWidget(self._cmb_scan_model, 1); model_row.addWidget(self._btn_model_history)
|
||||||
|
g.addWidget(QLabel("Model:"), 0, 0); g.addLayout(model_row, 0, 1, 1, 3)
|
||||||
|
g.addWidget(self._btn_scan, 1, 0); g.addWidget(self._btn_auto_export, 1, 1)
|
||||||
|
g.addWidget(self._btn_speech, 1, 2); g.addWidget(self._btn_scan_mode, 1, 3)
|
||||||
|
g.addWidget(self._spn_auto_fuse, 2, 0); g.addWidget(self._sld_threshold, 2, 1)
|
||||||
|
g.setColumnStretch(3, 1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2:** Reverse-sync Review with the View menu (the forward sync was added in Task 1.1):
|
||||||
|
```python
|
||||||
|
self._btn_scan_mode.toggled.connect(self._act_review.setChecked)
|
||||||
|
```
|
||||||
|
Add this right after `_build_scan_tab` runs (both `_btn_scan_mode` and `_act_review` exist by then).
|
||||||
|
|
||||||
|
**Step 3:** Delete the menu-only buttons and their `settings_row` lines: `self._btn_train` (main.py:4167–4170), `self._btn_scan_all` (main.py:4172–4174), `self._btn_hide_subcats` (main.py:4154–4157). Their handlers (`_open_train_dialog`, `_start_scan_all`, `_show_subcat_menu`) stay — now reached via menus.
|
||||||
|
|
||||||
|
**Step 4:** Re-anchor `_show_subcat_menu` (main.py:5989) so it no longer depends on the deleted `_btn_hide_subcats`:
|
||||||
|
```python
|
||||||
|
# was: self._btn_hide_subcats.mapToGlobal(self._btn_hide_subcats.rect().bottomLeft())
|
||||||
|
from PyQt6.QtGui import QCursor
|
||||||
|
menu.exec(QCursor.pos())
|
||||||
|
```
|
||||||
|
Apply to **both** `exec` call sites in that method.
|
||||||
|
|
||||||
|
**Step 5 (manual verify):** Scan tab shows Model+history, Scan/Auto/Speech/Review, Fuse/Threshold. `Scan` runs; `Review` toggles and stays in sync with View ▸ Review mode (both directions); View ▸ Subcategory markers… opens the full popup near the cursor; Scan ▸ Scan All / Train still work.
|
||||||
|
|
||||||
|
**Step 6:** Commit: `refactor: move scan controls into Scan tab; Train/ScanAll/Sub to menus`.
|
||||||
|
|
||||||
|
### Task 3.5: Fix deck height; remove dead `path_row`/`settings_row`
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — `__init__`.
|
||||||
|
|
||||||
|
**Step 1:** The `path_row`/`settings_row` `QHBoxLayout`s should now be empty. Delete their construction blocks entirely (main.py:4321–4370 minus what was already removed), including the `self._transport_row = transport_row` line only if unused elsewhere (it IS used by `_rebuild_subprofile_buttons` — keep `transport_row`).
|
||||||
|
|
||||||
|
**Step 2:** Pin the deck height so tab switches don't move the video:
|
||||||
|
```python
|
||||||
|
self._control_deck.setFixedHeight(self._control_deck.sizeHint().height())
|
||||||
|
```
|
||||||
|
Call after all three tabs are built. If the tallest tab (Export, 5 rows) clips, set an explicit value instead (e.g. `setFixedHeight(150)`); confirm visually.
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** Switching Export↔Crop↔Scan keeps the video size constant; no clipped controls; all three tabs fully usable.
|
||||||
|
|
||||||
|
**Step 4:** Commit: `refactor: fix control-deck height; drop dead settings rows`.
|
||||||
|
|
||||||
|
### Task 3.6: Extend the structure test for the deck
|
||||||
|
|
||||||
|
**Files:** Modify `tests/test_ui_structure.py`.
|
||||||
|
|
||||||
|
**Step 1:** Add invariants:
|
||||||
|
```python
|
||||||
|
def test_menubar_has_expected_menus(win):
|
||||||
|
titles = [m.title().replace("&", "") for m in win.menuBar().findChildren(type(win.menuBar().addMenu("")))]
|
||||||
|
for expected in ("File", "Edit", "Scan", "View", "Help"):
|
||||||
|
assert any(expected == t for t in titles)
|
||||||
|
|
||||||
|
def test_status_bar_exists(win):
|
||||||
|
assert win.statusBar() is not None
|
||||||
|
|
||||||
|
def test_workers_spinbox_in_export_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QSpinBox
|
||||||
|
assert win._spn_workers in win._tab_export.findChildren(QSpinBox)
|
||||||
|
|
||||||
|
def test_scan_button_in_scan_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QPushButton
|
||||||
|
assert win._btn_scan in win._tab_scan.findChildren(QPushButton)
|
||||||
|
|
||||||
|
def test_portrait_combo_in_crop_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QComboBox
|
||||||
|
assert win._cmb_portrait in win._tab_crop.findChildren(QComboBox)
|
||||||
|
```
|
||||||
|
(Adjust the menu-title introspection if the helper is awkward; the key invariants are the tab-containment ones.)
|
||||||
|
|
||||||
|
**Step 2:** Run: `pytest tests/test_ui_structure.py -v` → PASS with a display (or SKIP headless).
|
||||||
|
|
||||||
|
**Step 3:** Commit: `test: assert control-deck containment invariants`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 4 — Transport bar tidy & subprofile menu sync
|
||||||
|
|
||||||
|
### Task 4.1: Confirm transport bar contents; keep subprofile export buttons inline
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — `transport_row` (main.py:4296–4319).
|
||||||
|
|
||||||
|
**Step 1:** The workers spinbox was moved in Task 3.2 — confirm `transport_row.addWidget(self._spn_workers)` is gone. Remaining transport order: Play, Pause, x2, x4, Lock, time, stretch, next-label, **Export**, subprofile buttons, `+` (add subprofile), Cancel, Delete. Leave subprofile **export** buttons inline (they carry the 1–9 shortcuts and belong with Export).
|
||||||
|
|
||||||
|
**Step 2:** Keep the inline `+` add-subprofile button, but also ensure the Edit ▸ Subprofiles ▸ Remove submenu is rebuilt whenever subprofiles change. In `_rebuild_subprofile_buttons` (main.py:5530-ish) and after add/remove, call `self._rebuild_remove_subprofile_menu()`.
|
||||||
|
|
||||||
|
**Step 3 (manual verify):** Transport row reads cleanly; adding/removing a subprofile updates both the inline buttons and Edit ▸ Subprofiles ▸ Remove; number keys 1–9 still export to subprofiles.
|
||||||
|
|
||||||
|
**Step 4:** Commit: `change: tidy transport row; sync subprofile remove menu`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 5 — Visual polish
|
||||||
|
|
||||||
|
All Stage 5 verification is **manual** (visual). Take a screenshot before 5.1 for comparison (use the `run`/`verify` skill).
|
||||||
|
|
||||||
|
### Task 5.1: Consolidate the stylesheet (tabs, status bar, toggles, primary button)
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — global stylesheet in `main()` (main.py:3811–3827).
|
||||||
|
|
||||||
|
**Step 1:** Extend the central sheet (append rules; keep existing ones):
|
||||||
|
```css
|
||||||
|
QTabWidget::pane { border: 1px solid #444; border-radius: 3px; top: -1px; }
|
||||||
|
QTabBar::tab { background: #2a2a2a; color: #bbb; padding: 5px 12px;
|
||||||
|
border: 1px solid #444; border-bottom: none;
|
||||||
|
border-top-left-radius: 3px; border-top-right-radius: 3px; }
|
||||||
|
QTabBar::tab:selected { background: #333; color: #fff; }
|
||||||
|
QPushButton:checked { background: #4a3000; border-color: #ffd230; color: #fff; }
|
||||||
|
QStatusBar { background: #1a1a1a; color: #bbb; }
|
||||||
|
QStatusBar::item { border: none; }
|
||||||
|
QPushButton#primary { background: #3a6ea8; border-color: #4f86c6; color: #fff; }
|
||||||
|
QPushButton#primary:hover { background: #4f86c6; }
|
||||||
|
QMenuBar { background: #1e1e1e; } QMenuBar::item:selected { background: #3a6ea8; }
|
||||||
|
QMenu { background: #2a2a2a; border: 1px solid #555; }
|
||||||
|
QMenu::item:selected { background: #3a6ea8; }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2:** Mark Export primary: `self._btn_export.setObjectName("primary")`.
|
||||||
|
|
||||||
|
**Step 3:** Replace Lock's inline stylesheet swap (main.py:5705) — since `QPushButton:checked` now styles all toggles, delete the two `self._btn_lock.setStyleSheet(...)` lines in `_on_lock_toggled` (keep the rest of the handler).
|
||||||
|
|
||||||
|
**Step 4 (manual verify):** Tabs, menus, status bar, and checked toggles (x2/x4/Lock/Review) all read consistently; Export stands out as primary; Lock still highlights when active.
|
||||||
|
|
||||||
|
**Step 5:** Commit: `style: unify tab/menu/statusbar/toggle styling; mark Export primary`.
|
||||||
|
|
||||||
|
### Task 5.2: Preserve the "armed to overwrite" Export state
|
||||||
|
|
||||||
|
**Files:** Inspect `main.py` — the red-Export swaps (main.py:5403, and the resets at 4960/5211/5447/7170/7199/7218).
|
||||||
|
|
||||||
|
**Step 1:** These set/clear `self._btn_export.setStyleSheet("QPushButton { background: #6a3030; ... }")` to mean "this export will overwrite". With Export now `objectName("primary")`, an empty `setStyleSheet("")` reset reverts to the **primary** look (good). Confirm the armed (red) state still visually overrides primary — inline stylesheet beats the objectName rule, so it does.
|
||||||
|
|
||||||
|
**Step 2 (manual verify):** Select a marker for re-export → Export turns red (armed); deselect → returns to blue primary; export → resets correctly.
|
||||||
|
|
||||||
|
**Step 3:** Commit (only if changes were needed): `fix: keep armed-overwrite Export state over primary style`.
|
||||||
|
|
||||||
|
### Task 5.3: Label cleanup
|
||||||
|
|
||||||
|
**Files:** Modify `main.py` — prefixes/labels.
|
||||||
|
|
||||||
|
**Step 1:** De-abbreviate where free: `_sld_threshold.setPrefix("Threshold: ")` (main.py:4207) → keep short if it overflows the tab; `_spn_auto_fuse` prefix stays `"Fuse: "`. Replace the `⏲` history button text with a tooltip-backed `"History"` or a clearer glyph; keep `setFixedWidth` generous enough.
|
||||||
|
|
||||||
|
**Step 2 (manual verify):** Labels legible; nothing clipped in the Scan tab.
|
||||||
|
|
||||||
|
**Step 3:** Commit: `style: de-abbreviate scan labels`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 6 — Finalize
|
||||||
|
|
||||||
|
### Task 6.1: Full regression pass
|
||||||
|
|
||||||
|
**Step 1 (manual, use `verify` skill):** With a real video loaded, confirm end-to-end: scrub/play/pause/speed/lock; export (E) single + batch + subprofile (1–9); re-export; delete; portrait crop + random + track; scan + auto + speech + review + threshold/fuse; scan-all; train dialog opens; profile switch; queue filter/hide/show-hidden; Ctrl+Z undo; F1/`?` shortcuts.
|
||||||
|
|
||||||
|
**Step 2:** Run `pytest -q` (all suites). Expected: `core/` PASS; `test_ui_structure` PASS (display) or SKIP.
|
||||||
|
|
||||||
|
### Task 6.2: Docs & changelog
|
||||||
|
|
||||||
|
**Files:** Modify `README.md` (UI/shortcuts sections if any references moved) and the in-app `CHANGELOG` list (main.py:4500) — bump `APP_VERSION` and add a "UI restructure" entry so the What's-new dialog announces it.
|
||||||
|
|
||||||
|
**Step 1:** Add changelog entry summarizing: menu bar, tabbed control deck, status bar, visual polish; note all shortcuts unchanged.
|
||||||
|
|
||||||
|
**Step 2:** Commit: `docs: changelog + README for UI restructure`.
|
||||||
|
|
||||||
|
### Task 6.3: Hand off the branch
|
||||||
|
|
||||||
|
**Step 1:** `git log --oneline master..ui-restructure` — review the commit series.
|
||||||
|
**Step 2:** Offer the user: merge to `master`, open a PR, or keep iterating (use `finishing-a-development-branch` skill).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Risk register
|
||||||
|
|
||||||
|
| Risk | Mitigation |
|
||||||
|
|------|-----------|
|
||||||
|
| Re-parenting breaks a `connect()` | Widgets keep identity; only layout membership changes. Manual launch after every task catches breakage immediately. |
|
||||||
|
| Headless test can't build `MpvWidget` | Structure test skips on construction failure; manual launch is authoritative. |
|
||||||
|
| Menu/button state desync (Review, Hide exported) | Bidirectional `setChecked` (no re-emit on equal value → no loop); verified manually in 3.4. |
|
||||||
|
| Subcat popup anchored to deleted button | Re-anchored to `QCursor.pos()` in Task 3.4. |
|
||||||
|
| Deck height jump on tab switch | `setFixedHeight` in Task 3.5. |
|
||||||
|
| Armed-overwrite red Export lost under primary style | Inline stylesheet overrides objectName rule; verified in 5.2. |
|
||||||
|
| Mid-Stage-3 app missing controls | Expected between 3.1–3.4; each sub-task is still committable and launchable. |
|
||||||
|
|
||||||
|
## What this plan does NOT change
|
||||||
|
|
||||||
|
`core/` logic · export/scan/tracking/DB behavior · keyboard shortcuts · timeline mouse interactions · the Queue and Scan-results panes' internals · the dark Fusion theme.
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
# Multi-pane Control Deck — Design + Plan Addendum
|
||||||
|
|
||||||
|
> Addendum to `2026-06-13-ui-restructure-design.md` / `-implementation.md`. Same branch (`ui-restructure`), same constraints (preserve behavior; reorg/feature only; no `core/` changes).
|
||||||
|
|
||||||
|
**Goal:** Let the control-deck panels (Export / Crop & Track / Scan) optionally show **side-by-side as resizable columns** instead of one-at-a-time tabs — mirroring the existing playlist pin→side-by-side pattern.
|
||||||
|
|
||||||
|
> **Revision (post-use, 2026-06-13):** The first implementation showed unpinned panels as a "leftover" tab-column so nothing was hidden — but in use, pinning 2 panels then displayed 3 columns, which read as "all three pinned" and was confusing (and inconsistent with what persisted). **Revised behavior:** the split view shows **exactly the pinned panels** as columns (pin 2 → 2 columns, pin 3 → 3). Unpinned panels are not shown as columns. Because the right-click-tab "Show side-by-side" gesture only works in tabbed mode, an always-available **View ▸ Side-by-side panels ▸ Export / Crop / Scan** submenu of checkable toggles is the way to pin/unpin any panel (including adding a 3rd while already in split view). The `if leftovers:` block below is removed; the View submenu + its sync in `_refresh_deck_layout` replace it.
|
||||||
|
|
||||||
|
**Mirror these existing playlist members** (study them — the deck is a simpler, fixed-3-panel version): `_PlaylistTabBar` (main.py:3284), `_refresh_layout` (~4872), `_on_pin_toggle`/`_on_unpin` (~4942), `_detach_all_pws`/`_clear_split_container` (~4861), and the `_list_stack`/`_split_container` setup (~3916–3923).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Design
|
||||||
|
|
||||||
|
### Panel identity
|
||||||
|
The deck's three pages (`_tab_export`, `_tab_crop`, `_tab_scan`) each get three attributes (set in `_build_control_deck`):
|
||||||
|
- `_pinned: bool = False`
|
||||||
|
- `_label: str` — "Export" / "Crop & Track" / "Scan"
|
||||||
|
- `_deck_key: str` — "export" / "crop" / "scan" (stable key for persistence)
|
||||||
|
|
||||||
|
Keep an ordered list `self._deck_panels = [self._tab_export, self._tab_crop, self._tab_scan]` for deterministic column order.
|
||||||
|
|
||||||
|
### Tab bar
|
||||||
|
New `class _DeckTabBar(QTabBar)` (minimal version of `_PlaylistTabBar`): on `contextMenuEvent`, show a checkable "Show side-by-side" action reflecting the page's `_pinned`, and emit `pin_toggle_requested(idx)` when chosen. No rename/folder. Install via `self._control_deck.setTabBar(_DeckTabBar())` in `_build_control_deck` and connect `pin_toggle_requested → self._on_deck_pin_toggle`.
|
||||||
|
|
||||||
|
### Stacked container (mirrors `_list_stack`)
|
||||||
|
Wrap the deck so it can swap between tabbed and split views:
|
||||||
|
- `self._deck_split_container = QWidget()` with an `QHBoxLayout` (`_deck_split_layout`, margins 0, spacing 2).
|
||||||
|
- `self._deck_stack = QStackedWidget()`; page 0 = `self._control_deck`, page 1 = `self._deck_split_container`.
|
||||||
|
- In `right_layout`, mount `self._deck_stack` where `self._control_deck` is currently added (replace that one `addWidget`).
|
||||||
|
|
||||||
|
### `_refresh_deck_layout()` (mirrors `_refresh_layout`)
|
||||||
|
```
|
||||||
|
pinned = [p for p in self._deck_panels if p._pinned]
|
||||||
|
guard self._deck_loading = True (avoid re-entrant signals)
|
||||||
|
detach all panels (setParent(None)); self._control_deck.clear(); clear _deck_split_layout
|
||||||
|
if len(pinned) >= 2:
|
||||||
|
splitter = QSplitter(Horizontal); splitter.setChildrenCollapsible(False)
|
||||||
|
leftovers = []
|
||||||
|
for panel in self._deck_panels: # preserve deck order
|
||||||
|
if panel._pinned:
|
||||||
|
col = QWidget(); v = QVBoxLayout(col) (0 margins)
|
||||||
|
header = label(panel._label, bold) + "✕" button (unpin, fixed 18x18,
|
||||||
|
tooltip "Return to tabs", clicked → self._on_deck_unpin(panel))
|
||||||
|
header fixed height ~22
|
||||||
|
panel.setVisible(True) # reparented pages start hidden
|
||||||
|
v.addWidget(header); v.addWidget(panel, 1)
|
||||||
|
splitter.addWidget(col)
|
||||||
|
else:
|
||||||
|
leftovers.append(panel)
|
||||||
|
if leftovers: # keep unpinned reachable as a tab-column
|
||||||
|
lt = QTabWidget(); lt.setDocumentMode(True)
|
||||||
|
for panel in leftovers:
|
||||||
|
panel.setVisible(True); lt.addTab(panel, panel._label)
|
||||||
|
splitter.addWidget(lt)
|
||||||
|
splitter.setSizes([1000]*splitter.count())
|
||||||
|
_deck_split_layout.addWidget(splitter)
|
||||||
|
self._deck_stack.setCurrentWidget(self._deck_split_container)
|
||||||
|
else:
|
||||||
|
for panel in self._deck_panels: # fixed order
|
||||||
|
self._control_deck.addTab(panel, panel._label)
|
||||||
|
self._deck_stack.setCurrentWidget(self._control_deck)
|
||||||
|
restore self._deck_loading
|
||||||
|
```
|
||||||
|
|
||||||
|
### Toggle handlers (mirror `_on_pin_toggle`/`_on_unpin`)
|
||||||
|
- `_on_deck_pin_toggle(idx)`: `panel = self._control_deck.widget(idx)` (only valid in tabbed mode — pin is only offered there); flip `panel._pinned`; if now pinned and `<2` pinned, `_show_status("Pin another panel to show them side-by-side", 3500)`; `_refresh_deck_layout()`; `_save_deck_layout()`.
|
||||||
|
- `_on_deck_unpin(panel)`: `panel._pinned = False`; `_refresh_deck_layout()`; `_save_deck_layout()`.
|
||||||
|
|
||||||
|
### Persistence
|
||||||
|
- `_save_deck_layout()`: `self._settings.setValue("deck_pinned", [p._deck_key for p in self._deck_panels if p._pinned])`.
|
||||||
|
- Restore at the end of `__init__` (after the deck + menubar exist): read `deck_pinned` (handle str/list like the subprofiles loader at main.py:3867), set each panel's `_pinned`, then `_refresh_deck_layout()` once.
|
||||||
|
|
||||||
|
### Height
|
||||||
|
The deck pages now also render with a 22px header in split mode. After building, set the stack's minimum height to fit the tallest **split-mode** column (header + Export content) so split mode never clips: compute once via `self._deck_stack.setMinimumHeight(...)` using `sizeHint`, and keep vertical size policy `Fixed` (as the deck has now). Switching INTO split mode may change the deck height slightly (deliberate user action — acceptable); switching tabs within tabbed mode must still not jump. Reuse the existing height-pin logic — apply it to `_deck_stack` instead of `_control_deck`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation tasks (bite-sized, commit per task)
|
||||||
|
|
||||||
|
**Task M.1 — scaffolding (no behavior change yet).** Add `_DeckTabBar`; in `_build_control_deck` set it on the deck, set `_pinned/_label/_deck_key` on the three pages, build `self._deck_panels`, create `_deck_split_container`/`_deck_split_layout`/`_deck_stack`, and mount `_deck_stack` in `right_layout` instead of `_control_deck`. Connect `pin_toggle_requested` to a stub. App still behaves as plain tabs. Verify: `import main`, structure tests 6/6, and a probe that `_deck_stack.currentWidget() is _control_deck`.
|
||||||
|
|
||||||
|
**Task M.2 — split rendering.** Implement `_refresh_deck_layout`, `_detach_deck_panels`, `_clear_deck_split`, `_on_deck_pin_toggle`, `_on_deck_unpin`. Verify with a probe: set two panels `_pinned=True`, call `_refresh_deck_layout()`, assert stack shows `_deck_split_container`, the splitter has 3 columns (2 pinned + 1 leftover QTabWidget), and all three panels are visible/parented; unpin one → back to `_control_deck` with 3 tabs in order.
|
||||||
|
|
||||||
|
**Task M.3 — persistence.** Add `_save_deck_layout()` + restore block in `__init__`. Verify a probe round-trips a pinned set through QSettings (use an isolated QSettings scope in the test if needed) without error and that restore calls refresh exactly once.
|
||||||
|
|
||||||
|
**Task M.4 — height + tests.** Apply the height-pin to `_deck_stack`; confirm split mode doesn't clip the tallest column. Add structure tests: `test_deck_stack_exists`, and `test_pinning_two_panels_switches_to_split` (programmatically pin 2, refresh, assert `_deck_stack.currentWidget() is _deck_split_container`).
|
||||||
|
|
||||||
|
## Verification note
|
||||||
|
Env quirk (same as the restructure): bare `python -c` constructing `MainWindow` segfaults on mpv GL; run checks under the pytest fixture and `LD_PRELOAD=/usr/lib/libstdc++.so.6 QT_QPA_PLATFORM=offscreen`. Visual confirmation (drag dividers, pin/unpin gestures, persistence across real launches) is the user's, done at the end.
|
||||||
|
|
||||||
|
## Risks
|
||||||
|
- **Reparenting hidden pages:** QTabWidget hides non-current pages; reparented panels must be `setVisible(True)` in split columns (same gotcha the playlist documents at main.py:4909-4911).
|
||||||
|
- **Signal re-entrancy:** guard with `_deck_loading` during refresh.
|
||||||
|
- **Pin offered in split mode:** `_on_deck_pin_toggle` reads `_control_deck.widget(idx)`, which is only meaningful in tabbed mode. The ✕ header is the unpin path in split mode — don't rely on the context menu there.
|
||||||
|
- **Height jump on mode toggle:** acceptable (deliberate); tab-switch-within-tabs must remain jump-free.
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
# LTX-2 per-tab export mode — Design
|
||||||
|
|
||||||
|
**Goal:** Add an export *pipeline mode* to each file-list tab — **Foley** (current behavior) or **LTX-2** — so the same source videos can feed both a Foley dataset (8 s clips) and an LTX-2 V2A dataset (frame-exact, ÷32, 25 fps) without the two ever mixing.
|
||||||
|
|
||||||
|
**Depends on:** the per-tab export folder feature (branch `tab-export-folder`) — this design extends that per-tab state. Implementation branch `ltx2-preset` is based on it.
|
||||||
|
|
||||||
|
**Scope:** soft preset (no hard enforcement — defaults are LTX-2-legal but every control stays editable). `core/` gains optional pipeline params; Foley path is byte-for-byte unchanged.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## LTX-2 constraints (why this exists)
|
||||||
|
|
||||||
|
LTX-2 (32× spatial VAE, 8× temporal + 1) requires, for a clip:
|
||||||
|
- **W and H each divisible by 32.**
|
||||||
|
- **Frame count F such that `F % 8 == 1`** → 9, 17, 25, … 201, … (transformer seq-len ∝ `(W/32)·(H/32)·((F−1)/8+1)`).
|
||||||
|
- **fps** only sets real duration `F/fps`; for V2A it fixes the paired-audio length and audio↔motion sync, so it must be **consistent across the dataset and equal to the inference `frame_rate`**. Target: **25 fps**.
|
||||||
|
- V2A video is frozen conditioning → low spatial res (384–512) is fine and cheaper.
|
||||||
|
|
||||||
|
Note: 8 s @ 25 fps = 200 frames, and `200 % 8 == 0` → **8 s is not legal**. Nearest legal: F=193 (7.72 s) or **F=201 (8.04 s)**.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model: per-tab mode
|
||||||
|
|
||||||
|
Each tab (`PlaylistWidget`) gains `_mode ∈ {"foley","ltx2"}`, persisted alongside `_dest_folder`/`_pinned`/`_tab_folder` in `_save_playlist_tabs`/`_load_playlist_tabs`. Default `"foley"` → existing tabs load unchanged. The **active tab's mode drives the export pipeline and the length control.**
|
||||||
|
|
||||||
|
### Tab context menu (`_DeckTabBar`/`_PlaylistTabBar`)
|
||||||
|
- **Duplicate as LTX-2** — headline action: clone the tab's file list + separators into a new tab; set `mode="ltx2"`; derive a separate export folder `"<dest_folder>_ltx2"`; load LTX-2 default geometry. Lets you spin an LTX-2 dataset off a Foley working set.
|
||||||
|
- **Duplicate tab** — clone keeping the same mode.
|
||||||
|
- **LTX-2 mode** — checkable, flips an existing tab between foley/ltx2.
|
||||||
|
- Tab label shows a small **`[LTX2]`** badge when `mode=="ltx2"`.
|
||||||
|
|
||||||
|
## What `ltx2` mode changes (soft — still editable)
|
||||||
|
|
||||||
|
| Aspect | Foley | LTX-2 |
|
||||||
|
|--------|-------|-------|
|
||||||
|
| Clip length | Duration spinbox (seconds) | **Frame-count F** control stepping the legal series (9, 17, …, 201, …); shows `= F/25 s` |
|
||||||
|
| Output fps | inherits source | **forced 25 fps** (resample; preserves duration/sync) |
|
||||||
|
| Output W×H | short-side resize → even long side | **center-cropped to ÷32** on both axes (no aspect distortion; loses ≤31 px/side); resize default **512** |
|
||||||
|
| Frame exactness | duration-based | exactly **F** frames (`-frames:v F`) |
|
||||||
|
|
||||||
|
Defaults loaded on convert: resize **512**, **F = 201** (≈8.04 s, mirrors the 8 s Foley clips), ratio as set. All editable afterward.
|
||||||
|
|
||||||
|
## Pipeline (`core/ffmpeg.build_ffmpeg_command`)
|
||||||
|
|
||||||
|
Add optional params; Foley calls pass none → identical output to today:
|
||||||
|
- `target_fps: float | None` — when set, append `fps={target_fps}` filter and `-r {target_fps}`.
|
||||||
|
- `snap32: bool` — when true, after the scale append a centered crop to the nearest lower multiple of 32 on each axis: `crop=trunc(iw/32)*32:trunc(ih/32)*32`.
|
||||||
|
- Frame-exact length: caller computes `duration = F/target_fps` and passes `-frames:v F` on the video output so the clip has exactly F frames; audio extract uses the same `F/target_fps` duration so V2A pairing stays aligned.
|
||||||
|
|
||||||
|
Filter order: portrait-crop (aspect) → scale (short side, ÷32 default) → snap32 crop → fps. The snap32 center-crop runs after scaling so the ÷32 trim is on final pixels.
|
||||||
|
|
||||||
|
## UI wiring (`MainWindow`)
|
||||||
|
|
||||||
|
- The length spinbox area swaps with the active tab's mode: Foley shows *Duration (s)*; LTX-2 shows *Frames (F)* with a live `= s @25fps` readout. Switching tabs (or toggling mode) reconfigures it; uses the existing `_sync_folder_field_to_tab`-style sync hook on tab change.
|
||||||
|
- `_on_export` / `_start_export_batch`: when the active tab is `ltx2`, pass `target_fps=25`, `snap32=True`, and frame-exact length to the ffmpeg builder; otherwise unchanged.
|
||||||
|
- The mismatch guardrail (just added) and per-tab folder continue to apply.
|
||||||
|
|
||||||
|
## Persistence & migration
|
||||||
|
`_mode` added to each tab's saved JSON (default `"foley"` when absent). No DB changes. Existing sessions load every tab as Foley → zero behavior change until a tab is converted.
|
||||||
|
|
||||||
|
## What this does NOT do
|
||||||
|
- No hard enforcement: you can set an illegal F or non-÷32 resize manually; the pipeline still crops to ÷32 and uses whatever F you pick (the *control* defaults/steps keep you legal, but nothing blocks you).
|
||||||
|
- No motion interpolation on fps resample (frame drop/dup only); keep sources native 25 fps where possible.
|
||||||
|
- No change to Foley exports, the scan pipeline, or the DB schema.
|
||||||
|
- No automatic re-export of existing clips into LTX-2 — you cut LTX-2 clips in the converted tab.
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
# LTX-2 per-tab export mode — Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Add a per-tab export pipeline mode (Foley | LTX-2) so the same videos can feed both an 8 s Foley dataset and a frame-exact, ÷32, 25 fps LTX-2 V2A dataset, with a "Duplicate as LTX-2" tab action.
|
||||||
|
|
||||||
|
**Architecture:** `core/ffmpeg.build_ffmpeg_command` gains optional `target_fps` / `snap32` / `frames` params (Foley path unchanged); a tiny `core/ltx2.py` holds the legal-frame math. `PlaylistWidget` gains `_mode`; the tab menu gains duplicate/convert actions; the length control + `_on_export` wiring switch on the active tab's mode. Soft preset — defaults are legal, everything stays editable.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.11+, PyQt6, ffmpeg, pytest. Branch `ltx2-preset` (based on `tab-export-folder`). Design: `docs/plans/2026-06-18-ltx2-preset-design.md`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
- **Core (`core/ffmpeg.py`, `core/ltx2.py`) is real TDD** — pure functions tested in `tests/test_utils.py` style. Run: `LD_PRELOAD=/usr/lib/libstdc++.so.6 python -m pytest tests/test_utils.py -q` (the preload is needed because importing `main` pulls `mpv`; see `project_qt_test_env`). 3 pre-existing failures there are unrelated — don't count them.
|
||||||
|
- **GUI parts** verified by the offscreen structure test (`LD_PRELOAD=/usr/lib/libstdc++.so.6 QT_QPA_PLATFORM=offscreen python -m pytest tests/test_ui_structure.py -v`) plus a **manual launch** (`./8cut.sh`).
|
||||||
|
- Line numbers are starting anchors; locate by symbol. Commit per task. Co-author trailer on every commit:
|
||||||
|
`Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 1 — LTX-2 math (`core/ltx2.py`) [TDD]
|
||||||
|
|
||||||
|
### Task 1.1: legal-frame helpers
|
||||||
|
**Files:** Create `core/ltx2.py`; Test in `tests/test_utils.py` (append).
|
||||||
|
|
||||||
|
**Step 1 — failing tests** (append to `tests/test_utils.py`):
|
||||||
|
```python
|
||||||
|
from core.ltx2 import is_legal_frames, nearest_legal_frames, frames_for_duration, duration_for_frames, legal_frames
|
||||||
|
|
||||||
|
def test_ltx2_is_legal():
|
||||||
|
assert is_legal_frames(201) and is_legal_frames(9) and is_legal_frames(25)
|
||||||
|
assert not is_legal_frames(200) and not is_legal_frames(8)
|
||||||
|
|
||||||
|
def test_ltx2_nearest():
|
||||||
|
assert nearest_legal_frames(200) == 201 # 200 -> nearest 8k+1
|
||||||
|
assert nearest_legal_frames(196) == 193
|
||||||
|
assert nearest_legal_frames(5) == 9 # floor at 9
|
||||||
|
|
||||||
|
def test_ltx2_duration_roundtrip():
|
||||||
|
assert duration_for_frames(201, 25) == 201 / 25
|
||||||
|
assert frames_for_duration(8.0, 25) == 201 # 200 -> 201
|
||||||
|
|
||||||
|
def test_ltx2_legal_series():
|
||||||
|
s = legal_frames(min_f=9, max_f=33)
|
||||||
|
assert s == [9, 17, 25, 33]
|
||||||
|
```
|
||||||
|
**Step 2 — run, expect ImportError/FAIL:** `LD_PRELOAD=/usr/lib/libstdc++.so.6 python -m pytest tests/test_utils.py -k ltx2 -q`
|
||||||
|
|
||||||
|
**Step 3 — implement `core/ltx2.py`:**
|
||||||
|
```python
|
||||||
|
"""LTX-2 frame-count math. Legal F satisfy F % 8 == 1 (8x temporal + 1)."""
|
||||||
|
|
||||||
|
def is_legal_frames(f: int) -> bool:
|
||||||
|
return f >= 9 and f % 8 == 1
|
||||||
|
|
||||||
|
def legal_frames(min_f: int = 9, max_f: int = 1000) -> list[int]:
|
||||||
|
start = max(9, min_f + ((1 - min_f) % 8)) # first 8k+1 >= min_f
|
||||||
|
return list(range(start, max_f + 1, 8))
|
||||||
|
|
||||||
|
def nearest_legal_frames(f: int) -> int:
|
||||||
|
if f <= 9:
|
||||||
|
return 9
|
||||||
|
low = ((f - 1) // 8) * 8 + 1
|
||||||
|
high = low + 8
|
||||||
|
return low if (f - low) <= (high - f) else high
|
||||||
|
|
||||||
|
def duration_for_frames(frames: int, fps: float) -> float:
|
||||||
|
return frames / fps
|
||||||
|
|
||||||
|
def frames_for_duration(duration: float, fps: float) -> int:
|
||||||
|
return nearest_legal_frames(round(duration * fps))
|
||||||
|
```
|
||||||
|
**Step 4 — run, expect PASS** (same command). **Step 5 — commit:** `feat: LTX-2 legal-frame helpers (core/ltx2.py)`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 2 — ffmpeg pipeline params [TDD]
|
||||||
|
|
||||||
|
### Task 2.1: `target_fps`, `snap32`, `frames` in `build_ffmpeg_command`
|
||||||
|
**Files:** Modify `core/ffmpeg.py:74` (`build_ffmpeg_command`); Test `tests/test_utils.py`.
|
||||||
|
|
||||||
|
**Step 1 — failing tests:**
|
||||||
|
```python
|
||||||
|
def test_ffmpeg_ltx2_fps_and_frames():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||||
|
short_side=512, target_fps=25, frames=201)
|
||||||
|
assert "-r" in cmd and cmd[cmd.index("-r")+1] == "25"
|
||||||
|
assert "-frames:v" in cmd and cmd[cmd.index("-frames:v")+1] == "201"
|
||||||
|
vf = cmd[cmd.index("-vf")+1]
|
||||||
|
assert "fps=25" in vf
|
||||||
|
|
||||||
|
def test_ffmpeg_ltx2_snap32_crop():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||||
|
short_side=512, snap32=True)
|
||||||
|
vf = cmd[cmd.index("-vf")+1]
|
||||||
|
assert "crop=trunc(iw/32)*32:trunc(ih/32)*32" in vf
|
||||||
|
|
||||||
|
def test_ffmpeg_foley_unchanged():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4", short_side=256)
|
||||||
|
assert "-r" not in cmd and "-frames:v" not in cmd
|
||||||
|
assert "crop=trunc" not in cmd[cmd.index("-vf")+1]
|
||||||
|
```
|
||||||
|
**Step 2 — run, expect FAIL** (unexpected kwargs).
|
||||||
|
|
||||||
|
**Step 3 — implement:** add params `target_fps: float | None = None, snap32: bool = False, frames: int | None = None` to the signature. After the scale filter (and before the VAAPI block), append:
|
||||||
|
```python
|
||||||
|
if snap32:
|
||||||
|
filters.append("crop=trunc(iw/32)*32:trunc(ih/32)*32")
|
||||||
|
if target_fps is not None:
|
||||||
|
filters.append(f"fps={target_fps:g}")
|
||||||
|
```
|
||||||
|
Add output flags: after `-t duration` (or near the encoder args, before `output_path`), when `target_fps` set add `cmd += ["-r", f"{target_fps:g}"]`; when `frames` set add `cmd += ["-frames:v", str(frames)]` (video frame cap — exact F). Ensure ordering keeps `-vf` before outputs. Keep `fps`/`snap32` filters out of the `image_sequence=False` vs `True` branches consistently (they apply to both; webp seq also benefits from fps/÷32).
|
||||||
|
|
||||||
|
**Step 4 — run, expect PASS.** Also run full `tests/test_utils.py` (the 3 pre-existing failures only). **Step 5 — commit:** `feat: LTX-2 ffmpeg params (target_fps, snap32, frames)`.
|
||||||
|
|
||||||
|
### Task 2.2: audio extract honors frame-exact duration
|
||||||
|
**Files:** `core/ffmpeg.py:145` (`build_audio_extract_command`) — confirm it takes a duration; if it derives from a fixed 8 s, add a `duration` param so the `.wav` for an LTX-2 webp sequence is exactly `F/25 s`. Add a test mirroring `test_audio_extract_timing` asserting the `-t` value equals `frames/fps`. Commit: `fix: audio extract duration for LTX-2 frame-exact clips`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 3 — per-tab `_mode`
|
||||||
|
|
||||||
|
### Task 3.1: attribute + persistence + migration
|
||||||
|
**Files:** `main.py` — `PlaylistWidget.__init__` (~3409, next to `_dest_folder`); `_save_playlist_tabs` (~5271); `_load_playlist_tabs` (~5315).
|
||||||
|
- Add `self._mode: str = "foley"` in `PlaylistWidget.__init__`.
|
||||||
|
- `_save_playlist_tabs`: add `"mode": pw._mode` to each tab dict.
|
||||||
|
- `_load_playlist_tabs`: after creating each pw, `pw._mode = t.get("mode", "foley")`.
|
||||||
|
- `_add_playlist_tab`: new tabs default `_mode="foley"` (already via init).
|
||||||
|
|
||||||
|
**Verify:** structure test passes; add `test_tab_mode_defaults_foley` (construct, assert each `_pws[i]._mode == "foley"`). Commit: `feat: per-tab export mode attribute (foley default)`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 4 — tab menu: duplicate / convert / toggle
|
||||||
|
|
||||||
|
### Task 4.1: menu actions + label badge
|
||||||
|
**Files:** `main.py` — `_PlaylistTabBar.contextMenuEvent` (~3300) add items; new handlers in `MainWindow`; tab-title rendering.
|
||||||
|
- Add to the tab context menu: **"Duplicate tab"**, **"Duplicate as LTX-2"**, and a checkable **"LTX-2 mode"** (checked when `pw._mode=="ltx2"`). Emit new signals (e.g. `duplicate_requested(idx, as_ltx2: bool)`, `mode_toggle_requested(idx)`) like the existing `pin_toggle_requested`.
|
||||||
|
- `MainWindow._on_duplicate_tab(idx, as_ltx2)`: build a new tab via `_add_playlist_tab(label=…, files=list(src._paths), separators=sorted(src._separators_before), select=True)`; set `pw._dest_folder = src._dest_folder + ("_ltx2" if as_ltx2 else "")`; `pw._mode = "ltx2" if as_ltx2 else src._mode`; if ltx2, apply LTX-2 defaults (Stage 5 hook); `_save_playlist_tabs()`; refresh.
|
||||||
|
- `MainWindow._on_tab_mode_toggle(idx)`: flip `pw._mode`; if now ltx2, apply LTX-2 defaults; `_save_playlist_tabs()`; re-sync controls (Stage 5).
|
||||||
|
- Label badge: when adding/refreshing a tab whose `_mode=="ltx2"`, show `f"{label} [LTX2]"` (or set a distinct color) — apply in `_refresh_layout`/`_add_playlist_tab` title set.
|
||||||
|
|
||||||
|
**Verify:** manual launch — right-click a tab → Duplicate as LTX-2 creates a `[LTX2]` tab with `_ltx2` folder; toggle works. Structure test still green. Commit: `feat: tab duplicate / Duplicate-as-LTX-2 / mode toggle + [LTX2] badge`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 5 — length control swap + export wiring
|
||||||
|
|
||||||
|
### Task 5.1: length control reflects active tab mode
|
||||||
|
**Files:** `main.py` — the clip-length widgets (`_spn_clip_dur` ~4051 area) + the tab-change sync hook (`_on_tab_changed` / `_sync_folder_field_to_tab` neighbor).
|
||||||
|
- Add a frames spinbox `_spn_frames` (min 9, singleStep 8 → always 8k+1; suffix " f"; tooltip live `= F/25 s`). Default 201.
|
||||||
|
- Add `_apply_mode_to_controls()`: if active tab `ltx2` → show `_spn_frames` (+ "Frames" label), hide the seconds Duration control, default resize 512 if unset; else show Duration (seconds), hide frames. Call it from `_on_tab_changed`, after `_on_duplicate_tab`/`_on_tab_mode_toggle`, and once after `_load_playlist_tabs`.
|
||||||
|
- A small label shows `= {F/25:.2f}s @25fps` updating on `_spn_frames.valueChanged`.
|
||||||
|
|
||||||
|
### Task 5.2: route LTX-2 params through export
|
||||||
|
**Files:** `main.py` — `_on_export` (~7317) + `ExportWorker` construction (~7484) + `_update_next_label`.
|
||||||
|
- When the active tab's `_mode=="ltx2"`: compute `frames = self._spn_frames.value()`; `fps = 25`; `duration = frames / fps`; pass `target_fps=25, snap32=True, frames=frames, duration=duration` through to `ExportWorker` → `build_ffmpeg_command`. Default `short_side` to 512 if 0/None in ltx2.
|
||||||
|
- Foley path: unchanged (no new params).
|
||||||
|
- `ExportWorker.__init__`/`run`: thread the new params (default None/False) into `build_ffmpeg_command`.
|
||||||
|
|
||||||
|
**Verify (manual, authoritative):** in an LTX-2 tab, export → inspect an output clip: `ffprobe` shows **25 fps, exactly F frames, W&H ÷32**; a Foley tab still exports 8 s/source-fps unchanged. Structure test green; full `pytest tests/test_utils.py` (3 pre-existing fails only). Commit: `feat: route LTX-2 (25fps, ÷32 crop, F frames) through export for ltx2 tabs`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 6 — finalize
|
||||||
|
- **Task 6.1:** Full regression — `pytest tests/test_ui_structure.py` + `tests/test_utils.py` separately; manual: Foley export unchanged, LTX-2 export legal (ffprobe), duplicate/convert, persistence across relaunch, guardrail + per-tab folder still work.
|
||||||
|
- **Task 6.2:** Changelog (`main.py` CHANGELOG, bump APP_VERSION) + README note (per-tab LTX-2 mode). Commit `docs: changelog + README for LTX-2 export mode`.
|
||||||
|
- **Task 6.3:** Hand off branch (depends on `tab-export-folder`; merge that first, then this).
|
||||||
|
|
||||||
|
## Risks
|
||||||
|
| Risk | Mitigation |
|
||||||
|
|------|-----------|
|
||||||
|
| `-frames:v` vs `-t` interaction yields F±1 frames | Set both `-t F/fps` and `-frames:v F`; verify exact count with ffprobe in 5.2. |
|
||||||
|
| `fps` filter + HW (VAAPI) filter ordering | Place `fps`/`snap32` among CPU filters before the VAAPI hwupload block; test a HW-encoder build if available. |
|
||||||
|
| Length-control swap leaves stale state across tab switches | `_apply_mode_to_controls()` called on every tab change + mode toggle + load. |
|
||||||
|
| Depends on unmerged `tab-export-folder` | Branch is based on it; land that branch first. |
|
||||||
|
|
||||||
|
## NOT in scope
|
||||||
|
Hard enforcement (illegal F/resize allowed manually), motion-interpolated fps, auto re-export of existing Foley clips, DB schema changes, scan-pipeline changes.
|
||||||
@@ -1,4 +1,23 @@
|
|||||||
|
# Core GUI
|
||||||
PyQt6>=6.4
|
PyQt6>=6.4
|
||||||
python-mpv>=1.0
|
python-mpv>=1.0
|
||||||
pytest>=7.0
|
|
||||||
|
# Audio & ML
|
||||||
|
librosa>=0.10
|
||||||
|
numpy>=1.24
|
||||||
|
scikit-learn>=1.3
|
||||||
|
joblib>=1.3
|
||||||
|
soundfile>=0.12
|
||||||
|
|
||||||
|
# Deep learning — install via setup_env.sh for correct CUDA version,
|
||||||
|
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
torch>=2.0
|
||||||
|
torchaudio>=2.0
|
||||||
|
transformers>=4.30,<5.0 # EAT remote model code incompatible with transformers 5.x
|
||||||
|
timm>=0.9
|
||||||
|
|
||||||
|
# Object detection
|
||||||
ultralytics>=8.0
|
ultralytics>=8.0
|
||||||
|
|
||||||
|
# Dev
|
||||||
|
pytest>=7.0
|
||||||
|
|||||||
@@ -1,19 +1,46 @@
|
|||||||
# 8-cut Windows setup script
|
# 8-cut Windows setup script
|
||||||
# Run once: powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
# Run once: powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||||
#
|
#
|
||||||
# Prerequisites: Python 3.10+ must be installed and on PATH
|
# Prerequisites: Python 3.11+ must be installed and on PATH
|
||||||
# https://www.python.org/downloads/
|
# https://www.python.org/downloads/
|
||||||
|
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
|
trap { Write-Host "`n$_" -ForegroundColor Red; Read-Host "Press Enter to close"; exit 1 }
|
||||||
$root = Split-Path -Parent $MyInvocation.MyCommand.Path
|
$root = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||||
|
|
||||||
Write-Host "=== 8-cut Windows Setup ===" -ForegroundColor Cyan
|
Write-Host "=== 8-cut Windows Setup ===" -ForegroundColor Cyan
|
||||||
|
|
||||||
# ── Python deps ────────────────────────────────────────────
|
# ── Virtual environment ───────────────────────────────────
|
||||||
Write-Host "`nInstalling Python dependencies..."
|
$venvDir = Join-Path $root ".venv"
|
||||||
pip install PyQt6 python-mpv
|
if (Test-Path (Join-Path $venvDir "Scripts\python.exe")) {
|
||||||
|
Write-Host "`nVirtual environment already exists, activating..." -ForegroundColor Green
|
||||||
|
} else {
|
||||||
|
Write-Host "`nCreating virtual environment..."
|
||||||
|
python -m venv $venvDir
|
||||||
|
Write-Host "Virtual environment created at $venvDir" -ForegroundColor Green
|
||||||
|
}
|
||||||
|
& "$venvDir\Scripts\Activate.ps1"
|
||||||
|
|
||||||
# ── libmpv ─────────────────────────────────────────────────
|
# ── PyTorch ───────────────────────────────────────────────
|
||||||
|
# Detect NVIDIA GPU via nvidia-smi
|
||||||
|
$hasNvidia = Get-Command nvidia-smi -ErrorAction SilentlyContinue
|
||||||
|
if ($hasNvidia) {
|
||||||
|
$torchIndex = "https://download.pytorch.org/whl/cu128"
|
||||||
|
Write-Host "`nNVIDIA GPU detected — using CUDA 12.8 PyTorch index" -ForegroundColor Green
|
||||||
|
} else {
|
||||||
|
$torchIndex = "https://download.pytorch.org/whl/cpu"
|
||||||
|
Write-Host "`nNo NVIDIA GPU detected — using CPU-only PyTorch index" -ForegroundColor Yellow
|
||||||
|
}
|
||||||
|
# Always install/upgrade torch stack from correct index
|
||||||
|
# (pip install is a no-op if already at the right version)
|
||||||
|
Write-Host "Installing PyTorch + torchaudio + torchvision..."
|
||||||
|
pip install torch torchaudio torchvision --index-url $torchIndex
|
||||||
|
|
||||||
|
# ── Python deps ───────────────────────────────────────────
|
||||||
|
Write-Host "`nInstalling project dependencies..."
|
||||||
|
pip install -r (Join-Path $root "requirements.txt") --extra-index-url $torchIndex
|
||||||
|
|
||||||
|
# ── libmpv ────────────────────────────────────────────────
|
||||||
$mpvDll = Join-Path $root "libmpv-2.dll"
|
$mpvDll = Join-Path $root "libmpv-2.dll"
|
||||||
if (Test-Path $mpvDll) {
|
if (Test-Path $mpvDll) {
|
||||||
Write-Host "`nlibmpv-2.dll already present, skipping." -ForegroundColor Green
|
Write-Host "`nlibmpv-2.dll already present, skipping." -ForegroundColor Green
|
||||||
@@ -30,12 +57,11 @@ if (Test-Path $mpvDll) {
|
|||||||
Write-Host "libmpv-2.dll downloaded." -ForegroundColor Green
|
Write-Host "libmpv-2.dll downloaded." -ForegroundColor Green
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── ffmpeg ─────────────────────────────────────────────────
|
# ── ffmpeg ────────────────────────────────────────────────
|
||||||
$ffmpeg = Join-Path $root "ffmpeg.exe"
|
$ffmpeg = Join-Path $root "ffmpeg.exe"
|
||||||
if (Test-Path $ffmpeg) {
|
if (Test-Path $ffmpeg) {
|
||||||
Write-Host "`nffmpeg.exe already present, skipping." -ForegroundColor Green
|
Write-Host "`nffmpeg.exe already present, skipping." -ForegroundColor Green
|
||||||
} else {
|
} else {
|
||||||
# Check if ffmpeg is on PATH
|
|
||||||
$onPath = Get-Command ffmpeg -ErrorAction SilentlyContinue
|
$onPath = Get-Command ffmpeg -ErrorAction SilentlyContinue
|
||||||
if ($onPath) {
|
if ($onPath) {
|
||||||
Write-Host "`nffmpeg found on PATH: $($onPath.Source)" -ForegroundColor Green
|
Write-Host "`nffmpeg found on PATH: $($onPath.Source)" -ForegroundColor Green
|
||||||
@@ -54,6 +80,12 @@ if (Test-Path $ffmpeg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ── Verify ────────────────────────────────────────────────
|
||||||
|
Write-Host "`n--- Verification ---" -ForegroundColor Cyan
|
||||||
|
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||||
|
python -c "import sklearn, librosa, torchaudio; print('All imports OK')"
|
||||||
|
|
||||||
Write-Host "`n=== Setup complete ===" -ForegroundColor Cyan
|
Write-Host "`n=== Setup complete ===" -ForegroundColor Cyan
|
||||||
Write-Host "Run 8-cut with: python main.py"
|
Write-Host "Run 8-cut with: .venv\Scripts\python.exe main.py"
|
||||||
Write-Host "Or double-click: 8cut.bat"
|
Write-Host "Or double-click: 8cut.bat"
|
||||||
|
Read-Host "`nPress Enter to close"
|
||||||
|
|||||||
@@ -0,0 +1,114 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# 8-cut environment setup — supports conda (miniforge) or python venv
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./setup_env.sh # auto-detect (prefers conda if available)
|
||||||
|
# ./setup_env.sh --conda # force conda
|
||||||
|
# ./setup_env.sh --venv # force python venv
|
||||||
|
# ─��────────────────────────────��───────────────────────────────────────
|
||||||
|
|
||||||
|
ENV_NAME="8cut"
|
||||||
|
PYTHON_VERSION="3.12"
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||||
|
|
||||||
|
# Auto-detect GPU for PyTorch index URL
|
||||||
|
if command -v nvidia-smi &>/dev/null; then
|
||||||
|
TORCH_INDEX="https://download.pytorch.org/whl/cu128"
|
||||||
|
echo "NVIDIA GPU detected — will install PyTorch with CUDA 12.8"
|
||||||
|
else
|
||||||
|
TORCH_INDEX="https://download.pytorch.org/whl/cpu"
|
||||||
|
echo "No NVIDIA GPU detected — will install CPU-only PyTorch"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Parse args ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
MODE=""
|
||||||
|
for arg in "$@"; do
|
||||||
|
case "$arg" in
|
||||||
|
--conda) MODE="conda" ;;
|
||||||
|
--venv) MODE="venv" ;;
|
||||||
|
*) echo "Unknown arg: $arg"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$MODE" ]; then
|
||||||
|
if command -v conda &>/dev/null; then
|
||||||
|
MODE="conda"
|
||||||
|
else
|
||||||
|
MODE="venv"
|
||||||
|
fi
|
||||||
|
echo "Auto-detected mode: $MODE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Conda setup ─────────────��─────────────────────────────────────────
|
||||||
|
|
||||||
|
setup_conda() {
|
||||||
|
echo "==> Setting up conda environment: $ENV_NAME"
|
||||||
|
|
||||||
|
# Source conda shell hooks if not already active
|
||||||
|
if ! command -v conda &>/dev/null; then
|
||||||
|
echo "conda not found in PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
eval "$(conda shell.bash hook)"
|
||||||
|
|
||||||
|
if conda env list | grep -qw "$ENV_NAME"; then
|
||||||
|
echo " Environment '$ENV_NAME' already exists, updating..."
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
else
|
||||||
|
echo " Creating environment '$ENV_NAME' with Python $PYTHON_VERSION..."
|
||||||
|
conda create -y -n "$ENV_NAME" python="$PYTHON_VERSION"
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo " Installing project dependencies..."
|
||||||
|
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done! Activate with:"
|
||||||
|
echo " conda activate $ENV_NAME"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Venv setup ───────��────────────────────────────────────────────────
|
||||||
|
|
||||||
|
setup_venv() {
|
||||||
|
echo "==> Setting up Python venv at: $VENV_DIR"
|
||||||
|
|
||||||
|
if [ ! -d "$VENV_DIR" ]; then
|
||||||
|
python3 -m venv "$VENV_DIR"
|
||||||
|
echo " Created venv"
|
||||||
|
else
|
||||||
|
echo " Venv already exists, updating..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
source "$VENV_DIR/bin/activate"
|
||||||
|
|
||||||
|
echo " Installing PyTorch + torchaudio (CUDA 12.8)..."
|
||||||
|
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo " Installing project dependencies..."
|
||||||
|
pip install -r "$SCRIPT_DIR/requirements.txt" --extra-index-url "$TORCH_INDEX"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done! Activate with:"
|
||||||
|
echo " source $VENV_DIR/bin/activate"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Run ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
case "$MODE" in
|
||||||
|
conda) setup_conda ;;
|
||||||
|
venv) setup_venv ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Verify with:"
|
||||||
|
echo " python -c \"import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)\""
|
||||||
|
echo " python -c \"import librosa, torchaudio, sklearn; print('All imports OK')\""
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import tempfile, os
|
||||||
|
import numpy as np
|
||||||
|
from core.audio_scan import scan_video, load_classifier, default_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_video_no_model_returns_empty():
|
||||||
|
"""scan_video with no model should return empty list."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(vid.name, np.random.randn(16000 * 20).astype(np.float32) * 0.1, 16000)
|
||||||
|
try:
|
||||||
|
regions = scan_video(vid.name, model=None)
|
||||||
|
assert regions == []
|
||||||
|
finally:
|
||||||
|
os.unlink(vid.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_classifier_missing_returns_none():
|
||||||
|
assert load_classifier("/no/such/model.joblib") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_model_path_contains_profile():
|
||||||
|
path = default_model_path("test_profile")
|
||||||
|
assert "test_profile" in path
|
||||||
|
assert path.endswith(".joblib")
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_multi_layer():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
# Multi-layer models should report concatenated dimension
|
||||||
|
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||||
|
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||||
|
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||||
|
# Single-layer unchanged
|
||||||
|
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||||
|
|
||||||
|
|
||||||
|
def test_ml_config():
|
||||||
|
from core.audio_scan import _ml_config
|
||||||
|
assert _ml_config("HUBERT_XLARGE") is None
|
||||||
|
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||||
|
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||||
|
assert base == "HUBERT_XLARGE"
|
||||||
|
assert layers == [11, 23, 35, 47]
|
||||||
|
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||||
|
assert base == "HUBERT_BASE"
|
||||||
|
assert layers == [2, 5, 8, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_ast():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("AST") == 768
|
||||||
|
assert _embed_dim("AST_ML") == 3072
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_dim_eat():
|
||||||
|
from core.audio_scan import _embed_dim
|
||||||
|
assert _embed_dim("EAT") == 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_get_all_export_paths():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||||
|
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||||
|
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||||
|
paths = db.get_all_export_paths("test")
|
||||||
|
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from core.db import ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_folders_excludes_scan_exports():
|
||||||
|
"""Scan-export-only folders should not appear when include_scan_exports=False."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Manual export
|
||||||
|
db.add("a.mp4", 10.0, "/out/mp4_Intense/g1/clip.mp4", profile="test")
|
||||||
|
# Scan export to different folder
|
||||||
|
db.add("a.mp4", 20.0, "/out/mp4_ScanOnly/g1/clip.mp4", profile="test",
|
||||||
|
scan_export=True)
|
||||||
|
folders = db.get_export_folders("test")
|
||||||
|
assert "mp4_Intense" in folders
|
||||||
|
assert "mp4_ScanOnly" not in folders, "scan-only folder should be excluded"
|
||||||
|
# With include_scan_exports=True, both should appear
|
||||||
|
folders_all = db.get_export_folders("test", include_scan_exports=True)
|
||||||
|
assert "mp4_ScanOnly" in folders_all
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_result_history():
|
||||||
|
"""save_scan_results should keep multiple versions."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Save three versions (microsecond-precision timestamps avoid collisions)
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(0, 8, 0.9)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A",
|
||||||
|
[(0, 8, 0.8), (10, 18, 0.7)])
|
||||||
|
db.save_scan_results("v.mp4", "test", "MODEL_A", [(5, 13, 0.95)])
|
||||||
|
versions = db.get_scan_versions("v.mp4", "test", "MODEL_A")
|
||||||
|
assert len(versions) == 3
|
||||||
|
# Most recent first
|
||||||
|
assert versions[0]["count"] == 1 # latest: 1 region
|
||||||
|
assert versions[1]["count"] == 2 # middle: 2 regions
|
||||||
|
assert versions[2]["count"] == 1 # oldest: 1 region
|
||||||
|
# get_scan_results returns latest version by default
|
||||||
|
results = db.get_scan_results("v.mp4", "test")
|
||||||
|
assert len(results.get("MODEL_A", [])) == 1
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hard_negatives_source_model():
|
||||||
|
"""Hard negatives should store source_model."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0],
|
||||||
|
source_path="/a.mp4", source_model="HUBERT_XLARGE")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert all(r["source_model"] == "HUBERT_XLARGE" for r in rows)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_data_skips_hard_negatives():
|
||||||
|
"""get_training_data with use_hard_negatives=False should skip them."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
# Create a source file that "exists" — use the temp db file itself
|
||||||
|
db.add("a.mp4", 10.0, "/out/folder/g/clip.mp4", profile="test",
|
||||||
|
source_path=path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [500.0], source_path=path)
|
||||||
|
# With hard negatives
|
||||||
|
data_with = db.get_training_data("test", "folder", use_hard_negatives=True)
|
||||||
|
# Without hard negatives
|
||||||
|
data_without = db.get_training_data("test", "folder", use_hard_negatives=False)
|
||||||
|
assert len(data_with) >= 1
|
||||||
|
# The "with" case should have the hard negative time in neg list
|
||||||
|
neg_with = sum(len(vi[3]) for vi in data_with)
|
||||||
|
neg_without = sum(len(vi[3]) for vi in data_without)
|
||||||
|
assert neg_with > neg_without, "hard negatives should be excluded when use_hard_negatives=False"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hard_negatives_by_ids():
|
||||||
|
"""delete_hard_negatives_by_ids should remove specific rows."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
db = ProcessedDB(path)
|
||||||
|
db.add_hard_negatives("a.mp4", "test", [10.0, 20.0, 30.0],
|
||||||
|
source_path="/a.mp4")
|
||||||
|
rows = db.get_hard_negatives("test")
|
||||||
|
assert len(rows) == 3
|
||||||
|
# Delete first two
|
||||||
|
db.delete_hard_negatives_by_ids([rows[0]["id"], rows[1]["id"]])
|
||||||
|
remaining = db.get_hard_negatives("test")
|
||||||
|
assert len(remaining) == 1
|
||||||
|
assert remaining[0]["start_time"] == 30.0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
@@ -0,0 +1,273 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
# Redirect QSettings to a throwaway dir BEFORE any MainWindow is constructed, so
|
||||||
|
# these GUI tests can never read or clobber the user's real ~/.config/8cut.conf
|
||||||
|
# (constructing MainWindow loads — and on window close re-saves — the playlist
|
||||||
|
# tabs; a test mutating tab state would otherwise persist into the real session).
|
||||||
|
import tempfile as _tempfile
|
||||||
|
from PyQt6.QtCore import QSettings as _QSettings
|
||||||
|
_QS_DIR = _tempfile.mkdtemp(prefix="8cut-test-qs-")
|
||||||
|
_QSettings.setPath(_QSettings.Format.NativeFormat, _QSettings.Scope.UserScope, _QS_DIR)
|
||||||
|
_QSettings.setPath(_QSettings.Format.IniFormat, _QSettings.Scope.UserScope, _QS_DIR)
|
||||||
|
|
||||||
|
# A real platform is needed because MpvWidget creates a GL context.
|
||||||
|
# If construction fails for any environment reason, skip — this test is a
|
||||||
|
# best-effort structural net, not a gate on core/ tests.
|
||||||
|
pytestmark = pytest.mark.gui
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def app():
|
||||||
|
from PyQt6.QtWidgets import QApplication
|
||||||
|
inst = QApplication.instance() or QApplication([])
|
||||||
|
yield inst
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def win(app):
|
||||||
|
try:
|
||||||
|
from main import MainWindow
|
||||||
|
w = MainWindow()
|
||||||
|
except Exception as e: # GL/mpv/display unavailable, etc.
|
||||||
|
pytest.skip(f"MainWindow could not be constructed here: {e}")
|
||||||
|
# Deterministic deck state regardless of any persisted side-by-side layout
|
||||||
|
# (construction restores deck_pinned from QSettings).
|
||||||
|
for _p in w._deck_panels:
|
||||||
|
_p._pinned = False
|
||||||
|
w._refresh_deck_layout()
|
||||||
|
yield w
|
||||||
|
w.close()
|
||||||
|
w.deleteLater()
|
||||||
|
|
||||||
|
|
||||||
|
def test_window_constructs(win):
|
||||||
|
assert win.windowTitle().startswith("8-cut")
|
||||||
|
|
||||||
|
|
||||||
|
def test_status_bar_exists(win):
|
||||||
|
assert win.statusBar() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_workers_spinbox_in_export_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QSpinBox
|
||||||
|
assert win._spn_workers in win._tab_export.findChildren(QSpinBox)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_button_in_scan_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QPushButton
|
||||||
|
assert win._btn_scan in win._tab_scan.findChildren(QPushButton)
|
||||||
|
|
||||||
|
|
||||||
|
def test_portrait_combo_in_crop_tab(win):
|
||||||
|
from PyQt6.QtWidgets import QComboBox
|
||||||
|
assert win._cmb_portrait in win._tab_crop.findChildren(QComboBox)
|
||||||
|
|
||||||
|
|
||||||
|
def test_menu_only_buttons_not_in_deck(win):
|
||||||
|
from PyQt6.QtWidgets import QPushButton
|
||||||
|
deck_btns = win._control_deck.findChildren(QPushButton)
|
||||||
|
assert win._btn_train not in deck_btns
|
||||||
|
assert win._btn_scan_all not in deck_btns
|
||||||
|
assert win._btn_hide_subcats not in deck_btns
|
||||||
|
|
||||||
|
|
||||||
|
def test_deck_stack_exists(win):
|
||||||
|
# The deck is wrapped in a stack so it can swap tabbed <-> side-by-side.
|
||||||
|
# Default (nothing pinned) shows the tabbed control deck.
|
||||||
|
assert win._deck_stack is not None
|
||||||
|
assert win._deck_stack.currentWidget() is win._control_deck
|
||||||
|
|
||||||
|
|
||||||
|
def _split_columns(win):
|
||||||
|
"""Widgets of the splitter actually mounted in the layout (not findChild,
|
||||||
|
which can return a stale deleteLater'd splitter)."""
|
||||||
|
from PyQt6.QtWidgets import QSplitter
|
||||||
|
item = win._deck_split_layout.itemAt(0)
|
||||||
|
spl = item.widget() if item else None
|
||||||
|
assert isinstance(spl, QSplitter)
|
||||||
|
return [spl.widget(i) for i in range(spl.count())]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pinning_two_panels_shows_exactly_two_columns(win):
|
||||||
|
# Pin two panels directly (avoid the toggle handler so no QSettings write
|
||||||
|
# leaks into other test windows) and refresh.
|
||||||
|
from PyQt6.QtWidgets import QTabWidget
|
||||||
|
win._tab_export._pinned = True
|
||||||
|
win._tab_crop._pinned = True
|
||||||
|
win._refresh_deck_layout()
|
||||||
|
assert win._deck_stack.currentWidget() is win._deck_split_container
|
||||||
|
cols = _split_columns(win)
|
||||||
|
assert len(cols) == 2 # only the pinned ones
|
||||||
|
assert not any(isinstance(c, QTabWidget) for c in cols) # no leftover tab-column
|
||||||
|
|
||||||
|
|
||||||
|
def test_side_by_side_menu_pins_third_panel(win):
|
||||||
|
# In split mode the View ▸ Side-by-side menu is the way to pin a 3rd panel
|
||||||
|
# (there's no tab bar to right-click). Suppress the QSettings save via the
|
||||||
|
# _deck_loading guard so this doesn't leak into other windows.
|
||||||
|
win._tab_export._pinned = True
|
||||||
|
win._tab_scan._pinned = True
|
||||||
|
win._refresh_deck_layout()
|
||||||
|
assert len(_split_columns(win)) == 2
|
||||||
|
act = next(a for a, p in win._deck_pin_actions if p is win._tab_crop)
|
||||||
|
win._deck_loading = True # suppress _save_deck_layout
|
||||||
|
try:
|
||||||
|
act.trigger() # simulate clicking the menu item
|
||||||
|
finally:
|
||||||
|
win._deck_loading = False
|
||||||
|
assert win._tab_crop._pinned is True
|
||||||
|
assert len(_split_columns(win)) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_tab(win):
|
||||||
|
# Right-click → Duplicate tab: clones files into a new tab with an adapted
|
||||||
|
# name + adapted own folder, no file moves. Suppress QSettings writes via
|
||||||
|
# _loading_tabs so the test can't touch the real session.
|
||||||
|
win._loading_tabs = True
|
||||||
|
try:
|
||||||
|
src = win._pws[0]
|
||||||
|
src._label = "AlexisCrystal"
|
||||||
|
src._dest_folder = "/data/alexis/" # trailing slash, like real folders
|
||||||
|
n_before = len(win._pws)
|
||||||
|
win._on_duplicate_tab(win._playlist_tabs.indexOf(src))
|
||||||
|
finally:
|
||||||
|
win._loading_tabs = False
|
||||||
|
assert len(win._pws) == n_before + 1
|
||||||
|
dup = win._pws[-1]
|
||||||
|
assert dup._label == "AlexisCrystal copy"
|
||||||
|
# sibling, not a child: ".../alexis/" -> ".../alexis_copy" (not ".../alexis/_copy")
|
||||||
|
assert dup._dest_folder == "/data/alexis_copy"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tab_mode_defaults_foley(win):
|
||||||
|
# Fresh tabs use the Foley pipeline; sessions/tabs without a stored mode
|
||||||
|
# load unchanged.
|
||||||
|
assert win._pws
|
||||||
|
for pw in win._pws:
|
||||||
|
assert pw._mode == "foley"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tab_mode_toggle(win):
|
||||||
|
# Right-click → "LTX-2 mode" flips the per-tab mode and the displayed title
|
||||||
|
# gains a [LTX2] badge (without mutating pw._label). Suppress QSettings
|
||||||
|
# writes via _loading_tabs so the test can't touch the real session.
|
||||||
|
win._loading_tabs = True
|
||||||
|
try:
|
||||||
|
win._on_tab_mode_toggle(win._playlist_tabs.indexOf(win._pws[0]))
|
||||||
|
finally:
|
||||||
|
win._loading_tabs = False
|
||||||
|
assert win._pws[0]._mode == "ltx2"
|
||||||
|
assert win._tab_title(win._pws[0]).endswith("[LTX2]")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ltx2_params_none_for_foley(win):
|
||||||
|
# A Foley tab feeds no LTX-2 ffmpeg params into export. Set the mode
|
||||||
|
# explicitly: a prior test's closeEvent can persist an ltx2 tab into the
|
||||||
|
# shared (throwaway) QSettings, so don't rely on the loaded default here.
|
||||||
|
win._playlist._mode = "foley"
|
||||||
|
assert win._ltx2_export_params() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_ltx2_params_for_ltx2_tab(win):
|
||||||
|
# An ltx2-mode active tab: _ltx2_export_params returns the 25fps / ÷32 /
|
||||||
|
# exact-frames kwargs, and _apply_mode_to_controls swaps the length control
|
||||||
|
# (Duration hidden, frames shown). short_side defaults to 512 when unset.
|
||||||
|
win._spn_resize.setValue(0) # force the 512 LTX-2 default path
|
||||||
|
win._pws[0]._mode = "ltx2"
|
||||||
|
win._active_pw = win._pws[0]
|
||||||
|
win._playlist_tabs.setCurrentWidget(win._pws[0])
|
||||||
|
win._spn_frames.setValue(201)
|
||||||
|
win._apply_mode_to_controls()
|
||||||
|
|
||||||
|
assert win._ltx2_export_params() == {
|
||||||
|
"target_fps": 25.0,
|
||||||
|
"snap32": True,
|
||||||
|
"frames": 201,
|
||||||
|
"duration": 201 / 25,
|
||||||
|
"short_side": 512,
|
||||||
|
}
|
||||||
|
# In offscreen, isVisibleTo(win) may be False for both; assert via the
|
||||||
|
# show/hide flag that the Duration control is hidden in ltx2 mode.
|
||||||
|
assert win._spn_clip_dur.isHidden()
|
||||||
|
assert not win._spn_frames.isHidden()
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_preserves_ltx2_mode(win):
|
||||||
|
# Duplicating an LTX-2 tab must yield an LTX-2 tab (mode is copied alongside
|
||||||
|
# the folder fields). Suppress QSettings writes via _loading_tabs.
|
||||||
|
win._loading_tabs = True
|
||||||
|
try:
|
||||||
|
src = win._pws[0]
|
||||||
|
src._mode = "ltx2"
|
||||||
|
win._on_duplicate_tab(win._playlist_tabs.indexOf(src))
|
||||||
|
finally:
|
||||||
|
win._loading_tabs = False
|
||||||
|
dup = win._pws[-1]
|
||||||
|
assert dup._mode == "ltx2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_snaps_to_legal(win):
|
||||||
|
# A typed (illegal) frame count snaps to the nearest legal 8k+1 value so the
|
||||||
|
# displayed value == the exported value and is always a valid LTX-2 clip.
|
||||||
|
win._spn_frames.setValue(100)
|
||||||
|
win._snap_frames_to_legal() # the editingFinished slot
|
||||||
|
assert win._spn_frames.value() == 97 # nearest 8k+1 to 100
|
||||||
|
assert (win._spn_frames.value() - 1) % 8 == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_base_name_handles_trailing_slash(win):
|
||||||
|
# A folder ending in "/" must still yield the real base name, else
|
||||||
|
# subprofile naming breaks ("_blowjob" instead of "mp4_blowjob").
|
||||||
|
win._txt_folder.setText("/x/AlexisCrystal/mp4/")
|
||||||
|
assert win._export_base_name() == "mp4"
|
||||||
|
win._txt_folder.setText("/x/AlexisCrystal/mp4")
|
||||||
|
assert win._export_base_name() == "mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprofile_button_visibility_exact_match(win):
|
||||||
|
# A subcategory's export button must track ITS folder exactly. A ghost
|
||||||
|
# "_blowjob" (empty-base leftover) or an unrelated "mp4_no_clap" must NOT
|
||||||
|
# hide the "blowjob"/"clap" buttons (the old fuzzy endswith() match did,
|
||||||
|
# so enabling a subcategory never revealed its export button).
|
||||||
|
win._txt_folder.setText("/x/AlexisCrystal/mp4")
|
||||||
|
win._subprofiles = ["blowjob", "clap"]
|
||||||
|
win._rebuild_subprofile_buttons()
|
||||||
|
btns = {b.text().removeprefix("▸ "): b for b in win._subprofile_btns}
|
||||||
|
|
||||||
|
win._hidden_subcats = {"_blowjob", "mp4_no_clap"}
|
||||||
|
win._apply_subcat_visibility()
|
||||||
|
assert not btns["blowjob"].isHidden() # ghost "_blowjob" must not hide it
|
||||||
|
assert not btns["clap"].isHidden() # "mp4_no_clap" must not hide "clap"
|
||||||
|
|
||||||
|
win._hidden_subcats = {"mp4_blowjob"} # exact folder -> hidden
|
||||||
|
win._apply_subcat_visibility()
|
||||||
|
assert btns["blowjob"].isHidden()
|
||||||
|
assert not btns["clap"].isHidden()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_audio_controls_exist(win):
|
||||||
|
from PyQt6.QtWidgets import QPushButton, QDoubleSpinBox
|
||||||
|
assert isinstance(win._btn_extract_audio, QPushButton)
|
||||||
|
assert isinstance(win._spn_audio_len, QDoubleSpinBox)
|
||||||
|
# Disabled until a file is loaded.
|
||||||
|
assert not win._btn_extract_audio.isEnabled()
|
||||||
|
# Arrows step by 1s and there's no practical upper cap (long audio areas).
|
||||||
|
assert win._spn_audio_len.singleStep() == 1.0
|
||||||
|
assert win._spn_audio_len.maximum() >= 3600.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_region_tracks_cursor_and_length(win):
|
||||||
|
# The teal audio band spans [cursor, cursor + length]; changing the length
|
||||||
|
# or moving the cursor moves the band. Fake a loaded file so the guard in
|
||||||
|
# _update_audio_region passes.
|
||||||
|
win._file_path = "/x/video.mp4"
|
||||||
|
win._cursor = 10.0
|
||||||
|
win._spn_audio_len.setValue(4.0) # fires _on_audio_len_changed
|
||||||
|
assert win._timeline._audio_region == (10.0, 14.0)
|
||||||
|
win._cursor = 20.0
|
||||||
|
win._update_audio_region()
|
||||||
|
assert win._timeline._audio_region == (20.0, 24.0)
|
||||||
|
# No file -> band cleared.
|
||||||
|
win._file_path = ""
|
||||||
|
win._update_audio_region()
|
||||||
|
assert win._timeline._audio_region is None
|
||||||
@@ -1,24 +1,26 @@
|
|||||||
import tempfile, os, json
|
import tempfile, os, json
|
||||||
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, build_annotation_json_path, upsert_clip_annotation, resolve_keyframe, apply_keyframes_to_jobs
|
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, resolve_keyframe, apply_keyframes_to_jobs
|
||||||
|
from core.ffmpeg import build_audio_clip_command
|
||||||
|
from core.annotations import build_annotation_json_path, upsert_clip_annotation
|
||||||
from main import ProcessedDB
|
from main import ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
def test_build_export_path_first():
|
def test_build_export_path_first():
|
||||||
assert build_export_path("/out", "clip", 1) == "/out/clip_001/clip_001.mp4"
|
assert build_export_path("/out", "clip", 1) == "/out/clip_001.mp4"
|
||||||
|
|
||||||
def test_build_export_path_counter():
|
def test_build_export_path_counter():
|
||||||
assert build_export_path("/out", "clip", 42) == "/out/clip_042/clip_042.mp4"
|
assert build_export_path("/out", "clip", 42) == "/out/clip_042.mp4"
|
||||||
|
|
||||||
def test_build_export_path_deep_counter():
|
def test_build_export_path_deep_counter():
|
||||||
assert build_export_path("/out", "shot", 999) == "/out/shot_999/shot_999.mp4"
|
assert build_export_path("/out", "shot", 999) == "/out/shot_999.mp4"
|
||||||
|
|
||||||
def test_build_export_path_sub():
|
def test_build_export_path_sub():
|
||||||
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0.mp4"
|
assert build_export_path("/out", "clip", 1, sub=0) == "/out/clip_001_0.mp4"
|
||||||
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001/clip_001_2.mp4"
|
assert build_export_path("/out", "clip", 1, sub=2) == "/out/clip_001_2.mp4"
|
||||||
|
|
||||||
def test_build_sequence_dir_sub():
|
def test_build_sequence_dir_sub():
|
||||||
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001/clip_001_0"
|
assert build_sequence_dir("/out", "clip", 1, sub=0) == "/out/clip_001_0"
|
||||||
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001/clip_001_1"
|
assert build_sequence_dir("/out", "clip", 1, sub=1) == "/out/clip_001_1"
|
||||||
|
|
||||||
def test_format_time_seconds():
|
def test_format_time_seconds():
|
||||||
assert format_time(0.0) == "0:00.0"
|
assert format_time(0.0) == "0:00.0"
|
||||||
@@ -53,6 +55,27 @@ def test_ffmpeg_command_with_resize():
|
|||||||
assert cmd[-1] == "/out/clip_001.mp4"
|
assert cmd[-1] == "/out/clip_001.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_clip_command_exact_length():
|
||||||
|
cmd = build_audio_clip_command("/in/video.mp4", 12.5, 3.2, "/out/clip.wav")
|
||||||
|
assert cmd[0] == "ffmpeg"
|
||||||
|
# fast seek before input, exact duration, no video
|
||||||
|
assert cmd[cmd.index("-ss") + 1] == "12.5"
|
||||||
|
assert cmd[cmd.index("-t") + 1] == "3.2"
|
||||||
|
assert cmd.index("-ss") < cmd.index("-i")
|
||||||
|
assert "-vn" in cmd
|
||||||
|
assert cmd[-1] == "/out/clip.wav"
|
||||||
|
|
||||||
|
def test_audio_clip_command_codec_by_extension():
|
||||||
|
assert "pcm_s16le" in build_audio_clip_command("/in.mp4", 0, 1, "/o/a.wav")
|
||||||
|
assert "libmp3lame" in build_audio_clip_command("/in.mp4", 0, 1, "/o/a.mp3")
|
||||||
|
assert "flac" in build_audio_clip_command("/in.mp4", 0, 1, "/o/a.flac")
|
||||||
|
# Unknown extension -> no explicit -c:a, let ffmpeg pick from the container.
|
||||||
|
assert "-c:a" not in build_audio_clip_command("/in.mp4", 0, 1, "/o/a.xyz")
|
||||||
|
|
||||||
|
def test_audio_clip_command_extension_case_insensitive():
|
||||||
|
assert "flac" in build_audio_clip_command("/in.mp4", 0, 1, "/o/A.FLAC")
|
||||||
|
|
||||||
|
|
||||||
# --- ProcessedDB ---
|
# --- ProcessedDB ---
|
||||||
|
|
||||||
def test_db_add_and_get_markers():
|
def test_db_add_and_get_markers():
|
||||||
@@ -177,10 +200,10 @@ def test_audio_extract_timing():
|
|||||||
|
|
||||||
|
|
||||||
def test_build_sequence_dir_basic():
|
def test_build_sequence_dir_basic():
|
||||||
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001/clip_001"
|
assert build_sequence_dir("/out", "clip", 1) == "/out/clip_001"
|
||||||
|
|
||||||
def test_build_sequence_dir_counter():
|
def test_build_sequence_dir_counter():
|
||||||
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042/clip_042"
|
assert build_sequence_dir("/out", "clip", 42) == "/out/clip_042"
|
||||||
|
|
||||||
def test_ffmpeg_command_image_sequence():
|
def test_ffmpeg_command_image_sequence():
|
||||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/seq_001", image_sequence=True)
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/seq_001", image_sequence=True)
|
||||||
@@ -264,13 +287,13 @@ def test_db_get_group_returns_all_sub_clips():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_2.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_2.mp4")
|
||||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(group) == 3
|
assert len(group) == 3
|
||||||
assert "/out/clip_001/clip_001_0.mp4" in group
|
assert "/out/vid_001/clip_001_0.mp4" in group
|
||||||
assert "/out/clip_001/clip_001_2.mp4" in group
|
assert "/out/vid_001/clip_001_2.mp4" in group
|
||||||
finally:
|
finally:
|
||||||
os.unlink(path)
|
os.unlink(path)
|
||||||
|
|
||||||
@@ -280,10 +303,10 @@ def test_db_get_group_isolates_by_start_time():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||||
group = db.get_group("/out/clip_001/clip_001_0.mp4")
|
group = db.get_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(group) == 2
|
assert len(group) == 2
|
||||||
finally:
|
finally:
|
||||||
os.unlink(path)
|
os.unlink(path)
|
||||||
@@ -294,10 +317,10 @@ def test_db_delete_group_removes_all():
|
|||||||
path = f.name
|
path = f.name
|
||||||
try:
|
try:
|
||||||
db = ProcessedDB(path)
|
db = ProcessedDB(path)
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_0.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_0.mp4")
|
||||||
db.add("video.mp4", 10.0, "/out/clip_001/clip_001_1.mp4")
|
db.add("video.mp4", 10.0, "/out/vid_001/clip_001_1.mp4")
|
||||||
db.add("video.mp4", 30.0, "/out/clip_002/clip_002_0.mp4")
|
db.add("video.mp4", 30.0, "/out/vid_001/clip_002_0.mp4")
|
||||||
deleted = db.delete_group("/out/clip_001/clip_001_0.mp4")
|
deleted = db.delete_group("/out/vid_001/clip_001_0.mp4")
|
||||||
assert len(deleted) == 2
|
assert len(deleted) == 2
|
||||||
# clip_002 should still exist
|
# clip_002 should still exist
|
||||||
markers = db.get_markers("video.mp4")
|
markers = db.get_markers("video.mp4")
|
||||||
@@ -438,3 +461,57 @@ def test_apply_keyframes_before_first_uses_base():
|
|||||||
result = apply_keyframes_to_jobs(jobs, kfs, base_center=0.5, base_ratio="4:5",
|
result = apply_keyframes_to_jobs(jobs, kfs, base_center=0.5, base_ratio="4:5",
|
||||||
base_rand_p=True, base_rand_s=False)
|
base_rand_p=True, base_rand_s=False)
|
||||||
assert result == [(1.0, "/out/a", "4:5", 0.5, True, False)]
|
assert result == [(1.0, "/out/a", "4:5", 0.5, True, False)]
|
||||||
|
|
||||||
|
|
||||||
|
# --- LTX-2 legal-frame math (core/ltx2.py) ---
|
||||||
|
|
||||||
|
from core.ltx2 import is_legal_frames, nearest_legal_frames, frames_for_duration, duration_for_frames, legal_frames
|
||||||
|
|
||||||
|
def test_ltx2_is_legal():
|
||||||
|
assert is_legal_frames(201) and is_legal_frames(9) and is_legal_frames(25)
|
||||||
|
assert not is_legal_frames(200) and not is_legal_frames(8)
|
||||||
|
|
||||||
|
def test_ltx2_nearest():
|
||||||
|
assert nearest_legal_frames(200) == 201 # 200 -> nearest 8k+1
|
||||||
|
assert nearest_legal_frames(196) == 193
|
||||||
|
assert nearest_legal_frames(5) == 9 # floor at 9
|
||||||
|
|
||||||
|
def test_ltx2_duration_roundtrip():
|
||||||
|
assert duration_for_frames(201, 25) == 201 / 25
|
||||||
|
assert frames_for_duration(8.0, 25) == 201 # 200 -> 201
|
||||||
|
|
||||||
|
def test_ltx2_legal_series():
|
||||||
|
s = legal_frames(min_f=9, max_f=33)
|
||||||
|
assert s == [9, 17, 25, 33]
|
||||||
|
|
||||||
|
|
||||||
|
# --- LTX-2 ffmpeg params (target_fps, snap32, frames) ---
|
||||||
|
|
||||||
|
def test_ffmpeg_ltx2_fps_and_frames():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||||
|
short_side=512, target_fps=25, frames=201)
|
||||||
|
assert "-r" in cmd and cmd[cmd.index("-r")+1] == "25"
|
||||||
|
assert "-frames:v" in cmd and cmd[cmd.index("-frames:v")+1] == "201"
|
||||||
|
vf = cmd[cmd.index("-vf")+1]
|
||||||
|
assert "fps=25" in vf
|
||||||
|
|
||||||
|
def test_ffmpeg_ltx2_snap32_crop():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||||
|
short_side=512, snap32=True)
|
||||||
|
vf = cmd[cmd.index("-vf")+1]
|
||||||
|
assert "crop=trunc(iw/32)*32:trunc(ih/32)*32" in vf
|
||||||
|
|
||||||
|
def test_ffmpeg_foley_unchanged():
|
||||||
|
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4", short_side=256)
|
||||||
|
assert "-r" not in cmd and "-frames:v" not in cmd
|
||||||
|
assert "crop=trunc" not in cmd[cmd.index("-vf")+1]
|
||||||
|
|
||||||
|
|
||||||
|
# --- LTX-2 audio extract frame-exact duration ---
|
||||||
|
|
||||||
|
def test_audio_extract_ltx2_duration():
|
||||||
|
frames, fps = 201, 25
|
||||||
|
cmd = build_audio_extract_command("/in/v.mp4", 0.0, "/out/clip_001",
|
||||||
|
duration=frames / fps)
|
||||||
|
assert "-t" in cmd
|
||||||
|
assert cmd[cmd.index("-t") + 1] == str(frames / fps)
|
||||||
|
|||||||