Compare commits
165 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eab5c690c7 | |||
| 4445f0e7f4 | |||
| ed63d04abf | |||
| 7ae1720b9e | |||
| 514607eddd | |||
| 4299de5f97 | |||
| 86ab606059 | |||
| 87ccd8650c | |||
| ad9e564991 | |||
| 4baac54930 | |||
| 879684ce25 | |||
| 92774216d4 | |||
| 02fd0f0919 | |||
| c537ac678d | |||
| 755f7e5131 | |||
| 1eb7de2a1a | |||
| d7680283a2 | |||
| bf4b6dad2d | |||
| 4715c0ce49 | |||
| e5ce59c065 | |||
| cbbdfeadb1 | |||
| 8a7d761815 | |||
| 140a424469 | |||
| bc6e30a2d4 | |||
| 2ea3a9149a | |||
| e820c106af | |||
| 780832d4aa | |||
| 6037f15e7b | |||
| 035eaf3894 | |||
| 35ea1baec8 | |||
| 6a71386ed8 | |||
| d1fb35af8e | |||
| c55693094d | |||
| 5832d08b26 | |||
| b4cfa7561a | |||
| 0ccc29709e | |||
| 7e917d00a6 | |||
| 2ffb81eaa3 | |||
| b448085242 | |||
| 7cf90c1e5c | |||
| 5aa6878cf6 | |||
| 0e903812fa | |||
| d23ae2e88a | |||
| d97de8de10 | |||
| c6673228fa | |||
| fa4104eded | |||
| 9f7d2e1185 | |||
| c2e6c62c00 | |||
| 8aa8d8805b | |||
| 35c67f4bd5 | |||
| b738a19304 | |||
| dbd8e6a8ac | |||
| 73dfea4ae9 | |||
| 2170e72cbd | |||
| c9915914c4 | |||
| 251747bb0b | |||
| 13c4d3f7f6 | |||
| 1d49ce7cee | |||
| 109bc658c3 | |||
| ec7138f51b | |||
| 68c633ab46 | |||
| d0a94e7b68 | |||
| 632c2dc076 | |||
| 0f335c5e66 | |||
| f1f8fd5244 | |||
| 299779cf29 | |||
| 56218c18f4 | |||
| 2c45aff668 | |||
| 07e2f733b9 | |||
| 8c5a4c4524 | |||
| 4e5b631efb | |||
| ec77b8224f | |||
| 9becd5a06d | |||
| fae5560e2d | |||
| 07e3a1223c | |||
| 3af6e05fb7 | |||
| d787871735 | |||
| 85c08d7c48 | |||
| f6966a092a | |||
| 7cee3ab768 | |||
| 47f910644d | |||
| e972c7a2ae | |||
| cb805c5bda | |||
| bf14247b00 | |||
| 73396659dc | |||
| c8bc629419 | |||
| de8840e1eb | |||
| def966a913 | |||
| bc4ae21153 | |||
| a731fbfc32 | |||
| 1bdeb33a6f | |||
| 387ed7bc6a | |||
| f268d61fe4 | |||
| 24db32c09f | |||
| 0f6ae88ea6 | |||
| 4d99cf6015 | |||
| b75fa85ff5 | |||
| e7d47331c6 | |||
| 7cd31ebe55 | |||
| 3a37dddfd9 | |||
| b249705506 | |||
| aaf405dd3d | |||
| cb2060beb8 | |||
| 0db412baf4 | |||
| 876026d1f6 | |||
| 6c1d42adfe | |||
| d8b3972bdc | |||
| bd345abca2 | |||
| 7d6fee9df1 | |||
| fd043f4172 | |||
| 3c3b1d74bb | |||
| a3c657c66e | |||
| 5d45b8d8eb | |||
| e6db83f00b | |||
| edc5784ba6 | |||
| 8ed9fbf557 | |||
| 4fb2ae144f | |||
| 2614a765d5 | |||
| c020c0dfec | |||
| e7b791fbfa | |||
| f5361a963e | |||
| 8fb8581816 | |||
| 5b25e85e98 | |||
| e3f133ef84 | |||
| 4736f150b0 | |||
| 52aa982aa2 | |||
| 07457d0d6f | |||
| c5d613fc5f | |||
| 7855ea62c2 | |||
| 70be5974cf | |||
| a0286d5cf9 | |||
| 2b7dfb330d | |||
| 518554f788 | |||
| 282156e8ed | |||
| 3417a0f603 | |||
| cd0552197f | |||
| 7dffcb08eb | |||
| 93bcb23fa7 | |||
| eda7826a40 | |||
| e7e20b0fe6 | |||
| 814ef946eb | |||
| 2e738df9ae | |||
| 6ddfcde8ee | |||
| b161412d94 | |||
| 5a9e068903 | |||
| 6870e5aaf3 | |||
| f597ff29e8 | |||
| e1789d4e71 | |||
| 7834b1d05c | |||
| 12ed183f1b | |||
| f2c38aee79 | |||
| 8ab5bdba77 | |||
| c6c5934fe8 | |||
| 73d5367424 | |||
| 1e2cebd424 | |||
| c439aca9b9 | |||
| afda9b2d9f | |||
| fd42791c9f | |||
| 4cf54f2642 | |||
| e7f4de9ec1 | |||
| 9cf9e3233f | |||
| e17d8f67aa | |||
| b1980de6d1 | |||
| 85e0641440 | |||
| 834b89b682 |
@@ -3,5 +3,8 @@ __pycache__/
|
||||
*.pyo
|
||||
.pytest_cache/
|
||||
.worktrees/
|
||||
client/node_modules/
|
||||
client/src-tauri/target/
|
||||
.venv/
|
||||
models/
|
||||
cache/
|
||||
*.joblib
|
||||
*.pt
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
@echo off
rem Launch 8-cut from the directory this script lives in.
cd /d "%~dp0"
rem Start the app exactly once: prefer the project venv interpreter when it
rem exists, otherwise fall back to whatever "python" is on PATH.
if exist ".venv\Scripts\python.exe" (
    .venv\Scripts\python.exe main.py %*
) else (
    python main.py %*
)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
# Launch 8-cut with an auto-detected venv/conda environment.
# Tries, in order: project .venv, conda env by direct path, conda env via the
# shell hook, system python3. All arguments are forwarded to main.py.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ENV_NAME="8cut"
# Conda install root. Defaults to the original hard-coded path; export
# CONDA_PREFIX_BASE before launching to point at a different install.
CONDA_PREFIX_BASE="${CONDA_PREFIX_BASE:-/media/p5/miniforge3}"

# Preload the system libstdc++ only when it exists at this path: the dynamic
# loader prints an error on every exec if LD_PRELOAD names a missing file.
# NOTE(review): presumably this works around a libstdc++ mismatch — confirm.
SYSTEM_LIBSTDCXX="/usr/lib/libstdc++.so.6"
if [ -f "$SYSTEM_LIBSTDCXX" ]; then
    export LD_PRELOAD="$SYSTEM_LIBSTDCXX"
fi

# 1. Try .venv in project dir
if [ -f "$SCRIPT_DIR/.venv/bin/activate" ]; then
    source "$SCRIPT_DIR/.venv/bin/activate"
    exec python "$SCRIPT_DIR/main.py" "$@"
fi

# 2. Try conda env (works without shell init)
CONDA_PYTHON="$CONDA_PREFIX_BASE/envs/$ENV_NAME/bin/python"
if [ -x "$CONDA_PYTHON" ]; then
    exec "$CONDA_PYTHON" "$SCRIPT_DIR/main.py" "$@"
fi

# 3. Try conda via shell hook (interactive shells)
if command -v conda &>/dev/null; then
    eval "$(conda shell.bash hook 2>/dev/null)"
    if conda env list 2>/dev/null | grep -qw "$ENV_NAME"; then
        conda activate "$ENV_NAME"
        exec python "$SCRIPT_DIR/main.py" "$@"
    fi
fi

# 4. Fallback to system Python
exec python3 "$SCRIPT_DIR/main.py" "$@"
|
||||
@@ -30,6 +30,11 @@ mpv_dir = Path(os.environ.get("MPV_DIR", base))
|
||||
|
||||
datas = []
|
||||
|
||||
# Bundled assets (icons, logo) — must exist at runtime under sys._MEIPASS/assets
|
||||
assets_dir = base / "assets"
|
||||
if assets_dir.exists():
|
||||
datas.append((str(assets_dir), "assets"))
|
||||
|
||||
# YOLOv8 model (optional — large, skip if missing)
|
||||
yolo = base / "yolov8n.pt"
|
||||
if yolo.exists():
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Calibration — per-video normalized features + classifier."""
|
||||
import sys, os, time, warnings
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
|
||||
from core.audio_scan import _SR, _WINDOW
|
||||
|
||||
_HOP_LENGTH = 1024
|
||||
_N_FFT = 2048
|
||||
from core.db import ProcessedDB
|
||||
|
||||
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||
PROFILE_NAME = "JAV_missionary"
|
||||
TOLERANCE = 12.0
|
||||
NEG_MARGIN = 120.0
|
||||
|
||||
|
||||
def extract_rich_features(y, sr=_SR):
    """Build the per-frame feature matrix for one audio signal.

    Rows, in order: RMS, spectral centroid, bandwidth, rolloff, flatness,
    zero-crossing rate, onset strength, eight mel-band energies (dB), then
    the rows returned by ``spectral_contrast``. Every row is cropped to a
    common frame count before stacking.
    """
    hop_len = _HOP_LENGTH
    power_spec = np.abs(librosa.stft(y, n_fft=_N_FFT, hop_length=hop_len)) ** 2

    # Mel-band energies: average the mel bins whose centre frequency falls in
    # each band and convert to dB; a band with no bins becomes a zero row.
    mel_power = librosa.feature.melspectrogram(S=power_spec, sr=sr, hop_length=hop_len, n_mels=128)
    bin_freqs = librosa.mel_frequencies(n_mels=128, fmin=0, fmax=sr/2)
    band_edges = [(0, 100), (100, 300), (300, 600), (600, 1200),
                  (1200, 2000), (2000, 3500), (3500, 5500), (5500, 8000)]
    band_rows = []
    for lo_hz, hi_hz in band_edges:
        in_band = (bin_freqs >= lo_hz) & (bin_freqs < hi_hz)
        if in_band.any():
            band_mean = mel_power[in_band].mean(axis=0, keepdims=True)
            band_rows.append(librosa.power_to_db(band_mean + 1e-10))
        else:
            band_rows.append(np.zeros((1, mel_power.shape[1])))

    # Scalar-per-frame descriptors, all derived from the same spectrogram
    # except the zero-crossing rate, which works on the raw samples.
    energy = librosa.feature.rms(S=power_spec, hop_length=hop_len)
    centroid = librosa.feature.spectral_centroid(S=power_spec, sr=sr)
    bandwidth = librosa.feature.spectral_bandwidth(S=power_spec, sr=sr)
    roll = librosa.feature.spectral_rolloff(S=power_spec, sr=sr)
    flat = librosa.feature.spectral_flatness(S=power_spec)
    crossings = librosa.feature.zero_crossing_rate(y, hop_length=hop_len)
    onset_env = librosa.onset.onset_strength(
        S=librosa.power_to_db(power_spec), sr=sr, hop_length=hop_len
    ).reshape(1, -1)
    contrast = librosa.feature.spectral_contrast(S=power_spec, sr=sr, hop_length=hop_len)

    # Crop everything to the shortest of these five so the rows can be stacked.
    n_frames = min(
        energy.shape[1],
        centroid.shape[1],
        onset_env.shape[1],
        contrast.shape[1],
        band_rows[0].shape[1],
    )
    ordered = [energy, centroid, bandwidth, roll, flat, crossings, onset_env]
    ordered.extend(band_rows)
    ordered.append(contrast)
    return np.vstack([row[:, :n_frames] for row in ordered])
|
||||
|
||||
|
||||
def compute_window_stats(feat, hop=1.0):
    """Sliding-window mean/std over a per-frame feature matrix.

    Args:
        feat: 2-D array of shape (n_feats, T), one column per frame.
        hop: Stride between consecutive window starts, in seconds. Strides
            shorter than one frame are clamped to a single frame.

    Returns:
        ``(timestamps, vectors)``: window start times in seconds and an
        array of shape (n_windows, 2 * n_feats) holding the per-window means
        followed by the per-window standard deviations. Two empty arrays are
        returned when the input is shorter than one window.
    """
    n_feats, T = feat.shape
    fps = _SR / _HOP_LENGTH
    win_frames = int(_WINDOW * fps)
    # int() truncates a sub-frame hop to 0 and np.arange fails on a zero
    # step, so never advance by less than one frame.
    hop_frames = max(1, int(hop * fps))
    if win_frames > T:
        return np.array([]), np.array([])

    # Prefix sums let every window's sum and sum of squares be read off with
    # one subtraction instead of re-summing the window.
    cumsum = np.zeros((n_feats, T + 1))
    cumsum[:, 1:] = np.cumsum(feat, axis=1)
    cumsq = np.zeros((n_feats, T + 1))
    cumsq[:, 1:] = np.cumsum(feat ** 2, axis=1)

    starts = np.arange(0, T - win_frames + 1, hop_frames)
    ends = starts + win_frames
    sums = cumsum[:, ends] - cumsum[:, starts]
    sq_sums = cumsq[:, ends] - cumsq[:, starts]
    means = sums / win_frames
    # Clamp at 0 before the sqrt: rounding can push E[x^2] - E[x]^2 slightly
    # negative. The 1e-10 keeps the std strictly positive.
    stds = np.sqrt(np.maximum(sq_sums / win_frames - means ** 2, 0) + 1e-10)

    return starts / fps, np.vstack([means, stds]).T
|
||||
|
||||
|
||||
def label_windows(timestamps, gt_intense, gt_soft):
    """Assign a training label to every window start time.

    Label values:
        1  -- closer than TOLERANCE to some intense ground-truth mark
        -1 -- farther than NEG_MARGIN from every mark (intense or soft)
        0  -- neither of the above
    """
    every_mark = list(gt_intense) + list(gt_soft)

    def nearest(t, marks):
        # 9999 stands in for "no mark at all", so an empty list counts as far.
        return min((abs(t - g) for g in marks), default=9999)

    labels = np.zeros(len(timestamps), dtype=int)
    for pos, t in enumerate(timestamps):
        if nearest(t, gt_intense) < TOLERANCE:
            labels[pos] = 1
        elif nearest(t, every_mark) > NEG_MARGIN:
            labels[pos] = -1
    return labels
|
||||
|
||||
|
||||
def main():
    """Run leave-one-video-out calibration on raw and normalized features.

    Reads clip start times for PROFILE_NAME from the processed DB, extracts
    windowed audio features for each source video found under PLEX_DIR, then
    for both feature variants prints per-fold results, a threshold sweep and
    the top feature importances. An optional first command-line argument
    caps the number of videos processed.
    """
    db = ProcessedDB()
    query = "SELECT filename, start_time, output_path FROM processed WHERE profile = ?"
    rows = db._con.execute(query, (PROFILE_NAME,)).fetchall()

    # Group start times per source video, split by export folder.
    intense_by_video, soft_by_video = {}, {}
    for fn, st, op in rows:
        if '/mp4_Intense/' in op:
            bucket = intense_by_video
        elif '/mp4_Soft/' in op:
            bucket = soft_by_video
        else:
            continue
        bucket.setdefault(fn, set()).add(st)

    # Keep only videos that have intense clips and still exist under PLEX_DIR.
    videos = []
    for fn in intense_by_video:
        if os.path.exists(os.path.join(PLEX_DIR, fn)):
            videos.append(fn)
    limit = int(sys.argv[1]) if len(sys.argv) > 1 else len(videos)
    videos = videos[:limit]
    print(f"Processing {len(videos)} videos...")

    raw_sets = []   # raw features
    norm_sets = []  # per-video z-scored features

    for vi, vname in enumerate(videos):
        started = time.time()
        audio, _ = librosa.load(os.path.join(PLEX_DIR, vname), sr=_SR, mono=True)
        frame_feats = extract_rich_features(audio)
        timestamps, window_vectors = compute_window_stats(frame_feats, hop=1.0)
        dt = time.time() - started

        if len(timestamps) == 0:
            continue

        gt_intense = sorted(intense_by_video.get(vname, set()))
        gt_soft = sorted(soft_by_video.get(vname, set()))
        labels = label_windows(timestamps, gt_intense, gt_soft)

        # Per-video z-score normalization; the floor avoids dividing by ~0.
        vid_mean = window_vectors.mean(axis=0)
        vid_std = np.maximum(window_vectors.std(axis=0), 1e-6)
        normed = (window_vectors - vid_mean) / vid_std

        n_pos = (labels == 1).sum()
        n_neg = (labels == -1).sum()
        print(f" [{vi+1}/{len(videos)}] {vname[:55]} pos={n_pos} neg={n_neg} ({dt:.1f}s)")

        raw_sets.append((vi, vname, timestamps, window_vectors, labels))
        norm_sets.append((vi, vname, timestamps, normed, labels))

    # Run CV for both raw and normalized
    variants = [
        ("RAW features", raw_sets),
        ("PER-VIDEO NORMALIZED features", norm_sets),
    ]
    for label, data in variants:
        print(f"\n{'='*70}")
        print(f" {label}")
        print(f"{'='*70}")

        all_y_true, all_y_prob = [], []

        for test_idx, (_, vname, _, test_X, test_labels) in enumerate(data):
            test_mask = test_labels != 0
            y_test = (test_labels[test_mask] == 1).astype(int)
            # A fold needs labeled windows and at least one positive.
            if not test_mask.any() or not y_test.any():
                continue
            X_test = test_X[test_mask]

            # Training pool: labeled windows of every other video.
            X_parts, y_parts = [], []
            for other_idx, (_, _, _, feats, labs) in enumerate(data):
                keep = labs != 0
                if other_idx == test_idx or not keep.any():
                    continue
                X_parts.append(feats[keep])
                y_parts.append((labs[keep] == 1).astype(int))

            if not X_parts:
                continue
            X_train = np.vstack(X_parts)
            y_train = np.concatenate(y_parts)

            pos_idx = np.where(y_train == 1)[0]
            neg_idx = np.where(y_train == 0)[0]
            if len(pos_idx) == 0 or len(neg_idx) == 0:
                continue

            # Subsample negatives to at most 3x the positives (fixed seed).
            rng = np.random.RandomState(42)
            n_neg = min(len(neg_idx), len(pos_idx) * 3)
            neg_sample = rng.choice(neg_idx, n_neg, replace=False)
            train_idx = np.concatenate([pos_idx, neg_sample])

            clf = GradientBoostingClassifier(
                n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42
            )
            clf.fit(X_train[train_idx], y_train[train_idx])
            probs = clf.predict_proba(X_test)[:, 1]

            is_pos = y_test == 1
            is_neg = y_test == 0
            tp = ((probs >= 0.5) & is_pos).sum()
            fp = ((probs >= 0.5) & is_neg).sum()
            fn_count = ((probs < 0.5) & is_pos).sum()
            pos_s = probs[is_pos].mean() if is_pos.sum() > 0 else 0
            neg_s = probs[is_neg].mean() if is_neg.sum() > 0 else 0
            print(f" {vname[:50]:50s} TP={tp:3d} FP={fp:4d} FN={fn_count:3d} pos_p={pos_s:.3f} neg_p={neg_s:.3f}")

            all_y_true.extend(y_test)
            all_y_prob.extend(probs)

        if not all_y_true:
            print(" No test results.")
            continue

        truth = np.array(all_y_true)
        scores = np.array(all_y_prob)
        pos_probs = scores[truth == 1]
        neg_probs = scores[truth == 0]

        if len(pos_probs) > 0 and len(neg_probs) > 0:
            print(f"\n POS: 25%={np.percentile(pos_probs,25):.3f} 50%={np.percentile(pos_probs,50):.3f}"
                  f" 75%={np.percentile(pos_probs,75):.3f} max={pos_probs.max():.3f}")
            print(f" NEG: 25%={np.percentile(neg_probs,25):.3f} 50%={np.percentile(neg_probs,50):.3f}"
                  f" 75%={np.percentile(neg_probs,75):.3f} max={neg_probs.max():.3f}")

        # Threshold sweep over the pooled out-of-fold predictions.
        best_f1, best_thr = 0, 0
        print(f"\n {'thr':>5} {'prec':>6} {'recall':>6} {'TP':>5} {'FP':>5} {'FN':>4} {'F1':>6}")
        for thr in np.arange(0.10, 0.91, 0.05):
            flagged = scores >= thr
            tp = (flagged & (truth == 1)).sum()
            fp = (flagged & (truth == 0)).sum()
            fn_count = ((scores < thr) & (truth == 1)).sum()
            prec = tp / (tp + fp) if (tp + fp) > 0 else 0
            rec = tp / (tp + fn_count) if (tp + fn_count) > 0 else 0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
            if f1 > best_f1:
                best_f1, best_thr = f1, thr
            print(f" {thr:.2f} {prec:.4f} {rec:.4f} {tp:5d} {fp:5d} {fn_count:4d} {f1:.4f}")
        print(f"\n Best F1={best_f1:.4f} at thr={best_thr:.2f}")

        # Feature importance: refit once on all labeled windows.
        labeled_X = np.vstack([f[l != 0] for _, _, _, f, l in data])
        labeled_y = np.concatenate([(l[l != 0] == 1).astype(int) for _, _, _, _, l in data])
        pos_idx = np.where(labeled_y == 1)[0]
        neg_idx = np.where(labeled_y == 0)[0]
        rng = np.random.RandomState(42)
        neg_sub = rng.choice(neg_idx, min(len(neg_idx), len(pos_idx)*3), replace=False)
        fit_idx = np.concatenate([pos_idx, neg_sub])
        clf = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42)
        clf.fit(labeled_X[fit_idx], labeled_y[fit_idx])

        # Names follow the row order of extract_rich_features; each feature
        # contributes a window mean (_m) and a window std (_s) column.
        feat_names = ["rms", "centroid", "bw", "rolloff", "flat", "zcr", "onset"]
        feat_names += [f"mel{i}" for i in range(8)]
        feat_names += [f"sc{i}" for i in range(7)]
        stat_names = [f"{f}_m" for f in feat_names] + [f"{f}_s" for f in feat_names]
        imp = clf.feature_importances_
        top = sorted(zip(stat_names, imp), key=lambda pair: -pair[1])[:10]
        print(f" Top features: {', '.join(f'{n}={v:.3f}' for n, v in top)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train an audio scan classifier from DB ground truth.
|
||||
|
||||
Usage:
|
||||
python 8cut_train.py # default model, auto-detect positive
|
||||
python 8cut_train.py --model BEATS # specific embedding model
|
||||
python 8cut_train.py --positive mp4_Intense # explicit positive folder
|
||||
python 8cut_train.py --positive mp4_Intense --model BEATS # both
|
||||
"""
|
||||
import sys, os, warnings
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
|
||||
from core.db import ProcessedDB
|
||||
|
||||
PROFILE_NAME = "JAV_missionary"
|
||||
|
||||
# Fallback for old DB rows without source_path
|
||||
PLEX_DIR = "/media/unraid/appdata/plex/download/porn_jav/"
|
||||
|
||||
|
||||
def _option_value(argv, flag):
    """Return the argument following *flag* in *argv*, or None if absent.

    Prints a message and exits with status 1 when the flag is present but no
    value follows it (end of the argument list, or another ``--option``).
    """
    if flag not in argv:
        return None
    pos = argv.index(flag)
    if pos + 1 >= len(argv) or argv[pos + 1].startswith("--"):
        print(f"{flag} requires a value")
        sys.exit(1)
    return argv[pos + 1]


def main():
    """Train an audio scan classifier for PROFILE_NAME from DB ground truth.

    Command-line options (read from ``sys.argv``):
        --model NAME       embedding model; must be a key of _EMBED_MODELS
        --positive SUFFIX  positive export folder, resolved through
                           ``ProcessedDB.get_training_data``

    Without ``--positive`` the rows are classified by their output path
    (``_Intense/`` vs ``_Soft/``). Exits with status 1 on a bad option, when
    no training data is found for ``--positive``, or when training returns
    no result.
    """
    embed_model = _option_value(sys.argv, "--model")
    if embed_model is not None and embed_model not in _EMBED_MODELS:
        print(f"Unknown model: {embed_model}")
        print(f"Available: {', '.join(_EMBED_MODELS)}")
        sys.exit(1)

    positive_suffix = _option_value(sys.argv, "--positive")

    db = ProcessedDB()

    # If --positive given, use the new DB helper
    if positive_suffix:
        video_infos = db.get_training_data(
            PROFILE_NAME, positive_suffix, fallback_video_dir=PLEX_DIR,
        )
        if not video_infos:
            print(f"No training data found for positive='{positive_suffix}'")
            sys.exit(1)
    else:
        # Legacy fallback: classify by folder path pattern
        rows = db._con.execute(
            "SELECT filename, start_time, output_path, source_path"
            " FROM processed WHERE profile = ?",
            (PROFILE_NAME,),
        ).fetchall()

        intense_by_video, soft_by_video = {}, {}
        source_by_fn = {}
        for fn, st, op, sp in rows:
            if sp:
                source_by_fn[fn] = sp
            # Treat a NULL output_path as "matches neither folder" instead of
            # raising TypeError on the substring test.
            op = op or ""
            # "_Intense/" also matches "/mp4_Intense/", so one test suffices.
            if "_Intense/" in op:
                intense_by_video.setdefault(fn, set()).add(st)
            elif "_Soft/" in op:
                soft_by_video.setdefault(fn, set()).add(st)

        video_infos = []
        for fn, intense_starts in intense_by_video.items():
            # Try source_path from DB first, fall back to PLEX_DIR
            vpath = source_by_fn.get(fn) or os.path.join(PLEX_DIR, fn)
            if not os.path.exists(vpath):
                print(f" skip (not found): {fn}")
                continue
            gt_intense = sorted(intense_starts)
            gt_soft = sorted(soft_by_video.get(fn, set()))
            video_infos.append((vpath, gt_intense, gt_soft))

    # embed_model=None is passed through unchanged; the name below is only
    # what gets printed in that case.
    label = embed_model or "WAV2VEC2_BASE"
    print(f"Training {label} model on {len(video_infos)} videos...")
    model_path = default_model_path(PROFILE_NAME)
    result = train_classifier(
        video_infos, model_path=model_path, embed_model=embed_model,
    )
    if result is None:
        print("Training failed: no valid samples or missing class balance")
        sys.exit(1)
    print(f"Model saved to {model_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,13 +0,0 @@
|
||||
FROM nvidia/cuda:12.6.3-runtime-ubuntu24.04

# System packages: Python, pip and ffmpeg. The apt lists are removed in the
# same layer so they do not end up in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends python3 python3-pip ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Only the core and server packages are copied into the image.
COPY core/ core/
COPY server/ server/

# Python dependencies of the server.
RUN pip install --no-cache-dir --break-system-packages fastapi uvicorn[standard]

# Serve server.app:app with uvicorn on port 8000, all interfaces.
EXPOSE 8000
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -8,7 +8,7 @@
|
||||
<a href="https://github.com/ethanfel/8-cut/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
|
||||
</p>
|
||||
|
||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets.
|
||||
A desktop tool for cutting 8-second clips from video files, designed for building foley datasets. Includes audio classification for automated scanning and batch export.
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -22,19 +22,54 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
||||
|
||||
## Features
|
||||
|
||||
### Clip export
|
||||
|
||||
- **Frame-accurate scrubbing** — click or drag the timeline; arrow keys and J/L for frame-by-frame, Shift for 1-second steps
|
||||
- **Batch export** — export multiple overlapping clips per cut point with configurable count and spread offset
|
||||
- **Two export formats** — H.264 MP4 with lossless PCM audio, or WebP image sequence (frames + `.wav`)
|
||||
- **Portrait crop** — crop to 9:16, 4:5, or 1:1 before export; click the video or crop bar to reposition
|
||||
- **Random portrait** — optionally apply a random portrait crop to a subset of each batch
|
||||
- **Random portrait/square** — optionally apply a random crop to a subset of each batch
|
||||
- **Resize** — scale short side to a fixed pixel size (e.g. 512)
|
||||
- **Sound annotation** — label and category fields saved to the clip database; label also written to `dataset.json`
|
||||
- **Export history** — timeline markers show previously exported clips; double-click to enter overwrite mode; right-click to delete
|
||||
- **End-frame preview** — floating window shows the last frame of the selection region
|
||||
- **Playlist** — drag-and-drop or use the Open Files button; right-click to remove items
|
||||
- **Playback loop** — plays the exact selection region on loop so you can preview what will be exported
|
||||
- **Group operations** — delete or overwrite acts on all sub-clips in a batch, not just one
|
||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait") for the same video
|
||||
- **Hardware encoding** — GPU-accelerated export via NVENC, VAAPI, QSV, AMF, or VideoToolbox
|
||||
- **Subject tracking** — auto-adjust crop center using YOLOv8 detection (optional)
|
||||
|
||||
### Audio scanning
|
||||
|
||||
- **Embedding models** — WAV2VEC2 (base/large), HuBERT (base/large/xlarge), BEATs
|
||||
- **Train classifier** — train a gradient boosting classifier on your exported clips to find similar audio
|
||||
- **Scan video** — detect regions matching your trained model with configurable threshold
|
||||
- **Scan All** — batch scan every video in the playlist
|
||||
- **Region fusion** — merge overlapping detections into contiguous regions
|
||||
- **Hard negatives** — mark false positives to refine training
|
||||
- **Model versioning** — timestamped backups with rollback support
|
||||
- **Scan export** — batch export from scan results with spread and minimum duration filtering
|
||||
|
||||
### Scan results panel
|
||||
|
||||
- **Tabbed results** — one tab per model, showing start/end/score per region
|
||||
- **Disable regions** — Delete/Backspace toggles regions off (greyed out, excluded from export) without removing them
|
||||
- **Resize regions** — double-click Time or End cells to edit, or drag region edges directly on the timeline
|
||||
- **Grey ghost** — trimmed portions of resized regions shown as grey overlay on timeline
|
||||
- **Undo** — Ctrl+Z reverts the last disable, resize, drag, or negative toggle
|
||||
|
||||
### Organization
|
||||
|
||||
- **Sound annotation** — label and category fields saved to the clip database and `dataset.json`
|
||||
- **Export history** — timeline markers show previously exported clips; double-click to overwrite; right-click to delete
|
||||
- **Playlist** — drag-and-drop video queue with progress tracking
|
||||
- **Profiles** — switch between independent marker sets (e.g. "landscape" vs "portrait")
|
||||
- **Subprofiles** — lightweight export folder variants for multiple output targets
|
||||
- **Review mode** — clean timeline view for navigating scan results without export clutter
|
||||
|
||||
### Interface
|
||||
|
||||
- **Menu bar** — File / Edit / Scan / View / Help hold the occasional actions (open files, train, scan all, profiles); the profile selector and shortcuts (`?`) sit in the top-right corner
|
||||
- **Control deck** — a compact tabbed panel under the video groups the settings into **Export** (label, name, folder, format, resize, duration/clips/spread, workers), **Crop & Track**, and **Scan** (model, threshold, fuse, scan/auto/speech/review)
|
||||
- **Side-by-side panels** — pin deck panels to view them as resizable columns: right-click a deck tab → *Show side-by-side*, or toggle them under *View ▸ Side-by-side panels*; drag the dividers to reallocate space, and the layout persists between sessions
|
||||
- **Per-tab export folder** — each file-list tab remembers its own output folder; switching tabs follows that tab's folder, and a guardrail warns when the loaded video doesn't match the destination
|
||||
- **Duplicate tab** — right-click a file-list tab → *Duplicate tab* to clone its files into a new tab with its own export folder
|
||||
- **LTX-2 export mode** — per-tab **Foley | LTX-2** toggle (right-click a tab, shown with an `[LTX2]` badge): LTX-2 clips are frame-exact (`frames % 8 == 1`), forced to 25 fps, and center-cropped so width & height are divisible by 32 — for LTX-2 video-to-audio datasets; applies to manual, re-export, and auto-export
|
||||
- **Status bar** — export/scan progress and messages, with the current file · profile · worker count always shown
|
||||
|
||||
## Keyboard shortcuts
|
||||
|
||||
@@ -50,37 +85,158 @@ All clips are exactly 8 seconds — the standard length for foley sound datasets
|
||||
| `M` | Jump to next marker (wraps) |
|
||||
| `N` | Next file in playlist |
|
||||
| `G` | Toggle cursor lock |
|
||||
| `Delete` / `Backspace` | Toggle disable on selected scan regions |
|
||||
| `Ctrl+Z` | Undo last scan panel action |
|
||||
| `?` / `F1` | Show keyboard shortcuts |
|
||||
|
||||
Shortcuts are suppressed when a text field has focus.
|
||||
|
||||
## Requirements
|
||||
## Installation
|
||||
|
||||
- Python 3.11+
|
||||
- `ffmpeg` on `PATH`
|
||||
- PyQt6
|
||||
- python-mpv (requires libmpv)
|
||||
### Prerequisites
|
||||
|
||||
- **Python 3.11+** — [python.org/downloads](https://www.python.org/downloads/)
|
||||
- **ffmpeg** — video encoding
|
||||
- **libmpv** — video playback
|
||||
|
||||
### Quick start (all platforms)
|
||||
|
||||
The setup script creates a virtual environment and installs everything including PyTorch with CUDA support:
|
||||
|
||||
```bash
|
||||
# Linux / macOS
|
||||
./setup_env.sh
|
||||
|
||||
# Windows (PowerShell)
|
||||
powershell -ExecutionPolicy Bypass -File setup-windows.ps1
|
||||
```
|
||||
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
# Linux / macOS
|
||||
./8cut.sh
|
||||
|
||||
# Windows
|
||||
8cut.bat
|
||||
```
|
||||
|
||||
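Both launch scripts prefer the project's `.venv`. `8cut.sh` additionally looks for a conda environment named `8cut`; if neither is found, the scripts fall back to the Python on your `PATH`.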
|
||||
|
||||
### Manual installation
|
||||
|
||||
#### 1. Install system dependencies
|
||||
|
||||
**Linux (Arch):**
|
||||
```bash
|
||||
pacman -S python mpv ffmpeg
|
||||
```
|
||||
|
||||
**Linux (Debian/Ubuntu):**
|
||||
```bash
|
||||
apt install python3 python3-venv libmpv-dev ffmpeg
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```powershell
|
||||
# ffmpeg
|
||||
winget install ffmpeg
|
||||
|
||||
# libmpv — download mpv-2.dll and place next to main.py
|
||||
# https://sourceforge.net/projects/mpv-player-windows/files/libmpv/
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install python mpv ffmpeg
|
||||
```
|
||||
|
||||
#### 2. Create a virtual environment
|
||||
|
||||
A virtual environment keeps 8-cut's dependencies isolated from your system Python. This is strongly recommended — PyTorch alone is several GB and can conflict with other projects.
|
||||
|
||||
**Using venv (recommended):**
|
||||
|
||||
```bash
|
||||
# Create the venv in the project directory
|
||||
python3 -m venv .venv
|
||||
|
||||
# Activate it
|
||||
# Linux / macOS:
|
||||
source .venv/bin/activate
|
||||
# Windows (cmd):
|
||||
.venv\Scripts\activate.bat
|
||||
# Windows (PowerShell):
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
**Using conda / miniforge:**
|
||||
|
||||
```bash
|
||||
conda create -n 8cut python=3.12
|
||||
conda activate 8cut
|
||||
```
|
||||
|
||||
You must activate the environment every time you open a new terminal before running 8-cut. The `8cut.sh` launcher does this automatically.
|
||||
|
||||
#### 3. Install PyTorch
|
||||
|
||||
PyTorch must be installed separately with the correct CUDA version for GPU acceleration. Without CUDA, audio scanning will fall back to CPU (much slower).
|
||||
|
||||
**With NVIDIA GPU (CUDA 12.8):**
|
||||
```bash
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
```
|
||||
|
||||
**CPU only (no GPU):**
|
||||
```bash
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
Check available CUDA versions at [pytorch.org/get-started](https://pytorch.org/get-started/locally/).
|
||||
|
||||
#### 4. Install project dependencies
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Platform notes
|
||||
#### 5. Verify
|
||||
|
||||
| Platform | libmpv |
|
||||
|----------|--------|
|
||||
| **Linux** | `apt install libmpv-dev` or `pacman -S mpv` |
|
||||
| **macOS** | `brew install mpv` |
|
||||
| **Windows** | Download `mpv-2.dll` from [mpv Windows builds](https://sourceforge.net/projects/mpv-player-windows/files/libmpv/) and place it in `PATH` or next to `main.py` |
|
||||
```bash
|
||||
python -c "import torch; print('PyTorch', torch.__version__, 'CUDA', torch.version.cuda)"
|
||||
python -c "import librosa, torchaudio, sklearn; print('All imports OK')"
|
||||
```
|
||||
|
||||
Windows also needs `ffmpeg.exe` on `PATH` (e.g. `winget install ffmpeg`).
|
||||
### Running
|
||||
|
||||
```bash
|
||||
# With venv activated:
|
||||
python main.py
|
||||
|
||||
# Or use the launcher (auto-activates venv/conda):
|
||||
./8cut.sh # Linux / macOS
|
||||
8cut.bat # Windows
|
||||
```
|
||||
|
||||
### GPU encoding
|
||||
|
||||
Hardware encoders are auto-detected from ffmpeg. Available encoders by platform:
|
||||
|
||||
| Platform | Encoders |
|
||||
|----------|----------|
|
||||
| **Linux** | `h264_nvenc` (NVIDIA), `h264_vaapi` (AMD/Intel), `h264_qsv` (Intel) |
|
||||
| **Windows** | `h264_nvenc` (NVIDIA), `h264_qsv` (Intel), `h264_amf` (AMD) |
|
||||
| **macOS** | `h264_videotoolbox` |
|
||||
|
||||
Enable the **HW** checkbox in the export controls to use GPU encoding.
|
||||
|
||||
### Optional: audio scanning
|
||||
|
||||
Audio scanning requires PyTorch (installed above). Embedding models are downloaded on first use and cached in `cache/downloads/`. A CUDA-capable GPU is strongly recommended for training and scanning speed.
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
Drop videos onto the queue or click **+ Open Files**. Scrub to your cut point, then press **Export** (or `E`).
|
||||
|
||||
### Export layout
|
||||
@@ -109,6 +265,20 @@ output/
|
||||
clip_001_0.wav
|
||||
```
|
||||
|
||||
### Scan export layout
|
||||
|
||||
Scan exports create one group folder per detected area:
|
||||
|
||||
```
|
||||
output/
|
||||
clip_037/
|
||||
clip_037_a1_0.mp4 # area 1, clip 0
|
||||
clip_037_a1_1.mp4 # area 1, clip 1
|
||||
clip_038/
|
||||
clip_038_a2_0.mp4 # area 2, clip 0
|
||||
...
|
||||
```
|
||||
|
||||
### Sound annotation
|
||||
|
||||
Set a **Label** (e.g. "dog barking") and **Category** (Human / Animal / Vehicle / Tool / Music / Nature / Sport / Other) before exporting. These are saved to:
|
||||
@@ -124,9 +294,73 @@ Labels persist between exports so you can cut many clips of the same class witho
|
||||
- **Right-click** a marker to delete it from the database
|
||||
- The **Delete** button removes all clips in a group from disk, database, and `dataset.json`
|
||||
|
||||
## Audio scan workflow
|
||||
|
||||
### 1. Build a dataset
|
||||
|
||||
Export clips manually from several videos. Clips from the same export folder (e.g. `mp4_Intense`) become your positive training class.
|
||||
|
||||
**Minimum dataset:** ~20 clips from 2–3 different videos. This is enough for the classifier to learn a basic boundary, but expect noisy results — you'll need to mark hard negatives and retrain.
|
||||
|
||||
**Ideal dataset:** 50–100+ clips from 5+ videos covering the full range of variation in your target sound (different recording conditions, distances, intensities). More variety in your positives makes the model generalize better to unseen footage. Negatives are sampled automatically from regions far from your markers, but adding explicit negatives of confusable sounds (e.g. thunder when training for explosions) significantly reduces false positives.
|
||||
|
||||
The classifier improves iteratively: export a small initial set → train → scan → mark false positives as hard negatives → retrain. Each cycle sharpens the decision boundary without needing a large upfront dataset.
|
||||
|
||||
### 2. Train a classifier
|
||||
|
||||
Click **Train** to open the training dialog:
|
||||
|
||||
- **Positive class** — select the export folder containing your target sounds
|
||||
- **Negative class** — optional explicit negatives, or leave as "(auto only)" for automatic sampling
|
||||
- **Model** — embedding model to use (HuBERT XLARGE recommended)
|
||||
- **Auto-neg margin** — minimum distance from your markers at which automatic negatives are sampled (default: 30 s)
|
||||
- **Include scan-exported clips** — whether to include previously scan-exported clips in training
|
||||
|
||||
The classifier trains a `HistGradientBoostingClassifier` on audio embeddings and saves to `models/`.
|
||||
|
||||
### 3. Scan videos
|
||||
|
||||
Select a trained model from the dropdown and click **Scan**. Adjust the threshold slider to control sensitivity. Detected regions appear as colored bands on the timeline and as rows in the results panel.
|
||||
|
||||
Audio embeddings are computed once per video and cached to disk (`cache/w2v/`). Subsequent scans with the same embedding model skip the GPU entirely and only re-run the classifier, which takes milliseconds. This makes the retrain → rescan loop nearly free after the first pass.
|
||||
|
||||
### 4. Review and refine
|
||||
|
||||
- Toggle **Review** mode for a clean timeline focused on scan results
|
||||
- **Disable** false-positive regions (Delete key) — they stay in the list but are excluded from export
|
||||
- **Resize** regions by dragging edges on the timeline or editing times in the table
|
||||
- **Mark as negative** — add false positives to the hard negative set for retraining
|
||||
- **Ctrl+Z** to undo any of the above
|
||||
|
||||
### 5. Export results
|
||||
|
||||
Click **Export Scan Results** to batch export all enabled regions. The button shows the estimated clip count based on spread and minimum duration settings.
|
||||
|
||||
### 6. Retrain with feedback
|
||||
|
||||
Train again — hard negatives are automatically included. Each training run saves with a timestamp. Click the **⏲** button next to the model dropdown to restore a previous version if results degrade — restoring automatically rescans with the selected version.
|
||||
|
||||
## Database
|
||||
|
||||
Export history is stored in `~/.8cut.db` (SQLite). The database records filename, start time, output path, label, category, and all encoding settings for every clip. When you open a file, 8-cut matches the filename and pre-populates the timeline with existing markers.
|
||||
Export history is stored in `~/.8cut.db` (SQLite). Tables:
|
||||
|
||||
| Table | Purpose |
|
||||
|-------|---------|
|
||||
| `processed` | Every exported clip with full encoding settings |
|
||||
| `scan_results` | Audio scan detections per video/model |
|
||||
| `hard_negatives` | Timestamps marked as false positives for training |
|
||||
| `hidden_files` | Playlist files hidden by the user |
|
||||
|
||||
The database auto-migrates when new columns are added.
|
||||
|
||||
## File locations
|
||||
|
||||
| Path | Contents |
|
||||
|------|----------|
|
||||
| `~/.8cut.db` | SQLite database |
|
||||
| `models/` | Trained classifier models (`.joblib`) |
|
||||
| `cache/w2v/` | Embedding cache (`.npz`, keyed by video hash) |
|
||||
| `cache/downloads/` | Downloaded pretrained models |
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
|
||||
<defs>
|
||||
<linearGradient id="g8" x1="0" y1="0" x2="0" y2="1">
|
||||
<stop offset="0%" stop-color="#ffd230"/>
|
||||
<stop offset="100%" stop-color="#e6a800"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<rect width="64" height="64" rx="13" fill="#161616"/>
|
||||
<rect x="8" y="42" width="48" height="11" rx="2" fill="#2a2a2a" stroke="#333" stroke-width="1"/>
|
||||
<rect x="26" y="42" width="16" height="11" fill="#3c82dc" fill-opacity="0.45"/>
|
||||
<line x1="26" y1="38" x2="26" y2="55" stroke="#ffd230" stroke-width="2"/>
|
||||
<polygon points="22,38 30,38 26,44" fill="#ffd230"/>
|
||||
<text x="32" y="33" font-family="'Helvetica Neue',Helvetica,Arial,sans-serif" font-size="34" font-weight="bold" fill="url(#g8)" text-anchor="middle">8</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 790 B |
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
||||
<path d="M7.5 10 V7.5 a4.5 4.5 0 0 1 9 0 V10" stroke="#ffd230" stroke-width="2"/>
|
||||
<rect x="5" y="10" width="14" height="10" rx="2" fill="#ffd230"/>
|
||||
<circle cx="12" cy="14.3" r="1.4" fill="#161616"/>
|
||||
<rect x="11.2" y="14.3" width="1.6" height="3.4" rx="0.8" fill="#161616"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 362 B |
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
||||
<path d="M7.5 10 V7.5 a4.5 4.5 0 0 1 8.6 -1.8" stroke="#8a8a8a" stroke-width="2"/>
|
||||
<rect x="5" y="10" width="14" height="10" rx="2" fill="#8a8a8a"/>
|
||||
<circle cx="12" cy="14.3" r="1.4" fill="#1e1e1e"/>
|
||||
<rect x="11.2" y="14.3" width="1.6" height="3.4" rx="0.8" fill="#1e1e1e"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 363 B |
@@ -0,0 +1,4 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<rect x="6.5" y="5" width="4" height="14" rx="1.2" fill="#ffd230"/>
|
||||
<rect x="13.5" y="5" width="4" height="14" rx="1.2" fill="#ffd230"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 209 B |
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<path d="M7 5 L19 12 L7 19 Z" fill="#ffd230" stroke="#ffd230" stroke-width="1.5" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 177 B |
@@ -0,0 +1,4 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#aad4ff" stroke-width="2" stroke-linecap="round">
|
||||
<circle cx="10.5" cy="10.5" r="6"/>
|
||||
<line x1="15" y1="15" x2="20" y2="20"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 217 B |
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#ffd230" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<circle cx="6.5" cy="6.5" r="2.6"/>
|
||||
<circle cx="6.5" cy="17.5" r="2.6"/>
|
||||
<line x1="8.8" y1="8" x2="20" y2="17"/>
|
||||
<line x1="8.8" y1="16" x2="20" y2="7"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 322 B |
@@ -0,0 +1,4 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#ffd230" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<polyline points="4,17 10,11 14,14 20,6"/>
|
||||
<polyline points="15,6 20,6 20,11"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 245 B |
@@ -1,10 +0,0 @@
|
||||
.DS_Store
|
||||
node_modules
|
||||
/build
|
||||
/.svelte-kit
|
||||
/package
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
@@ -1,7 +0,0 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"svelte.svelte-vscode",
|
||||
"tauri-apps.tauri-vscode",
|
||||
"rust-lang.rust-analyzer"
|
||||
]
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
{
|
||||
"svelte.enable-ts-plugin": true
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
# Tauri + SvelteKit + TypeScript
|
||||
|
||||
This template should help get you started developing with Tauri, SvelteKit and TypeScript in Vite.
|
||||
|
||||
## Recommended IDE Setup
|
||||
|
||||
[VS Code](https://code.visualstudio.com/) + [Svelte](https://marketplace.visualstudio.com/items?itemName=svelte.svelte-vscode) + [Tauri](https://marketplace.visualstudio.com/items?itemName=tauri-apps.tauri-vscode) + [rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer).
|
||||
@@ -1,29 +0,0 @@
|
||||
{
|
||||
"name": "client",
|
||||
"version": "0.1.0",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite dev",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview",
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
||||
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
|
||||
"tauri": "tauri"
|
||||
},
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@tauri-apps/api": "^2",
|
||||
"@tauri-apps/plugin-opener": "^2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sveltejs/adapter-static": "^3.0.6",
|
||||
"@sveltejs/kit": "^2.9.0",
|
||||
"@sveltejs/vite-plugin-svelte": "^5.0.0",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
"typescript": "~5.6.2",
|
||||
"vite": "^6.0.3",
|
||||
"@tauri-apps/cli": "^2"
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
/target/
|
||||
|
||||
# Generated by Tauri
|
||||
# will have schema files for capabilities auto-completion
|
||||
/gen/schemas
|
||||
@@ -1,25 +0,0 @@
|
||||
[package]
|
||||
name = "client"
|
||||
version = "0.1.0"
|
||||
description = "A Tauri App"
|
||||
authors = ["you"]
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
# The `_lib` suffix may seem redundant but it is necessary
|
||||
# to make the lib name unique and wouldn't conflict with the bin name.
|
||||
# This seems to be only an issue on Windows, see https://github.com/rust-lang/cargo/issues/8519
|
||||
name = "client_lib"
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
[build-dependencies]
|
||||
tauri-build = { version = "2", features = [] }
|
||||
|
||||
[dependencies]
|
||||
tauri = { version = "2", features = [] }
|
||||
tauri-plugin-opener = "2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
fn main() {
|
||||
tauri_build::build()
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"$schema": "../gen/schemas/desktop-schema.json",
|
||||
"identifier": "default",
|
||||
"description": "Capability for the main window",
|
||||
"windows": ["main"],
|
||||
"permissions": [
|
||||
"core:default",
|
||||
"opener:default"
|
||||
]
|
||||
}
|
||||
|
Before Width: | Height: | Size: 3.4 KiB |
|
Before Width: | Height: | Size: 6.8 KiB |
|
Before Width: | Height: | Size: 974 B |
|
Before Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 3.8 KiB |
|
Before Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 7.6 KiB |
|
Before Width: | Height: | Size: 903 B |
|
Before Width: | Height: | Size: 8.4 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 2.4 KiB |
|
Before Width: | Height: | Size: 1.5 KiB |
|
Before Width: | Height: | Size: 85 KiB |
|
Before Width: | Height: | Size: 14 KiB |
@@ -1,56 +0,0 @@
|
||||
use tauri::State;
|
||||
use std::sync::Mutex;
|
||||
use crate::mpv::Mpv;
|
||||
|
||||
/// Tauri-managed state: the single `Mpv` handle behind a `Mutex`, locked by each `mpv_*` command.
pub struct MpvState(pub Mutex<Mpv>);
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_start(state: State<MpvState>) -> Result<(), String> {
|
||||
state.0.lock().unwrap().start()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_stop(state: State<MpvState>) -> Result<(), String> {
|
||||
state.0.lock().unwrap().stop();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_load(state: State<MpvState>, video_url: String, audio_url: String) -> Result<(), String> {
|
||||
state.0.lock().unwrap().load_file(&video_url, &audio_url)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_seek(state: State<MpvState>, time: f64) -> Result<(), String> {
|
||||
state.0.lock().unwrap().seek(time)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_pause(state: State<MpvState>) -> Result<(), String> {
|
||||
state.0.lock().unwrap().pause()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_resume(state: State<MpvState>) -> Result<(), String> {
|
||||
state.0.lock().unwrap().resume()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_set_loop(state: State<MpvState>, a: f64, b: f64) -> Result<(), String> {
|
||||
state.0.lock().unwrap().set_loop(a, b)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_clear_loop(state: State<MpvState>) -> Result<(), String> {
|
||||
state.0.lock().unwrap().clear_loop()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_time_pos(state: State<MpvState>) -> Result<f64, String> {
|
||||
state.0.lock().unwrap().time_pos()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn mpv_duration(state: State<MpvState>) -> Result<f64, String> {
|
||||
state.0.lock().unwrap().get_duration()
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
mod mpv;
|
||||
mod commands;
|
||||
|
||||
use commands::MpvState;
|
||||
use mpv::Mpv;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_opener::init())
|
||||
.manage(MpvState(Mutex::new(Mpv::new())))
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
commands::mpv_start,
|
||||
commands::mpv_stop,
|
||||
commands::mpv_load,
|
||||
commands::mpv_seek,
|
||||
commands::mpv_pause,
|
||||
commands::mpv_resume,
|
||||
commands::mpv_set_loop,
|
||||
commands::mpv_clear_loop,
|
||||
commands::mpv_time_pos,
|
||||
commands::mpv_duration,
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
|
||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||
|
||||
fn main() {
|
||||
client_lib::run()
|
||||
}
|
||||
@@ -1,167 +0,0 @@
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::os::unix::net::UnixStream;
|
||||
use std::process::{Child, Command};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub struct Mpv {
|
||||
process: Option<Child>,
|
||||
writer: Option<UnixStream>,
|
||||
reader: Option<BufReader<UnixStream>>,
|
||||
socket_path: String,
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl Mpv {
|
||||
pub fn new() -> Self {
|
||||
let socket_path = format!("/tmp/8cut-mpv-{}", std::process::id());
|
||||
Mpv {
|
||||
process: None,
|
||||
writer: None,
|
||||
reader: None,
|
||||
socket_path,
|
||||
next_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self) -> Result<(), String> {
|
||||
self.stop();
|
||||
|
||||
let child = Command::new("mpv")
|
||||
.args([
|
||||
"--idle=yes",
|
||||
"--force-window=no",
|
||||
"--vo=null",
|
||||
"--keep-open=yes",
|
||||
&format!("--input-ipc-server={}", self.socket_path),
|
||||
])
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to start mpv: {e}"))?;
|
||||
|
||||
self.process = Some(child);
|
||||
|
||||
// Wait for socket
|
||||
for _ in 0..50 {
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
if let Ok(stream) = UnixStream::connect(&self.socket_path) {
|
||||
stream.set_nonblocking(false).ok();
|
||||
let reader_stream = stream.try_clone().map_err(|e| e.to_string())?;
|
||||
self.writer = Some(stream);
|
||||
self.reader = Some(BufReader::new(reader_stream));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err("Timeout waiting for mpv IPC socket".into())
|
||||
}
|
||||
|
||||
pub fn stop(&mut self) {
|
||||
if let Some(ref mut child) = self.process {
|
||||
child.kill().ok();
|
||||
child.wait().ok();
|
||||
}
|
||||
self.process = None;
|
||||
self.writer = None;
|
||||
self.reader = None;
|
||||
std::fs::remove_file(&self.socket_path).ok();
|
||||
}
|
||||
|
||||
/// Send a command and wait for the matching response (by request_id).
|
||||
/// Skips over asynchronous mpv events while waiting.
|
||||
fn send_and_recv(&mut self, cmd: Value) -> Result<Value, String> {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let writer = self.writer.as_mut().ok_or("mpv not running")?;
|
||||
let reader = self.reader.as_mut().ok_or("mpv not running")?;
|
||||
|
||||
let mut msg_val = cmd;
|
||||
msg_val["request_id"] = json!(id);
|
||||
let mut msg = serde_json::to_string(&msg_val).unwrap();
|
||||
msg.push('\n');
|
||||
writer.write_all(msg.as_bytes()).map_err(|e| e.to_string())?;
|
||||
|
||||
// Read lines until we find the response matching our request_id
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
line.clear();
|
||||
reader.read_line(&mut line).map_err(|e| e.to_string())?;
|
||||
let parsed: Value = serde_json::from_str(&line).map_err(|e| e.to_string())?;
|
||||
// mpv events have "event" key, responses have "request_id"
|
||||
if parsed.get("request_id").and_then(|v| v.as_u64()) == Some(id) {
|
||||
return Ok(parsed);
|
||||
}
|
||||
// Otherwise it's an async event — skip it
|
||||
}
|
||||
}
|
||||
|
||||
pub fn command(&mut self, args: &[&str]) -> Result<(), String> {
|
||||
let resp = self.send_and_recv(json!({ "command": args }))?;
|
||||
if resp.get("error").and_then(|e| e.as_str()) != Some("success") {
|
||||
return Err(format!("mpv error: {}", resp.get("error").unwrap_or(&Value::Null)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_property(&mut self, name: &str, value: Value) -> Result<(), String> {
|
||||
let resp = self.send_and_recv(json!({ "command": ["set_property", name, value] }))?;
|
||||
if resp.get("error").and_then(|e| e.as_str()) != Some("success") {
|
||||
return Err(format!("mpv error: {}", resp.get("error").unwrap_or(&Value::Null)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_property(&mut self, name: &str) -> Result<Value, String> {
|
||||
let resp = self.send_and_recv(json!({ "command": ["get_property", name] }))?;
|
||||
if resp.get("error").and_then(|e| e.as_str()) != Some("success") {
|
||||
return Err(format!("mpv error: {}", resp.get("error").unwrap_or(&Value::Null)));
|
||||
}
|
||||
Ok(resp.get("data").cloned().unwrap_or(Value::Null))
|
||||
}
|
||||
|
||||
pub fn load_file(&mut self, video_url: &str, audio_url: &str) -> Result<(), String> {
|
||||
let options = format!("audio-file={}", audio_url);
|
||||
let resp = self.send_and_recv(json!({
|
||||
"command": ["loadfile", video_url, "replace", -1, options]
|
||||
}))?;
|
||||
if resp.get("error").and_then(|e| e.as_str()) != Some("success") {
|
||||
return Err(format!("mpv error: {}", resp.get("error").unwrap_or(&Value::Null)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn seek(&mut self, time: f64) -> Result<(), String> {
|
||||
self.command(&["seek", &time.to_string(), "absolute"])
|
||||
}
|
||||
|
||||
pub fn pause(&mut self) -> Result<(), String> {
|
||||
self.set_property("pause", json!(true))
|
||||
}
|
||||
|
||||
pub fn resume(&mut self) -> Result<(), String> {
|
||||
self.set_property("pause", json!(false))
|
||||
}
|
||||
|
||||
pub fn set_loop(&mut self, a: f64, b: f64) -> Result<(), String> {
|
||||
self.set_property("ab-loop-a", json!(a))?;
|
||||
self.set_property("ab-loop-b", json!(b))
|
||||
}
|
||||
|
||||
pub fn clear_loop(&mut self) -> Result<(), String> {
|
||||
self.set_property("ab-loop-a", json!("no"))?;
|
||||
self.set_property("ab-loop-b", json!("no"))
|
||||
}
|
||||
|
||||
pub fn time_pos(&mut self) -> Result<f64, String> {
|
||||
let val = self.get_property("time-pos")?;
|
||||
val.as_f64().ok_or("time-pos not a number".into())
|
||||
}
|
||||
|
||||
pub fn get_duration(&mut self) -> Result<f64, String> {
|
||||
let val = self.get_property("duration")?;
|
||||
val.as_f64().ok_or("duration not a number".into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Mpv {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
{
|
||||
"$schema": "https://schema.tauri.app/config/2",
|
||||
"productName": "8cut",
|
||||
"version": "0.1.0",
|
||||
"identifier": "com.ethanfel.8cut",
|
||||
"build": {
|
||||
"beforeDevCommand": "pnpm dev",
|
||||
"devUrl": "http://localhost:1420",
|
||||
"beforeBuildCommand": "pnpm build",
|
||||
"frontendDist": "../build"
|
||||
},
|
||||
"app": {
|
||||
"windows": [
|
||||
{
|
||||
"title": "8-cut",
|
||||
"width": 1200,
|
||||
"height": 800
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": null
|
||||
}
|
||||
},
|
||||
"bundle": {
|
||||
"active": true,
|
||||
"targets": ["deb", "appimage"],
|
||||
"icon": [
|
||||
"icons/32x32.png",
|
||||
"icons/128x128.png",
|
||||
"icons/128x128@2x.png",
|
||||
"icons/icon.icns",
|
||||
"icons/icon.ico"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%sveltekit.assets%/favicon.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Tauri + SvelteKit + Typescript App</title>
|
||||
%sveltekit.head%
|
||||
</head>
|
||||
<body data-sveltekit-preload-data="hover">
|
||||
<div style="display: contents">%sveltekit.body%</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,113 +0,0 @@
|
||||
<script lang="ts" module>
|
||||
// Module-level export so App can call doExport via bind:this
|
||||
</script>
|
||||
|
||||
<script lang="ts">
|
||||
import { startExport } from "$lib/api";
|
||||
import {
|
||||
currentFile, cursor, clips, spread, shortSide, portraitRatio,
|
||||
cropCenter, format, label, category, clipName, profile,
|
||||
hwEncode,
|
||||
exportStatus, exportCompleted, exportTotal, subprofiles
|
||||
} from "$lib/stores";
|
||||
|
||||
const CATEGORIES = ["", "Human", "Animal", "Vehicle", "Tool", "Music", "Nature", "Sport", "Other"];
|
||||
const RATIOS = ["Off", "9:16", "4:5", "1:1"];
|
||||
|
||||
/**
 * Start an export of the currently selected file with the settings held in
 * the stores. `folderSuffix` is forwarded as `folder_suffix` (empty for the
 * plain Export button, a subprofile name otherwise). Does nothing when no
 * file is selected; on failure the status store is set to "error".
 */
export async function doExport(folderSuffix: string = "") {
    const file = $currentFile;
    if (!file) return;

    // Reset the progress stores before the request goes out.
    $exportStatus = "running";
    $exportCompleted = 0;
    $exportTotal = $clips;

    // Fall back to the file name without its extension when no clip name is set.
    const fallbackName = file.name.replace(/\.[^.]+$/, "");
    const encoder = $hwEncode ? "h264_nvenc" : "libx264";

    const req = {
        input_path: `${file.root}/${file.path}`,
        cursor: $cursor,
        name: $clipName || fallbackName,
        clips: $clips,
        spread: $spread,
        short_side: $shortSide,
        portrait_ratio: $portraitRatio,
        crop_center: $cropCenter,
        format: $format,
        label: $label,
        category: $category,
        profile: $profile,
        folder_suffix: folderSuffix,
        encoder,
    };

    try {
        await startExport(req);
    } catch (e) {
        $exportStatus = "error";
        console.error(e);
    }
}
|
||||
</script>
|
||||
|
||||
<div class="export-panel">
|
||||
<div class="row">
|
||||
<button onclick={() => doExport()} disabled={$exportStatus === "running"}>
|
||||
Export{#if $exportStatus === "running"} ({$exportCompleted}/{$exportTotal}){/if}
|
||||
</button>
|
||||
{#each $subprofiles as sub}
|
||||
<button onclick={() => doExport(sub)} title="Export {sub}">
|
||||
{sub}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<label>Clips <input type="number" bind:value={$clips} min="1" max="99" /></label>
|
||||
<label>Spread <input type="number" bind:value={$spread} min="2" max="8" step="0.5" /></label>
|
||||
<label>Size <input type="number" bind:value={$shortSide} min="0" max="4320" step="64" /></label>
|
||||
<label>Ratio
|
||||
<select bind:value={$portraitRatio}>
|
||||
{#each RATIOS as r}
|
||||
<option value={r === "Off" ? null : r}>{r}</option>
|
||||
{/each}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<label>Label <input type="text" bind:value={$label} /></label>
|
||||
<label>Category
|
||||
<select bind:value={$category}>
|
||||
{#each CATEGORIES as c}
|
||||
<option value={c}>{c || "---"}</option>
|
||||
{/each}
|
||||
</select>
|
||||
</label>
|
||||
<label>Format
|
||||
<select bind:value={$format}>
|
||||
<option>MP4</option>
|
||||
<option>WebP sequence</option>
|
||||
</select>
|
||||
</label>
|
||||
<label><input type="checkbox" bind:checked={$hwEncode} /> GPU</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.export-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
padding: 4px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.row {
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
label { display: flex; align-items: center; gap: 2px; }
|
||||
input[type="number"] { width: 50px; background: #2d2d2d; color: #e0e0e0; border: 1px solid #444; }
|
||||
input[type="text"] { width: 120px; background: #2d2d2d; color: #e0e0e0; border: 1px solid #444; }
|
||||
select { background: #2d2d2d; color: #e0e0e0; border: 1px solid #444; }
|
||||
button { background: #0066cc; color: white; border: none; padding: 4px 12px; cursor: pointer; }
|
||||
button:disabled { background: #444; }
|
||||
</style>
|
||||
@@ -1,173 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import { getFiles, getRoots, getHidden, getMarkers, hideFile, unhideFile } from "$lib/api";
|
||||
import {
|
||||
files, roots, hiddenFiles, currentFile, showHidden,
|
||||
profile, markers, visibleFiles
|
||||
} from "$lib/stores";
|
||||
|
||||
let selectedRoot = $state("");
|
||||
let currentFolder = $state("");
|
||||
|
||||
// On mount: fetch the available roots, select the first one and list its files.
onMount(async () => {
    $roots = await getRoots();
    if (!$roots.length) return;

    selectedRoot = $roots[0];
    await loadFiles();
});
|
||||
|
||||
// Reload hidden files when profile changes
|
||||
$effect(() => {
    // Reading $profile here makes the effect re-run whenever it changes.
    void $profile;
    if (!selectedRoot) return;
    loadFiles();
});
|
||||
|
||||
/** Refresh the file list for the selected root, then the hidden-file set for the active profile. */
async function loadFiles() {
    const listing = await getFiles(selectedRoot);
    $files = listing;

    const hiddenNames = await getHidden($profile);
    $hiddenFiles = new Set(hiddenNames);
}
|
||||
|
||||
// Derive subfolders and files at current folder level
|
||||
// Sorted, de-duplicated names of the folders directly below `currentFolder`,
// derived from the paths of the visible files.
let subfolders = $derived.by(() => {
    const prefix = currentFolder ? currentFolder + "/" : "";
    const names = $visibleFiles
        .filter((f) => f.path.startsWith(prefix))
        .map((f) => f.path.slice(prefix.length))
        // A remaining "/" means the file lives in a subfolder of this level.
        .filter((rest) => rest.indexOf("/") !== -1)
        .map((rest) => rest.slice(0, rest.indexOf("/")));
    return [...new Set<string>(names)].sort();
});
|
||||
|
||||
// Visible files that are direct children of `currentFolder`.
let currentFiles = $derived.by(() => {
    const prefix = currentFolder ? currentFolder + "/" : "";
    return $visibleFiles.filter(
        (f) => f.path.startsWith(prefix) && !f.path.slice(prefix.length).includes("/")
    );
});
|
||||
|
||||
/**
 * Make `file` the current file and load its markers for the active profile.
 *
 * The marker request is asynchronous; if a different file is selected
 * before it resolves, the stale result is discarded so it cannot overwrite
 * the markers of the newer selection.
 */
async function selectFile(file: typeof $files[0]) {
    $currentFile = file;
    const loaded = await getMarkers(file.name, $profile);
    if ($currentFile?.path !== file.path) return;
    $markers = loaded;
}
|
||||
|
||||
/** Descend into the subfolder `name` of the current folder. */
function navigateToFolder(name: string) {
    if (currentFolder) {
        currentFolder = currentFolder + "/" + name;
    } else {
        currentFolder = name;
    }
}
|
||||
|
||||
/** Go one folder level up; from a top-level folder this returns to the root. */
function navigateUp() {
    const cut = currentFolder.lastIndexOf("/");
    if (cut === -1) {
        currentFolder = "";
        return;
    }
    currentFolder = currentFolder.slice(0, cut);
}
|
||||
|
||||
/**
 * Format a byte count using decimal units: one decimal for GB, none for MB,
 * and KB for everything that is not above one million bytes.
 */
function formatSize(bytes: number): string {
    // [threshold (also the divisor), decimals, suffix], largest unit first.
    const tiers: [number, number, string][] = [
        [1e9, 1, " GB"],
        [1e6, 0, " MB"],
    ];
    for (const [limit, digits, suffix] of tiers) {
        if (bytes > limit) return (bytes / limit).toFixed(digits) + suffix;
    }
    return (bytes / 1e3).toFixed(0) + " KB";
}
|
||||
|
||||
/** Flip the hidden state of `file` for the active profile, then reload the lists. */
async function toggleHidden(file: typeof $files[0]) {
    const isHidden = $hiddenFiles.has(file.name);
    const update = isHidden ? unhideFile : hideFile;
    await update(file.name, $profile);
    await loadFiles();
}
|
||||
</script>
|
||||
|
||||
<div class="file-browser">
|
||||
<div class="controls">
|
||||
<select bind:value={selectedRoot} onchange={() => { currentFolder = ""; loadFiles(); }}>
|
||||
{#each $roots as root}
|
||||
<option value={root}>{root}</option>
|
||||
{/each}
|
||||
</select>
|
||||
<label><input type="checkbox" bind:checked={$showHidden} /> Hidden</label>
|
||||
</div>
|
||||
{#if currentFolder}
|
||||
<div class="breadcrumb" onclick={navigateUp}>.. / {currentFolder}</div>
|
||||
{/if}
|
||||
<ul class="file-list">
|
||||
{#each subfolders as folder}
|
||||
<li class="folder" onclick={() => navigateToFolder(folder)}>
|
||||
<span class="name">{folder}/</span>
|
||||
<span class="badge">dir</span>
|
||||
</li>
|
||||
{/each}
|
||||
{#each currentFiles as file}
|
||||
<li
|
||||
class:selected={$currentFile?.path === file.path}
|
||||
onclick={() => selectFile(file)}
|
||||
oncontextmenu={(e) => { e.preventDefault(); toggleHidden(file); }}
|
||||
>
|
||||
<span class="name">{file.name}</span>
|
||||
<span class="size">{formatSize(file.size)}</span>
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.file-browser {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
min-width: 200px;
|
||||
}
|
||||
.controls {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
padding: 4px;
|
||||
align-items: center;
|
||||
}
|
||||
.controls select {
|
||||
flex: 1;
|
||||
background: #2d2d2d;
|
||||
color: #e0e0e0;
|
||||
border: 1px solid #444;
|
||||
padding: 2px;
|
||||
}
|
||||
.breadcrumb {
|
||||
padding: 3px 8px;
|
||||
font-size: 11px;
|
||||
color: #88aaff;
|
||||
cursor: pointer;
|
||||
background: #252525;
|
||||
border-bottom: 1px solid #333;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.breadcrumb:hover { background: #2a2a2a; }
|
||||
.file-list {
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
overflow-y: auto;
|
||||
flex: 1;
|
||||
}
|
||||
.file-list li {
|
||||
padding: 4px 8px;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
font-size: 12px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.file-list li:hover { background: #333; }
|
||||
.file-list li.selected { background: #0066cc; }
|
||||
.file-list li.folder { color: #88aaff; }
|
||||
.name { flex: 1; overflow: hidden; text-overflow: ellipsis; }
|
||||
.size { flex-shrink: 0; margin-left: 8px; color: #888; font-size: 11px; }
|
||||
.badge { flex-shrink: 0; margin-left: 8px; color: #666; font-size: 10px; }
|
||||
</style>
|
||||
@@ -1,93 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import { getProfiles, setServer, getServer } from "$lib/api";
|
||||
import { profile, subprofiles, serverUrl } from "$lib/stores";
|
||||
import { saveSettings } from "$lib/settings";
|
||||
|
||||
let profiles = $state<string[]>([]);
|
||||
let serverInput = $state(getServer());
|
||||
|
||||
// On mount: mirror the active server URL into the input and load its profiles.
onMount(async () => {
  serverInput = getServer();
  try {
    const available = await getProfiles();
    profiles = available;
    // Select the first profile when the current one is not offered by this server.
    if (profiles.length > 0 && !profiles.includes($profile)) {
      $profile = profiles[0];
    }
  } catch {
    /* server not reachable yet */
  }
});
|
||||
|
||||
/**
 * Apply the URL typed into the server input: normalise it, store it in the API
 * module and the `serverUrl` store, persist settings, then reload the profile
 * list from the new server.
 */
function applyServer() {
  // Surrounding whitespace would otherwise end up inside every request URL.
  const url = serverInput.trim().replace(/\/+$/, "");
  setServer(url);
  $serverUrl = url;
  saveSettings();
  // Reload profiles from new server
  getProfiles()
    .then((list) => {
      profiles = list;
      // Same reconciliation as on mount: the previously selected profile may
      // not exist on the new server.
      if (list.length && !list.includes($profile)) {
        $profile = list[0];
      }
    })
    .catch(() => {
      /* new server not reachable yet; keep the list that is currently shown */
    });
}
|
||||
|
||||
/**
 * Ask the user for a subprofile suffix and append it to `$subprofiles` unless
 * it is empty, whitespace-only, or already present.
 */
function addSubprofile() {
  // prompt() returns null when cancelled; trim so "  x " and "x" are the same entry.
  const name = prompt("Subprofile suffix:")?.trim();
  if (name && !$subprofiles.includes(name)) {
    $subprofiles = [...$subprofiles, name];
  }
}
|
||||
|
||||
/** Drop every entry equal to `name` from the `$subprofiles` store. */
function removeSubprofile(name: string) {
  const remaining = $subprofiles.filter((existing) => existing !== name);
  $subprofiles = remaining;
}
|
||||
</script>
|
||||
|
||||
<div class="profile-bar">
|
||||
<input
|
||||
class="server-input"
|
||||
type="text"
|
||||
bind:value={serverInput}
|
||||
onkeydown={(e) => { if (e.key === "Enter") applyServer(); }}
|
||||
placeholder="http://host:8000"
|
||||
/>
|
||||
<button onclick={applyServer}>Set</button>
|
||||
|
||||
<select bind:value={$profile}>
|
||||
{#each profiles as p}
|
||||
<option value={p}>{p}</option>
|
||||
{/each}
|
||||
</select>
|
||||
|
||||
<span class="subs">
|
||||
{#each $subprofiles as sub}
|
||||
<span class="sub-tag" oncontextmenu={(e) => { e.preventDefault(); removeSubprofile(sub); }}>
|
||||
{sub}
|
||||
</span>
|
||||
{/each}
|
||||
<button onclick={addSubprofile}>+</button>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.profile-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 4px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.server-input {
|
||||
width: 180px;
|
||||
background: #2d2d2d;
|
||||
color: #e0e0e0;
|
||||
border: 1px solid #444;
|
||||
padding: 2px 4px;
|
||||
font-size: 11px;
|
||||
}
|
||||
select { background: #2d2d2d; color: #e0e0e0; border: 1px solid #444; }
|
||||
.subs { display: flex; gap: 4px; align-items: center; }
|
||||
.sub-tag {
|
||||
background: #444;
|
||||
padding: 2px 6px;
|
||||
border-radius: 3px;
|
||||
cursor: context-menu;
|
||||
font-size: 11px;
|
||||
}
|
||||
button { background: #333; color: #e0e0e0; border: 1px solid #555; padding: 1px 6px; cursor: pointer; }
|
||||
</style>
|
||||
@@ -1,170 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
duration, cursor, playPos, markers, clips, spread, locked, clipSpan
|
||||
} from "$lib/stores";
|
||||
|
||||
let {
|
||||
onCursorChange = (_time: number) => {},
|
||||
onSeek = (_time: number) => {},
|
||||
onMarkerClick = (_marker: { start_time: number; output_path: string }) => {},
|
||||
onMarkerDelete = (_outputPath: string) => {},
|
||||
} = $props<{
|
||||
onCursorChange?: (time: number) => void;
|
||||
onSeek?: (time: number) => void;
|
||||
onMarkerClick?: (marker: { start_time: number; output_path: string }) => void;
|
||||
onMarkerDelete?: (outputPath: string) => void;
|
||||
}>();
|
||||
|
||||
let canvas: HTMLCanvasElement;
|
||||
let ctx: CanvasRenderingContext2D;
|
||||
let dragging = $state(false);
|
||||
|
||||
const HEIGHT = 160;
|
||||
|
||||
/** Map a time `t` to a canvas x coordinate; 0 while no duration is known. */
function timeToX(t: number): number {
  const total = $duration;
  if (total <= 0) return 0;
  const fraction = t / total;
  return fraction * canvas.width;
}
|
||||
|
||||
/** Map a canvas x coordinate to a time clamped to [0, $duration]; 0 while no duration is known. */
function xToTime(x: number): number {
  const total = $duration;
  if (total <= 0) return 0;
  const unclamped = (x / canvas.width) * total;
  return Math.max(0, Math.min(total, unclamped));
}
|
||||
|
||||
/**
 * Repaint the whole timeline. Layers are painted back to front: background,
 * clip-span region, markers, cursor, play position, time labels.
 * Does nothing until the 2D context has been acquired.
 */
function draw() {
  if (!ctx) return;
  const { width, height } = canvas;
  const total = $duration;
  const hasDuration = total > 0;

  // Full-height vertical bar starting one pixel left of `x`.
  const bar = (x: number, thickness: number, color: string) => {
    ctx.fillStyle = color;
    ctx.fillRect(x - 1, 0, thickness, height);
  };

  ctx.clearRect(0, 0, width, height);

  // Background
  ctx.fillStyle = "#1a1a1a";
  ctx.fillRect(0, 0, width, height);

  // Clip span region
  if (hasDuration) {
    const left = timeToX($cursor);
    const right = timeToX($cursor + $clipSpan);
    ctx.fillStyle = "rgba(0, 100, 200, 0.15)";
    ctx.fillRect(left, 0, right - left, height);
  }

  // Markers (drawn even without a duration; timeToX then places them at x = 0)
  for (const marker of $markers) {
    bar(timeToX(marker.start_time), 3, "#22aa44");
  }

  // Cursor
  if (hasDuration) {
    bar(timeToX($cursor), 3, "#ff4444");
  }

  // Play position
  if ($playPos !== null && hasDuration) {
    bar(timeToX($playPos), 2, "#ffaa00");
  }

  // Time labels: a tick and a caption every `step` seconds, step >= 10.
  if (hasDuration) {
    ctx.fillStyle = "#888";
    ctx.font = "11px monospace";
    const magnitude = Math.floor(Math.log10(total / 5));
    const step = Math.max(10, Math.pow(10, magnitude));
    for (let t = 0; t <= total; t += step) {
      const tickX = timeToX(t);
      ctx.fillText(formatTime(t), tickX + 2, height - 4);
      ctx.fillRect(tickX, height - 16, 1, 16);
    }
  }
}
|
||||
|
||||
/** Format seconds as `m:ss.t`, truncating (not rounding) to tenths of a second. */
function formatTime(s: number): string {
  const minutes = Math.floor(s / 60);
  const tenths = Math.floor((s % 60) * 10);
  const seconds = (tenths / 10).toFixed(1).padStart(4, "0");
  return `${minutes}:${seconds}`;
}
|
||||
|
||||
/** Move the cursor to the time under the pointer and notify the parent. */
function moveCursorTo(e: MouseEvent) {
  const time = xToTime(e.offsetX);
  $cursor = time;
  onCursorChange(time);
}

/**
 * Start scrubbing. Only the primary button scrubs: a right-click also fires
 * mousedown, and must not move the cursor before the context-menu handler
 * deletes a marker.
 */
function handleMouseDown(e: MouseEvent) {
  if ($locked || e.button !== 0) return;
  dragging = true;
  moveCursorTo(e);
}

/** Continue scrubbing while a drag is in progress and the timeline is unlocked. */
function handleMouseMove(e: MouseEvent) {
  if (!dragging || $locked) return;
  moveCursorTo(e);
}

/** End the drag (also wired to mouseleave). */
function handleMouseUp() {
  dragging = false;
}
|
||||
|
||||
/**
 * Double-click: activate the marker closest to the pointer when one lies
 * within 8 px, otherwise seek to the clicked time.
 */
function handleDblClick(e: MouseEvent) {
  // Pick the nearest marker rather than the first one in list order, so
  // closely spaced markers resolve to the one actually under the pointer.
  let hit = -1;
  let bestDistance = 8;
  $markers.forEach((m, i) => {
    const distance = Math.abs(e.offsetX - timeToX(m.start_time));
    if (distance < bestDistance) {
      bestDistance = distance;
      hit = i;
    }
  });
  if (hit >= 0) {
    onMarkerClick($markers[hit]);
    return;
  }
  onSeek(xToTime(e.offsetX));
}
|
||||
|
||||
/**
 * Right-click: suppress the native menu and request deletion of the marker
 * closest to the pointer, if one lies within 8 px.
 */
function handleContextMenu(e: MouseEvent) {
  e.preventDefault();
  // Deletion is destructive, so resolve to the nearest marker instead of the
  // first one in list order that happens to be within range.
  let hit = -1;
  let bestDistance = 8;
  $markers.forEach((m, i) => {
    const distance = Math.abs(e.offsetX - timeToX(m.start_time));
    if (distance < bestDistance) {
      bestDistance = distance;
      hit = i;
    }
  });
  if (hit >= 0) {
    onMarkerDelete($markers[hit].output_path);
  }
}
|
||||
|
||||
// Redraw on any state change
$effect(() => {
  // Reference every store the rendering depends on, then repaint.
  void $duration;
  void $cursor;
  void $playPos;
  void $markers;
  void $clips;
  void $spread;
  void $clipSpan;
  draw();
});

// Acquire the 2D context and keep the backing store sized to the element.
onMount(() => {
  ctx = canvas.getContext("2d")!;
  const syncSize = () => {
    canvas.width = canvas.clientWidth;
    canvas.height = HEIGHT;
    draw();
  };
  const observer = new ResizeObserver(() => syncSize());
  observer.observe(canvas);
  return () => observer.disconnect();
});
|
||||
</script>
|
||||
|
||||
<canvas
|
||||
bind:this={canvas}
|
||||
style="width:100%;height:{HEIGHT}px"
|
||||
onmousedown={handleMouseDown}
|
||||
onmousemove={handleMouseMove}
|
||||
onmouseup={handleMouseUp}
|
||||
onmouseleave={handleMouseUp}
|
||||
ondblclick={handleDblClick}
|
||||
oncontextmenu={handleContextMenu}
|
||||
></canvas>
|
||||
|
||||
<style>
|
||||
canvas {
|
||||
display: block;
|
||||
background: #1a1a1a;
|
||||
cursor: crosshair;
|
||||
}
|
||||
</style>
|
||||
@@ -1,158 +0,0 @@
|
||||
// Server used until setServer() is called.
const DEFAULT_SERVER = "http://192.168.1.51:8000";

// Base URL prepended to every request path in this module.
let serverUrl = DEFAULT_SERVER;

/** Set the base URL, dropping any trailing slashes. */
export function setServer(url: string) {
  const withoutTrailingSlashes = url.replace(/\/+$/, "");
  serverUrl = withoutTrailingSlashes;
}

/** Return the base URL currently in use. */
export function getServer(): string {
  return serverUrl;
}
|
||||
|
||||
// Shared transport: send the request to the configured server, reject on a
// non-OK status with "<status> <statusText>", otherwise parse the body as JSON.
async function requestJson<T>(path: string, init?: RequestInit): Promise<T> {
  const res = await fetch(`${serverUrl}${path}`, init);
  if (!res.ok) throw new Error(`${res.status} ${res.statusText}`);
  return res.json();
}

/** GET `path` and return the parsed JSON response. */
async function get<T>(path: string): Promise<T> {
  return requestJson<T>(path);
}

/** POST `path`; a truthy `body` is sent as JSON, a falsy one sends no body. */
async function post<T>(path: string, body?: unknown): Promise<T> {
  const init: RequestInit = body
    ? {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(body),
      }
    : { method: "POST", headers: {}, body: undefined };
  return requestJson<T>(path, init);
}

/** DELETE `path` and return the parsed JSON response. */
async function del<T>(path: string): Promise<T> {
  return requestJson<T>(path, { method: "DELETE" });
}
|
||||
|
||||
// --- Files ---
|
||||
|
||||
/** One entry of the list returned by `/api/files`. */
export interface VideoFile {
  name: string;
  path: string;
  root: string;
  size: number;
}

/** Fetch the list of roots from `/api/roots`. */
export function getRoots(): Promise<string[]> {
  return get<string[]>("/api/roots");
}

/** Fetch the file list, restricted to `root` when one is given. */
export function getFiles(root?: string): Promise<VideoFile[]> {
  if (!root) return get<VideoFile[]>("/api/files");
  return get<VideoFile[]>(`/api/files?root=${encodeURIComponent(root)}`);
}
|
||||
|
||||
// For {path:path} routes: percent-encode every segment on its own so the "/"
// separators survive unescaped.
function encodePath(p: string): string {
  const segments = p.split("/");
  const encoded = segments.map((segment) => encodeURIComponent(segment));
  return encoded.join("/");
}
|
||||
|
||||
/**
 * Build the video stream URL for `path` under `root` at the given `quality`.
 * Every caller-supplied part is percent-encoded, including `quality`, so a
 * value containing `&`, `#` or spaces cannot corrupt the query string.
 */
export function streamUrl(path: string, root: string, quality: string): string {
  const query = `root=${encodeURIComponent(root)}&quality=${encodeURIComponent(quality)}`;
  return `${serverUrl}/api/stream/${encodePath(path)}?${query}`;
}
|
||||
|
||||
/** Build the audio URL for `path` under `root`. */
export function audioUrl(path: string, root: string): string {
  const encodedPath = encodePath(path);
  const encodedRoot = encodeURIComponent(root);
  return `${serverUrl}/api/audio/${encodedPath}?root=${encodedRoot}`;
}
|
||||
|
||||
/**
 * Poll cache status until both video and audio are ready.
 *
 * Resolves once the status object reports "ready" for both `quality` and
 * `audio`. Rejects when `signal` is aborted (either from an in-flight fetch or
 * with Error("Aborted")) and when the status endpoint answers with a non-OK
 * HTTP status, using the same "<status> <statusText>" message as the other
 * request helpers in this module.
 *
 * @param interval Milliseconds to wait between polls.
 */
export async function waitForCache(
  path: string, root: string, quality: string,
  signal: AbortSignal, interval = 2000
): Promise<void> {
  const url = `${serverUrl}/api/cache/status/${encodePath(path)}?root=${encodeURIComponent(root)}`;
  // Trigger transcode/audio extraction by hitting stream+audio once
  await fetch(streamUrl(path, root, quality), { signal }).catch(() => {});
  await fetch(audioUrl(path, root), { signal }).catch(() => {});

  while (!signal.aborted) {
    const res = await fetch(url, { signal });
    // Without this check an error response would simply never report "ready"
    // and the loop would poll it until aborted.
    if (!res.ok) throw new Error(`${res.status} ${res.statusText}`);
    const status = await res.json();
    if (status[quality] === "ready" && status.audio === "ready") return;

    // Sleep for `interval`, waking immediately when the signal is aborted.
    await new Promise<void>((resolve) => {
      if (signal.aborted) {
        resolve();
        return;
      }
      const finish = () => {
        clearTimeout(timer);
        signal.removeEventListener("abort", finish);
        resolve();
      };
      const timer = setTimeout(finish, interval);
      signal.addEventListener("abort", finish, { once: true });
    });
  }
  throw new Error("Aborted");
}
|
||||
|
||||
/** Fetch the cache status object for `path` under `root`. */
export function cacheStatus(path: string, root: string): Promise<Record<string, string>> {
  const encodedPath = encodePath(path);
  const encodedRoot = encodeURIComponent(root);
  return get<Record<string, string>>(`/api/cache/status/${encodedPath}?root=${encodedRoot}`);
}
|
||||
|
||||
// --- Markers & Profiles ---
|
||||
|
||||
/** One entry of the list returned by `/api/markers/<filename>`. */
export interface Marker {
  start_time: number;
  marker_number: number;
  output_path: string;
}

/** Fetch the markers recorded for `filename` under `profile`. */
export function getMarkers(filename: string, profile: string = "default"): Promise<Marker[]> {
  const file = encodeURIComponent(filename);
  const prof = encodeURIComponent(profile);
  return get<Marker[]>(`/api/markers/${file}?profile=${prof}`);
}

/** Fetch the profile names from `/api/profiles`. */
export function getProfiles(): Promise<string[]> {
  return get<string[]>("/api/profiles");
}

/** Fetch the label names from `/api/labels`. */
export function getLabels(): Promise<string[]> {
  return get<string[]>("/api/labels");
}
|
||||
|
||||
// --- Export ---
|
||||
|
||||
/** JSON body accepted by `POST /api/export`; only the first three fields are required. */
export interface ExportRequest {
  input_path: string;
  cursor: number;
  name: string;
  clips?: number;
  spread?: number;
  short_side?: number | null;
  portrait_ratio?: string | null;
  crop_center?: number;
  format?: string;
  label?: string;
  category?: string;
  profile?: string;
  folder_suffix?: string;
  encoder?: string;
}

/** Submit an export request; resolves with the id of the created job. */
export function startExport(req: ExportRequest): Promise<{ job_id: string }> {
  return post<{ job_id: string }>("/api/export", req);
}
|
||||
|
||||
/**
 * Fetch the status of export job `jobId`.
 * The id is percent-encoded so that it always stays a single path segment.
 */
export function getExportStatus(jobId: string): Promise<{
  status: string;
  total: number;
  completed: number;
  outputs: string[];
  error?: string;
}> {
  return get(`/api/export/${encodeURIComponent(jobId)}`);
}
|
||||
|
||||
/** Ask the server to delete the export stored at `outputPath`. */
export function deleteExport(outputPath: string): Promise<{ deleted: string }> {
  const target = encodeURIComponent(outputPath);
  return del<{ deleted: string }>(`/api/export?output_path=${target}`);
}
|
||||
|
||||
// --- Hidden ---
|
||||
|
||||
/** POST to `/api/hidden/<filename>` for the given profile. */
export function hideFile(filename: string, profile: string = "default"): Promise<unknown> {
  const file = encodeURIComponent(filename);
  const prof = encodeURIComponent(profile);
  return post<unknown>(`/api/hidden/${file}?profile=${prof}`);
}

/** DELETE `/api/hidden/<filename>` for the given profile. */
export function unhideFile(filename: string, profile: string = "default"): Promise<unknown> {
  const file = encodeURIComponent(filename);
  const prof = encodeURIComponent(profile);
  return del<unknown>(`/api/hidden/${file}?profile=${prof}`);
}

/** Fetch the hidden-file list for the given profile. */
export function getHidden(profile: string = "default"): Promise<string[]> {
  const prof = encodeURIComponent(profile);
  return get<string[]>(`/api/hidden?profile=${prof}`);
}
|
||||
@@ -1,41 +0,0 @@
|
||||
import { invoke } from "@tauri-apps/api/core";
|
||||
|
||||
// Thin typed wrappers: each one forwards to the Tauri command of the same
// (snake_case) name via invoke() and returns its promise.

/** Invoke the `mpv_start` command. */
export async function mpvStart(): Promise<void> {
  return invoke<void>("mpv_start");
}

/** Invoke the `mpv_stop` command. */
export async function mpvStop(): Promise<void> {
  return invoke<void>("mpv_stop");
}

/** Invoke `mpv_load` with the video and audio URLs. */
export async function mpvLoad(videoUrl: string, audioUrl: string): Promise<void> {
  const args = { videoUrl, audioUrl };
  return invoke<void>("mpv_load", args);
}

/** Invoke `mpv_seek` with the target `time`. */
export async function mpvSeek(time: number): Promise<void> {
  const args = { time };
  return invoke<void>("mpv_seek", args);
}

/** Invoke the `mpv_pause` command. */
export async function mpvPause(): Promise<void> {
  return invoke<void>("mpv_pause");
}

/** Invoke the `mpv_resume` command. */
export async function mpvResume(): Promise<void> {
  return invoke<void>("mpv_resume");
}

/** Invoke `mpv_set_loop` with the loop bounds `a` and `b`. */
export async function mpvSetLoop(a: number, b: number): Promise<void> {
  const args = { a, b };
  return invoke<void>("mpv_set_loop", args);
}

/** Invoke the `mpv_clear_loop` command. */
export async function mpvClearLoop(): Promise<void> {
  return invoke<void>("mpv_clear_loop");
}

/** Invoke `mpv_time_pos` and resolve with the number it returns. */
export async function mpvTimePos(): Promise<number> {
  return invoke<number>("mpv_time_pos");
}

/** Invoke `mpv_duration` and resolve with the number it returns. */
export async function mpvDuration(): Promise<number> {
  return invoke<number>("mpv_duration");
}
|
||||
@@ -1,58 +0,0 @@
|
||||
import {
|
||||
serverUrl, quality, clips, spread, shortSide, portraitRatio,
|
||||
format, hwEncode, profile, subprofiles
|
||||
} from "./stores";
|
||||
import { setServer } from "./api";
|
||||
import { get } from "svelte/store";
|
||||
|
||||
const KEY = "8cut-settings";
|
||||
|
||||
/** Shape of the JSON document persisted in localStorage under KEY. */
interface Settings {
  serverUrl: string;
  quality: string;
  clips: number;
  spread: number;
  shortSide: number | null;
  portraitRatio: string | null;
  format: string;
  hwEncode: boolean;
  profile: string;
  subprofiles: string[];
}

/** Snapshot the current value of every persisted store and write it to localStorage. */
export function saveSettings() {
  const snapshot: Settings = {
    serverUrl: get(serverUrl),
    quality: get(quality),
    clips: get(clips),
    spread: get(spread),
    shortSide: get(shortSide),
    portraitRatio: get(portraitRatio),
    format: get(format),
    hwEncode: get(hwEncode),
    profile: get(profile),
    subprofiles: get(subprofiles),
  };
  const serialized = JSON.stringify(snapshot);
  localStorage.setItem(KEY, serialized);
}
|
||||
|
||||
/**
 * Restore persisted settings from localStorage into the stores.
 *
 * Each field is applied only when it has the expected type, so a stale or
 * hand-edited blob (e.g. `"clips": "3"` or a non-array `subprofiles`) cannot
 * put wrongly typed values into the stores. Missing, unreadable or corrupt
 * settings leave the current store values untouched.
 */
export function loadSettings() {
  try {
    const raw = localStorage.getItem(KEY);
    if (!raw) return;
    const data = JSON.parse(raw) as Partial<Settings> | null;
    if (typeof data !== "object" || data === null) return;

    if (typeof data.serverUrl === "string" && data.serverUrl) {
      serverUrl.set(data.serverUrl);
      // Keep the API module's base URL in step with the store.
      setServer(data.serverUrl);
    }
    if (typeof data.quality === "string" && data.quality) quality.set(data.quality);
    // As before, a stored 0 is treated as "not set" for clips and spread.
    if (typeof data.clips === "number" && data.clips) clips.set(data.clips);
    if (typeof data.spread === "number" && data.spread) spread.set(data.spread);
    // null is a legitimate stored value for these two.
    if (data.shortSide === null || typeof data.shortSide === "number") {
      shortSide.set(data.shortSide);
    }
    if (data.portraitRatio === null || typeof data.portraitRatio === "string") {
      portraitRatio.set(data.portraitRatio);
    }
    if (typeof data.format === "string" && data.format) format.set(data.format);
    if (typeof data.hwEncode === "boolean") hwEncode.set(data.hwEncode);
    if (typeof data.profile === "string" && data.profile) profile.set(data.profile);
    if (
      Array.isArray(data.subprofiles) &&
      data.subprofiles.every((s) => typeof s === "string")
    ) {
      subprofiles.set(data.subprofiles);
    }
  } catch {
    /* ignore unreadable or corrupt settings */
  }
}
|
||||
@@ -1,66 +0,0 @@
|
||||
import { writable, derived } from "svelte/store";
|
||||
import type { VideoFile, Marker } from "./api";
|
||||
|
||||
// --- Connection ---
export const serverUrl = writable<string>("http://192.168.1.51:8000");

// --- Files ---
export const roots = writable<string[]>([]);
export const files = writable<VideoFile[]>([]);
export const hiddenFiles = writable<Set<string>>(new Set());
export const currentFile = writable<VideoFile | null>(null);
export const hideExported = writable<boolean>(false);
export const showHidden = writable<boolean>(false);

// --- Playback ---
export const duration = writable<number>(0);
export const cursor = writable<number>(0);
export const playPos = writable<number | null>(null);
export const playing = writable<boolean>(false);
export const quality = writable<string>("low");

// --- Timeline ---
export const markers = writable<Marker[]>([]);
export const locked = writable<boolean>(false);

// --- Export settings ---
export const clips = writable<number>(3);
export const spread = writable<number>(3.0);
export const shortSide = writable<number | null>(512);
export const portraitRatio = writable<string | null>(null);
export const cropCenter = writable<number>(0.5);
export const format = writable<string>("MP4");
export const hwEncode = writable<boolean>(false);
export const label = writable<string>("");
export const category = writable<string>("");
export const clipName = writable<string>("");
export const exportFolder = writable<string>("");
export const encoder = writable<string>("libx264");
export const trackSubject = writable<boolean>(false);
export const randPortrait = writable<boolean>(false);
export const randSquare = writable<boolean>(false);

// --- Profiles ---
export const profile = writable<string>("default");
export const subprofiles = writable<string[]>([]);

// --- Export progress ---
export const exportStatus = writable<string>("idle"); // idle | running | done | error
export const exportCompleted = writable<number>(0);
export const exportTotal = writable<number>(0);
|
||||
|
||||
// --- Derived ---

// 8.0 plus one `spread` for every clip after the first.
export const clipSpan = derived(
  [clips, spread],
  ([count, gap]) => 8.0 + (count - 1) * gap
);

// All files, minus those whose name is in the hidden set unless showHidden is on.
export const visibleFiles = derived(
  [files, hiddenFiles, showHidden],
  ([all, hidden, revealHidden]) =>
    all.filter((f) => revealHidden || !hidden.has(f.name))
);
|
||||
@@ -1,48 +0,0 @@
|
||||
import { getServer } from "./api";
|
||||
import { exportStatus, exportCompleted } from "./stores";
|
||||
|
||||
const INITIAL_RECONNECT_DELAY = 2000;
const MAX_RECONNECT_DELAY = 30000;

let socket: WebSocket | null = null;
let reconnectDelay = INITIAL_RECONNECT_DELAY;
// Handle of the pending reconnect attempt, so it can be cancelled.
let reconnectTimer: ReturnType<typeof setTimeout> | null = null;

// Cancel any pending reconnect and close the current socket without letting
// its close handler schedule another attempt.
function teardown() {
  if (reconnectTimer !== null) {
    clearTimeout(reconnectTimer);
    reconnectTimer = null;
  }
  if (socket) {
    socket.onclose = null; // prevent reconnect
    socket.close();
    socket = null;
  }
}

/**
 * Open the export-progress WebSocket at `<server>/ws/export` (http -> ws,
 * https -> wss) and mirror its messages into the export stores.
 *
 * Calling it again replaces the existing connection instead of stacking a
 * second socket. When the connection closes on its own, a reconnect is
 * scheduled with exponential backoff (2 s doubling up to 30 s).
 */
export function connectExportWs() {
  teardown();

  const wsUrl = getServer().replace(/^http/, "ws") + "/ws/export";
  const ws = new WebSocket(wsUrl);
  socket = ws;

  ws.onopen = () => {
    reconnectDelay = INITIAL_RECONNECT_DELAY; // reset backoff on successful connect
  };

  ws.onmessage = (event) => {
    try {
      const msg = JSON.parse(event.data);
      switch (msg.type) {
        case "clip_done":
          exportCompleted.update(n => n + 1);
          break;
        case "all_done":
          exportStatus.set("done");
          break;
        case "error":
          exportStatus.set("error");
          console.error("Export error:", msg.msg);
          break;
      }
    } catch (e) {
      console.error("Failed to parse WebSocket message:", e);
    }
  };

  ws.onclose = () => {
    // Ignore a close event from a socket that has already been replaced.
    if (socket !== ws) return;
    socket = null;
    // Reconnect with exponential backoff, max 30s
    reconnectTimer = setTimeout(connectExportWs, reconnectDelay);
    reconnectDelay = Math.min(reconnectDelay * 2, MAX_RECONNECT_DELAY);
  };
}

/**
 * Close the export WebSocket and stop reconnecting, including an attempt that
 * was already scheduled after an earlier unexpected close.
 */
export function disconnectExportWs() {
  teardown();
  reconnectDelay = INITIAL_RECONNECT_DELAY;
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Tauri doesn't have a Node.js server to do proper SSR
|
||||
// so we use adapter-static with a fallback to index.html to put the site in SPA mode
|
||||
// See: https://svelte.dev/docs/kit/single-page-apps
|
||||
// See: https://v2.tauri.app/start/frontend/sveltekit/ for more info
|
||||
export const ssr = false;
|
||||
@@ -1,251 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount, onDestroy } from "svelte";
|
||||
import FileBrowser from "../components/FileBrowser.svelte";
|
||||
import Timeline from "../components/Timeline.svelte";
|
||||
import ExportPanel from "../components/ExportPanel.svelte";
|
||||
import ProfileBar from "../components/ProfileBar.svelte";
|
||||
import { mpvStart, mpvLoad, mpvSeek, mpvPause, mpvResume, mpvSetLoop, mpvClearLoop, mpvTimePos, mpvDuration } from "$lib/mpv";
|
||||
import { streamUrl, audioUrl, waitForCache, deleteExport, getMarkers } from "$lib/api";
|
||||
import { connectExportWs, disconnectExportWs } from "$lib/ws";
|
||||
import { loadSettings, saveSettings } from "$lib/settings";
|
||||
import {
|
||||
currentFile, cursor, duration, playPos, playing, quality,
|
||||
clips, spread, locked, markers, profile, clipSpan, subprofiles
|
||||
} from "$lib/stores";
|
||||
|
||||
let pollInterval: ReturnType<typeof setInterval>;
|
||||
let exportPanelRef: ExportPanel;
|
||||
|
||||
// Page start-up. The callback is deliberately synchronous: Svelte only runs a
// cleanup function that onMount receives synchronously, so the async work
// lives in an inner task and the store subscriptions are released properly.
onMount(() => {
  loadSettings();

  // Set when the component is torn down before mpv has finished starting.
  let disposed = false;

  // Auto-save settings on changes
  const unsubs = [
    quality.subscribe(() => saveSettings()),
    clips.subscribe(() => saveSettings()),
    spread.subscribe(() => saveSettings()),
    profile.subscribe(() => saveSettings()),
    subprofiles.subscribe(() => saveSettings()),
  ];

  (async () => {
    try {
      await mpvStart();
    } catch (err) {
      console.error("Failed to start mpv:", err);
      return;
    }
    // Do not open a socket or start a timer that nothing would ever clean up.
    if (disposed) return;

    connectExportWs();

    // Poll mpv for time position
    pollInterval = setInterval(async () => {
      if ($playing) {
        try {
          $playPos = await mpvTimePos();
        } catch { /* mpv not ready */ }
      }
    }, 50);
  })();

  return () => {
    disposed = true;
    unsubs.forEach(u => u());
  };
});

onDestroy(() => {
  clearInterval(pollInterval);
  disconnectExportWs();
});
|
||||
|
||||
// Load file into mpv when currentFile OR quality changes
let loadAbort: AbortController | null = null;
$effect(() => {
  const file = $currentFile;
  const q = $quality;

  // Cancel any previous load, including when the selection was cleared.
  loadAbort?.abort();
  loadAbort = null;
  if (!file) return;

  const ac = new AbortController();
  loadAbort = ac;

  const vUrl = streamUrl(file.path, file.root, q);
  const aUrl = audioUrl(file.path, file.root);

  (async () => {
    try {
      await waitForCache(file.path, file.root, q, ac.signal);
      // A newer selection may have superseded this one while we were waiting.
      if (ac.signal.aborted) return;
      await mpvLoad(vUrl, aUrl);
      // Short pause before asking mpv for the duration of the new file.
      await new Promise(r => setTimeout(r, 500));
      if (ac.signal.aborted) return;
      const loadedDuration = await mpvDuration();
      if (!ac.signal.aborted) $duration = loadedDuration;
    } catch (err) {
      // Aborts are expected; anything else is worth seeing in the console.
      if (!ac.signal.aborted) console.error("Failed to load file:", err);
    }
  })();

  // Also abort when the effect re-runs or the component is destroyed.
  return () => ac.abort();
});
|
||||
|
||||
/** Seek mpv to `time`. */
async function handleCursorChange(time: number) {
  await mpvSeek(time);
}

/** Seek to the cursor, loop over [cursor, cursor + clipSpan], resume, then flag playback as running. */
async function handlePlay() {
  const loopStart = $cursor;
  const loopEnd = loopStart + $clipSpan;
  await mpvSeek(loopStart);
  await mpvSetLoop(loopStart, loopEnd);
  await mpvResume();
  $playing = true;
}

/** Pause mpv, clear the loop, then flag playback as stopped. */
async function handlePause() {
  await mpvPause();
  await mpvClearLoop();
  $playing = false;
}
|
||||
|
||||
/**
 * Jump to a marker. Unlocked: the cursor goes to the marker's start time.
 * Locked: the cursor goes one clip span past the marker's start time.
 */
async function handleMarkerClick(m: { start_time: number; output_path: string }) {
  // Use the shared derived store rather than re-deriving the span formula here.
  const target = $locked ? m.start_time + $clipSpan : m.start_time;
  $cursor = target;
  await mpvSeek(target);
}
|
||||
|
||||
/** Delete the export at `outputPath`, then refresh the markers of the current file (if any). */
async function handleMarkerDelete(outputPath: string) {
  await deleteExport(outputPath);
  const file = $currentFile;
  if (!file) return;
  $markers = await getMarkers(file.name, $profile);
}
|
||||
|
||||
/**
 * Global keyboard shortcuts:
 *   Space       toggle play / pause
 *   E           export
 *   Left/Right  move the cursor by one second
 *   1-9         export with the subprofile at that position
 *
 * Ignored while typing in a form field or editable element, and while
 * Ctrl / Meta / Alt is held so browser and OS shortcuts are left alone.
 */
function handleKeydown(e: KeyboardEvent) {
  const target = e.target as HTMLElement | null;
  const tag = target?.tagName;
  if (tag === "INPUT" || tag === "SELECT" || tag === "TEXTAREA") return;
  if (target?.isContentEditable) return;
  if (e.ctrlKey || e.metaKey || e.altKey) return;

  switch (e.key) {
    case " ":
      e.preventDefault();
      $playing ? handlePause() : handlePlay();
      break;
    case "e":
    case "E":
      // Holding the key down must not queue one export per auto-repeat.
      if (!e.repeat) exportPanelRef?.doExport();
      break;
    case "ArrowLeft":
      $cursor = Math.max(0, $cursor - 1);
      handleCursorChange($cursor);
      break;
    case "ArrowRight":
      $cursor = Math.min($duration, $cursor + 1);
      handleCursorChange($cursor);
      break;
  }

  // Exactly one digit 1-9 selects a subprofile by position.
  if (!e.repeat && /^[1-9]$/.test(e.key)) {
    const idx = Number(e.key) - 1;
    if (idx < $subprofiles.length) {
      exportPanelRef?.doExport($subprofiles[idx]);
    }
  }
}
|
||||
|
||||
/** Format seconds as `m:ss.t`, truncating (not rounding) to tenths of a second. */
function fmtTime(s: number): string {
  const minutes = Math.floor(s / 60);
  const tenths = Math.floor((s % 60) * 10);
  const seconds = (tenths / 10).toFixed(1).padStart(4, "0");
  return `${minutes}:${seconds}`;
}
|
||||
</script>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
|
||||
<main>
|
||||
<div class="layout">
|
||||
<div class="sidebar">
|
||||
<FileBrowser />
|
||||
</div>
|
||||
<div class="content">
|
||||
<ProfileBar />
|
||||
<div class="player-area">
|
||||
<div class="video-placeholder">
|
||||
{#if $currentFile}
|
||||
<p>{$currentFile.name}</p>
|
||||
{:else}
|
||||
<p>Select a file</p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
<Timeline
|
||||
onCursorChange={handleCursorChange}
|
||||
onSeek={handleCursorChange}
|
||||
onMarkerClick={handleMarkerClick}
|
||||
onMarkerDelete={handleMarkerDelete}
|
||||
/>
|
||||
<div class="transport">
|
||||
<button onclick={handlePlay} disabled={!$currentFile}>Play</button>
|
||||
<button onclick={handlePause}>Pause</button>
|
||||
<button onclick={() => $locked = !$locked}>
|
||||
{$locked ? "Locked" : "Unlocked"}
|
||||
</button>
|
||||
<span class="time">
|
||||
{#if $duration > 0}
|
||||
{fmtTime($cursor)} / {fmtTime($duration)}
|
||||
{/if}
|
||||
</span>
|
||||
<select bind:value={$quality} style="margin-left:auto">
|
||||
<option value="potato">480p</option>
|
||||
<option value="low">720p</option>
|
||||
<option value="medium">1080p</option>
|
||||
<option value="high">Original</option>
|
||||
</select>
|
||||
</div>
|
||||
<ExportPanel bind:this={exportPanelRef} />
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
<style>
|
||||
:global(body) {
|
||||
margin: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
||||
background: #1e1e1e;
|
||||
color: #e0e0e0;
|
||||
}
|
||||
main { height: 100vh; overflow: hidden; }
|
||||
.layout {
|
||||
display: flex;
|
||||
height: 100%;
|
||||
}
|
||||
.sidebar {
|
||||
width: 220px;
|
||||
min-width: 220px;
|
||||
flex-shrink: 0;
|
||||
border-right: 1px solid #333;
|
||||
overflow: hidden;
|
||||
}
|
||||
.content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
.player-area {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: #000;
|
||||
min-height: 200px;
|
||||
}
|
||||
.video-placeholder {
|
||||
color: #666;
|
||||
text-align: center;
|
||||
}
|
||||
.transport {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 4px 8px;
|
||||
background: #222;
|
||||
}
|
||||
.transport button {
|
||||
background: #333;
|
||||
color: #e0e0e0;
|
||||
border: 1px solid #555;
|
||||
padding: 4px 10px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.time {
|
||||
font-family: monospace;
|
||||
font-size: 13px;
|
||||
}
|
||||
select { background: #2d2d2d; color: #e0e0e0; border: 1px solid #444; }
|
||||
</style>
|
||||
|
Before Width: | Height: | Size: 1.5 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="26.6" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 308"><path fill="#FF3E00" d="M239.682 40.707C211.113-.182 154.69-12.301 113.895 13.69L42.247 59.356a82.198 82.198 0 0 0-37.135 55.056a86.566 86.566 0 0 0 8.536 55.576a82.425 82.425 0 0 0-12.296 30.719a87.596 87.596 0 0 0 14.964 66.244c28.574 40.893 84.997 53.007 125.787 27.016l71.648-45.664a82.182 82.182 0 0 0 37.135-55.057a86.601 86.601 0 0 0-8.53-55.577a82.409 82.409 0 0 0 12.29-30.718a87.573 87.573 0 0 0-14.963-66.244"></path><path fill="#FFF" d="M106.889 270.841c-23.102 6.007-47.497-3.036-61.103-22.648a52.685 52.685 0 0 1-9.003-39.85a49.978 49.978 0 0 1 1.713-6.693l1.35-4.115l3.671 2.697a92.447 92.447 0 0 0 28.036 14.007l2.663.808l-.245 2.659a16.067 16.067 0 0 0 2.89 10.656a17.143 17.143 0 0 0 18.397 6.828a15.786 15.786 0 0 0 4.403-1.935l71.67-45.672a14.922 14.922 0 0 0 6.734-9.977a15.923 15.923 0 0 0-2.713-12.011a17.156 17.156 0 0 0-18.404-6.832a15.78 15.78 0 0 0-4.396 1.933l-27.35 17.434a52.298 52.298 0 0 1-14.553 6.391c-23.101 6.007-47.497-3.036-61.101-22.649a52.681 52.681 0 0 1-9.004-39.849a49.428 49.428 0 0 1 22.34-33.114l71.664-45.677a52.218 52.218 0 0 1 14.563-6.398c23.101-6.007 47.497 3.036 61.101 22.648a52.685 52.685 0 0 1 9.004 39.85a50.559 50.559 0 0 1-1.713 6.692l-1.35 4.116l-3.67-2.693a92.373 92.373 0 0 0-28.037-14.013l-2.664-.809l.246-2.658a16.099 16.099 0 0 0-2.89-10.656a17.143 17.143 0 0 0-18.398-6.828a15.786 15.786 0 0 0-4.402 1.935l-71.67 45.674a14.898 14.898 0 0 0-6.73 9.975a15.9 15.9 0 0 0 2.709 12.012a17.156 17.156 0 0 0 18.404 6.832a15.841 15.841 0 0 0 4.402-1.935l27.345-17.427a52.147 52.147 0 0 1 14.552-6.397c23.101-6.006 47.497 3.037 61.102 22.65a52.681 52.681 0 0 1 9.003 39.848a49.453 49.453 0 0 1-22.34 33.12l-71.664 45.673a52.218 52.218 0 0 1-14.563 6.398"></path></svg>
|
||||
|
Before Width: | Height: | Size: 1.9 KiB |
@@ -1,6 +0,0 @@
|
||||
<svg width="206" height="231" viewBox="0 0 206 231" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M143.143 84C143.143 96.1503 133.293 106 121.143 106C108.992 106 99.1426 96.1503 99.1426 84C99.1426 71.8497 108.992 62 121.143 62C133.293 62 143.143 71.8497 143.143 84Z" fill="#FFC131"/>
|
||||
<ellipse cx="84.1426" cy="147" rx="22" ry="22" transform="rotate(180 84.1426 147)" fill="#24C8DB"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M166.738 154.548C157.86 160.286 148.023 164.269 137.757 166.341C139.858 160.282 141 153.774 141 147C141 144.543 140.85 142.121 140.558 139.743C144.975 138.204 149.215 136.139 153.183 133.575C162.73 127.404 170.292 118.608 174.961 108.244C179.63 97.8797 181.207 86.3876 179.502 75.1487C177.798 63.9098 172.884 53.4021 165.352 44.8883C157.82 36.3744 147.99 30.2165 137.042 27.1546C126.095 24.0926 114.496 24.2568 103.64 27.6274C92.7839 30.998 83.1319 37.4317 75.8437 46.1553C74.9102 47.2727 74.0206 48.4216 73.176 49.5993C61.9292 50.8488 51.0363 54.0318 40.9629 58.9556C44.2417 48.4586 49.5653 38.6591 56.679 30.1442C67.0505 17.7298 80.7861 8.57426 96.2354 3.77762C111.685 -1.01901 128.19 -1.25267 143.769 3.10474C159.348 7.46215 173.337 16.2252 184.056 28.3411C194.775 40.457 201.767 55.4101 204.193 71.404C206.619 87.3978 204.374 103.752 197.73 118.501C191.086 133.25 180.324 145.767 166.738 154.548ZM41.9631 74.275L62.5557 76.8042C63.0459 72.813 63.9401 68.9018 65.2138 65.1274C57.0465 67.0016 49.2088 70.087 41.9631 74.275Z" fill="#FFC131"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M38.4045 76.4519C47.3493 70.6709 57.2677 66.6712 67.6171 64.6132C65.2774 70.9669 64 77.8343 64 85.0001C64 87.1434 64.1143 89.26 64.3371 91.3442C60.0093 92.8732 55.8533 94.9092 51.9599 97.4256C42.4128 103.596 34.8505 112.392 30.1816 122.756C25.5126 133.12 23.9357 144.612 25.6403 155.851C27.3449 167.09 32.2584 177.598 39.7906 186.112C47.3227 194.626 57.153 200.784 68.1003 203.846C79.0476 206.907 90.6462 206.743 101.502 203.373C112.359 200.002 122.011 193.568 129.299 184.845C130.237 183.722 131.131 182.567 131.979 181.383C143.235 180.114 154.132 176.91 164.205 171.962C160.929 182.49 155.596 192.319 148.464 200.856C138.092 213.27 124.357 222.426 108.907 227.222C93.458 232.019 76.9524 232.253 61.3736 227.895C45.7948 223.538 31.8055 214.775 21.0867 202.659C10.3679 190.543 3.37557 175.59 0.949823 159.596C-1.47592 143.602 0.768139 127.248 7.41237 112.499C14.0566 97.7497 24.8183 85.2327 38.4045 76.4519ZM163.062 156.711L163.062 156.711C162.954 156.773 162.846 156.835 162.738 156.897C162.846 156.835 162.954 156.773 163.062 156.711Z" fill="#24C8DB"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.5 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>
|
||||
|
Before Width: | Height: | Size: 1.5 KiB |
@@ -1,18 +0,0 @@
|
||||
// Tauri doesn't have a Node.js server to do proper SSR
|
||||
// so we use adapter-static with a fallback to index.html to put the site in SPA mode
|
||||
// See: https://svelte.dev/docs/kit/single-page-apps
|
||||
// See: https://v2.tauri.app/start/frontend/sveltekit/ for more info
|
||||
import adapter from "@sveltejs/adapter-static";
|
||||
import { vitePreprocess } from "@sveltejs/vite-plugin-svelte";
|
||||
|
||||
// Page served for every route that has no prerendered file (SPA mode).
const spaFallbackPage = "index.html";

/** @type {import('@sveltejs/kit').Config} */
const config = {
  preprocess: vitePreprocess(),
  kit: { adapter: adapter({ fallback: spaFallbackPage }) },
};

export default config;
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"extends": "./.svelte-kit/tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"allowJs": true,
|
||||
"checkJs": true,
|
||||
"esModuleInterop": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"skipLibCheck": true,
|
||||
"sourceMap": true,
|
||||
"strict": true,
|
||||
"moduleResolution": "bundler"
|
||||
}
|
||||
// Path aliases are handled by https://svelte.dev/docs/kit/configuration#alias
|
||||
// except $lib which is handled by https://svelte.dev/docs/kit/configuration#files
|
||||
//
|
||||
// If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes
|
||||
// from the referenced tsconfig.json - TypeScript does not merge them in
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
import { defineConfig } from "vite";
|
||||
import { sveltekit } from "@sveltejs/kit/vite";
|
||||
|
||||
// @ts-expect-error process is a nodejs global
const host = process.env.TAURI_DEV_HOST;

// https://vite.dev/config/
export default defineConfig(async () => {
  // HMR goes over a dedicated websocket port only when a dev host is set.
  const hmr = host ? { protocol: "ws", host, port: 1421 } : undefined;

  // Vite options tailored for Tauri development and only applied in
  // `tauri dev` or `tauri build`.
  return {
    plugins: [sveltekit()],

    // 1. prevent Vite from obscuring rust errors
    clearScreen: false,

    // 2. tauri expects a fixed port, fail if that port is not available
    server: {
      port: 1420,
      strictPort: true,
      host: host || false,
      hmr,
      // 3. tell Vite to ignore watching `src-tauri`
      watch: { ignored: ["**/src-tauri/**"] },
    },
  };
});
|
||||
@@ -1,2 +1,6 @@
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the custom ``gui`` marker on the pytest *config* object."""
    marker_line = "gui: constructs Qt widgets; needs a display"
    config.addinivalue_line("markers", marker_line)
|
||||
|
||||
@@ -0,0 +1,809 @@
|
||||
"""Audio scanning — embedding-based classifier for audio event detection."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
_SR = 16000 # lower sr = faster
|
||||
|
||||
|
||||
def _load_audio_ffmpeg(path: str, sr: int = _SR) -> np.ndarray:
    """Decode the audio of *path* into a mono float32 array by piping ffmpeg.

    Args:
        path: media file handed to ffmpeg via ``-i``.
        sr: target sample rate in Hz (``-ar``).

    Returns:
        1-D float32 array built directly from ffmpeg's raw ``f32le`` output.

    Raises:
        RuntimeError: if ffmpeg runs longer than 300 s or exits non-zero.
    """
    cmd = [
        _bin("ffmpeg"), "-i", path,
        "-vn",              # skip video
        "-ac", "1",         # mono
        "-ar", str(sr),     # resample
        "-f", "f32le",      # raw 32-bit float little-endian
        "-loglevel", "error",
        "pipe:1",
    ]
    try:
        proc = subprocess.run(cmd, capture_output=True, timeout=300)
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(
            f"ffmpeg timed out (300s) on {os.path.basename(path)}"
        ) from exc
    if proc.returncode != 0:
        # ffmpeg may echo file names/metadata in arbitrary encodings; never let
        # a decode error mask the real failure message.
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed: {detail}")
    return np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
# Length of one analysis window, in seconds.
_WINDOW = 8.0

# Project root = parent of the directory that contains this module.
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Trained classifiers, cached embeddings and downloaded weights all live
# underneath the project root.
_MODEL_DIR = os.path.join(_PROJECT_DIR, "models")
_W2V_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "w2v")
_DL_CACHE_DIR = os.path.join(_PROJECT_DIR, "cache", "downloads")

# Redirect torch hub and huggingface downloads into the project, unless the
# environment already points them somewhere else.
os.environ.setdefault("TORCH_HOME", _DL_CACHE_DIR)
os.environ.setdefault(
    "HF_HOME",
    os.path.join(_DL_CACHE_DIR, "huggingface"),
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding extraction (lazy-loaded)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Process-wide cache for the currently loaded embedding model; filled in
# lazily by _get_w2v_model().
_w2v_model = None
_w2v_device = None
_w2v_model_name = None
_ast_feature_extractor = None

# Supported embedding models — name → embed_dim
_EMBED_MODELS = {
    "WAV2VEC2_BASE": 768,
    "WAV2VEC2_LARGE": 1024,
    "WAV2VEC2_LARGE_LV60K": 1024,
    "HUBERT_BASE": 768,
    "HUBERT_LARGE": 1024,
    "HUBERT_XLARGE": 1280,
    "BEATS": 768,
    # Multi-layer variants (4 quartile layers concatenated)
    "WAV2VEC2_BASE_ML": 768 * 4,
    "HUBERT_BASE_ML": 768 * 4,
    "HUBERT_LARGE_ML": 1024 * 4,
    "HUBERT_XLARGE_ML": 1280 * 4,
    # Transformers-based models
    "AST": 768,
    "AST_ML": 768 * 4,
    "EAT": 768,
    "EAT_LARGE": 1024,
}
_DEFAULT_EMBED_MODEL = "EAT_LARGE"

# Location of the BEATs checkpoint inside the project-local Hugging Face cache.
_BEATS_CHECKPOINT = os.path.join(
    _DL_CACHE_DIR,
    "huggingface",
    "hub",
    "models--lpepino--beats_ckpts",
    "snapshots",
    "5b53b0404df452a3a607d7e67687227730e5bad1",
    "BEATs_iter3_plus_AS2M.pt",
)
|
||||
|
||||
|
||||
def _get_w2v_model(model_name: str | None = None):
    """Return ``(model, device)`` for *model_name*, loading it on first use.

    The loaded model is kept in module globals and reused until a different
    model is requested. Multi-layer (``*_ML``) names are resolved through
    ``_ml_config`` and share the weights of their base model.

    Args:
        model_name: key of ``_EMBED_MODELS``; ``None`` selects
            ``_DEFAULT_EMBED_MODEL``.

    Returns:
        Tuple of the model (in eval mode) and the device string
        (``"cuda"`` or ``"cpu"``).
    """
    global _w2v_model, _w2v_device, _w2v_model_name, _ast_feature_extractor
    if model_name is None:
        model_name = _DEFAULT_EMBED_MODEL
    # Multi-layer variants use the same base model weights
    ml = _ml_config(model_name)
    load_name = ml[0] if ml else model_name

    if _w2v_model is not None and _w2v_model_name == load_name:
        return _w2v_model, _w2v_device

    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_extractor = None

    # Build everything in locals first: if loading raises, the previously
    # cached model/name pair stays consistent instead of being half replaced.
    if load_name == "BEATS":
        from .beats_model import BEATs, BEATsConfig
        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # the checkpoint file; only use checkpoints from a trusted source.
        checkpoint = torch.load(_BEATS_CHECKPOINT, map_location=device,
                                weights_only=False)
        model = BEATs(BEATsConfig(checkpoint['cfg']))
        model.load_state_dict(checkpoint['model'])
        model.to(device)
    elif load_name == "AST":
        from transformers import ASTModel, ASTFeatureExtractor
        model = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593"
        ).to(device)
        feature_extractor = ASTFeatureExtractor.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593"
        )
    elif load_name in ("EAT", "EAT_LARGE"):
        from transformers import AutoModel
        eat_repo = ("worstchan/EAT-large_epoch20_finetune_AS2M"
                    if load_name == "EAT_LARGE"
                    else "worstchan/EAT-base_epoch30_finetune_AS2M")
        # NOTE(security): trust_remote_code executes code shipped in the repo.
        model = AutoModel.from_pretrained(
            eat_repo, trust_remote_code=True,
        ).to(device)
    else:
        import torchaudio
        bundle = getattr(torchaudio.pipelines, load_name)
        model = bundle.get_model().to(device)

    model.eval()

    # Commit the new state only after every step above succeeded.
    _w2v_model, _w2v_device, _w2v_model_name = model, device, load_name
    if feature_extractor is not None:
        _ast_feature_extractor = feature_extractor
    _log(f"audio_scan: {load_name} loaded on {_w2v_device}")
    return _w2v_model, _w2v_device
|
||||
|
||||
|
||||
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
    """Turn raw audio chunks into the mel-spectrogram batch fed to EAT.

    Each chunk becomes a 128-bin Kaldi filterbank, is zero-padded or cut to
    1024 frames and normalised with fixed mean/std constants.

    Returns a tensor of shape [B, 1, 1024, 128] placed on *device*.
    """
    import torch
    import torchaudio.compliance.kaldi as kaldi

    n_frames = 1024
    norm_mean, norm_std = -4.268, 4.569

    def _mel_for(samples):
        wav = torch.from_numpy(np.array(samples)).unsqueeze(0).float()
        mel = kaldi.fbank(
            wav,
            htk_compat=True,
            sample_frequency=sr,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=128,
            dither=0.0,
            frame_shift=10,
        )
        shortfall = n_frames - mel.shape[0]
        if shortfall > 0:
            # pad only at the end of the time axis
            mel = torch.nn.functional.pad(mel, (0, 0, 0, shortfall))
        else:
            mel = mel[:n_frames]
        return (mel - norm_mean) / (norm_std * 2)

    batch = torch.stack([_mel_for(chunk) for chunk in chunks])
    return batch.unsqueeze(1).to(device)  # [B, 1, T, 128]
|
||||
|
||||
|
||||
def _embed_dim(model_name: str | None = None) -> int:
    """Look up the embedding width of *model_name*.

    ``None`` means the default model; names missing from ``_EMBED_MODELS``
    fall back to 768.
    """
    name = _DEFAULT_EMBED_MODEL if model_name is None else model_name
    try:
        return _EMBED_MODELS[name]
    except KeyError:
        return 768
|
||||
|
||||
|
||||
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
    """Describe a multi-layer (``*_ML``) model as (base_model, layer_indices).

    Returns None when *model_name* has no ``_ML`` suffix, when the stripped
    name is not a known embedding model, or when no layer count is recorded
    for it. The four indices are the 0-based ends of each quarter of the
    layer stack.

    NOTE(review): for AST the indices are applied to ``hidden_states`` of the
    transformers model; if that tuple also carries the embedding output as
    element 0 the selected layers are shifted by one — confirm.
    """
    suffix = "_ML"
    if not model_name.endswith(suffix):
        return None
    base = model_name[:-len(suffix)]
    if base not in _EMBED_MODELS:
        return None

    # Transformer depth per supported base model.
    depth_by_model = {
        "WAV2VEC2_BASE": 12,
        "WAV2VEC2_LARGE": 24,
        "WAV2VEC2_LARGE_LV60K": 24,
        "HUBERT_BASE": 12,
        "HUBERT_LARGE": 24,
        "HUBERT_XLARGE": 48,
        "AST": 12,
    }
    if base not in depth_by_model:
        return None
    depth = depth_by_model[base]
    quartile_ends = [(depth * q) // 4 - 1 for q in (1, 2, 3, 4)]
    return base, quartile_ends
|
||||
|
||||
|
||||
def _w2v_cache_path(video_path: str, hop: float, window: float,
                    model_name: str | None = None) -> str:
    """Build the cache file path for one video's embeddings.

    The file name is the first 16 hex digits of a SHA-256 over the absolute
    video path, its mtime, *hop*, *window* and the model name, so changing
    any of them selects a different cache entry.
    """
    name = _DEFAULT_EMBED_MODEL if model_name is None else model_name
    resolved = os.path.abspath(video_path)
    parts = (resolved, os.path.getmtime(resolved), hop, window, name)
    fingerprint = "|".join(str(part) for part in parts)
    digest = hashlib.sha256(fingerprint.encode()).hexdigest()
    return os.path.join(_W2V_CACHE_DIR, digest[:16] + ".npz")
|
||||
|
||||
|
||||
def _w2v_cache_exists(video_path: str, hop: float, window: float,
                      model_name: str | None = None) -> bool:
    """Tell whether an embedding cache file exists for this video/settings.

    Any error while deriving the cache path is treated as "no cache".
    """
    try:
        return os.path.exists(
            _w2v_cache_path(video_path, hop, window, model_name)
        )
    except Exception:
        # Best effort: a failing path lookup simply means nothing is cached.
        return False
|
||||
|
||||
|
||||
def _w2v_cache_load(video_path: str, hop: float, window: float,
                    model_name: str | None = None) -> tuple[np.ndarray, np.ndarray] | None:
    """Load cached embeddings for a video.

    Returns:
        ``(timestamps, embeddings)`` when a readable cache file exists,
        otherwise ``None``. Read errors are logged, never raised.
    """
    try:
        path = _w2v_cache_path(video_path, hop, window, model_name)
        if not os.path.exists(path):
            return None
        # np.load on an .npz returns an open archive; close it once both
        # arrays have been materialised.
        with np.load(path) as data:
            timestamps = data["timestamps"]
            embeddings = data["embeddings"]
        _log(f"audio_scan: cache hit ({path})")
        return timestamps, embeddings
    except Exception as e:
        _log(f"audio_scan: cache read failed: {e}")
    return None
|
||||
|
||||
|
||||
def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
                         hop: float = 1.0, window: float = _WINDOW,
                         video_path: str | None = None,
                         cancel_flag: object = None,
                         model_name: str | None = None,
                         ) -> tuple[np.ndarray, np.ndarray]:
    """Extract one embedding per sliding window over the whole signal.

    Args:
        y: audio samples.
        sr: sample rate of *y* in Hz.
        hop: distance between window starts, in seconds.
        window: window length in seconds.
        video_path: when given, embeddings are read from / written to the
            on-disk cache keyed by this file.
        cancel_flag: optional object; a truthy ``_cancel`` attribute aborts
            the extraction between batches.
        model_name: embedding model key; ``None`` selects the default.

    Returns:
        ``(timestamps, embeddings)`` with embeddings of shape (N, D). Both are
        empty when the audio is shorter than one window or the run was
        cancelled.

    Raises:
        ValueError: if *window* or *hop* is shorter than one sample.
    """
    edim = _embed_dim(model_name)

    # Try loading from cache
    cache_file = None
    if video_path:
        try:
            cache_file = _w2v_cache_path(video_path, hop, window, model_name)
            if os.path.exists(cache_file):
                # Close the .npz archive as soon as the arrays are read.
                with np.load(cache_file) as data:
                    cached = data["timestamps"], data["embeddings"]
                _log(f"audio_scan: cache hit ({cache_file})")
                return cached
        except Exception as e:
            _log(f"audio_scan: cache read failed: {e}")

    win_samples = int(window * sr)
    hop_samples = int(hop * sr)
    if win_samples <= 0 or hop_samples <= 0:
        # A zero hop would otherwise surface as ZeroDivisionError below.
        raise ValueError(
            f"window ({window}s) and hop ({hop}s) must each span at least "
            f"one sample at {sr} Hz"
        )
    n_windows = max(0, (len(y) - win_samples) // hop_samples + 1)

    if n_windows == 0:
        return np.array([]), np.empty((0, edim))

    import torch
    model, device = _get_w2v_model(model_name)
    name = model_name or _DEFAULT_EMBED_MODEL
    is_beats = name == "BEATS"
    is_ast = name in ("AST", "AST_ML")
    is_eat = name in ("EAT", "EAT_LARGE")
    ml_cfg = _ml_config(name)

    # Auto-size batches based on available GPU memory
    batch_size = 16
    if device == "cuda":
        try:
            # The properties object exposes `total_memory` (bytes).
            vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            if vram_gb >= 16:
                batch_size = 64
            elif vram_gb >= 8:
                batch_size = 32
            _log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
        except Exception as e:
            # Sizing is only an optimisation; keep the conservative default.
            _log(f"audio_scan: VRAM query failed, batch_size={batch_size}: {e}")

    timestamps = np.arange(n_windows) * hop
    embeddings = []

    for batch_start in range(0, n_windows, batch_size):
        if cancel_flag and getattr(cancel_flag, '_cancel', False):
            return np.array([]), np.empty((0, edim))
        batch_end = min(batch_start + batch_size, n_windows)
        chunks = [
            y[i * hop_samples:i * hop_samples + win_samples]
            for i in range(batch_start, batch_end)
        ]
        with torch.no_grad():
            if is_ast:
                inputs = _ast_feature_extractor(
                    list(chunks), sampling_rate=sr, return_tensors="pt",
                    padding=True,
                )
                input_values = inputs.input_values.to(device)
                if ml_cfg is not None:
                    out = model(input_values, output_hidden_states=True)
                    selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
                    batch_emb = torch.cat(selected, dim=1).cpu().numpy()
                else:
                    out = model(input_values)
                    batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
            elif is_eat:
                mel_input = _eat_preprocess(chunks, sr, device)
                features = model.extract_features(mel_input)
                # drop the first token before mean-pooling
                batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
            else:
                waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
                if is_beats:
                    padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
                    features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
                    batch_emb = features.mean(dim=1).cpu().numpy()
                elif ml_cfg is not None:
                    all_layers, _ = model.extract_features(waveforms)
                    selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
                    batch_emb = torch.cat(selected, dim=1).cpu().numpy()
                else:
                    features, _ = model(waveforms)
                    batch_emb = features.mean(dim=1).cpu().numpy()
        embeddings.append(batch_emb)

    result_ts = timestamps
    result_emb = np.vstack(embeddings)

    # Save to cache
    if cache_file:
        try:
            os.makedirs(_W2V_CACHE_DIR, exist_ok=True)
            np.savez(cache_file, timestamps=result_ts, embeddings=result_emb)
            _log(f"audio_scan: w2v cache saved ({cache_file})")
        except Exception as e:
            _log(f"audio_scan: cache write failed: {e}")

    return result_ts, result_emb
|
||||
|
||||
|
||||
def _extract_w2v_targeted(y: np.ndarray, sr: int, gt_intense: list[float],
                          gt_soft: list[float], tolerance: float = 12.0,
                          neg_margin: float = 120.0,
                          model_name: str | None = None,
                          gt_negative: list[float] | None = None,
                          ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Embed only the windows that matter for training.

    Windows are taken at whole seconds around intense markers (positives),
    around explicit negative markers, and every 4 s in stretches that lie
    farther than *neg_margin* from every intense/soft marker.

    Returns (timestamps, embeddings, labels); labels are 1 for positive,
    -1 for negative and 0 for ambiguous windows.
    """
    edim = _embed_dim(model_name)
    duration = len(y) / sr
    win_samples = int(_WINDOW * sr)
    latest_start = duration - _WINDOW
    span = int(tolerance)
    all_gt = list(gt_intense) + list(gt_soft)
    explicit_negatives = gt_negative or []

    def _seconds_around(markers):
        """Whole-second window starts within ±span of any marker."""
        found = set()
        for marker in markers:
            for offset in range(-span, span + 1):
                t = marker + offset
                if 0 <= t <= latest_start:
                    found.add(int(t))
        return found

    # Positive windows: every second near intense markers
    pos_times = _seconds_around(gt_intense)

    # Manual negative windows never override a positive one
    manual_neg_times = _seconds_around(explicit_negatives) - pos_times

    # Auto negative windows: every 4s, far from any marker
    # (skipped if margin <= 0 or there are no markers)
    if all_gt and neg_margin > 0:
        neg_times = {
            t for t in range(0, int(latest_start), 4)
            if min(abs(t - g) for g in all_gt) > neg_margin
        }
    else:
        neg_times = set()

    # Drop windows that would run past the end of the signal
    valid_times = [
        t for t in sorted(pos_times | neg_times | manual_neg_times)
        if int(t * sr) + win_samples <= len(y)
    ]
    if not valid_times:
        return np.array([]), np.zeros((0, edim)), np.array([], dtype=int)

    import torch
    model, device = _get_w2v_model(model_name)
    batch_size = 16

    resolved_name = model_name or _DEFAULT_EMBED_MODEL
    is_beats = resolved_name == "BEATS"
    is_ast = resolved_name in ("AST", "AST_ML")
    is_eat = resolved_name in ("EAT", "EAT_LARGE")
    ml_cfg = _ml_config(resolved_name)

    embeddings_list: list[np.ndarray] = []
    for lo in range(0, len(valid_times), batch_size):
        chunks = []
        for t in valid_times[lo:lo + batch_size]:
            first = int(t * sr)
            chunks.append(y[first:first + win_samples])
        with torch.no_grad():
            if is_ast:
                inputs = _ast_feature_extractor(
                    list(chunks), sampling_rate=sr, return_tensors="pt",
                    padding=True,
                )
                input_values = inputs.input_values.to(device)
                if ml_cfg is None:
                    out = model(input_values)
                    pooled = out.last_hidden_state.mean(dim=1)
                else:
                    out = model(input_values, output_hidden_states=True)
                    pooled = torch.cat(
                        [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]],
                        dim=1,
                    )
            elif is_eat:
                mel_input = _eat_preprocess(chunks, sr, device)
                features = model.extract_features(mel_input)
                pooled = features[:, 1:, :].mean(dim=1)
            else:
                waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
                if is_beats:
                    padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
                    features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
                    pooled = features.mean(dim=1)
                elif ml_cfg is not None:
                    all_layers, _ = model.extract_features(waveforms)
                    pooled = torch.cat(
                        [all_layers[i].mean(dim=1) for i in ml_cfg[1]],
                        dim=1,
                    )
                else:
                    features, _ = model(waveforms)
                    pooled = features.mean(dim=1)
            batch_emb = pooled.cpu().numpy()
        embeddings_list.append(batch_emb)

    timestamps = np.array([float(t) for t in valid_times])
    embeddings = np.vstack(embeddings_list)

    def _nearest(t, markers):
        """Distance from t to the closest marker (9999 when there is none)."""
        return min((abs(t - g) for g in markers), default=9999)

    labels = np.zeros(len(timestamps), dtype=int)
    for i, t in enumerate(timestamps):
        if _nearest(t, gt_intense) < tolerance:
            labels[i] = 1
            continue
        near_manual_negative = _nearest(t, explicit_negatives) < tolerance
        far_from_everything = neg_margin > 0 and _nearest(t, all_gt) > neg_margin
        if near_manual_negative or far_from_everything:
            labels[i] = -1
    return timestamps, embeddings, labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classifier mode — train / save / load / scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train_classifier(video_infos: list[tuple],
                     model_path: str | None = None,
                     tolerance: float = 12.0,
                     neg_margin: float = 120.0,
                     embed_model: str | None = None,
                     cancel_flag: object = None,
                     n_workers: int = 4,
                     progress_cb: object = None) -> dict | None:
    """Train a classifier from labeled videos.

    Args:
        video_infos: list of ``(video_path, intense_times, soft_times)``
            tuples; an optional fourth element holds explicit negative times.
        model_path: if given, the model is saved to a timestamped file next
            to this path and then copied to the path itself.
        tolerance/neg_margin: labeling parameters passed to the targeted
            embedding extraction.
        embed_model: embedding model name (e.g. "HUBERT_BASE", "BEATS");
            ``None`` selects ``_DEFAULT_EMBED_MODEL``.
        cancel_flag: object with a ``_cancel`` attribute; if set, training
            aborts early.
        n_workers: number of threads for parallel audio loading (capped at 4).
        progress_cb: optional callable receiving progress messages.

    Returns:
        dict with 'classifier', 'n_features' and 'embed_model', or None when
        training was cancelled or no usable two-class training set was found.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from sklearn.ensemble import HistGradientBoostingClassifier

    def _progress(msg: str) -> None:
        _log(msg)
        if progress_cb:
            progress_cb(msg)

    def _cancelled() -> bool:
        return bool(cancel_flag and getattr(cancel_flag, '_cancel', False))

    # Phase 1: load all audio in parallel (cap workers — disk I/O bound)
    n = len(video_infos)
    load_workers = max(1, min(n_workers, 4))  # the executor rejects <= 0
    _progress(f"Loading audio: 0/{n} videos ({load_workers} workers)...")
    audio_data: dict[int, np.ndarray] = {}
    failed: set[int] = set()
    with ThreadPoolExecutor(max_workers=load_workers) as pool:
        future_to_idx = {
            pool.submit(_load_audio_ffmpeg, info[0], sr=_SR): i
            for i, info in enumerate(video_infos)
        }
        for future in as_completed(future_to_idx):
            if _cancelled():
                # Leaving the `with` block waits for outstanding work; drop
                # everything still queued so cancel does not decode the rest.
                for pending in future_to_idx:
                    pending.cancel()
                _log("audio_scan: training cancelled")
                return None
            idx = future_to_idx[future]
            try:
                audio_data[idx] = future.result()
            except Exception as e:
                _log(f"audio_scan: failed to load {os.path.basename(video_infos[idx][0])}: {e}")
                failed.add(idx)
            _progress(f"Loading audio: {len(audio_data) + len(failed)}/{n}")

    # Phase 2: extract embeddings sequentially on GPU
    _progress(f"Extracting embeddings: 0/{n}")
    all_X, all_y = [], []
    for vi, vinfo in enumerate(video_infos):
        if vi in failed:
            continue
        gt_intense, gt_soft = vinfo[1], vinfo[2]
        gt_negative = vinfo[3] if len(vinfo) > 3 else []
        if _cancelled():
            _log("audio_scan: training cancelled")
            return None
        _progress(f"Extracting embeddings: {vi+1}/{n}")
        y = audio_data.pop(vi)  # release the samples once consumed

        timestamps, embeddings, labels = _extract_w2v_targeted(
            y, _SR, gt_intense, gt_soft, tolerance, neg_margin,
            model_name=embed_model, gt_negative=gt_negative,
        )
        if len(timestamps) == 0:
            continue
        # Per-video z-score normalize
        vid_mean = embeddings.mean(axis=0)
        vid_std = np.maximum(embeddings.std(axis=0), 1e-6)
        normed = (embeddings - vid_mean) / vid_std
        # Keep labeled rows only (0 = ambiguous), preserving their order.
        keep = labels != 0
        if keep.any():
            all_X.append(normed[keep])
            all_y.append((labels[keep] == 1).astype(int))

    if not all_X:
        _log("audio_scan: no training samples collected")
        return None

    X = np.vstack(all_X)
    y_arr = np.concatenate(all_y)
    n_pos = int((y_arr == 1).sum())
    n_neg = int((y_arr == 0).sum())
    _log(f"audio_scan: training set — {n_pos} positive, {n_neg} negative")

    if n_pos == 0 or n_neg == 0:
        _log(f"audio_scan: need both classes — {n_pos} pos, {n_neg} neg")
        return None

    # Subsample negatives for balance (at most 3 negatives per positive)
    rng = np.random.RandomState(42)
    pos_idx = np.where(y_arr == 1)[0]
    neg_idx = np.where(y_arr == 0)[0]
    n_neg_sample = min(len(neg_idx), len(pos_idx) * 3)
    neg_sample = rng.choice(neg_idx, n_neg_sample, replace=False)
    train_idx = np.concatenate([pos_idx, neg_sample])
    rng.shuffle(train_idx)

    _progress(f"Fitting classifier on {len(train_idx)} samples...")
    clf = HistGradientBoostingClassifier(
        max_iter=200, max_depth=5, learning_rate=0.1, random_state=42,
    )
    clf.fit(X[train_idx], y_arr[train_idx])
    _log("audio_scan: classifier trained")

    # Calibrate probabilities for better threshold behavior
    from sklearn.calibration import CalibratedClassifierCV
    min_class = min(n_pos, int(n_neg_sample))
    if min_class >= 6:
        cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
        cal_clf.fit(X[train_idx], y_arr[train_idx])
        clf = cal_clf
        _log("audio_scan: classifier calibrated (isotonic, 3-fold)")
    else:
        _log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")

    model = {"classifier": clf, "n_features": X.shape[1],
             "embed_model": embed_model or _DEFAULT_EMBED_MODEL}

    if model_path:
        import joblib
        import shutil
        from datetime import datetime
        parent = os.path.dirname(model_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Save with timestamp in name, then copy it over the base path so the
        # base path always holds the latest version.
        stem, ext = os.path.splitext(model_path)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        versioned = f"{stem}_{ts}{ext}"
        joblib.dump(model, versioned)
        _log(f"audio_scan: model saved to {versioned}")
        shutil.copy2(versioned, model_path)
        _log(f"audio_scan: latest model updated: {model_path}")

    return model
|
||||
|
||||
|
||||
def load_classifier(model_path: str) -> dict | None:
|
||||
"""Load a saved classifier model."""
|
||||
if not os.path.exists(model_path):
|
||||
return None
|
||||
import joblib
|
||||
return joblib.load(model_path)
|
||||
|
||||
|
||||
def default_model_path(profile_name: str = "default",
                       embed_model: str | None = None) -> str:
    """Build the classifier file path for a profile.

    With *embed_model* the file is ``{profile}_{model}.joblib``; without it
    the legacy single-model name ``{profile}.joblib`` is used.
    """
    if embed_model:
        filename = f"{profile_name}_{embed_model}.joblib"
    else:
        filename = f"{profile_name}.joblib"
    return os.path.join(_MODEL_DIR, filename)
|
||||
|
||||
|
||||
def list_model_versions(profile_name: str = "default",
                        embed_model: str | None = None) -> list[tuple[str, str]]:
    """Return available versions of a model, newest backup first.

    Returns:
        List of ``(label, file_path)``. If the active model file exists it
        comes first with the label ``"current"``; the timestamped backups
        follow, sorted by their ``YYYYMMDD_HHMMSS`` label in descending order.
    """
    import re
    current = default_model_path(profile_name, embed_model)
    stem, ext = os.path.splitext(current)

    # Track the active entry separately from the backups so a missing active
    # file can never cause a backup to be left out of the sort.
    head: list[tuple[str, str]] = []
    if os.path.exists(current):
        head.append(("current", current))
    if not os.path.isdir(_MODEL_DIR):
        return head

    pattern = re.compile(
        re.escape(os.path.basename(stem)) + r"_(\d{8}_\d{6})" + re.escape(ext) + "$"
    )
    backups: list[tuple[str, str]] = []
    for fname in os.listdir(_MODEL_DIR):
        m = pattern.match(fname)
        if m:
            backups.append((m.group(1), os.path.join(_MODEL_DIR, fname)))
    backups.sort(key=lambda v: v[0], reverse=True)
    return head + backups
|
||||
|
||||
|
||||
def restore_model_version(version_path: str, profile_name: str = "default",
                          embed_model: str | None = None) -> None:
    """Restore a backup version as the active model.

    The current active file is first preserved as a timestamped backup,
    unless a byte-identical backup already exists, and is then replaced by a
    copy of *version_path*.

    Raises:
        FileNotFoundError: if *version_path* does not exist. The check runs
            before the active model is touched, so a bad path can no longer
            leave the profile without an active model.
    """
    import filecmp
    import re
    import shutil
    from datetime import datetime

    current = default_model_path(profile_name, embed_model)
    if version_path == current:
        return
    if not os.path.isfile(version_path):
        raise FileNotFoundError(f"model version not found: {version_path}")

    # Back up current before replacing — but only if no identical backup exists
    if os.path.exists(current):
        stem, ext = os.path.splitext(current)
        already_saved = False
        if os.path.isdir(_MODEL_DIR):
            pat = re.compile(
                re.escape(os.path.basename(stem)) + r"_\d{8}_\d{6}" + re.escape(ext) + "$"
            )
            already_saved = any(
                pat.match(fname)
                and filecmp.cmp(current, os.path.join(_MODEL_DIR, fname), shallow=False)
                for fname in os.listdir(_MODEL_DIR)
            )
        if not already_saved:
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            shutil.move(current, f"{stem}_{ts}{ext}")

    shutil.copy2(version_path, current)
    _log(f"audio_scan: restored {os.path.basename(version_path)} as active model")
|
||||
|
||||
|
||||
def list_trained_models(profile_name: str = "default") -> list[str]:
    """Return embedding model keys that have a trained .joblib for *profile_name*.

    Looks for files matching ``{profile}_{KEY}.joblib`` in the models dir.
    KEY is either a bare embed model name (e.g. ``EAT_LARGE``) or
    ``{MODEL}_{name}`` for user-named variants. Timestamped backups
    (``..._YYYYMMDD_HHMMSS.joblib``) are versions of a model rather than
    models of their own and are skipped.

    If nothing matches but the legacy ``{profile}.joblib`` exists, the result
    is ``[""]``.
    """
    import re
    prefix = f"{profile_name}_"
    suffix = ".joblib"
    result: list[str] = []
    if not os.path.isdir(_MODEL_DIR):
        return result

    backup_suffix = re.compile(r"_\d{8}_\d{6}$")
    for fname in os.listdir(_MODEL_DIR):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        key = fname[len(prefix):-len(suffix)]
        if backup_suffix.search(key):
            # e.g. "EAT_LARGE_20240101_120000" would otherwise pass the
            # user-variant prefix test below.
            continue
        if key in _EMBED_MODELS or any(key.startswith(m + "_") for m in _EMBED_MODELS):
            result.append(key)

    # Also check legacy {profile}.joblib
    legacy = os.path.join(_MODEL_DIR, f"{profile_name}.joblib")
    if os.path.exists(legacy) and not result:
        result.append("")
    return sorted(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scanning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fuse_regions(regions: list[tuple[float, float, float]]
|
||||
) -> list[tuple[float, float, float]]:
|
||||
"""Merge overlapping/adjacent regions, keeping max score."""
|
||||
if not regions:
|
||||
return []
|
||||
by_start = sorted(regions, key=lambda r: r[0])
|
||||
fused: list[tuple[float, float, float]] = []
|
||||
s, e, sc = by_start[0]
|
||||
for s2, e2, sc2 in by_start[1:]:
|
||||
if s2 <= e: # overlapping or touching
|
||||
e = max(e, e2)
|
||||
sc = max(sc, sc2)
|
||||
else:
|
||||
fused.append((s, e, sc))
|
||||
s, e, sc = s2, e2, sc2
|
||||
fused.append((s, e, sc))
|
||||
return fused
|
||||
|
||||
|
||||
def prefetch_audio(video_path: str, embed_model: str | None = None,
                   hop: float = 1.0, window: float = _WINDOW) -> np.ndarray | None:
    """Decode a video's audio ahead of time unless its embeddings are cached.

    Intended to be run from a background thread while the GPU is busy with
    another video.

    Returns:
        The audio array loaded via ``_load_audio_ffmpeg`` at ``_SR``, or
        ``None`` when ``_w2v_cache_exists`` reports a cache entry for this
        (video, hop, window, embed_model) combination.
    """
    if _w2v_cache_exists(video_path, hop, window, embed_model):
        # Embeddings are already on disk; no need to decode anything.
        return None

    clip_name = os.path.basename(video_path)
    _log(f"audio_scan: prefetching {clip_name}")
    audio = _load_audio_ffmpeg(video_path, sr=_SR)
    _log(f"audio_scan: prefetched {len(audio)/_SR:.1f}s")
    return audio
|
||||
|
||||
|
||||
def scan_video(
    video_path: str,
    model: dict = None,
    threshold: float = 0.50,
    hop: float = 1.0,
    window: float = _WINDOW,
    cancel_flag: object = None,
    prefetched_audio: np.ndarray | None = None,
) -> list[tuple[float, float, float]]:
    """Find the audio regions of a video that a trained classifier accepts.

    Window embeddings are taken from the cache when available; otherwise the
    audio is decoded (or *prefetched_audio* is used, skipping the ffmpeg
    decode) and embeddings are extracted.  Embeddings are z-score normalised
    per video, scored with ``model["classifier"].predict_proba`` and every
    window whose positive-class probability is at least *threshold* is kept.
    Kept windows are merged with ``_fuse_regions``.

    Args:
        video_path: Path of the video to scan.
        model: Dict holding ``"classifier"`` and optionally ``"embed_model"``.
        threshold: Minimum probability for a window to count as a hit.
        hop: Hop between windows, forwarded to the cache/extraction helpers.
        window: Window length; also added to a hit's start to get its end.
        cancel_flag: Optional object; a truthy ``_cancel`` attribute aborts
            the scan.
        prefetched_audio: Already decoded audio to use instead of decoding.

    Returns:
        A list of ``(start_time, end_time, score)`` tuples, or ``[]`` when no
        model is given, the scan is cancelled, or there are no windows.
    """
    if model is None:
        _log("audio_scan: no model provided")
        return []

    classifier = model["classifier"]
    embed_name = model.get("embed_model")

    def _cancelled() -> bool:
        # Same test the scan performs at each checkpoint.
        return bool(cancel_flag and getattr(cancel_flag, '_cancel', False))

    # Try cache first — skip expensive audio loading if embeddings exist
    cache_hit = _w2v_cache_load(video_path, hop, window, embed_name)
    if cache_hit is None:
        if prefetched_audio is None:
            _log(f"audio_scan: loading {video_path}")
            audio = _load_audio_ffmpeg(video_path, sr=_SR)
        else:
            _log("audio_scan: using prefetched audio")
            audio = prefetched_audio
        rate = _SR
        _log(f"audio_scan: {len(audio)/rate:.1f}s loaded")

        if _cancelled():
            return []

        _log(f"audio_scan: extracting embeddings ({embed_name or 'default'})...")
        timestamps, window_vectors = _extract_w2v_windows(
            audio, rate, hop=hop, window=window, video_path=video_path,
            cancel_flag=cancel_flag, model_name=embed_name,
        )
    else:
        timestamps, window_vectors = cache_hit

    if len(timestamps) == 0:
        _log("audio_scan: video shorter than window")
        return []

    # Per-video z-score normalize (std floored to avoid division by zero)
    center = window_vectors.mean(axis=0)
    spread = np.maximum(window_vectors.std(axis=0), 1e-6)
    normed = (window_vectors - center) / spread

    _log(f"audio_scan: classifying {len(normed)} windows...")

    if _cancelled():
        return []

    probs = classifier.predict_proba(normed)[:, 1]
    raw = []
    for idx in np.nonzero(probs >= threshold)[0]:
        begin = timestamps[idx]
        raw.append((begin, begin + window, float(probs[idx])))

    results = _fuse_regions(raw)
    _log(f"audio_scan: {len(results)} regions above threshold {threshold} (from {len(raw)} raw)")
    return results
|
||||
@@ -0,0 +1,783 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Dict, Optional, Tuple
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm, Parameter
|
||||
from .beats_modules import (
|
||||
GradMultiply,
|
||||
SamePad,
|
||||
get_activation_fn,
|
||||
GLU_Linear,
|
||||
quant_noise,
|
||||
)
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerSentenceEncoderLayer`` blocks with a convolutional
    positional embedding, configured from an ``args`` object.

    When ``args.relative_position_embedding`` is set, every layer shares the
    relative-attention-bias embedding owned by the first layer.
    """

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Positional embedding: grouped 1-D convolution, weight-normalised,
        # followed by SamePad and GELU.
        kernel = args.conv_pos
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=kernel,
            padding=kernel // 2,
            groups=args.conv_pos_groups,
        )
        init_std = math.sqrt(4.0 / (kernel * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=init_std)
        nn.init.constant_(self.pos_conv.bias, 0)
        self.pos_conv = nn.Sequential(
            nn.utils.weight_norm(self.pos_conv, name="weight", dim=2),
            SamePad(kernel),
            nn.GELU(),
        )

        # Relative position settings are optional on the config object.
        has_rel_pos_cfg = hasattr(args, "relative_position_embedding")
        self.relative_position_embedding = (
            args.relative_position_embedding if has_rel_pos_cfg else False
        )
        self.num_buckets = args.num_buckets if has_rel_pos_cfg else 0
        self.max_distance = args.max_distance if has_rel_pos_cfg else 0

        layer_kwargs = dict(
            embedding_dim=self.embedding_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=self.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            activation_fn=args.activation_fn,
            layer_norm_first=args.layer_norm_first,
            deep_norm=args.deep_norm,
            has_relative_attention_bias=self.relative_position_embedding,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
            gru_rel_pos=args.gru_rel_pos,
            encoder_layers=args.encoder_layers,
        )
        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(**layer_kwargs)
                for _ in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            # Tie every later layer's bias table to the first layer's.
            for idx in range(1, args.encoder_layers):
                attn = self.layers[idx].self_attn
                del attn.relative_attention_bias
                attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            # DeepNorm: re-initialise selected projections with a depth-dependent gain.
            beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for idx in range(args.encoder_layers):
                block = self.layers[idx]
                attn = block.self_attn
                nn.init.xavier_normal_(attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(attn.v_proj.weight, gain=beta)
                nn.init.xavier_normal_(attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(attn.out_proj.weight, gain=beta)
                nn.init.xavier_normal_(block.fc1.weight, gain=beta)
                nn.init.xavier_normal_(block.fc2.weight, gain=beta)

        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        """Encode *x* and return ``(features, layer_results)``.

        The final layer norm is applied only in pre-norm mode and only when
        no specific target *layer* was requested.
        """
        features, layer_results = self.extract_features(x, padding_mask, layer)
        if layer is None and self.layer_norm_first:
            features = self.layer_norm(features)
        return features, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the positional convolution and the layer stack.

        Args:
            x: Input features; transposed to time-major for the layers and
                back again before returning.
            padding_mask: Optional mask; masked positions of *x* are zeroed
                IN PLACE before processing.
            tgt_layer: Optional index of the last layer to run.  When given,
                ``layer_results`` collects ``(x, z)`` for the input and for
                every layer visited.

        Returns:
            ``(x, layer_results)``.
        """
        if padding_mask is not None:
            x[padding_mask] = 0

        x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        collect = tgt_layer is not None
        layer_results = []
        z = None
        if collect:
            layer_results.append((x, z))
        pos_bias = None
        for idx, block in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            # Drawn every iteration (also in eval mode) to keep RNG use unchanged.
            keep_draw = np.random.random()
            if not self.training or (keep_draw > self.layerdrop):
                x, z, pos_bias = block(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    pos_bias=pos_bias,
                )
            if collect:
                layer_results.append((x, z))
            if idx == tgt_layer:
                # x already holds the requested layer's output.
                break

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
|
||||
|
||||
|
||||
class TransformerSentenceEncoderLayer(nn.Module):
    """One transformer encoder block: self-attention followed by a
    feed-forward network, in either pre-norm (``layer_norm_first``) or
    post-norm arrangement.  In post-norm mode the residual branch is scaled
    by ``deep_norm_alpha`` (1 unless ``deep_norm`` is enabled).
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)

        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        # dropout1: after attention, dropout2: inside the FFN, dropout3: after the FFN.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        # With "glu" the activation lives inside GLU_Linear itself.
        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        self.deep_norm_alpha = (
            math.pow(2 * encoder_layers, 1 / 4) if self.deep_norm else 1
        )

    def _self_attend(self, x, attn_mask, padding_mask, need_weights, pos_bias):
        """Apply self-attention with *x* as query, key and value."""
        return self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            position_bias=pos_bias,
        )

    def _feed_forward(self, x):
        """Apply fc1 (+ activation unless GLU), dropout2, fc2 and dropout3."""
        hidden = self.fc1(x)
        if self.activation_name != "glu":
            hidden = self.activation_fn(hidden)
        hidden = self.dropout2(hidden)
        hidden = self.fc2(hidden)
        return self.dropout3(hidden)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None
    ):
        """Run the block and return ``(x, attn, pos_bias)``.

        In pre-norm mode attention weights are never requested
        (``need_weights`` is ignored) and the residual is not scaled.
        """
        skip = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self._self_attend(
                x, self_attn_mask, self_attn_padding_mask, False, pos_bias
            )
            x = skip + self.dropout1(x)

            skip = x
            x = self.final_layer_norm(x)
            x = skip + self._feed_forward(x)
        else:
            x, attn, pos_bias = self._self_attend(
                x, self_attn_mask, self_attn_padding_mask, need_weights, pos_bias
            )
            x = skip * self.deep_norm_alpha + self.dropout1(x)
            x = self.self_attn_layer_norm(x)

            skip = x
            x = skip * self.deep_norm_alpha + self._feed_forward(x)
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    This variant optionally adds a bucketed relative position bias to the
    attention logits (``has_relative_attention_bias``) and can gate that bias
    per query position (``gru_rel_pos``).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            # One learned bias per (bucket, head).
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        # With rescale_init the key projection is created without a bias.
        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            # Produces the 8 gate logits consumed in forward() (2 gates x 4 summed units).
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise projection weights, optional k/v biases and the
        relative attention bias table."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        """Map signed relative positions to bucket indices.

        Small distances (below ``max_exact``) each get their own bucket;
        larger distances share logarithmically sized buckets, clamped to the
        last bucket.  In bidirectional mode half of the buckets are reserved
        for positive offsets.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Return the relative position bias, shaped
        ``(num_heads, query_length, key_length)``."""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position,
            bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            position_bias (Tensor, optional): precomputed relative position
                bias; computed here when the module owns a bias table and
                none is passed in.

        Returns:
            ``(attn, attn_weights, position_bias)``; ``attn_weights`` is
            ``None`` unless weights were requested.  With *before_softmax*
            the tuple is ``(raw_attn_weights, v, position_bias)`` instead.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # Compare as a tuple: "assert src_len, bsz == ..." would only
                # test src_len and treat the comparison as the assert message.
                assert (src_len, bsz) == tuple(value.shape[:2])

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        # q is additionally shrunk by alpha here; the logits are multiplied
        # by alpha again after their row maximum has been subtracted below.
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Split heads: (len, bsz, embed) -> (bsz * heads, len, head_dim).
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        # Subtract the row maximum, then undo the 1/alpha applied to q above.
        attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                # Recover the unscaled query to compute the per-position gates.
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # Merge heads back: (bsz * heads, tgt_len, head_dim) -> (tgt_len, bsz, embed).
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine the cached key padding mask with the current one so the
        result covers ``src_len`` key positions (zero-filled where one of the
        two masks is missing).  Returns ``None`` when both are ``None``."""
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return the cached attention state, or an empty dict if none is stored."""
        # NOTE(review): get_incremental_state is not defined in this class as
        # shown; presumably supplied elsewhere — confirm before relying on
        # incremental_state.
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store *buffer* as the cached attention state."""
        # NOTE(review): set_incremental_state is likewise not defined in this
        # class as shown — confirm where it comes from.
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for sparse attention variants; returns *attn_weights* unchanged."""
        return attn_weights
|
||||
|
||||
|
||||
def init_bert_params(module):
    """Apply BERT-style initialisation to a single module.

    Intended for use with ``nn.Module.apply``.  Depending on the module type:

    * ``nn.Linear``: weight drawn from N(0, 0.02), bias (if any) zeroed.
    * ``nn.Embedding``: weight drawn from N(0, 0.02), padding row (if any)
      zeroed.
    * ``MultiheadAttention``: q/k/v projection weights drawn from N(0, 0.02).

    Other module types are left untouched.
    """

    def _fill_normal(tensor):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()

    if isinstance(module, MultiheadAttention):
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _fill_normal(projection.weight.data)
|
||||
@@ -0,0 +1,179 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
from .beats_backbone import (
|
||||
TransformerEncoder,
|
||||
)
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BEATsConfig:
    """Hyper-parameter container for the BEATs model.

    Every option is a plain instance attribute set to the default below;
    entries of the optional ``cfg`` mapping are then applied on top.
    """

    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        # Caller-supplied values override the defaults above.
        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key of ``cfg`` onto this object as an attribute.

        Keys are not validated: unknown keys become new attributes.
        """
        self.__dict__.update(cfg)
|
||||
|
||||
|
||||
class BEATs(nn.Module):
    """BEATs model: fbank front end, patch embedding, Transformer encoder and,
    for fine-tuned configurations, a linear label predictor."""

    def __init__(
            self,
            cfg: BEATsConfig,
    ) -> None:
        """Build all sub-modules from ``cfg``."""
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Only project when the patch embedding width differs from the encoder width.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        # Kernel size equals stride, so the input is cut into non-overlapping patches.
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first must not both be enabled.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        # The predictor head (and its dropout) exists only for fine-tuned models.
        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce ``padding_mask`` to ``features.size(1)`` entries per row.

        Trailing mask entries that do not divide evenly are dropped, the rest
        is grouped into ``features.size(1)`` chunks, and a chunk counts as
        padding only if all of its entries are set.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Convert each waveform in ``source`` to a normalized fbank tensor.

        Each item is scaled by 2**15, passed to the Kaldi-compatible fbank
        routine, and the stacked result is normalized as
        ``(fbank - fbank_mean) / (2 * fbank_std)``.
        """
        fbanks = []
        for waveform in source:
            # NOTE(review): the 2**15 factor presumably maps [-1, 1] floats to
            # the 16-bit integer range Kaldi expects -- confirm.
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ):
        """Run the full model on ``source``.

        Returns ``(x, padding_mask)`` with the encoder output when there is no
        predictor head, otherwise ``(lprobs, padding_mask)`` where ``lprobs``
        is the sigmoid of the logits averaged over non-padded positions.
        """
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        # First reduction: bring the mask down to fbank.size(1) entries per row.
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Add the single input channel expected by the Conv2d patch embedding.
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        # Flatten the patch grid and move it in front of the embedding dimension.
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        # Second reduction: one mask entry per patch.
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                # Masked mean: zero padded positions, then divide the sum by
                # the number of non-padded positions in each row.
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
|
||||
@@ -0,0 +1,219 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
    """Autograd function that forwards its input and scales the gradient.

    The forward pass returns a new tensor built from ``x``; the backward pass
    multiplies the incoming gradient by ``scale``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Keep the factor around for the backward pass.
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled_grad = grad * ctx.scale
        # ``scale`` receives no gradient.
        return scaled_grad, None
|
||||
|
||||
|
||||
class SamePad(nn.Module):
    """Trim trailing positions from the third dimension of the input.

    With ``causal`` set, ``kernel_size - 1`` positions are removed; otherwise
    one position is removed for even kernel sizes and none for odd ones.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            trim = kernel_size - 1
        elif kernel_size % 2 == 0:
            trim = 1
        else:
            trim = 0
        self.remove = trim

    def forward(self, x):
        # Nothing to cut: hand the input back untouched.
        if not self.remove > 0:
            return x
        return x[:, :, : -self.remove]
|
||||
|
||||
|
||||
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.act = nn.Sigmoid()

    def forward(self, x):
        gate = self.act(x)
        return x * gate
|
||||
|
||||
|
||||
class GLU_Linear(nn.Module):
    """Linear projection followed by a gated linear unit.

    The input is projected to ``2 * output_dim`` features along its last
    dimension. The first half is the value and the second half is the gate:
    the output is ``value * act(gate)``, or ``value * gate`` when
    ``glu_type`` is ``"bilinear"``.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        glu_type: gate activation -- ``"sigmoid"``, ``"swish"``, ``"relu"``,
            ``"gelu"`` or ``"bilinear"`` (no activation).
        bias_in_glu: whether the linear projection has a bias term.
    """

    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        # "bilinear" needs no activation module; any other unlisted value
        # leaves ``glu_act`` unset, exactly as before.
        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        self.linear = nn.Linear(input_dim, output_dim * 2, bias=bool(bias_in_glu))

    def forward(self, x):
        """Apply the projection and gate ``x`` along its last dimension."""
        # The channel (#dim) axis is always the last one, so slice with ``...``
        # rather than assuming a 3-D input; 3-D inputs behave as before.
        x = self.linear(x)
        value = x[..., : self.output_dim]
        gate = x[..., self.output_dim : self.output_dim * 2]

        if self.glu_type == "bilinear":
            return value * gate
        return value * self.glu_act(gate)
|
||||
|
||||
|
||||
def gelu_accurate(x):
    """Tanh-based GELU approximation.

    The constant ``sqrt(2 / pi)`` is computed on first use and cached on the
    function object as ``gelu_accurate._a``.
    """
    try:
        coeff = gelu_accurate._a
    except AttributeError:
        coeff = gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
    """GELU evaluated in float32, with the result cast back to the type of ``x``."""
    upcast = x.float()
    activated = F.gelu(upcast)
    return activated.type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str):
    """Return the callable that implements the activation named ``activation``.

    Supported names: ``relu``, ``gelu``, ``gelu_fast`` (deprecated alias of
    ``gelu_accurate``; emits a warning), ``gelu_accurate``, ``tanh``,
    ``linear`` and ``glu`` (both return an identity function).

    Raises:
        RuntimeError: if ``activation`` is not one of the names above.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return gelu
    if activation in ("gelu_fast", "gelu_accurate"):
        if activation == "gelu_fast":
            warnings.warn(
                "--activation-fn=gelu_fast has been renamed to gelu_accurate"
            )
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation in ("linear", "glu"):
        return lambda x: x
    raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same ``module`` object; when ``p > 0`` a forward pre-hook has been
        registered on it.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # Runs before every forward call of ``mod``.
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                # (one Bernoulli draw per block, expanded to block_size entries)
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    # NOTE(review): this mask is 2-D while the conv weight is
                    # 4-D -- confirm that masked_fill below broadcasts as
                    # intended for 1x1 convolutions.
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    # One draw per (weight.size(0), weight.size(1)) pair,
                    # repeated across the whole kernel.
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            # Rescale the kept entries by 1 / (1 - p).
            s = 1 / (1 - p)
            # The masked, rescaled tensor is written back to mod.weight.data.
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from .ffmpeg import build_ffmpeg_command, build_audio_extract_command
|
||||
from .paths import _log
|
||||
|
||||
|
||||
class ExportRunner:
|
||||
"""Run ffmpeg export jobs in a background thread pool.
|
||||
|
||||
Callbacks:
|
||||
on_clip_done(path: str)
|
||||
on_all_done()
|
||||
on_error(msg: str)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
jobs: list[tuple[float, str, str | None, float]],
|
||||
short_side: int | None = None,
|
||||
image_sequence: bool = False,
|
||||
max_workers: int | None = None,
|
||||
encoder: str = "libx264",
|
||||
on_clip_done: Callable[[str], None] | None = None,
|
||||
on_all_done: Callable[[], None] | None = None,
|
||||
on_error: Callable[[str], None] | None = None,
|
||||
):
|
||||
self._input = input_path
|
||||
self._jobs = jobs
|
||||
self._short_side = short_side
|
||||
self._image_sequence = image_sequence
|
||||
self._max_workers = max_workers
|
||||
self._encoder = encoder
|
||||
self._on_clip_done = on_clip_done
|
||||
self._on_all_done = on_all_done
|
||||
self._on_error = on_error
|
||||
self._cancel = False
|
||||
self._procs: list[subprocess.Popen] = []
|
||||
self._procs_lock = threading.Lock()
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def cancel(self):
|
||||
self._cancel = True
|
||||
with self._procs_lock:
|
||||
for proc in self._procs:
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def is_running(self) -> bool:
|
||||
return self._thread is not None and self._thread.is_alive()
|
||||
|
||||
def _run_one(self, start: float, output: str,
|
||||
portrait_ratio: str | None, crop_center: float) -> str:
|
||||
if self._cancel:
|
||||
raise RuntimeError("cancelled")
|
||||
if self._image_sequence:
|
||||
os.makedirs(output, exist_ok=True)
|
||||
cmd = build_ffmpeg_command(
|
||||
self._input, start, output,
|
||||
short_side=self._short_side,
|
||||
portrait_ratio=portrait_ratio,
|
||||
crop_center=crop_center,
|
||||
image_sequence=self._image_sequence,
|
||||
encoder=self._encoder,
|
||||
)
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
with self._procs_lock:
|
||||
self._procs.append(proc)
|
||||
try:
|
||||
_, stderr = proc.communicate(timeout=120)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
raise RuntimeError("ffmpeg timed out")
|
||||
finally:
|
||||
with self._procs_lock:
|
||||
self._procs.remove(proc)
|
||||
if self._cancel:
|
||||
raise RuntimeError("cancelled")
|
||||
if proc.returncode != 0:
|
||||
msg = stderr.decode(errors='replace')[-500:] if stderr else "ffmpeg failed"
|
||||
raise RuntimeError(msg)
|
||||
if self._image_sequence:
|
||||
audio_cmd = build_audio_extract_command(self._input, start, output)
|
||||
audio_result = subprocess.run(audio_cmd, capture_output=True, text=True, timeout=60)
|
||||
if audio_result.returncode != 0:
|
||||
msg = (audio_result.stderr or "audio extraction failed")[-500:]
|
||||
raise RuntimeError(msg)
|
||||
return output
|
||||
|
||||
def _run(self):
|
||||
cap = self._max_workers or (os.cpu_count() or 2)
|
||||
workers = min(len(self._jobs), cap)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(self._run_one, s, o, pr, cc): o
|
||||
for s, o, pr, cc in self._jobs
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
if self._cancel:
|
||||
break
|
||||
try:
|
||||
path = fut.result()
|
||||
if self._on_clip_done:
|
||||
self._on_clip_done(path)
|
||||
except Exception as e:
|
||||
if "cancelled" not in str(e) and self._on_error:
|
||||
self._on_error(str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
if self._on_error:
|
||||
self._on_error(str(e))
|
||||
return
|
||||
if self._cancel:
|
||||
return
|
||||
if self._on_all_done:
|
||||
self._on_all_done()
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from .paths import _bin, _log
|
||||
|
||||
@@ -63,6 +64,13 @@ def apply_keyframes_to_jobs(
|
||||
return result
|
||||
|
||||
|
||||
def _find_vaapi_device() -> str:
|
||||
"""Return the first available VAAPI render device path (Linux)."""
|
||||
import glob
|
||||
devices = sorted(glob.glob("/dev/dri/renderD*"))
|
||||
return devices[0] if devices else "/dev/dri/renderD128"
|
||||
|
||||
|
||||
def build_ffmpeg_command(
|
||||
input_path: str, start: float, output_path: str,
|
||||
short_side: int | None = None,
|
||||
@@ -70,23 +78,29 @@ def build_ffmpeg_command(
|
||||
crop_center: float = 0.5,
|
||||
image_sequence: bool = False,
|
||||
encoder: str = "libx264",
|
||||
duration: float = 8.0,
|
||||
target_fps: float | None = None,
|
||||
snap32: bool = False,
|
||||
frames: int | None = None,
|
||||
) -> list[str]:
|
||||
# -ss before -i: fast input-seeking. Safe here because we always re-encode,
|
||||
# so there is no keyframe-alignment issue from pre-input seek.
|
||||
# Image sequences always use libwebp, so skip HW encoder setup.
|
||||
use_hw_vaapi = encoder == "h264_vaapi" and not image_sequence
|
||||
use_hw_vaapi = (encoder == "h264_vaapi" and not image_sequence
|
||||
and sys.platform == "linux")
|
||||
cmd = [_bin("ffmpeg"), "-y"]
|
||||
|
||||
# VAAPI needs a device for hardware context.
|
||||
# VAAPI needs a render device for hardware context (Linux only).
|
||||
if use_hw_vaapi:
|
||||
vaapi_dev = _find_vaapi_device()
|
||||
cmd += ["-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi",
|
||||
"-vaapi_device", "/dev/dri/renderD128"]
|
||||
"-vaapi_device", vaapi_dev]
|
||||
|
||||
cmd += [
|
||||
"-threads", "0",
|
||||
"-ss", str(start),
|
||||
"-i", input_path,
|
||||
"-t", "8",
|
||||
"-t", str(duration),
|
||||
]
|
||||
|
||||
filters: list[str] = []
|
||||
@@ -98,6 +112,13 @@ def build_ffmpeg_command(
|
||||
f"scale='if(lt(iw,ih),{short_side},-2)':'if(lt(iw,ih),-2,{short_side})':flags=lanczos"
|
||||
)
|
||||
|
||||
# LTX-2: centered crop to ÷32 (no rescale → no aspect distortion) then fps.
|
||||
# Placed among CPU filters, after scale and before the VAAPI hwupload block.
|
||||
if snap32:
|
||||
filters.append("crop=trunc(iw/32)*32:trunc(ih/32)*32")
|
||||
if target_fps is not None:
|
||||
filters.append(f"fps={target_fps:g}")
|
||||
|
||||
# VAAPI: decoded frames are GPU surfaces. CPU filters need hwdownload first.
|
||||
if use_hw_vaapi:
|
||||
if filters:
|
||||
@@ -109,6 +130,12 @@ def build_ffmpeg_command(
|
||||
if filters:
|
||||
cmd += ["-vf", ",".join(filters)]
|
||||
|
||||
# LTX-2 output rate + exact frame cap (apply to both clip and webp-seq paths).
|
||||
if target_fps is not None:
|
||||
cmd += ["-r", f"{target_fps:g}"]
|
||||
if frames is not None:
|
||||
cmd += ["-frames:v", str(frames)]
|
||||
|
||||
if image_sequence:
|
||||
cmd += [
|
||||
"-an",
|
||||
@@ -118,27 +145,93 @@ def build_ffmpeg_command(
|
||||
os.path.join(output_path, "frame_%04d.webp"),
|
||||
]
|
||||
else:
|
||||
cmd += ["-c:v", encoder, "-c:a", "pcm_s16le", output_path]
|
||||
cmd += ["-c:v", encoder]
|
||||
if "nvenc" in encoder:
|
||||
cmd += ["-preset", "p4", "-cq", "28"]
|
||||
elif "vaapi" in encoder:
|
||||
cmd += ["-qp", "28"]
|
||||
elif "qsv" in encoder:
|
||||
cmd += ["-global_quality", "28"]
|
||||
elif "amf" in encoder:
|
||||
cmd += ["-qp_i", "28", "-qp_p", "28"]
|
||||
cmd += ["-c:a", "pcm_s16le", output_path]
|
||||
return cmd
|
||||
|
||||
|
||||
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str) -> list[str]:
|
||||
def build_audio_extract_command(input_path: str, start: float, sequence_dir: str,
|
||||
duration: float = 8.0) -> list[str]:
|
||||
"""Return an ffmpeg command that extracts audio to <sequence_dir>.wav."""
|
||||
audio_path = sequence_dir + ".wav"
|
||||
return [
|
||||
_bin("ffmpeg"), "-y",
|
||||
"-ss", str(start),
|
||||
"-i", input_path,
|
||||
"-t", "8",
|
||||
"-t", str(duration),
|
||||
"-vn",
|
||||
"-c:a", "pcm_s16le",
|
||||
audio_path,
|
||||
]
|
||||
|
||||
|
||||
# Audio codec chosen per output extension for the manual "Extract audio area"
|
||||
# tool. Empty list -> let ffmpeg pick a default encoder from the extension.
|
||||
_AUDIO_CODEC_BY_EXT: dict[str, list[str]] = {
|
||||
".wav": ["-c:a", "pcm_s16le"],
|
||||
".flac": ["-c:a", "flac"],
|
||||
".mp3": ["-c:a", "libmp3lame", "-q:a", "2"],
|
||||
".m4a": ["-c:a", "aac", "-b:a", "256k"],
|
||||
".aac": ["-c:a", "aac", "-b:a", "256k"],
|
||||
".ogg": ["-c:a", "libvorbis", "-q:a", "5"],
|
||||
".opus": ["-c:a", "libopus", "-b:a", "192k"],
|
||||
}
|
||||
|
||||
|
||||
def probe_duration(path: str) -> float | None:
    """Return the media duration in seconds via ffprobe, or None on failure.

    Any failure -- ffprobe missing, a non-zero exit status, a timeout after
    30 seconds, or output that is empty or not a number -- yields None.
    """
    try:
        probe_cmd = [
            _bin("ffprobe"), "-v", "error", "-show_entries", "format=duration",
            "-of", "default=nw=1:nk=1", path,
        ]
        completed = subprocess.run(
            probe_cmd, capture_output=True, text=True, timeout=30,
        )
        if completed.returncode == 0:
            text = completed.stdout.strip()
            if text:
                return float(text)
    except Exception:
        # Best-effort probe: callers only need "duration or None".
        pass
    return None
|
||||
|
||||
|
||||
def build_audio_clip_command(input_path: str, start: float, duration: float,
                             out_path: str) -> list[str]:
    """Build the ffmpeg command that cuts *duration* seconds of audio from
    *input_path*, beginning at *start*, into *out_path*.

    The codec arguments are looked up from the lower-cased extension of
    *out_path*; for an extension with no entry none are added.
    """
    _, suffix = os.path.splitext(out_path)
    codec_args = _AUDIO_CODEC_BY_EXT.get(suffix.lower(), [])
    command = [
        _bin("ffmpeg"), "-y",
        "-ss", str(start),
        "-i", input_path,
        "-t", str(duration),
        "-vn",
    ]
    command.extend(codec_args)
    command.append(out_path)
    return command
|
||||
|
||||
|
||||
def detect_hw_encoders() -> list[str]:
|
||||
"""Probe ffmpeg for available H.264 hardware encoders."""
|
||||
_HW_ENCODERS = ["h264_nvenc", "h264_vaapi", "h264_qsv", "h264_amf", "h264_videotoolbox"]
|
||||
"""Probe ffmpeg for available H.264 hardware encoders.
|
||||
|
||||
Returns only encoders relevant to the current platform:
|
||||
- Windows: h264_nvenc, h264_qsv, h264_amf
|
||||
- Linux: h264_nvenc, h264_vaapi, h264_qsv
|
||||
- macOS: h264_videotoolbox
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
candidates = ["h264_nvenc", "h264_qsv", "h264_amf"]
|
||||
elif sys.platform == "darwin":
|
||||
candidates = ["h264_videotoolbox"]
|
||||
else:
|
||||
candidates = ["h264_nvenc", "h264_vaapi", "h264_qsv"]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_bin("ffmpeg"), "-hide_banner", "-encoders"],
|
||||
@@ -149,10 +242,7 @@ def detect_hw_encoders() -> list[str]:
|
||||
output = result.stdout
|
||||
except Exception:
|
||||
return []
|
||||
available = []
|
||||
for enc in _HW_ENCODERS:
|
||||
if re.search(rf'\b{enc}\b', output):
|
||||
available.append(enc)
|
||||
available = [enc for enc in candidates if re.search(rf'\b{enc}\b', output)]
|
||||
if available:
|
||||
_log(f"HW encoders detected: {', '.join(available)}")
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""LTX-2 frame-count math. Legal F satisfy F % 8 == 1 (8x temporal + 1)."""
|
||||
|
||||
|
||||
def is_legal_frames(f: int) -> bool:
    """Return True when *f* is at least 9 and leaves remainder 1 modulo 8."""
    if f < 9:
        return False
    return f % 8 == 1
|
||||
|
||||
|
||||
def legal_frames(min_f: int = 9, max_f: int = 1000) -> list[int]:
    """List every legal frame count (8k + 1, at least 9) in ``[min_f, max_f]``."""
    # Smallest value >= min_f that is congruent to 1 modulo 8.
    first = min_f + (1 - min_f) % 8
    if first < 9:
        first = 9
    return [*range(first, max_f + 1, 8)]
|
||||
|
||||
|
||||
def nearest_legal_frames(f: int) -> int:
    """Return the legal frame count (8k + 1, at least 9) closest to *f*.

    Values of 9 or less map to 9; a tie between two neighbours resolves to
    the lower one.
    """
    if f <= 9:
        return 9
    # Largest 8k + 1 that does not exceed f, and its successor.
    below = f - (f - 1) % 8
    above = below + 8
    if f - below <= above - f:
        return below
    return above
|
||||
|
||||
|
||||
def duration_for_frames(frames: int, fps: float) -> float:
    """Return the duration in seconds of *frames* frames played at *fps*.

    Raises:
        ZeroDivisionError: if *fps* is zero.
    """
    return frames / fps
|
||||
|
||||
|
||||
def frames_for_duration(duration: float, fps: float) -> int:
    """Return the legal frame count nearest to ``duration * fps``.

    The product is rounded with the built-in ``round`` before being passed to
    ``nearest_legal_frames``.
    """
    return nearest_legal_frames(round(duration * fps))
|
||||
@@ -24,16 +24,26 @@ def _log(*args) -> None:
|
||||
print(f"[8-cut {ts}]", *args, file=sys.stderr)
|
||||
|
||||
|
||||
def build_export_path(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||
group = f"{basename}_{counter:03d}"
|
||||
name = f"{group}_{sub}" if sub is not None else group
|
||||
return os.path.join(folder, group, name + ".mp4")
|
||||
def build_export_path(folder: str, basename: str, counter: int,
|
||||
sub: int | None = None, tag: str | None = None) -> str:
|
||||
"""Build clip output path. *folder* should be the vid folder (e.g. .../mp4/vid_001)."""
|
||||
name = f"{basename}_{counter:03d}"
|
||||
if tag is not None:
|
||||
name = f"{name}_{tag}"
|
||||
if sub is not None:
|
||||
name = f"{name}_{sub}"
|
||||
return os.path.join(folder, name + ".mp4")
|
||||
|
||||
|
||||
def build_sequence_dir(folder: str, basename: str, counter: int, sub: int | None = None) -> str:
|
||||
group = f"{basename}_{counter:03d}"
|
||||
name = f"{group}_{sub}" if sub is not None else group
|
||||
return os.path.join(folder, group, name)
|
||||
def build_sequence_dir(folder: str, basename: str, counter: int,
|
||||
sub: int | None = None, tag: str | None = None) -> str:
|
||||
"""Build WebP sequence output dir. *folder* should be the vid folder."""
|
||||
name = f"{basename}_{counter:03d}"
|
||||
if tag is not None:
|
||||
name = f"{name}_{tag}"
|
||||
if sub is not None:
|
||||
name = f"{name}_{sub}"
|
||||
return os.path.join(folder, name)
|
||||
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
services:
|
||||
8cut:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- /path/to/videos:/videos:ro
|
||||
- /path/to/exports:/exports
|
||||
- 8cut-data:/data
|
||||
environment:
|
||||
MEDIA_DIRS: /videos
|
||||
EXPORT_DIR: /exports
|
||||
DB_PATH: /data/8cut.db
|
||||
CACHE_DIR: /data/cache
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
|
||||
volumes:
|
||||
8cut-data:
|
||||
@@ -0,0 +1,97 @@
|
||||
# Audio Similarity Scanning — Design
|
||||
|
||||
**Goal:** Scan a video's audio track and highlight segments that match the sound profile of existing reference clips, so the user can quickly find similar moments without scrubbing manually.
|
||||
|
||||
**Runs in:** Python/Qt client (`main.py`), not the server.
|
||||
|
||||
---
|
||||
|
||||
## Core Module: `core/audio_scan.py`
|
||||
|
||||
New module alongside `core/tracking.py`. Two main functions:
|
||||
|
||||
- `build_profile(clip_paths: list[str]) -> dict` — extracts MFCCs (20 coefficients) from each clip using `librosa`, returns a profile containing both the averaged vector and individual clip vectors.
|
||||
- `scan_video(video_path: str, profile: dict, mode: str, threshold: float, hop: float) -> list[tuple[float, float, float]]` — slides an 8s window across the video's audio, returns `(start_time, end_time, score)` tuples for segments above threshold.
|
||||
|
||||
### Feature Extraction
|
||||
|
||||
- Audio loaded via `librosa.load()` (handles video files directly, mono, 22050Hz).
|
||||
- MFCCs: `librosa.feature.mfcc(n_mfcc=20)`, averaged over time axis to produce a single vector per window/clip.
|
||||
- Similarity: cosine similarity (`numpy` dot product on L2-normalized vectors).
|
||||
|
||||
### Matching Modes
|
||||
|
||||
- **Average mode:** Compare each window to the mean of all reference MFCC vectors. Fast, good when references are homogeneous.
|
||||
- **Nearest mode:** Compare each window to every reference vector, take the max score. Better when references have variety within the style.
|
||||
|
||||
### Parameters
|
||||
|
||||
- `threshold` (float, 0.0–1.0): minimum cosine similarity to include a segment. Default 0.7.
|
||||
- `hop` (float, seconds): step size for the sliding window. Default 1.0s.
|
||||
- Window size fixed at 8s to match reference clip length.
|
||||
|
||||
---
|
||||
|
||||
## UI Integration in `main.py`
|
||||
|
||||
### Controls
|
||||
|
||||
Added near the existing tracking checkbox area:
|
||||
|
||||
- **"Scan" button** — triggers audio scan on current video.
|
||||
- **Threshold slider** (0.0–1.0, step 0.05) — controls match strictness.
|
||||
- **Mode combobox** — "Average" / "Nearest".
|
||||
- **Reference source combobox** — "Current Profile" / "Custom Folder" (shows folder picker when "Custom Folder" selected).
|
||||
|
||||
### Scan Workflow
|
||||
|
||||
1. User clicks Scan.
|
||||
2. Reference clips collected: either all export `output_path` values from the current profile (via DB) or all audio/video files in a custom folder.
|
||||
3. Scan runs in a `QThread` so UI stays responsive.
|
||||
4. On completion, results sent to Timeline widget via signal.
|
||||
|
||||
### Timeline Display
|
||||
|
||||
- New `set_scan_regions(regions: list[tuple[float, float, float]])` method on Timeline.
|
||||
- Drawn as semi-transparent colored rectangles behind existing markers.
|
||||
- Color intensity proportional to score (brighter = higher match).
|
||||
- Cleared on file change or re-scan.
|
||||
|
||||
### Keyboard Shortcut
|
||||
|
||||
- `S` — jump cursor to the next scan region (similar to `M` for next marker).
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
Reference clips (DB export paths or folder)
|
||||
|
|
||||
librosa.load() each -> MFCC vectors (20-dim)
|
||||
|
|
||||
Profile: { mean_vector, clip_vectors[] }
|
||||
|
|
||||
Current video -> librosa.load() full audio (mono 22050Hz)
|
||||
|
|
||||
Sliding 8s window (hop=1s) -> MFCC per window
|
||||
|
|
||||
Cosine similarity vs profile -> score per position
|
||||
|
|
||||
Threshold filter -> [(start, end, score), ...]
|
||||
|
|
||||
Timeline: semi-transparent highlight regions
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
- 2-hour video at 22050Hz mono ≈ 635 MB in memory (7200 s × 22050 samples/s × 4 bytes, since `librosa.load()` returns float32).
|
||||
- MFCC extraction + sliding window: ~10-30s.
|
||||
- QThread keeps UI responsive.
|
||||
|
||||
## What This Does NOT Do
|
||||
|
||||
- No DB schema changes — scan results are ephemeral (visual only).
|
||||
- No auto-export — user decides what to cut.
|
||||
- No server integration — runs entirely in the Python client.
|
||||
- No GPU/ML model dependency — just librosa + numpy.
|
||||
@@ -0,0 +1,739 @@
|
||||
# Audio Similarity Scanning — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Scan a video's audio track to find segments matching a reference sound profile, displayed as highlighted regions on the timeline.
|
||||
|
||||
**Architecture:** New `core/audio_scan.py` module extracts MFCC features from reference clips and slides an 8s window across the target video's audio, scoring each position via cosine similarity. A `ScanWorker` QThread runs the scan in the background, and results are drawn as semi-transparent rectangles on the existing Timeline widget.
|
||||
|
||||
**Tech Stack:** Python 3, librosa 0.11, numpy, PyQt6
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Core audio_scan module — build_profile
|
||||
|
||||
**Files:**
|
||||
- Create: `core/audio_scan.py`
|
||||
- Create: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the tests**
|
||||
|
||||
```python
|
||||
# tests/test_audio_scan.py
|
||||
import tempfile, os
|
||||
import numpy as np
|
||||
from core.audio_scan import build_profile, _extract_mfcc
|
||||
|
||||
|
||||
def _make_wav(path: str, duration: float = 8.0, sr: int = 22050):
|
||||
"""Create a short sine-wave WAV file for testing."""
|
||||
import soundfile as sf
|
||||
t = np.linspace(0, duration, int(sr * duration), endpoint=False)
|
||||
audio = 0.5 * np.sin(2 * np.pi * 440 * t)
|
||||
sf.write(path, audio, sr)
|
||||
|
||||
|
||||
def test_extract_mfcc_returns_1d_vector():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
vec = _extract_mfcc(f.name)
|
||||
assert vec.shape == (20,)
|
||||
assert not np.isnan(vec).any()
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_single_clip():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
profile = build_profile([f.name])
|
||||
assert "mean_vector" in profile
|
||||
assert "clip_vectors" in profile
|
||||
assert profile["mean_vector"].shape == (20,)
|
||||
assert len(profile["clip_vectors"]) == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_multiple_clips():
|
||||
paths = []
|
||||
try:
|
||||
for i in range(3):
|
||||
f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
freq = 440 + i * 200
|
||||
import soundfile as sf
|
||||
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||
sf.write(f.name, 0.5 * np.sin(2 * np.pi * freq * t), 22050)
|
||||
paths.append(f.name)
|
||||
f.close()
|
||||
|
||||
profile = build_profile(paths)
|
||||
assert len(profile["clip_vectors"]) == 3
|
||||
assert profile["mean_vector"].shape == (20,)
|
||||
finally:
|
||||
for p in paths:
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
def test_build_profile_skips_missing_files():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
_make_wav(f.name)
|
||||
try:
|
||||
profile = build_profile([f.name, "/no/such/file.wav"])
|
||||
assert len(profile["clip_vectors"]) == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_build_profile_empty_returns_none():
|
||||
result = build_profile([])
|
||||
assert result is None
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify they fail**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: FAIL with `ModuleNotFoundError: No module named 'core.audio_scan'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
```python
|
||||
# core/audio_scan.py
|
||||
"""Audio similarity scanning — MFCC-based profile matching."""
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
from .paths import _log
|
||||
|
||||
_N_MFCC = 20
|
||||
_SR = 22050
|
||||
|
||||
|
||||
def _extract_mfcc(path: str, sr: int = _SR) -> np.ndarray:
|
||||
"""Load audio from a file and return a mean MFCC vector (20-dim)."""
|
||||
y, _ = librosa.load(path, sr=sr, mono=True)
|
||||
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=_N_MFCC)
|
||||
return mfcc.mean(axis=1) # average over time → (20,)
|
||||
|
||||
|
||||
def build_profile(clip_paths: list[str]) -> dict | None:
|
||||
"""Extract MFCCs from reference clips.
|
||||
|
||||
Returns dict with:
|
||||
- mean_vector: averaged MFCC across all clips (20,)
|
||||
- clip_vectors: list of individual MFCC vectors
|
||||
Returns None if no clips could be loaded.
|
||||
"""
|
||||
vectors = []
|
||||
for p in clip_paths:
|
||||
try:
|
||||
vec = _extract_mfcc(p)
|
||||
vectors.append(vec)
|
||||
except Exception as e:
|
||||
_log(f"audio_scan: skip {p}: {e}")
|
||||
if not vectors:
|
||||
return None
|
||||
arr = np.stack(vectors)
|
||||
return {
|
||||
"mean_vector": arr.mean(axis=0),
|
||||
"clip_vectors": vectors,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run tests to verify they pass**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: all 5 PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add audio_scan module with build_profile"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Core audio_scan module — scan_video
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py`
|
||||
- Modify: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the tests**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
from core.audio_scan import scan_video
|
||||
|
||||
|
||||
def test_scan_video_finds_matching_region():
|
||||
"""A video made of the same sine wave as the reference should match."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
_make_wav(ref.name, duration=8.0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
_make_wav(vid.name, duration=20.0)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="average", threshold=0.5, hop=1.0)
|
||||
assert len(regions) > 0
|
||||
for start, end, score in regions:
|
||||
assert abs((end - start) - 8.0) < 1e-9
|
||||
assert score >= 0.5
|
||||
assert score >= 0.5
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
|
||||
|
||||
def test_scan_video_nearest_mode():
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
_make_wav(ref.name, duration=8.0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
_make_wav(vid.name, duration=20.0)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="nearest", threshold=0.5, hop=1.0)
|
||||
assert len(regions) > 0
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
|
||||
|
||||
def test_scan_video_high_threshold_no_match():
|
||||
"""Different frequencies with very high threshold should not match."""
|
||||
import soundfile as sf
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref:
|
||||
t = np.linspace(0, 8.0, 22050 * 8, endpoint=False)
|
||||
sf.write(ref.name, 0.5 * np.sin(2 * np.pi * 440 * t), 22050)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vid:
|
||||
# White noise — very different from sine wave
|
||||
sf.write(vid.name, np.random.randn(22050 * 20).astype(np.float32) * 0.1, 22050)
|
||||
try:
|
||||
profile = build_profile([ref.name])
|
||||
regions = scan_video(vid.name, profile, mode="average", threshold=0.99, hop=1.0)
|
||||
assert len(regions) == 0
|
||||
finally:
|
||||
os.unlink(ref.name)
|
||||
os.unlink(vid.name)
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify they fail**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_scan_video_finds_matching_region -v`
|
||||
Expected: FAIL with `ImportError: cannot import name 'scan_video'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
Add to `core/audio_scan.py`:
|
||||
|
||||
```python
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Cosine similarity between two vectors.
|
||||
|
||||
Returns value in [-1, 1]. Negative means anti-correlated (very
|
||||
dissimilar). For threshold filtering this is fine — negative scores
|
||||
never exceed the threshold. Scores near 0 may be uncorrelated or
|
||||
weakly anti-correlated.
|
||||
"""
|
||||
na = np.linalg.norm(a)
|
||||
nb = np.linalg.norm(b)
|
||||
if na == 0 or nb == 0:
|
||||
return 0.0
|
||||
return float(np.dot(a, b) / (na * nb))
|
||||
|
||||
|
||||
def scan_video(
|
||||
video_path: str,
|
||||
profile: dict,
|
||||
mode: str = "average",
|
||||
threshold: float = 0.7,
|
||||
hop: float = 1.0,
|
||||
window: float = 8.0,
|
||||
cancel_flag: object = None,
|
||||
) -> list[tuple[float, float, float]]:
|
||||
"""Slide a window across the video audio and score against the profile.
|
||||
|
||||
Args:
|
||||
video_path: path to video/audio file
|
||||
profile: dict from build_profile()
|
||||
mode: "average" (compare to mean) or "nearest" (max over all clips)
|
||||
threshold: minimum cosine similarity to include
|
||||
hop: step size in seconds
|
||||
window: window size in seconds (default 8s)
|
||||
cancel_flag: object with _cancel bool attribute; checked each iteration
|
||||
|
||||
Returns:
|
||||
list of (start_time, end_time, score) for regions above threshold
|
||||
"""
|
||||
_log(f"audio_scan: loading {video_path}")
|
||||
y, sr = librosa.load(video_path, sr=_SR, mono=True)
|
||||
duration = len(y) / sr
|
||||
_log(f"audio_scan: {duration:.1f}s loaded, scanning with hop={hop}s")
|
||||
|
||||
win_samples = int(window * sr)
|
||||
hop_samples = int(hop * sr)
|
||||
|
||||
results = []
|
||||
pos = 0
|
||||
while pos + win_samples <= len(y):
|
||||
if cancel_flag and getattr(cancel_flag, '_cancel', False):
|
||||
_log("audio_scan: cancelled")
|
||||
return results
|
||||
|
||||
chunk = y[pos : pos + win_samples]
|
||||
mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=_N_MFCC)
|
||||
vec = mfcc.mean(axis=1)
|
||||
|
||||
if mode == "nearest":
|
||||
score = max(
|
||||
_cosine_similarity(vec, cv) for cv in profile["clip_vectors"]
|
||||
)
|
||||
else: # average
|
||||
score = _cosine_similarity(vec, profile["mean_vector"])
|
||||
|
||||
if score >= threshold:
|
||||
start_t = pos / sr
|
||||
results.append((start_t, start_t + window, score))
|
||||
|
||||
pos += hop_samples
|
||||
|
||||
_log(f"audio_scan: {len(results)} regions above threshold {threshold}")
|
||||
return results
|
||||
```
|
||||
|
||||
**Step 4: Run tests to verify they pass**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py -v`
|
||||
Expected: all 8 PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add scan_video with average and nearest modes"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Timeline — draw scan regions
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (Timeline class, around lines 209-260 and 300-375)
|
||||
|
||||
**Step 1: Add scan region storage to Timeline.__init__**
|
||||
|
||||
In `main.py`, find the Timeline class `__init__` method (around line 198). After `self._markers` initialization (line 209), add:
|
||||
|
||||
```python
|
||||
self._scan_regions: list[tuple[float, float, float]] = [] # (start, end, score)
|
||||
```
|
||||
|
||||
**Step 2: Add set_scan_regions method**
|
||||
|
||||
After the `set_markers` method (line 249-252), add:
|
||||
|
||||
```python
|
||||
def set_scan_regions(self, regions: list[tuple[float, float, float]]) -> None:
|
||||
"""regions: list of (start_time, end_time, score)"""
|
||||
self._scan_regions = regions
|
||||
self.update()
|
||||
|
||||
def clear_scan_regions(self) -> None:
|
||||
self._scan_regions = []
|
||||
self.update()
|
||||
```
|
||||
|
||||
**Step 3: Draw scan regions in paintEvent**
|
||||
|
||||
In `paintEvent` (starts around line 282), find the marker drawing section (line 363, comment `# ── export markers`). BEFORE that section, add:
|
||||
|
||||
```python
|
||||
# ── scan regions ──────────────────────────────────────────────
|
||||
if self._scan_regions and self._duration > 0:
|
||||
for (start, end, score) in self._scan_regions:
|
||||
x1 = int(start / self._duration * w)
|
||||
x2 = int(end / self._duration * w)
|
||||
alpha = int(40 + score * 80) # 40–120 opacity
|
||||
p.fillRect(x1, rh, x2 - x1, h - rh, QColor(100, 200, 255, alpha))
|
||||
```
|
||||
|
||||
**Step 4: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: app starts without errors. No scan regions visible yet (none set).
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: timeline scan region rendering"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: ScanWorker QThread
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (add ScanWorker class, after ExportWorker around line 165)
|
||||
|
||||
**Step 1: Add the ScanWorker class**
|
||||
|
||||
After the `ExportWorker` class (ends around line 165), add:
|
||||
|
||||
```python
|
||||
class ScanWorker(QThread):
|
||||
"""Runs audio similarity scan off the main thread."""
|
||||
finished = pyqtSignal(list) # emits list of (start, end, score)
|
||||
error = pyqtSignal(str)
|
||||
progress = pyqtSignal(str) # status message
|
||||
|
||||
def __init__(self, video_path: str, clip_paths: list[str],
|
||||
mode: str = "average", threshold: float = 0.7):
|
||||
super().__init__()
|
||||
self._video_path = video_path
|
||||
self._clip_paths = clip_paths
|
||||
self._mode = mode
|
||||
self._threshold = threshold
|
||||
self._cancel = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._cancel = True
|
||||
|
||||
def run(self):
|
||||
from core.audio_scan import build_profile, scan_video
|
||||
try:
|
||||
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
|
||||
profile = build_profile(self._clip_paths)
|
||||
if self._cancel:
|
||||
return
|
||||
if profile is None:
|
||||
self.error.emit("No valid reference clips found")
|
||||
return
|
||||
self.progress.emit("Scanning audio...")
|
||||
regions = scan_video(
|
||||
self._video_path, profile,
|
||||
mode=self._mode, threshold=self._threshold,
|
||||
cancel_flag=self,
|
||||
)
|
||||
if not self._cancel:
|
||||
self.finished.emit(regions)
|
||||
except Exception as e:
|
||||
if not self._cancel:
|
||||
self.error.emit(str(e))
|
||||
```
|
||||
|
||||
**Step 2: Verify import works**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -c "from main import ScanWorker; print('ok')"`
|
||||
Expected: `ok`
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add ScanWorker QThread for background scanning"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: DB helper — get_all_export_paths
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/db.py`
|
||||
- Modify: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write the test**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
def test_db_get_all_export_paths():
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
from core.db import ProcessedDB
|
||||
db = ProcessedDB(path)
|
||||
db.add("a.mp4", 10.0, "/out/a_001.mp4", profile="test")
|
||||
db.add("b.mp4", 20.0, "/out/b_001.mp4", profile="test")
|
||||
db.add("c.mp4", 30.0, "/out/c_001.mp4", profile="other")
|
||||
paths = db.get_all_export_paths("test")
|
||||
assert set(paths) == {"/out/a_001.mp4", "/out/b_001.mp4"}
|
||||
finally:
|
||||
os.unlink(path)
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||
Expected: FAIL with `AttributeError: 'ProcessedDB' object has no attribute 'get_all_export_paths'`
|
||||
|
||||
**Step 3: Write the implementation**
|
||||
|
||||
Add to `core/db.py`, after the `get_markers` method. Note: no lock needed — follows
|
||||
the codebase convention where read-only methods don't acquire the lock.
|
||||
|
||||
```python
|
||||
def get_all_export_paths(self, profile: str = "default") -> list[str]:
|
||||
"""Return all unique output_path values for a given profile."""
|
||||
if not self._enabled:
|
||||
return []
|
||||
rows = self._con.execute(
|
||||
"SELECT DISTINCT output_path FROM processed WHERE profile = ?",
|
||||
(profile,),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/test_audio_scan.py::test_db_get_all_export_paths -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add core/db.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add get_all_export_paths to ProcessedDB"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: UI controls for audio scanning
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` (MainWindow class — control creation ~1490-1575, layout ~1620-1640)
|
||||
|
||||
**Step 1: Add scan control widgets**
|
||||
|
||||
In the MainWindow `__init__`, find the control creation section. After `self._chk_track` (around line 1501), add:
|
||||
|
||||
```python
|
||||
# ── audio scan controls ──────────────────────────────────────
|
||||
self._btn_scan = QPushButton("Scan")
|
||||
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
|
||||
self._btn_scan.clicked.connect(self._start_scan)
|
||||
|
||||
self._sld_threshold = QDoubleSpinBox()
|
||||
self._sld_threshold.setRange(0.0, 1.0)
|
||||
self._sld_threshold.setSingleStep(0.05)
|
||||
self._sld_threshold.setValue(0.7)
|
||||
self._sld_threshold.setPrefix("Thr: ")
|
||||
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
|
||||
|
||||
self._cmb_scan_mode = QComboBox()
|
||||
self._cmb_scan_mode.addItems(["Average", "Nearest"])
|
||||
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
|
||||
|
||||
self._cmb_scan_ref = QComboBox()
|
||||
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
|
||||
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
|
||||
self._scan_folder: str = ""
|
||||
|
||||
self._scan_worker: ScanWorker | None = None
|
||||
```
|
||||
|
||||
**Step 2: Add controls to settings_row layout**
|
||||
|
||||
Find the `settings_row` assembly (around line 1620). Before `settings_row.addStretch()` (around line 1635), add:
|
||||
|
||||
```python
|
||||
settings_row.addWidget(self._btn_scan)
|
||||
settings_row.addWidget(self._sld_threshold)
|
||||
settings_row.addWidget(self._cmb_scan_mode)
|
||||
settings_row.addWidget(self._cmb_scan_ref)
|
||||
```
|
||||
|
||||
**Step 3: Add handler methods**
|
||||
|
||||
Add these methods to MainWindow (after `_jump_to_next_marker` around line 2410):
|
||||
|
||||
```python
|
||||
def _on_scan_ref_changed(self, index: int) -> None:
|
||||
if index == 1: # Custom Folder
|
||||
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
|
||||
if folder:
|
||||
self._scan_folder = folder
|
||||
else:
|
||||
self._cmb_scan_ref.setCurrentIndex(0)
|
||||
|
||||
def _cleanup_scan_worker(self) -> None:
|
||||
"""Disconnect signals and schedule deletion of old scan worker."""
|
||||
if self._scan_worker is not None:
|
||||
try:
|
||||
self._scan_worker.finished.disconnect()
|
||||
self._scan_worker.error.disconnect()
|
||||
self._scan_worker.progress.disconnect()
|
||||
except TypeError:
|
||||
pass # already disconnected
|
||||
self._scan_worker.deleteLater()
|
||||
self._scan_worker = None
|
||||
|
||||
def _start_scan(self) -> None:
|
||||
if not self._file_path:
|
||||
self._show_status("No video loaded")
|
||||
return
|
||||
if self._scan_worker and self._scan_worker.isRunning():
|
||||
self._show_status("Scan already running")
|
||||
return
|
||||
|
||||
# Clean up previous worker
|
||||
self._cleanup_scan_worker()
|
||||
|
||||
# Collect reference clip paths
|
||||
if self._cmb_scan_ref.currentIndex() == 0:
|
||||
# Current profile — all exports across all files in this profile
|
||||
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
|
||||
if os.path.exists(p)]
|
||||
else:
|
||||
# Custom folder
|
||||
if not self._scan_folder:
|
||||
self._show_status("No reference folder selected")
|
||||
return
|
||||
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
|
||||
clip_paths = [
|
||||
os.path.join(self._scan_folder, f)
|
||||
for f in sorted(os.listdir(self._scan_folder))
|
||||
if f.lower().endswith(exts)
|
||||
]
|
||||
|
||||
if not clip_paths:
|
||||
self._show_status("No reference clips found")
|
||||
return
|
||||
|
||||
mode = self._cmb_scan_mode.currentText().lower()
|
||||
threshold = self._sld_threshold.value()
|
||||
|
||||
self._btn_scan.setEnabled(False)
|
||||
self._scan_file_path = self._file_path # remember which file we're scanning
|
||||
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
|
||||
|
||||
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
|
||||
self._scan_worker.finished.connect(self._on_scan_done)
|
||||
self._scan_worker.error.connect(self._on_scan_error)
|
||||
self._scan_worker.progress.connect(self._show_status)
|
||||
self._scan_worker.start()
|
||||
|
||||
def _on_scan_done(self, regions: list) -> None:
|
||||
self._btn_scan.setEnabled(True)
|
||||
# Ignore stale results if the user switched files during scan
|
||||
if self._file_path != getattr(self, '_scan_file_path', None):
|
||||
return
|
||||
self._timeline.set_scan_regions(regions)
|
||||
self._show_status(f"Scan complete: {len(regions)} matching regions")
|
||||
|
||||
def _on_scan_error(self, msg: str) -> None:
|
||||
self._btn_scan.setEnabled(True)
|
||||
self._show_status(f"Scan error: {msg}")
|
||||
```
|
||||
|
||||
**Step 4: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: Scan button, threshold spinner, mode dropdown, and reference source dropdown visible in the settings row. Clicking Scan with no file loaded shows "No video loaded" in status.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add scan UI controls and start_scan handler"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Keyboard shortcut — jump to next scan region
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py`
|
||||
|
||||
**Step 1: Add the keyboard shortcut**
|
||||
|
||||
Find the shortcut definitions (around line 1728, where `QShortcut(QKeySequence("M"), ...)` is defined). Add after it:
|
||||
|
||||
```python
|
||||
QShortcut(QKeySequence("S"), self, context=ctx).activated.connect(self._jump_to_next_scan_region)
|
||||
```
|
||||
|
||||
**Step 2: Add the jump method**
|
||||
|
||||
After `_on_scan_error` (or after `_jump_to_next_marker`), add:
|
||||
|
||||
```python
|
||||
def _jump_to_next_scan_region(self) -> None:
|
||||
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
|
||||
if not regions:
|
||||
return
|
||||
for (start, _end, _score) in regions:
|
||||
if start > self._cursor + 0.1:
|
||||
self._step_cursor(start - self._cursor)
|
||||
return
|
||||
# Wrap to first region
|
||||
self._step_cursor(regions[0][0] - self._cursor)
|
||||
```
|
||||
|
||||
**Step 3: Update help text**
|
||||
|
||||
Find the help/shortcuts tooltip (around line 1757). Add a row:
|
||||
|
||||
```python
|
||||
"<tr><td><b>S</b></td><td>Jump to next scan region</td></tr>"
|
||||
```
|
||||
|
||||
**Step 4: Clear scan regions and cancel running scan on file change**
|
||||
|
||||
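Find `_load_file` method (around line 1931). After the existing marker/state resets, add the snippet below. Caution: `cancel()` only sets a flag that `scan_video` polls between windows, so the worker thread may still be running (e.g. inside `librosa.load()`) when `_cleanup_scan_worker()` calls `deleteLater()`; make sure the thread has actually stopped before the worker object is destroyed.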
|
||||
|
||||
```python
|
||||
self._timeline.clear_scan_regions()
|
||||
if self._scan_worker and self._scan_worker.isRunning():
|
||||
self._scan_worker.cancel()
|
||||
self._cleanup_scan_worker()
|
||||
self._btn_scan.setEnabled(True)
|
||||
```
|
||||
|
||||
**Step 5: Verify manually**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python main.py`
|
||||
Expected: S key does nothing when no scan regions exist. After a scan, S jumps through matched regions.
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add main.py
|
||||
git commit -m "feat: add S shortcut and clear scan on file change"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Final integration test
|
||||
|
||||
**Step 1: End-to-end manual test**
|
||||
|
||||
1. Open the app: `cd /media/p5/8-cut && python main.py`
|
||||
2. Load a video file
|
||||
3. Export a few clips (these become the reference)
|
||||
4. Set reference source to "Current Profile"
|
||||
5. Click "Scan"
|
||||
6. Verify: status shows progress messages, then "Scan complete: N matching regions"
|
||||
7. Verify: cyan-tinted regions appear on the timeline
|
||||
8. Press S to jump through scan regions
|
||||
9. Change threshold and re-scan — verify different number of regions
|
||||
10. Switch mode to "Nearest" and re-scan
|
||||
11. Switch reference to "Custom Folder", pick a folder with clips
|
||||
12. Re-scan and verify results
|
||||
|
||||
**Step 2: Run all tests**
|
||||
|
||||
Run: `cd /media/p5/8-cut && python -m pytest tests/ -v`
|
||||
Expected: all tests PASS
|
||||
|
||||
**Step 3: Final commit**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "feat: audio similarity scanning complete"
|
||||
```
|
||||
@@ -0,0 +1,98 @@
|
||||
# Audio Pipeline Improvements Design
|
||||
|
||||
Date: 2026-04-19
|
||||
|
||||
## Goal
|
||||
|
||||
Improve audio scan classification accuracy, especially for non-speech sounds (suction, gagging, impacts), through three changes:
|
||||
|
||||
1. Multi-layer feature extraction from existing HuBERT/Wav2Vec2 models
|
||||
2. Two new embedding models: AST (AudioSet-supervised) and EAT (self-supervised + AudioSet finetuned)
|
||||
3. Calibrated classifier for better threshold behavior
|
||||
|
||||
## 1. Multi-Layer Feature Extraction
|
||||
|
||||
### Current behavior
|
||||
|
||||
`model(waveforms)` extracts embeddings from the **last transformer layer only**.
|
||||
|
||||
### Change
|
||||
|
||||
Use `model.extract_features(waveforms)` (torchaudio API) to get all layer outputs. Select layers at quartile boundaries, mean-pool each over time, concatenate.
|
||||
|
||||
| Model | Layers | Single-layer dim | Multi-layer dim (4 quartiles) |
|
||||
|-------|--------|-------------------|-------------------------------|
|
||||
| HUBERT_XLARGE | 48 | 1280 | 5120 |
|
||||
| HUBERT_LARGE | 24 | 1024 | 4096 |
|
||||
| HUBERT_BASE | 12 | 768 | 3072 |
|
||||
| WAV2VEC2_BASE | 12 | 768 | 3072 |
|
||||
|
||||
### Implementation
|
||||
|
||||
- New entries in `_EMBED_MODELS`: `"HUBERT_XLARGE_ML"` -> 5120, etc.
|
||||
- `_extract_w2v_windows`: when model name ends with `_ML`, call `extract_features()` instead of `model()`, select quartile layers, concat
|
||||
- Cache key: model name includes `_ML` suffix -> separate cache files
|
||||
- No change to classifier or training pipeline (HistGBT handles high-dim fine)
|
||||
|
||||
## 2. AST (Audio Spectrogram Transformer)
|
||||
|
||||
### What
|
||||
|
||||
`MIT/ast-finetuned-audioset-10-10-0.4593` via HuggingFace `transformers`. 86M params, 768-dim, supervised on AudioSet 527 sound classes.
|
||||
|
||||
### Integration
|
||||
|
||||
- Load: `ASTModel.from_pretrained()` + `ASTFeatureExtractor`
|
||||
- Preprocessing: `ASTFeatureExtractor` handles mel spectrogram from 16kHz raw audio
|
||||
- Batching: prepare `input_values` per window, stack into batch, forward through model
|
||||
- Multi-layer: `output_hidden_states=True` returns 13 layers; `AST_ML` variant concats quartile layers -> 3072-dim
|
||||
- Model cached via `_get_w2v_model()` same lazy-load pattern
|
||||
|
||||
### Entries
|
||||
|
||||
- `"AST"` -> 768
|
||||
- `"AST_ML"` -> 3072
|
||||
|
||||
## 3. EAT (Efficient Audio Transformer)
|
||||
|
||||
### What
|
||||
|
||||
`worstchan/EAT-base_epoch30_finetune_AS2M` via HuggingFace with `trust_remote_code=True`. 88M params, 768-dim, self-supervised + AudioSet finetuned.
|
||||
|
||||
### Integration
|
||||
|
||||
- Load: `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
||||
- Preprocessing: manual 128-bin Kaldi fbank mel spectrogram via torchaudio, normalize with EAT constants `(mel - (-4.268)) / (4.569 * 2)`, reshape to `[B, 1, T, 128]`
|
||||
- Feature extraction: `model.extract_features(mel)` returns `[B, seq, 768]`; CLS token `[:, 0, :]` for utterance-level, or mean-pool `[:, 1:, :]` for frame-level. Use mean-pool for consistency with other models.
|
||||
- Multi-layer: not natively supported, skip for now
|
||||
|
||||
### Entry
|
||||
|
||||
- `"EAT"` -> 768
|
||||
|
||||
## 4. Calibrated Classifier
|
||||
|
||||
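Wrap `HistGradientBoostingClassifier` in `CalibratedClassifierCV(clf, cv=3, method='isotonic')` and fit the wrapper in place of the bare classifier (with `cv=3` it clones and fits `clf` on each fold itself, so a separate pre-fit of `clf` is not needed). Gives well-calibrated probabilities -> threshold slider maps more linearly to precision/recall.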
|
||||
|
||||
One change in `train_classifier()`, no UI changes needed.
|
||||
|
||||
## 5. Requirements
|
||||
|
||||
Add to `requirements.txt`:
|
||||
```
|
||||
transformers>=4.30
|
||||
timm>=0.9
|
||||
```
|
||||
|
||||
Both AST and EAT need `transformers`. EAT additionally needs `timm` (used internally by its custom model code). Both setup scripts (`setup_env.sh`, `setup-windows.ps1`) install from `requirements.txt` so no changes needed there.
|
||||
|
||||
## Cache Compatibility
|
||||
|
||||
- All new model variants get distinct cache keys via model name in the hash
|
||||
- Existing caches for HUBERT_XLARGE, BEATs, etc. remain valid and untouched
|
||||
- New models create new `.npz` files in the same `cache/w2v/` directory
|
||||
|
||||
## UI Changes
|
||||
|
||||
- `_EMBED_MODELS` dict additions appear automatically in Train dialog model dropdown and scan model dropdown
|
||||
- No other UI changes needed
|
||||
@@ -0,0 +1,588 @@
|
||||
# Audio Pipeline Improvements Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Improve audio scan accuracy with multi-layer extraction, AST/EAT models, and calibrated classifier.
|
||||
|
||||
**Architecture:** All changes are in `core/audio_scan.py`. The embedding extraction functions gain new model-type branches (AST, EAT, multi-layer). The classifier gets a calibration wrapper. `_EMBED_MODELS` dict and `_get_w2v_model()` are extended. No UI changes needed — new models appear automatically in dropdowns.
|
||||
|
||||
**Tech Stack:** torchaudio (existing), transformers (new dep), timm (new dep), sklearn.calibration (existing dep)
|
||||
|
||||
**Key design notes:**
|
||||
- `_get_w2v_model()` resolves `_ML` suffixed names to their base model for loading (e.g. `HUBERT_XLARGE_ML` loads `HUBERT_XLARGE`). Both share the same GPU model — only the extraction path differs (last-layer vs multi-layer). The global `_w2v_model_name` stores the **base** name so switching between `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` does NOT trigger a reload.
|
||||
- Cache keys use the **full** model name (including `_ML`), so single-layer and multi-layer caches coexist as separate `.npz` files.
|
||||
- AST and EAT are separate model types that do NOT share the torchaudio loading path — they get their own `elif` branches in `_get_w2v_model()`.
|
||||
- Both `_extract_w2v_windows` and `_extract_w2v_targeted` need identical changes to their batch inference blocks. Keep them in sync.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add transformers and timm to requirements
|
||||
|
||||
**Files:**
|
||||
- Modify: `requirements.txt`
|
||||
|
||||
**Step 1: Add dependencies**
|
||||
|
||||
Add after the `torchaudio` line in `requirements.txt`:
|
||||
|
||||
```
|
||||
transformers>=4.30
|
||||
timm>=0.9
|
||||
```
|
||||
|
||||
**Step 2: Verify install**
|
||||
|
||||
Run: `pip install -r requirements.txt` (installs `transformers` and `timm` with the version constraints added in Step 1)
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add requirements.txt
|
||||
git commit -m "deps: add transformers and timm for AST/EAT models"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Multi-layer extraction for torchaudio models
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-58` (_EMBED_MODELS dict)
|
||||
- Modify: `core/audio_scan.py:96-100` (_embed_dim)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model)
|
||||
- Modify: `core/audio_scan.py:189-205` (_extract_w2v_windows batch loop)
|
||||
- Modify: `core/audio_scan.py:278-293` (_extract_w2v_targeted batch loop)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
Add to `tests/test_audio_scan.py`:
|
||||
|
||||
```python
|
||||
def test_embed_dim_multi_layer():
|
||||
from core.audio_scan import _embed_dim
|
||||
# Multi-layer models should report concatenated dimension
|
||||
assert _embed_dim("HUBERT_XLARGE_ML") == 5120
|
||||
assert _embed_dim("HUBERT_LARGE_ML") == 4096
|
||||
assert _embed_dim("HUBERT_BASE_ML") == 3072
|
||||
# Single-layer unchanged
|
||||
assert _embed_dim("HUBERT_XLARGE") == 1280
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||
Expected: FAIL — `_embed_dim("HUBERT_XLARGE_ML")` returns 768 (default fallback)
|
||||
|
||||
**Step 3: Add multi-layer entries to _EMBED_MODELS**
|
||||
|
||||
In `core/audio_scan.py:50-58`, add after existing entries:
|
||||
|
||||
```python
|
||||
_EMBED_MODELS = {
|
||||
"WAV2VEC2_BASE": 768,
|
||||
"WAV2VEC2_LARGE": 1024,
|
||||
"WAV2VEC2_LARGE_LV60K": 1024,
|
||||
"HUBERT_BASE": 768,
|
||||
"HUBERT_LARGE": 1024,
|
||||
"HUBERT_XLARGE": 1280,
|
||||
"BEATS": 768,
|
||||
# Multi-layer variants (4 quartile layers concatenated)
|
||||
"WAV2VEC2_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_BASE_ML": 3072, # 768 * 4
|
||||
"HUBERT_LARGE_ML": 4096, # 1024 * 4
|
||||
"HUBERT_XLARGE_ML": 5120, # 1280 * 4
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_multi_layer -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Add helper to resolve base model and layer indices**
|
||||
|
||||
Add after `_embed_dim()` (around line 101):
|
||||
|
||||
```python
|
||||
def _ml_config(model_name: str) -> tuple[str, list[int]] | None:
|
||||
"""If model_name is a multi-layer variant, return (base_model, layer_indices).
|
||||
|
||||
Returns None for single-layer models.
|
||||
Layer indices are 0-based into the list returned by extract_features().
|
||||
"""
|
||||
if not model_name.endswith("_ML"):
|
||||
return None
|
||||
base = model_name[:-3] # strip "_ML"
|
||||
if base not in _EMBED_MODELS:
|
||||
return None
|
||||
# Layer counts per model family
|
||||
layer_counts = {
|
||||
"WAV2VEC2_BASE": 12, "WAV2VEC2_LARGE": 24, "WAV2VEC2_LARGE_LV60K": 24,
|
||||
"HUBERT_BASE": 12, "HUBERT_LARGE": 24, "HUBERT_XLARGE": 48,
|
||||
"AST": 12,
|
||||
}
|
||||
n = layer_counts.get(base)
|
||||
if n is None:
|
||||
return None
|
||||
# Select 4 layers at quartile boundaries (0-indexed)
|
||||
indices = [n // 4 - 1, n // 2 - 1, 3 * n // 4 - 1, n - 1]
|
||||
return base, indices
|
||||
```
|
||||
|
||||
Note: AST is included in the layer_counts dict here already so Task 3 doesn't need to modify it again.
|
||||
|
||||
**Step 6: Write test for _ml_config**
|
||||
|
||||
```python
|
||||
def test_ml_config():
|
||||
from core.audio_scan import _ml_config
|
||||
assert _ml_config("HUBERT_XLARGE") is None
|
||||
assert _ml_config("BEATS_ML") is None # BEATS has no ML variant
|
||||
base, layers = _ml_config("HUBERT_XLARGE_ML")
|
||||
assert base == "HUBERT_XLARGE"
|
||||
assert layers == [11, 23, 35, 47]
|
||||
base, layers = _ml_config("HUBERT_BASE_ML")
|
||||
assert base == "HUBERT_BASE"
|
||||
assert layers == [2, 5, 8, 11]
|
||||
```
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_ml_config -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 7: Modify _get_w2v_model to resolve ML base names**
|
||||
|
||||
In `_get_w2v_model()` (line 68), the comparison key must use the resolved base name so that `HUBERT_XLARGE` and `HUBERT_XLARGE_ML` share the same loaded model without reloading:
|
||||
|
||||
```python
|
||||
def _get_w2v_model(model_name: str | None = None):
|
||||
"""Lazy-load an embedding model. Reloads if model_name differs from cached."""
|
||||
global _w2v_model, _w2v_device, _w2v_model_name
|
||||
if model_name is None:
|
||||
model_name = _DEFAULT_EMBED_MODEL
|
||||
# Multi-layer variants use the same base model weights
|
||||
ml = _ml_config(model_name)
|
||||
load_name = ml[0] if ml else model_name
|
||||
if _w2v_model is None or _w2v_model_name != load_name:
|
||||
import torch
|
||||
_w2v_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if load_name == "BEATS":
|
||||
... # existing BEATs code unchanged
|
||||
else:
|
||||
import torchaudio
|
||||
bundle = getattr(torchaudio.pipelines, load_name)
|
||||
_w2v_model = bundle.get_model().to(_w2v_device)
|
||||
_w2v_model.eval()
|
||||
_w2v_model_name = load_name
|
||||
_log(f"audio_scan: {load_name} loaded on {_w2v_device}")
|
||||
return _w2v_model, _w2v_device
|
||||
```
|
||||
|
||||
**Step 8: Modify _extract_w2v_windows batch inference**
|
||||
|
||||
In `_extract_w2v_windows`, compute `ml_cfg` **once** before the batch loop (after line 173 `is_beats = ...`):
|
||||
|
||||
```python
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
```
|
||||
|
||||
Then replace the batch inference block (lines 197-204):
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings.append(batch_emb)
|
||||
```
|
||||
|
||||
**Step 9: Modify _extract_w2v_targeted batch inference (keep in sync)**
|
||||
|
||||
In `_extract_w2v_targeted`, add `ml_cfg` computation after line 276 `is_beats = ...`:
|
||||
|
||||
```python
|
||||
ml_cfg = _ml_config(model_name or _DEFAULT_EMBED_MODEL)
|
||||
```
|
||||
|
||||
Then replace the batch inference block (lines 285-292) with the same branching logic as Step 8:
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
waveforms = torch.from_numpy(np.stack(chunks)).float().to(device)
|
||||
if is_beats:
|
||||
padding_mask = torch.zeros_like(waveforms, dtype=torch.bool)
|
||||
features, _ = model.extract_features(waveforms, padding_mask=padding_mask)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
elif ml_cfg is not None:
|
||||
all_layers, _ = model.extract_features(waveforms)
|
||||
selected = [all_layers[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
features, _ = model(waveforms)
|
||||
batch_emb = features.mean(dim=1).cpu().numpy()
|
||||
embeddings_list.append(batch_emb)
|
||||
```
|
||||
|
||||
Note: `_extract_w2v_targeted` appends to `embeddings_list` (not `embeddings`).
|
||||
|
||||
**Step 10: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 11: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: multi-layer extraction for HuBERT/Wav2Vec2 models"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: AST model integration
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add AST entries)
|
||||
- Modify: `core/audio_scan.py:45-47` (add _ast_feature_extractor global)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add AST loading branch)
|
||||
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add AST inference branch)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
```python
|
||||
def test_embed_dim_ast():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("AST") == 768
|
||||
assert _embed_dim("AST_ML") == 3072
|
||||
```
|
||||
|
||||
Run: `pytest tests/test_audio_scan.py::test_embed_dim_ast -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 2: Add AST entries to _EMBED_MODELS**
|
||||
|
||||
Add to the dict (after the ML entries):
|
||||
|
||||
```python
|
||||
# Transformers-based models
|
||||
"AST": 768,
|
||||
"AST_ML": 3072, # 768 * 4
|
||||
```
|
||||
|
||||
Run test again — should PASS now.
|
||||
|
||||
**Step 3: Add module-level global for AST feature extractor**
|
||||
|
||||
Near line 47 (after `_w2v_model_name = None`):
|
||||
|
||||
```python
|
||||
_ast_feature_extractor = None
|
||||
```
|
||||
|
||||
**Step 4: Add AST loading branch in _get_w2v_model**
|
||||
|
||||
In `_get_w2v_model()`, add an `elif` branch **before** the torchaudio fallback `else`:
|
||||
|
||||
```python
|
||||
elif load_name == "AST":
|
||||
from transformers import ASTModel, ASTFeatureExtractor
|
||||
_w2v_model = ASTModel.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
).to(_w2v_device)
|
||||
global _ast_feature_extractor
|
||||
_ast_feature_extractor = ASTFeatureExtractor.from_pretrained(
|
||||
"MIT/ast-finetuned-audioset-10-10-0.4593"
|
||||
)
|
||||
```
|
||||
|
||||
Note: `_ast_feature_extractor` is recreated on every model load (not cached separately) — simple and correct since the feature extractor is lightweight and model reloads are rare.
|
||||
|
||||
**Step 5: Add AST inference branch in both extraction functions**
|
||||
|
||||
In both `_extract_w2v_windows` and `_extract_w2v_targeted`, compute `is_ast` once before the loop:
|
||||
|
||||
```python
|
||||
is_ast = (model_name or _DEFAULT_EMBED_MODEL) in ("AST", "AST_ML")
|
||||
```
|
||||
|
||||
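Then in the batch inference block, add **before** the `elif ml_cfg` branch (i.e. directly after the `if is_beats` branch). `AST_ML` has a non-None `ml_cfg`, so if this branch came after `elif ml_cfg is not None`, `AST_ML` would be routed to the torchaudio `model.extract_features(waveforms)` path and the inner `if ml_cfg is not None` check below would never be reached: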
|
||||
|
||||
```python
|
||||
elif is_ast:
|
||||
# AST uses its own feature extractor for mel spectrogram
|
||||
inputs = _ast_feature_extractor(
|
||||
list(chunks), sampling_rate=sr, return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_values = inputs.input_values.to(device)
|
||||
if ml_cfg is not None:
|
||||
out = model(input_values, output_hidden_states=True)
|
||||
selected = [out.hidden_states[i].mean(dim=1) for i in ml_cfg[1]]
|
||||
batch_emb = torch.cat(selected, dim=1).cpu().numpy()
|
||||
else:
|
||||
out = model(input_values)
|
||||
batch_emb = out.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||
```
|
||||
|
||||
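Important: `chunks` is already a list of numpy arrays (built in the loop at lines 194-196). Pass it directly as `list(chunks)` — the `ASTFeatureExtractor` accepts a list of numpy arrays and handles batching/padding internally. Verified: `ASTFeatureExtractor([np.array, np.array, ...], sampling_rate=16000, return_tensors="pt", padding=True)` returns `input_values` of shape `[B, 1024, 128]`. Note on layer indexing: with `output_hidden_states=True`, `out.hidden_states` has 13 entries for the 12-layer AST (index 0 is the embedding output, index `k` is the output of layer `k`), whereas torchaudio's `extract_features()` list starts at the first transformer layer. The 0-based indices from `_ml_config` (`[2, 5, 8, 11]`) therefore select layers 2/5/8/11 here; use `out.hidden_states[i + 1]` to select layers 3/6/9/12 (including the final layer), matching the torchaudio multi-layer path.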
|
||||
|
||||
**Step 6: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 7: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add AST (Audio Spectrogram Transformer) embedding model"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: EAT model integration
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:50-65` (_EMBED_MODELS, add EAT entry)
|
||||
- Modify: `core/audio_scan.py:68-93` (_get_w2v_model, add EAT loading branch)
|
||||
- Add: `core/audio_scan.py` (_eat_preprocess helper function)
|
||||
- Modify: `core/audio_scan.py` (_extract_w2v_windows and _extract_w2v_targeted, add EAT inference branch)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
```python
|
||||
def test_embed_dim_eat():
|
||||
from core.audio_scan import _embed_dim
|
||||
assert _embed_dim("EAT") == 768
|
||||
```
|
||||
|
||||
**Step 2: Add EAT entry to _EMBED_MODELS**
|
||||
|
||||
```python
|
||||
"EAT": 768,
|
||||
```
|
||||
|
||||
Note: No `EAT_ML` variant — EAT's `extract_features()` does not natively support multi-layer output. Can be added later if needed by monkey-patching.
|
||||
|
||||
**Step 3: Add EAT loading branch in _get_w2v_model**
|
||||
|
||||
Add after the AST branch, before the torchaudio `else`:
|
||||
|
||||
```python
|
||||
elif load_name == "EAT":
|
||||
from transformers import AutoModel
|
||||
_w2v_model = AutoModel.from_pretrained(
|
||||
"worstchan/EAT-base_epoch30_finetune_AS2M",
|
||||
trust_remote_code=True,
|
||||
).to(_w2v_device)
|
||||
```
|
||||
|
||||
**Step 4: Add EAT preprocessing helper**
|
||||
|
||||
Add as a module-level function near `_get_w2v_model`:
|
||||
|
||||
```python
|
||||
def _eat_preprocess(chunks: list[np.ndarray], sr: int, device: str):
|
||||
"""Convert raw audio chunks to EAT mel spectrogram input.
|
||||
|
||||
Returns tensor of shape [B, 1, T, 128].
|
||||
8s audio at 10ms frame shift produces ~798 frames, zero-padded to 1024.
|
||||
"""
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
TARGET_LEN = 1024
|
||||
MEAN, STD = -4.268, 4.569
|
||||
|
||||
mels = []
|
||||
for chunk in chunks:
|
||||
wav = torch.from_numpy(chunk).unsqueeze(0).float()
|
||||
fbank = kaldi.fbank(
|
||||
wav, htk_compat=True, sample_frequency=sr, use_energy=False,
|
||||
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10,
|
||||
)
|
||||
# Pad or truncate to TARGET_LEN
|
||||
if fbank.shape[0] < TARGET_LEN:
|
||||
fbank = torch.nn.functional.pad(fbank, (0, 0, 0, TARGET_LEN - fbank.shape[0]))
|
||||
else:
|
||||
fbank = fbank[:TARGET_LEN]
|
||||
fbank = (fbank - MEAN) / (STD * 2)
|
||||
mels.append(fbank)
|
||||
return torch.stack(mels).unsqueeze(1).to(device) # [B, 1, T, 128]
|
||||
```
|
||||
|
||||
**Step 5: Add EAT inference branch in both extraction functions**
|
||||
|
||||
Compute `is_eat` once before the loop:
|
||||
|
||||
```python
|
||||
is_eat = (model_name or _DEFAULT_EMBED_MODEL) == "EAT"
|
||||
```
|
||||
|
||||
Then in the batch inference block, add after the `elif is_ast` branch and before `else`:
|
||||
|
||||
```python
|
||||
elif is_eat:
|
||||
mel_input = _eat_preprocess(chunks, sr, device)
|
||||
features = model.extract_features(mel_input)
|
||||
# Mean-pool frame-level tokens (skip CLS at index 0)
|
||||
batch_emb = features[:, 1:, :].mean(dim=1).cpu().numpy()
|
||||
```
|
||||
|
||||
Important: `model.extract_features()` returns a plain `torch.Tensor` of shape `[B, 513, 768]` (not a tuple). Index 0 is the CLS token, indices 1-512 are frame-level patch embeddings. We mean-pool the frame tokens for consistency with how other models are pooled.
|
||||
|
||||
**Step 6: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 7: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py tests/test_audio_scan.py
|
||||
git commit -m "feat: add EAT (Efficient Audio Transformer) embedding model"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Calibrated classifier
|
||||
|
||||
**Files:**
|
||||
- Modify: `core/audio_scan.py:424-429` (train_classifier, wrap clf)
|
||||
- Test: `tests/test_audio_scan.py`
|
||||
|
||||
**Step 1: Modify train_classifier**
|
||||
|
||||
After the existing `clf.fit()` call (line 428), add calibration behind a minimum-sample guard (the initial fit is kept as the fallback classifier when calibration is skipped):
|
||||
|
||||
```python
|
||||
clf.fit(X[train_idx], y_arr[train_idx])
|
||||
_log("audio_scan: classifier trained")
|
||||
|
||||
# Calibrate probabilities for better threshold behavior
|
||||
# Requires at least 6 samples per class for stable 3-fold isotonic calibration
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
min_class = min(int(n_pos), int(n_neg_sample))
|
||||
if min_class >= 6:
|
||||
cal_clf = CalibratedClassifierCV(clf, cv=3, method='isotonic')
|
||||
cal_clf.fit(X[train_idx], y_arr[train_idx])
|
||||
clf = cal_clf
|
||||
_log("audio_scan: classifier calibrated (isotonic, 3-fold)")
|
||||
else:
|
||||
_log(f"audio_scan: skipping calibration (min class size {min_class} < 6)")
|
||||
```
|
||||
|
||||
Why `min_class >= 6`: `CalibratedClassifierCV` uses stratified k-fold internally. With `cv=3`, each fold needs at least 2 samples per class. `min_class >= 6` guarantees this. With fewer samples, the uncalibrated HistGBT probabilities are still reasonable — calibration is an enhancement, not a requirement.
|
||||
|
||||
Previous plan bug: `cv=min(3, n_pos, n_neg_sample)` could produce `cv=1` when `n_pos=1`, which raises `ValueError` (minimum is 2). Even `cv=2` with 2 positives causes one fold to have only 1 positive, making isotonic regression unstable. The `>= 6` guard avoids all these edge cases.
|
||||
|
||||
**Step 2: Run all tests**
|
||||
|
||||
Run: `pytest tests/ -v`
|
||||
Expected: All pass
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add core/audio_scan.py
|
||||
git commit -m "feat: calibrate classifier probabilities with isotonic regression"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Integration test with real model (manual)
|
||||
|
||||
This task is manual — it requires GPU and a real video file.
|
||||
|
||||
**Step 1: Test multi-layer extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='HUBERT_XLARGE_ML')
|
||||
print(f'HUBERT_XLARGE_ML: {emb.shape}') # expect (13, 5120)
|
||||
assert emb.shape[1] == _embed_dim('HUBERT_XLARGE_ML')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 2: Test AST extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='AST')
|
||||
print(f'AST: {emb.shape}') # expect (13, 768)
|
||||
assert emb.shape[1] == _embed_dim('AST')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 3: Test AST multi-layer**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='AST_ML')
|
||||
print(f'AST_ML: {emb.shape}') # expect (13, 3072)
|
||||
assert emb.shape[1] == _embed_dim('AST_ML')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 4: Test EAT extraction**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _extract_w2v_windows, _embed_dim
|
||||
import numpy as np
|
||||
y = np.random.randn(16000 * 20).astype(np.float32) * 0.01
|
||||
ts, emb = _extract_w2v_windows(y, model_name='EAT')
|
||||
print(f'EAT: {emb.shape}') # expect (13, 768)
|
||||
assert emb.shape[1] == _embed_dim('EAT')
|
||||
print('PASS')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 5: Test model switching doesn't reload unnecessarily**
|
||||
|
||||
```bash
|
||||
python -c "
|
||||
from core.audio_scan import _get_w2v_model
|
||||
import core.audio_scan as m
|
||||
# Load HUBERT_XLARGE
|
||||
_get_w2v_model('HUBERT_XLARGE')
|
||||
name1 = m._w2v_model_name
|
||||
# Switch to ML variant — should NOT reload
|
||||
_get_w2v_model('HUBERT_XLARGE_ML')
|
||||
name2 = m._w2v_model_name
|
||||
assert name1 == name2 == 'HUBERT_XLARGE', f'Expected no reload, got {name1} -> {name2}'
|
||||
print('PASS: no reload on ML switch')
|
||||
"
|
||||
```
|
||||
|
||||
**Step 6: Test full train+scan cycle in app**
|
||||
|
||||
Load app, select each new model from scan model dropdown, scan a video, train, verify results display correctly.
|
||||
|
||||
**Step 7: Final commit and push**
|
||||
|
||||
```bash
|
||||
git push
|
||||
```
|
||||
@@ -0,0 +1,226 @@
|
||||
# ComfyUI-8cut Node Pack Design
|
||||
|
||||
Date: 2026-04-19
|
||||
|
||||
## Goal
|
||||
|
||||
Port 8-cut's video scanning, training, review, and export workflow to a ComfyUI node pack. The primary motivation is **remote access** — ComfyUI's web UI allows browser-based operation over the network, and HTML5 `<video>` handles streaming compression natively. No tensor-based image pipeline; videos stay as file paths throughout.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Approach
|
||||
|
||||
Monolithic Review Node + simple pipeline nodes. One central **VideoReview** node embeds the full interactive player/timeline/region table as a large DOM widget. Other nodes (Scan, Train, Export) are headless pipeline nodes that pass lightweight metadata.
|
||||
|
||||
### Core reuse
|
||||
|
||||
The entire `8-cut/core/` package is Qt-free and reusable as-is:
|
||||
- `core/audio_scan.py` — `scan_video()`, `train_classifier()`, `load_classifier()`
|
||||
- `core/db.py` — `ProcessedDB` (SQLite, all scan/training/export persistence)
|
||||
- `core/ffmpeg.py` — `build_ffmpeg_command()` (clip export)
|
||||
- `core/tracking.py` — YOLO-based subject tracking
|
||||
- `core/paths.py` — path helpers, `format_time()`
|
||||
|
||||
No porting required — these are imported directly.
|
||||
|
||||
---
|
||||
|
||||
## Node Pack Structure
|
||||
|
||||
```
|
||||
ComfyUI-8cut/
|
||||
__init__.py # NODE_CLASS_MAPPINGS, WEB_DIRECTORY
|
||||
core/ # symlink or copy of 8-cut/core/
|
||||
data/
|
||||
8cut.db # separate SQLite DB (can copy from ~/.8cut.db)
|
||||
models/ # trained classifiers (.joblib)
|
||||
nodes/
|
||||
load_video.py
|
||||
audio_scan.py
|
||||
video_review.py
|
||||
train_model.py
|
||||
export_clips.py
|
||||
server_routes.py # custom API routes
|
||||
web/
|
||||
js/
|
||||
video_review.js # timeline + player + scan panel widget
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Custom Types
|
||||
|
||||
No tensors anywhere in the pipeline. All data flows as lightweight metadata:
|
||||
|
||||
| Type | Python value | Purpose |
|
||||
|------|-------------|---------|
|
||||
| `VIDEO_PATH` | `str` (absolute path) | Video file reference |
|
||||
| `SCAN_REGIONS` | `list[dict]` with start/end/score/model/disabled | Scan output / review edits |
|
||||
| `SCAN_MODEL` | `str` (path to .joblib) | Trained classifier |
|
||||
|
||||
---
|
||||
|
||||
## Nodes
|
||||
|
||||
### LoadVideo
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `video_path` (STRING, file browser), `profile` (STRING combo from DB profiles) |
|
||||
| **Output** | `VIDEO_PATH`, `filename` (STRING) |
|
||||
| **Logic** | Validates path exists, returns it. Populates profile combo via API route. |
|
||||
|
||||
### AudioScan
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_MODEL`, `threshold` (FLOAT 0-1), `hop` (FLOAT) |
|
||||
| **Output** | `SCAN_REGIONS` |
|
||||
| **Logic** | Calls `core.audio_scan.scan_video()` directly. Progress via `PromptServer.send_sync("progress", ...)`. |
|
||||
|
||||
### VideoReview (interactive, blocking)
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS` (optional) |
|
||||
| **Output** | `SCAN_REGIONS` (edited) |
|
||||
| **OUTPUT_NODE** | `True` |
|
||||
| **Logic** | Execution pauses here. User interacts via the widget. Clicks "Continue" to pass edited regions downstream. |
|
||||
|
||||
The widget layout:
|
||||
|
||||
```
|
||||
+-------------------------------------+
|
||||
| [video player (HTML5 <video>)] |
|
||||
| +- timeline with scan regions ----+|
|
||||
| | cursor + region drag/resize ||
|
||||
| +---------------------------------+|
|
||||
| +- model tabs [EAT_LARGE][HuBERT]+|
|
||||
| | Time | End | Score ||
|
||||
| | 1:23 | 1:31 | 0.92 ||
|
||||
| | 3:45 | 3:53 | 0.87 ||
|
||||
| | [Add Negative] [Export] [Continue]|
|
||||
| +---------------------------------+|
|
||||
+-------------------------------------+
|
||||
```
|
||||
|
||||
Widget size: ~640x500px minimum, resizable via LiteGraph.
|
||||
|
||||
**Blocking mechanism**: The node's `run()` method blocks on a server-side event/queue. The frontend signals completion via `POST /8cut/review_done/{node_id}`, which unblocks `run()` and returns the edited `SCAN_REGIONS`.
|
||||
|
||||
### TrainModel
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `profile` (STRING combo), `positive_folder` (STRING combo), `negative_folder` (STRING combo, optional), `embed_model` (STRING combo from `_EMBED_MODELS`), `use_hard_negatives` (BOOL) |
|
||||
| **Output** | `SCAN_MODEL` |
|
||||
| **Logic** | Queries `db.get_training_data()` to assemble `video_infos`, calls `core.audio_scan.train_classifier()`. Saves to `models/{profile}_{embed_model}.joblib` with version rotation. Progress via ComfyUI progress bar. |
|
||||
|
||||
### ExportClips
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | `VIDEO_PATH`, `SCAN_REGIONS`, `output_folder` (STRING), `short_side` (INT), `format` (combo MP4/WEBM), `spread` (FLOAT), `clip_count` (INT), `fuse_gap` (FLOAT) |
|
||||
| **Output** | exported file paths (list) |
|
||||
| **Logic** | Region fusion via `_build_export_spans()`, then `core.ffmpeg.build_ffmpeg_command()` per clip. Records each clip in DB via `db.add()`. |
|
||||
|
||||
### Typical workflow
|
||||
|
||||
```
|
||||
[LoadVideo] --> [AudioScan] --> [VideoReview] --> [ExportClips]
|
||||
^
|
||||
[TrainModel]
|
||||
```
|
||||
|
||||
### Training loop (hard negatives round-trip)
|
||||
|
||||
1. Scan with existing model -> regions in VideoReview
|
||||
2. Review -> mark false positives as negatives (DB)
|
||||
3. Train -> new model uses hard negatives
|
||||
4. Rescan -> better results
|
||||
5. Repeat
|
||||
|
||||
---
|
||||
|
||||
## API Routes
|
||||
|
||||
### Video serving
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/video` | GET | Serve raw video file via `web.FileResponse`. Query param: `path`. Browser decodes mp4/h264 natively — key for remote streaming. |
|
||||
| `/8cut/video_transcode` | GET | Fallback: transcode to webm on-the-fly via ffmpeg `StreamResponse` for browser-incompatible formats (some MKV, odd codecs). |
|
||||
|
||||
### Region editing (from VideoReview widget)
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/toggle_region` | POST | `toggle_scan_result_disabled()` |
|
||||
| `/8cut/resize_region` | POST | `update_scan_result()` |
|
||||
| `/8cut/delete_region` | POST | `delete_scan_result()` |
|
||||
| `/8cut/add_negatives` | POST | `add_hard_negatives()` |
|
||||
| `/8cut/scan_versions` | GET | `get_scan_versions()` |
|
||||
| `/8cut/review_done/{node_id}` | POST | Unblock the VideoReview node's `run()`, pass final regions |
|
||||
|
||||
### Data queries (for combo widget population)
|
||||
|
||||
| Route | Method | Purpose |
|
||||
|-------|--------|---------|
|
||||
| `/8cut/profiles` | GET | `db.get_profiles()` |
|
||||
| `/8cut/export_folders` | GET | `db.get_export_folders()` |
|
||||
| `/8cut/models` | GET | List available `.joblib` models |
|
||||
|
||||
---
|
||||
|
||||
## Frontend JS Widget (`web/js/video_review.js`)
|
||||
|
||||
Registered via `app.registerExtension()`. Hooks into the VideoReview node's `onNodeCreated` and `onExecuted` callbacks.
|
||||
|
||||
### Components
|
||||
|
||||
1. **Video player** — HTML5 `<video>` element, src pointed at `/8cut/video?path=...`
|
||||
2. **Timeline** — `<canvas>` overlay below the video. Renders:
|
||||
- Scan region rectangles (color-coded by score, red for negatives, gray for disabled)
|
||||
- Cursor line (click to seek)
|
||||
- Drag handles on region edges (resize)
|
||||
- Waveform (optional, fetched via separate route)
|
||||
3. **Region table** — HTML table with model tabs. Click row to seek. Columns: Time, End, Score.
|
||||
4. **Action buttons** — Add Negative, Export, Continue
|
||||
5. **Version combo** — dropdown to switch scan history versions
|
||||
|
||||
### Interaction flow
|
||||
|
||||
- Widget activates when `onExecuted` fires with scan regions
|
||||
- User clicks/drags timeline, edits regions, marks negatives
|
||||
- Each edit hits an API route (immediate DB persistence)
|
||||
- "Continue" sends `POST /8cut/review_done/{node_id}` with final region state
|
||||
- Node's `run()` unblocks, passes `SCAN_REGIONS` downstream
|
||||
|
||||
---
|
||||
|
||||
## DB
|
||||
|
||||
Separate SQLite DB at `ComfyUI-8cut/data/8cut.db`. Uses the existing `ProcessedDB` class unchanged — same schema, same migration code. Users can copy their existing `~/.8cut.db` to carry over scan history, training data, and hard negatives.
|
||||
|
||||
---
|
||||
|
||||
## Dependencies
|
||||
|
||||
Same as 8-cut's `requirements.txt` minus PyQt6/python-mpv:
|
||||
- `torch`, `torchaudio`, `torchvision` (from CUDA index)
|
||||
- `transformers>=4.30,<5.0`, `timm>=0.9`
|
||||
- `librosa`, `scikit-learn`, `joblib`, `soundfile`, `numpy`
|
||||
- `ultralytics` (YOLO tracking)
|
||||
|
||||
ComfyUI already provides torch. The node pack's install script just needs the audio/ML extras.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
1. **Node pack skeleton** — structure, `__init__.py`, custom types, API routes for video serving
|
||||
2. **LoadVideo + AudioScan** — headless nodes, no widget needed yet
|
||||
3. **VideoReview widget (minimal)** — video player + static region display + Continue button
|
||||
4. **VideoReview interactivity** — timeline click/drag, region editing, negative marking
|
||||
5. **TrainModel + ExportClips** — complete the pipeline
|
||||
6. **Polish** — version history, waveform overlay, transcode fallback
|
||||
@@ -0,0 +1,205 @@
|
||||
# Scan History & Hard Negative Management — Final Design
|
||||
|
||||
Date: 2026-04-19 (implemented on `feat/training-ui`)
|
||||
|
||||
## Goal
|
||||
|
||||
1. Keep scan result history per `(file, model)` so users can track classifier improvement across training iterations
|
||||
2. Make hard negatives manageable — viewable, removable, and optionally disabled per training run
|
||||
3. Fix latent bug: `get_export_folders()` doesn't filter by `scan_export`
|
||||
|
||||
---
|
||||
|
||||
## 1. Ghost Folder Fix
|
||||
|
||||
### Bug
|
||||
|
||||
`get_export_folders()` queried all `output_path` rows without filtering `scan_export`. Folders that only contained scan-exported clips appeared in training dropdowns with 0 clips.
|
||||
|
||||
### Implementation (`core/db.py`)
|
||||
|
||||
**`get_export_folders(profile, include_scan_exports=False)`** — new parameter. When `False` (default), the SQL query adds `AND scan_export = 0` to exclude scan-only folders. The `get_training_stats()` method passes this through and also filters its return dict to remove folders with 0 clips:
|
||||
|
||||
```python
|
||||
return {k: v for k, v in stats.items() if v["clips"] > 0}
|
||||
```
|
||||
|
||||
### Test
|
||||
|
||||
`tests/test_db.py::test_export_folders_excludes_scan_exports` — verifies scan-only folders are excluded by default and included when `include_scan_exports=True`.
|
||||
|
||||
---
|
||||
|
||||
## 2. Scan Result History
|
||||
|
||||
### Schema
|
||||
|
||||
Added column to `scan_results`:
|
||||
|
||||
```sql
|
||||
scan_timestamp TEXT NOT NULL DEFAULT ''
|
||||
```
|
||||
|
||||
All rows from the same scan share one timestamp string with **microsecond precision** (`%Y%m%d_%H%M%S_%f`, e.g. `"20260419_143022_123456"`). Microsecond precision prevents version collisions on fast successive scans.
|
||||
|
||||
Migration adds the column via `ALTER TABLE` for existing databases. Legacy rows keep `scan_timestamp = ''`.
|
||||
|
||||
### DB methods (`core/db.py`)
|
||||
|
||||
**`save_scan_results(filename, profile, model, regions, max_versions=5)`**
|
||||
1. Inserts new rows with current microsecond-precision timestamp
|
||||
2. Counts distinct timestamps for this `(filename, profile, model)`
|
||||
3. Prunes oldest timestamps beyond `max_versions`
|
||||
|
||||
No more DELETE-then-INSERT — all versions coexist in the table.
|
||||
|
||||
**`get_scan_versions(filename, profile, model)`**
|
||||
Returns `[{timestamp, count, max_score}, ...]` ordered newest first. Filters `scan_timestamp != ''` so legacy rows don't appear as named versions.
|
||||
|
||||
**`get_scan_results(filename, profile, scan_timestamp=None)`**
|
||||
- With `scan_timestamp`: returns rows matching that exact version
|
||||
- Without (default): uses an `INNER JOIN` against a subquery that selects `MAX(scan_timestamp)` per model, so only the latest version is returned. Legacy rows have an empty timestamp, which compares lower than any real timestamp string, so they are returned only when no versioned scan exists for that model.
|
||||
|
||||
### UI (`main.py` — `ScanResultsPanel`)
|
||||
|
||||
Each model tab wraps its `QTableWidget` in a container `QWidget` with a `QComboBox` for version selection:
|
||||
|
||||
```
|
||||
container (QWidget)
|
||||
├── cmb_version (QComboBox) — hidden when ≤ 1 version
|
||||
└── table (QTableWidget)
|
||||
```
|
||||
|
||||
**Helper methods** unwrap this container:
|
||||
- `_current_table()` — returns `QTableWidget` from active tab (handles both raw table and container)
|
||||
- `_tab_table(index)` — same by tab index
|
||||
|
||||
**Version combo** is populated by `_populate_version_combos()` after every `load_for_file()` and `add_scan_results()` call. Labels use `datetime.strptime` parsing with try/except fallback for robustness:
|
||||
|
||||
```
|
||||
2026-04-19 14:30 (12 regions, best: 0.95)
|
||||
```
|
||||
|
||||
**Version switching** via `_on_version_changed(model, idx)`:
|
||||
1. Reads `scan_timestamp` from combo's `userData`
|
||||
2. Calls `get_scan_results(filename, profile, scan_timestamp=ts)`
|
||||
3. Repopulates the table in-place
|
||||
4. **Clears the undo stack** — stale undo entries from a different version would corrupt data
|
||||
5. Emits `regions_edited` to refresh the timeline
|
||||
|
||||
**Tab switch** connects `tab_changed` signal to `_on_scan_regions_edited` (not just `_update_scan_export_count`), so the timeline updates scan regions when switching model tabs.
|
||||
|
||||
### Cache interaction
|
||||
|
||||
Embedding cache is per `(file, model)` and doesn't change across scans. History stores classified regions (start, end, score), not embeddings.
|
||||
|
||||
### Test
|
||||
|
||||
`tests/test_db.py::test_scan_result_history` — saves 3 versions, verifies counts, ordering, and latest-by-default behavior.
|
||||
|
||||
---
|
||||
|
||||
## 3. Hard Negative Management
|
||||
|
||||
### Schema
|
||||
|
||||
Added column to `hard_negatives`:
|
||||
|
||||
```sql
|
||||
source_model TEXT NOT NULL DEFAULT ''
|
||||
```
|
||||
|
||||
Migration adds the column via `ALTER TABLE` for existing databases.
|
||||
|
||||
### DB methods (`core/db.py`)
|
||||
|
||||
**`add_hard_negatives(filename, profile, times, source_path="", source_model="")`** — now stores which embedding model produced the scan that led to the negative marking.
|
||||
|
||||
**`get_hard_negatives(profile)`** — returns all rows as `[{id, filename, start_time, source_path, source_model}, ...]` for the management dialog.
|
||||
|
||||
**`delete_hard_negatives_by_ids(ids)`** — bulk delete by row IDs.
|
||||
|
||||
**`get_training_data(..., use_hard_negatives=True)`** — new parameter. When `False`, the hard negatives query is skipped entirely. Non-destructive — negatives remain in DB.
|
||||
|
||||
### Source model tracking (`main.py`)
|
||||
|
||||
`_on_scan_negatives()` now passes `source_model=self._scan_panel.current_model_name()` when marking negatives from scan results. `current_model_name()` extracts the model name from the active tab text (stripping the count suffix).
|
||||
|
||||
### Training toggle (`main.py` — `TrainDialog`)
|
||||
|
||||
Checkbox **"Use hard negatives in training"** (default checked) with "Manage..." button in an HBox layout. The toggle:
|
||||
- Updates live training stats preview via debounced `_update_stats()`
|
||||
- Passes `use_hard_negatives` through `_open_train_dialog()` to `get_training_data()`
|
||||
|
||||
### Management dialog (`main.py` — `HardNegativesDialog`)
|
||||
|
||||
Accessible from TrainDialog's "Manage..." button. Features:
|
||||
|
||||
| Component | Details |
|
||||
|-----------|---------|
|
||||
| **Filter combo** | `(all)` + each distinct `source_model` found in data |
|
||||
| **Summary label** | `<b>N</b> hard negatives` |
|
||||
| **Table** | File, Time (`{:.1f}s`), Source Model, hidden ID column |
|
||||
| **Delete Selected** | Multi-select aware, skips hidden (filtered) rows |
|
||||
| **Clear All** | **Filter-aware**: if a model filter is active, only deletes negatives for that model with an appropriate confirmation message. If `(all)`, deletes everything. |
|
||||
| **Close** | Closes dialog, triggers stats refresh in parent TrainDialog |
|
||||
|
||||
`blockSignals(True)` guards prevent spurious filter callbacks during `_load()` repopulation.
|
||||
|
||||
### Tests
|
||||
|
||||
- `test_hard_negatives_source_model` — verifies source_model stored and retrieved
|
||||
- `test_training_data_skips_hard_negatives` — verifies `use_hard_negatives=False` excludes them
|
||||
- `test_delete_hard_negatives_by_ids` — verifies bulk deletion by ID
|
||||
|
||||
---
|
||||
|
||||
## 4. Runtime Fixes (discovered during testing)
|
||||
|
||||
### EAT/torchvision ABI mismatch
|
||||
|
||||
**Problem:** `torchvision` installed from PyPI (CPU build) was incompatible with `torch` installed from the CUDA wheel index, causing `operator torchvision::nms does not exist`.
|
||||
|
||||
**Fix:** Added `torchvision` to the explicit torch install line in both setup scripts:
|
||||
```bash
|
||||
pip install torch torchaudio torchvision --index-url "$TORCH_INDEX"
|
||||
```
|
||||
|
||||
Also added `--extra-index-url "$TORCH_INDEX"` to the `pip install -r requirements.txt` line to prevent transitive dependencies (timm, ultralytics) from pulling CPU-only torch packages.
|
||||
|
||||
Applied to: `setup_env.sh` (both conda and venv paths), `setup-windows.ps1`.
|
||||
|
||||
### EAT / transformers 5.x incompatibility
|
||||
|
||||
**Problem:** transformers 5.x broke EAT's remote model code (`'EATModel' object has no attribute 'all_tied_weights_keys'`).
|
||||
|
||||
**Fix:** Pinned `transformers>=4.30,<5.0` in `requirements.txt`.
|
||||
|
||||
### NumPy non-writable array warning
|
||||
|
||||
**Problem:** Cached HuBERT/EAT embeddings loaded from disk are read-only numpy arrays. `torch.from_numpy()` on a non-writable array triggers a `UserWarning` (PyTorch does not support non-writable tensors).
|
||||
|
||||
**Fix:** In `core/audio_scan.py`, changed EAT preprocessing to copy the array:
|
||||
```python
|
||||
wav = torch.from_numpy(np.array(chunk)).unsqueeze(0).float()
|
||||
```
|
||||
|
||||
### Timeline not updating on tab switch
|
||||
|
||||
**Problem:** Switching model tabs in the scan results panel didn't refresh the timeline's highlighted regions because `tab_changed` was only connected to `_update_scan_export_count`.
|
||||
|
||||
**Fix:** Connected `tab_changed` to `_on_scan_regions_edited` instead, which handles both timeline refresh and export count update.
|
||||
|
||||
---
|
||||
|
||||
## File Summary
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `core/db.py` | Schema migrations, `get_export_folders` filter, versioned `save_scan_results`, `get_scan_versions`, version-aware `get_scan_results`, `add_hard_negatives` with `source_model`, `get_hard_negatives`, `delete_hard_negatives_by_ids`, `get_training_data` with `use_hard_negatives` |
|
||||
| `main.py` | `HardNegativesDialog` class, `TrainDialog` hard neg toggle + manage button, `ScanResultsPanel` container/combo architecture, version combo population and switching, `current_model_name()`, tab-switch timeline fix |
|
||||
| `core/audio_scan.py` | `np.array(chunk)` copy for read-only numpy arrays in EAT preprocessing |
|
||||
| `requirements.txt` | `transformers>=4.30,<5.0` pin |
|
||||
| `setup_env.sh` | `torchvision` in torch install, `--extra-index-url` on requirements install |
|
||||
| `setup-windows.ps1` | `torchvision` in torch install, `--extra-index-url` on requirements install, removed skip-if-exists guard |
|
||||
| `tests/test_db.py` | 5 tests covering all DB-layer changes |
|
||||
@@ -0,0 +1,94 @@
|
||||
# Scan History & Hard Negative Management — Implementation Log
|
||||
|
||||
> All tasks complete. See the design doc for the final specification.
|
||||
|
||||
**Branch:** `feat/training-ui`
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Fix ghost folder bug in get_export_folders -- DONE
|
||||
|
||||
**Commit:** `2614a76 fix: get_export_folders respects scan_export filter`
|
||||
|
||||
- `core/db.py` — `get_export_folders(profile, include_scan_exports=False)`: filters `scan_export = 0` by default
|
||||
- `core/db.py` — `get_training_stats()`: passes `include_scan_exports` through, filters out 0-clip folders
|
||||
- `tests/test_db.py` — `test_export_folders_excludes_scan_exports`
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Scan result history — schema and DB methods -- DONE
|
||||
|
||||
**Commit:** `4fb2ae1 feat: scan result history — keep N versions per (file, model)`
|
||||
|
||||
- `core/db.py` — added `scan_timestamp TEXT NOT NULL DEFAULT ''` column with migration
|
||||
- `core/db.py` — `save_scan_results()`: versioned insert with microsecond-precision timestamp (`%Y%m%d_%H%M%S_%f`), auto-prunes beyond `max_versions=5`
|
||||
- `core/db.py` — `get_scan_versions()`: returns `[{timestamp, count, max_score}, ...]` newest first
|
||||
- `core/db.py` — `get_scan_results(scan_timestamp=None)`: `INNER JOIN` subquery with `MAX(scan_timestamp)` for latest-by-default
|
||||
- `tests/test_db.py` — `test_scan_result_history`
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Scan history UI — version selector in ScanResultsPanel -- DONE
|
||||
|
||||
**Commit:** `8ed9fbf feat: scan version selector in results panel`
|
||||
|
||||
- `main.py` — `_add_tab()`: wraps table in container `QWidget` with version `QComboBox` (hidden when ≤ 1 version)
|
||||
- `main.py` — `_current_table()` / `_tab_table(idx)`: unwrap container to get `QTableWidget`
|
||||
- `main.py` — `_populate_version_combos()`: queries `get_scan_versions()`, formats labels with `datetime.strptime` + try/except fallback
|
||||
- `main.py` — `_on_version_changed()`: reloads table from specific version, clears undo stack, emits `regions_edited`
|
||||
- `main.py` — `current_model_name()`: extracts model name from tab text
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Hard negatives — schema and training toggle -- DONE
|
||||
|
||||
**Commit:** `edc5784 feat: hard negative source_model tracking, training toggle`
|
||||
|
||||
- `core/db.py` — added `source_model TEXT NOT NULL DEFAULT ''` column to `hard_negatives` with migration
|
||||
- `core/db.py` — `add_hard_negatives(source_model="")`: stores originating model
|
||||
- `core/db.py` — `get_hard_negatives(profile)`: returns full rows as list of dicts
|
||||
- `core/db.py` — `delete_hard_negatives_by_ids(ids)`: bulk delete by row IDs
|
||||
- `core/db.py` — `get_training_data(use_hard_negatives=True)`: conditionally skips hard negatives query
|
||||
- `main.py` — `TrainDialog`: "Use hard negatives" checkbox + "Manage..." button in HBox layout
|
||||
- `main.py` — `_on_scan_negatives()`: passes `source_model=self._scan_panel.current_model_name()`
|
||||
- `tests/test_db.py` — `test_hard_negatives_source_model`, `test_training_data_skips_hard_negatives`, `test_delete_hard_negatives_by_ids`
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Hard negatives management dialog -- DONE
|
||||
|
||||
**Commit:** `e6db83f feat: hard negatives management dialog with filter and bulk delete`
|
||||
|
||||
- `main.py` — `HardNegativesDialog`: table with File/Time/Source Model/hidden ID columns, model filter combo, delete selected, filter-aware clear all, close button
|
||||
- Filter-aware "Clear All": respects active model filter, shows appropriate confirmation message
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Code review fixes -- DONE
|
||||
|
||||
**Commit:** `5d45b8d fix: timestamp collision, undo stack invalidation, label parsing, filter-aware clear`
|
||||
|
||||
Four issues found during code review:
|
||||
1. **Timestamp collision** — second-precision timestamps could merge versions on sub-second calls. Fixed with microsecond precision `%f`
|
||||
2. **Undo stack invalidation** — switching scan versions left stale undo entries. Fixed by clearing undo stack in `_on_version_changed()`
|
||||
3. **Timestamp label fragile parsing** — hard-coded string slicing. Fixed with `datetime.strptime` + try/except fallback
|
||||
4. **Clear All ignoring filter** — deleted all negatives regardless of model filter. Fixed to respect active filter
|
||||
|
||||
---
|
||||
|
||||
### Runtime fixes (discovered during manual testing)
|
||||
|
||||
| Commit | Fix |
|
||||
|--------|-----|
|
||||
| `a3c657c` | Install `torchvision` from CUDA wheel index (was pulling CPU build from PyPI) |
|
||||
| `3c3b1d7` | Remove "skip if torch exists" guard in Windows setup so re-runs fix broken envs |
|
||||
| `fd043f4` | Pin `transformers>=4.30,<5.0` — EAT remote model code incompatible with transformers 5.x |
|
||||
| `7d6fee9` | Copy read-only numpy array before `torch.from_numpy()` in EAT preprocessing |
|
||||
| `bd345ab` | Connect `tab_changed` to `_on_scan_regions_edited` so timeline refreshes on tab switch |
|
||||
| `d8b3972` | Add `--extra-index-url` to `pip install -r requirements.txt` in both setup scripts |
|
||||
|
||||
---
|
||||
|
||||
### Test results
|
||||
|
||||
All 68 tests pass (5 new DB tests + 63 existing).
|
||||
@@ -0,0 +1,130 @@
|
||||
# Main Window UI Restructure — Design
|
||||
|
||||
**Goal:** Reorganize the `MainWindow` UI in `main.py` from a flat wall of ~50 always-visible controls into a legible, grouped layout — a menu bar for rare actions, a tabbed control deck for settings, an always-visible transport bar, and a real status bar — plus a visual polish pass. Keep every existing behavior, shortcut, and mouse interaction working.
|
||||
|
||||
**Scope:** Reorganization **and** visual polish. **Not** an interaction-model change — single-key shortcuts, timeline mouse overloading, and the export/scan logic are untouched.
|
||||
|
||||
**Audience:** Single power user. Optimize for density and speed. The goal is *order, not hiding*: keep everything fast to reach; push only genuinely rare actions into menus.
|
||||
|
||||
**Runs in:** Python/Qt client (`main.py`), `MainWindow` class only. No `core/` changes.
|
||||
|
||||
---
|
||||
|
||||
## Problem (from audit)
|
||||
|
||||
- **No information architecture.** No menu bar, no toolbar; status bar explicitly disabled (`setStatusBar(None)`, main.py:4440). Every function is a permanently-visible widget at equal weight.
|
||||
- **`settings_row` overloaded** (main.py:4334–4370): 24 widgets in one non-wrapping `QHBoxLayout` spanning three unrelated domains (encode/clip params, export variants, audio-scan ML). Needs >1500px; window opens at 1100px.
|
||||
- **Stranded controls** — e.g. the workers spinbox sits between Cancel and Delete in the transport row (main.py:4316).
|
||||
- **Weak feedback** — only an 11px `#888` status label at the far-right end of the overflowing settings row (main.py:4364).
|
||||
- **Flat visual hierarchy** — single Fusion stylesheet, scattered inline `setStyleSheet` state swaps, no primary/secondary distinction, no grouping.
|
||||
|
||||
---
|
||||
|
||||
## Chosen approach: Tabbed control deck
|
||||
|
||||
The 3-pane horizontal splitter (Queue · Center · Scan results) is unchanged. The center column is restructured:
|
||||
|
||||
```
|
||||
╔═ File Edit Scan View Help ═══════════════════ Profile:[default▾] [?] ╗ menu bar (+ corner widgets)
|
||||
║ ┌Queue──┐ │ current_file.mp4 │ ┌ Scan results ─────┐ ║
|
||||
║ │+Open │ │ ┌──────────────────────────────────────┐ │ │ [model tabs] │ ║
|
||||
║ │filter │ │ │ VIDEO (mpv) │ │ │ version▾ │ ║
|
||||
║ │┌List┬+┐│ │ │ │ │ │ start end score │ ║
|
||||
║ ││f1 ││ │ │ └──────────────────────────────────────┘ │ │ ... │ ║
|
||||
║ ││f2 ││ │ │ [════════════ timeline ════════════════] │ │ │ ║
|
||||
║ │└────┘ ││ │ [════════════ crop bar ════════════════] │ │ [Neg] [Export] │ ║
|
||||
║ └───────┘ │ ┌─ transport (always visible) ──────────┐ │ └───────────────────┘ ║
|
||||
║ │ │▶ ⏸ x2 x4 🔒 --/-- ··· [Export] +₁+₂ Cancel Delete│ ║
|
||||
║ │ ├─[ Export ]─[ Crop & Track ]─[ Scan ]──┤ ← control deck (tabs) ║
|
||||
║ │ │ (controls for the active tab here) │ ║
|
||||
║ │ └───────────────────────────────────────┘ ║
|
||||
╠═══════════════════════════════════════════════════════════════════════════════╣
|
||||
║ Ready. current file · profile: default · 8 wk ║ status bar
|
||||
╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||
```
|
||||
|
||||
**Why tabbed deck:** Replaces the three stacked rows with a compact tab strip. The transport bar (most-used controls) stays always visible above the tabs; settings group by concern behind tabs. Trade-off accepted: viewing Scan + Export controls simultaneously costs a tab switch.
|
||||
|
||||
---
|
||||
|
||||
## Control mapping
|
||||
|
||||
Every current control has an explicit home; nothing is removed.
|
||||
|
||||
### Menu bar (rare / batch / management)
|
||||
|
||||
| Menu | Items |
|
||||
|------|-------|
|
||||
| **File** | Open Files… · Set export folder… · Quit |
|
||||
| **Edit** | Undo *(Ctrl+Z → `_scan_panel.undo`)* · Subprofiles ▸ (Add… / Remove…) |
|
||||
| **Scan** | Scan current · Auto-export · Scan All… · Train classifier… |
|
||||
| **View** | Review mode ✓ · Subcategory markers ▸ · Hide exported ✓ · Show hidden ✓ |
|
||||
| **Help** | Keyboard shortcuts *(? / F1)* · What's new · About |
|
||||
| *corner (right)* | Profile ▾ · `?` |
|
||||
|
||||
*Hard Negatives and Dataset Stats remain inside the Train dialog (main.py:682, 762) — not surfaced separately. Profile new/delete remains driven by the profile combo's `activated` handler.*
|
||||
|
||||
### Transport bar (always visible — playback + one-press export actions)
|
||||
|
||||
`▶ Play · ⏸ Pause · x2 · x4 · 🔒 Lock · --/-- time · ⟨stretch⟩ · next-preview · **Export** · subprofile buttons ₁₂… · Cancel · Delete`
|
||||
|
||||
### Control deck — Export tab
|
||||
`Label · Category · Name · Folder + browse · Format · HW encode · Resize · Duration · Clips · Spread · Workers · Re-export`
|
||||
|
||||
### Control deck — Crop & Track tab
|
||||
`Portrait ratio · 1 random portrait · 1 random square · Track subject`
|
||||
|
||||
### Control deck — Scan tab
|
||||
`Scan model ▾ · ⏲ history · Scan · Auto · Speech · Review · Fuse · Threshold`
|
||||
|
||||
### Left pane (Queue) — unchanged
|
||||
`+ Open · filter · Hide exported · Show hidden · list tabs (tabbed / side-by-side)`
|
||||
|
||||
### Right pane (Scan results) — unchanged structurally
|
||||
|
||||
### Decisions
|
||||
- **Train** → Scan menu only (no deck button).
|
||||
- **Subcategory markers ("Sub")** → View menu submenu (off the deck).
|
||||
- Items appearing in both a menu and a visible control (Hide exported, Show hidden, Review, Scan, Auto) share one handler and stay synced.
|
||||
|
||||
---
|
||||
|
||||
## Status bar
|
||||
|
||||
Restores `QStatusBar` (removes `setStatusBar(None)`):
|
||||
- **Left**: transient feedback — `Exporting 2/3…`, `Scan complete · 14 regions`, `Ready.` — with an optional inline `QProgressBar` for export/scan runs. Replaces `_lbl_status` and the `_status_timer` clear logic.
|
||||
- **Right (permanent widget)**: `current file · profile: <name> · <n> workers`.
|
||||
|
||||
---
|
||||
|
||||
## Visual polish
|
||||
|
||||
Extends the existing dark Fusion theme — no theme change.
|
||||
|
||||
1. **Aligned tab layouts** — each deck tab uses `QFormLayout`/grid so `label : control` pairs align in columns (biggest legibility win vs. today's ragged horizontal runs).
|
||||
2. **Primary/secondary button weight** — **Export** gets an accent style (blue, reusing `#3a6ea8`); Cancel/Delete read as secondary/destructive. The existing **red Export = "armed to overwrite"** state (main.py:5403) is preserved as a distinct state layered on top.
|
||||
3. **Consistent toggle states** — x2 / x4 / 🔒 Lock / Review are checkable; one global `:checked` style replaces Lock's ad-hoc inline `#4a3000` swap (main.py:5705).
|
||||
4. **Spacing rhythm** — uniform margins/spacing; **fixed deck height** (= tallest tab) so the video never resizes on tab switch.
|
||||
5. **Label cleanup** — de-abbreviate where cheap (`Thr→Threshold`, `Dur→Duration`); replace cryptic `⏲` with a clearer history affordance.
|
||||
6. **One stylesheet block** — fold scattered inline `setStyleSheet` calls into the central sheet (tabs, separators, status bar, toggles, primary button); keep per-widget overrides only for genuine state changes (overwrite-armed Export).
|
||||
|
||||
---
|
||||
|
||||
## Implementation notes & risks
|
||||
|
||||
- **Preserve all signal wiring.** Controls are re-parented into new layouts, but every existing `connect()` and the controls' object identities are kept — this is a layout move, not a rewrite of handlers.
|
||||
- **Preserve all shortcuts.** The `QShortcut` block (main.py:4450–4483) and `_KeyFilter` focus suppression are untouched. Menu items reuse the same handler methods and may display the matching shortcut text.
|
||||
- **Fixed deck height** prevents video-area jump when switching tabs.
|
||||
- **Synced menu/button state** — checkable menu items (Review, Hide exported, Show hidden) and their visible toggles must reflect each other; route both through the existing handler and update both widgets.
|
||||
- **Profile combo** moves to a menu-bar corner widget but keeps its existing `activated` → new/delete/switch logic intact.
|
||||
- Risk: re-parenting a large `__init__` is error-prone. Mitigate by moving controls in small, independently-runnable stages (menu bar → status bar → deck tabs → transport bar → polish), launching the app after each.
|
||||
|
||||
---
|
||||
|
||||
## What this does NOT do
|
||||
|
||||
- No change to export, scan, tracking, or DB logic — `core/` untouched.
|
||||
- No change to keyboard shortcuts or timeline mouse interactions.
|
||||
- No theme change — stays dark Fusion.
|
||||
- No new features — every control already exists; this is rehousing + polish.
|
||||
- No change to the Queue or Scan-results panes' internal structure.
|
||||
@@ -0,0 +1,547 @@
|
||||
# Main Window UI Restructure — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Re-house `MainWindow`'s ~50 flat controls into a menu bar (rare actions), an always-visible transport bar, a 3-tab control deck (Export / Crop & Track / Scan), and a real status bar — then a visual-polish pass — without changing any behavior, shortcut, or `core/` logic.
|
||||
|
||||
**Architecture:** Pure layout reorganization inside `main.py`'s `MainWindow`. Existing widget objects and every `connect()` are **preserved and re-parented**, not recreated. The monster `__init__` is incrementally broken into `_build_*` helper methods (stays single-file — matches the project's architecture). Companion design doc: `docs/plans/2026-06-13-ui-restructure-design.md`.
|
||||
|
||||
**Tech Stack:** Python 3.11+, PyQt6, pytest. App entry: `main.py`; launch via `./8cut.sh`.
|
||||
|
||||
---
|
||||
|
||||
## Conventions for every task
|
||||
|
||||
- **Line references drift** as edits land. Always locate by the named symbol (method/variable), not the line number alone. Numbers are the *starting* anchors as of this plan.
|
||||
- **Authoritative verification is a manual launch.** After each task, run `./8cut.sh`, load a video, and confirm the task's controls work AND prior behavior is intact (play, scrub, export, scan). Use the `verify` skill for structured manual checks.
|
||||
- **Structure test is the safety net.** `tests/test_ui_structure.py` (built in Task 0.2) constructs `MainWindow` and asserts containment invariants. It **skips gracefully** if construction fails (e.g. no GL for `MpvWidget` in headless CI), so it never blocks `core/` tests. Run with a display: `pytest tests/test_ui_structure.py -v`.
|
||||
- **Commit after every task.** Small, reversible commits. Commit message convention matches the repo (`feat:`/`fix:`/`refactor:`/`change:`).
|
||||
- **Do not touch** `core/`, export/scan/tracking logic, the `QShortcut` block (around main.py:4450–4483), `_KeyFilter`, or `TimelineWidget` mouse handling.
|
||||
|
||||
---
|
||||
|
||||
## Stage 0 — Branch & safety net
|
||||
|
||||
### Task 0.1: Create a working branch
|
||||
|
||||
**Step 1:** Starting from `master`, create the working branch (`git switch -c` branches from the current `HEAD`, so check out `master` first if needed):
|
||||
```bash
|
||||
git switch -c ui-restructure
|
||||
```
|
||||
**Step 2:** Verify: `git branch --show-current` → `ui-restructure`.
|
||||
(The repo has pre-existing untracked/modified files; leave them alone — they are not part of this work.)
|
||||
|
||||
### Task 0.2: Add the structure-test safety net
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/test_ui_structure.py`
|
||||
|
||||
**Step 1: Write the test harness + baseline invariant**
|
||||
|
||||
```python
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# A real platform is needed because MpvWidget creates a GL context.
|
||||
# If construction fails for any environment reason, skip — this test is a
|
||||
# best-effort structural net, not a gate on core/ tests.
|
||||
pytestmark = pytest.mark.gui
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
from PyQt6.QtWidgets import QApplication
|
||||
inst = QApplication.instance() or QApplication([])
|
||||
yield inst
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def win(app):
|
||||
try:
|
||||
from main import MainWindow
|
||||
w = MainWindow()
|
||||
except Exception as e: # GL/mpv/display unavailable, etc.
|
||||
pytest.skip(f"MainWindow could not be constructed here: {e}")
|
||||
yield w
|
||||
w.close()
|
||||
w.deleteLater()
|
||||
|
||||
|
||||
def _descendant_object_names(widget):
|
||||
"""All objectNames in a widget's child tree (for containment asserts)."""
|
||||
return {c.objectName() for c in widget.findChildren(object) if c.objectName()}
|
||||
|
||||
|
||||
def test_window_constructs(win):
|
||||
assert win.windowTitle() == "8-cut"
|
||||
```
|
||||
|
||||
**Step 2: Run it**
|
||||
|
||||
Run: `pytest tests/test_ui_structure.py -v`
|
||||
Expected: `test_window_constructs` PASSES (with a display) or SKIPS (headless). Either is acceptable — it must not ERROR.
|
||||
|
||||
**Step 3:** Register the `gui` marker to silence warnings.
|
||||
|
||||
Modify `conftest.py` — append:
|
||||
```python
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "gui: constructs Qt widgets; needs a display")
|
||||
```
|
||||
|
||||
**Step 4: Confirm core tests still pass**
|
||||
|
||||
Run: `pytest tests/test_utils.py tests/test_db.py -q`
|
||||
Expected: PASS (unchanged).
|
||||
|
||||
**Step 5: Commit**
|
||||
```bash
|
||||
git add tests/test_ui_structure.py conftest.py
|
||||
git commit -m "test: add MainWindow structure smoke test (skips headless)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Stage 1 — Menu bar
|
||||
|
||||
Add a `QMenuBar` whose actions reuse existing handler methods. Move the profile combo and `?` button into menu-bar corner widgets. Keep the original buttons that also live elsewhere (Scan, Auto) — menus and buttons share handlers.
|
||||
|
||||
### Task 1.1: Extract a `_build_menubar()` and add the five menus
|
||||
|
||||
**Files:**
|
||||
- Modify: `main.py` `MainWindow.__init__` (call site) and add method `_build_menubar`
|
||||
|
||||
**Step 1:** Add the method (place near other `_build`/setup helpers, e.g. after `__init__`). Wire each action to the **existing** handler method:
|
||||
|
||||
```python
|
||||
def _build_menubar(self) -> None:
|
||||
from PyQt6.QtGui import QAction
|
||||
mb = self.menuBar()
|
||||
|
||||
# File
|
||||
m_file = mb.addMenu("&File")
|
||||
m_file.addAction("Open Files…", self._on_open_files)
|
||||
m_file.addAction("Set export folder…", self._pick_folder)
|
||||
m_file.addSeparator()
|
||||
m_file.addAction("Quit", self.close)
|
||||
|
||||
# Edit
|
||||
m_edit = mb.addMenu("&Edit")
|
||||
self._act_undo = m_edit.addAction("Undo scan edit", self._scan_panel.undo)
|
||||
self._act_undo.setShortcut("Ctrl+Z")
|
||||
m_edit.addSeparator()
|
||||
m_subs = m_edit.addMenu("Subprofiles")
|
||||
m_subs.addAction("Add…", self._new_subprofile)
|
||||
self._menu_subprofiles_remove = m_subs.addMenu("Remove")
|
||||
self._rebuild_remove_subprofile_menu() # built in Task 4.x
|
||||
|
||||
# Scan
|
||||
m_scan = mb.addMenu("&Scan")
|
||||
m_scan.addAction("Scan current", self._start_scan)
|
||||
m_scan.addAction("Auto-export", self._auto_export)
|
||||
m_scan.addSeparator()
|
||||
m_scan.addAction("Scan All…", self._start_scan_all)
|
||||
m_scan.addAction("Train classifier…", self._open_train_dialog)
|
||||
|
||||
# View
|
||||
m_view = mb.addMenu("&View")
|
||||
self._act_review = m_view.addAction("Review mode")
|
||||
self._act_review.setCheckable(True)
|
||||
self._act_review.toggled.connect(self._btn_scan_mode.setChecked)
|
||||
m_view.addAction("Subcategory markers…", self._show_subcat_menu)
|
||||
m_view.addSeparator()
|
||||
self._act_hide_exported = m_view.addAction("Hide exported")
|
||||
self._act_hide_exported.setCheckable(True)
|
||||
self._act_hide_exported.toggled.connect(self._chk_hide_exported.setChecked)
|
||||
self._chk_hide_exported.toggled.connect(self._act_hide_exported.setChecked)
|
||||
self._act_show_hidden = m_view.addAction("Show hidden")
|
||||
self._act_show_hidden.setCheckable(True)
|
||||
self._act_show_hidden.toggled.connect(self._btn_show_hidden.setChecked)
|
||||
self._btn_show_hidden.toggled.connect(self._act_show_hidden.setChecked)
|
||||
|
||||
# Help
|
||||
m_help = mb.addMenu("&Help")
|
||||
m_help.addAction("Keyboard shortcuts", self._show_shortcuts).setShortcut("F1")
|
||||
m_help.addAction("What's new", self._show_changelog)
|
||||
m_help.addAction("About", self._show_about) # tiny method, Task 1.3
|
||||
```
|
||||
|
||||
> **Sync note:** `QAction.toggled`/`QAbstractButton.toggled` only fire when the checked state actually changes, so the bidirectional `setChecked` connections (Hide exported, Show hidden) cannot loop. Review is wired one way for now (`_act_review` → `_btn_scan_mode`); the reverse `_btn_scan_mode` → `_act_review` sync is added in Task 3.4 once the button is in the Scan tab, and is loop-safe for the same reason.
|
||||
|
||||
**Step 2:** Add the two small new methods referenced above (`_show_about` and `_rebuild_remove_subprofile_menu`):
|
||||
```python
|
||||
def _show_about(self) -> None:
|
||||
QMessageBox.about(self, "About 8-cut",
|
||||
f"<b>8-cut</b> v{self.APP_VERSION}<br>"
|
||||
"8-second clips for foley datasets.")
|
||||
|
||||
def _rebuild_remove_subprofile_menu(self) -> None:
|
||||
self._menu_subprofiles_remove.clear()
|
||||
for name in self._subprofiles:
|
||||
self._menu_subprofiles_remove.addAction(
|
||||
name, lambda _=False, n=name: self._remove_subprofile(n))
|
||||
self._menu_subprofiles_remove.setEnabled(bool(self._subprofiles))
|
||||
```
|
||||
|
||||
**Step 3:** Call `self._build_menubar()` in `__init__`, **after** `self._scan_panel` and all referenced buttons exist (i.e. just before/after the splitter assembly around main.py:4429). The scan panel is created at main.py:4414, so place the call after that.
|
||||
|
||||
**Step 4 (manual verify):** `./8cut.sh` → menu bar shows File/Edit/Scan/View/Help; each item triggers its action; Ctrl+Z still undoes scan edits; F1 shows shortcuts.
|
||||
|
||||
**Step 5:** Commit: `feat: add menu bar wired to existing handlers`.
|
||||
|
||||
### Task 1.2: Move profile combo + `?` into menu-bar corner
|
||||
|
||||
**Files:** Modify `main.py` — `top_bar` assembly (around main.py:4290–4294) and `_build_menubar`.
|
||||
|
||||
**Step 1:** Remove `self._cmb_profile` and `self._btn_shortcuts` (and the `"Profile:"` `QLabel`) from `top_bar`. Keep `self._lbl_file` in `top_bar` (it stays as the slim filename header above the video).
|
||||
|
||||
**Step 2:** In `_build_menubar`, set a corner widget:
|
||||
```python
|
||||
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QLabel
|
||||
corner = QWidget()
|
||||
ch = QHBoxLayout(corner)
|
||||
ch.setContentsMargins(0, 0, 6, 0)
|
||||
ch.addWidget(QLabel("Profile:"))
|
||||
ch.addWidget(self._cmb_profile)
|
||||
ch.addWidget(self._btn_shortcuts)
|
||||
mb.setCornerWidget(corner, Qt.Corner.TopRightCorner)
|
||||
```
|
||||
(Build the corner widget at the end of `_build_menubar`, after `self._cmb_profile` exists — it is created at main.py:4272.)
|
||||
|
||||
**Step 3 (manual verify):** Profile dropdown works (switch/new/delete); `?` opens shortcuts; filename still shows above the video.
|
||||
|
||||
**Step 4:** Commit: `change: move profile selector and help into menu-bar corner`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 2 — Status bar
|
||||
|
||||
### Task 2.1: Restore `QStatusBar` and route `_show_status` to it
|
||||
|
||||
**Files:** Modify `main.py` — `__init__` (`setStatusBar(None)` at main.py:4440, `_lbl_status`/`_status_timer` at main.py:4364–4370) and `_show_status` (main.py:5065).
|
||||
|
||||
**Step 1:** Replace `self.setStatusBar(None)` with a real status bar built in a helper:
|
||||
```python
|
||||
def _build_status_bar(self) -> None:
|
||||
sb = self.statusBar()
|
||||
self._status_perm = QLabel("")
|
||||
self._status_perm.setStyleSheet("color: #888;")
|
||||
sb.addPermanentWidget(self._status_perm)
|
||||
self._update_status_perm()
|
||||
|
||||
def _update_status_perm(self) -> None:
|
||||
name = os.path.basename(self._file_path) if self._file_path else "—"
|
||||
self._status_perm.setText(
|
||||
f"{name} · profile: {self._profile()} · {self._spn_workers.value()} workers")
|
||||
```
|
||||
Call `self._build_status_bar()` in `__init__` near the menubar call.
|
||||
|
||||
**Step 2:** Rewrite `_show_status` to use the status bar (this subsumes `_status_timer`):
|
||||
```python
|
||||
def _show_status(self, msg: str, timeout: int = 0) -> None:
|
||||
"""Show a transient message in the status bar. timeout in ms (0 = sticky)."""
|
||||
self.statusBar().showMessage(msg, timeout)
|
||||
```
|
||||
|
||||
**Step 3:** Delete `self._lbl_status`, `self._status_timer`, and `settings_row.addWidget(self._lbl_status)` (main.py:4364–4370). Remove the `_status_timer.timeout` connection.
|
||||
|
||||
**Step 4:** Keep `_update_status_perm()` fresh — call it where file/profile/workers change: end of `_after_load`, in `_on_profile_activated`, and in the `_spn_workers.valueChanged` lambda.
|
||||
|
||||
**Step 5 (manual verify):** Start an export → status text appears bottom-left and auto-clears; bottom-right shows file · profile · workers and updates on file/profile/worker change.
|
||||
|
||||
**Step 6:** Commit: `feat: real status bar replaces inline status label`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 3 — Control deck (the core move)
|
||||
|
||||
Build a fixed-height `QTabWidget` with three tab pages, then **re-parent** the existing controls from `path_row` and `settings_row` into them. Give each page an `objectName` for the structure test. Do tabs one at a time so the app stays runnable.
|
||||
|
||||
### Task 3.1: Build the empty deck and mount it
|
||||
|
||||
**Files:** Modify `main.py` — `right_layout` assembly (main.py:4372–4382).
|
||||
|
||||
**Step 1:** Add a helper that creates the deck and three empty pages:
|
||||
```python
|
||||
def _build_control_deck(self) -> "QTabWidget":
|
||||
from PyQt6.QtWidgets import QTabWidget, QWidget
|
||||
deck = QTabWidget()
|
||||
deck.setObjectName("control_deck")
|
||||
deck.setDocumentMode(True)
|
||||
self._tab_export = QWidget(); self._tab_export.setObjectName("export_tab")
|
||||
self._tab_crop = QWidget(); self._tab_crop.setObjectName("crop_tab")
|
||||
self._tab_scan = QWidget(); self._tab_scan.setObjectName("scan_tab")
|
||||
deck.addTab(self._tab_export, "Export")
|
||||
deck.addTab(self._tab_crop, "Crop && Track")
|
||||
deck.addTab(self._tab_scan, "Scan")
|
||||
self._control_deck = deck
|
||||
return deck
|
||||
```
|
||||
|
||||
**Step 2:** In `right_layout`, **keep** `transport_row` for now, but replace the `path_row` and `settings_row` additions with the deck:
|
||||
- Remove `right_layout.addLayout(path_row)` and `right_layout.addLayout(settings_row)`.
|
||||
- Add `right_layout.addWidget(self._build_control_deck())`.
|
||||
- Leave the `path_row`/`settings_row` *construction* in place for this task (the widgets are still parented to nothing visible) — they get moved into tabs in 3.2–3.4. **App is briefly missing those controls between 3.1 and 3.4; that's expected mid-stage.**
|
||||
|
||||
**Step 3 (manual verify):** App launches; three empty tabs appear under the transport bar; switching tabs doesn't resize the video (height fixed in Task 3.5).
|
||||
|
||||
**Step 4:** Commit: `refactor: add empty 3-tab control deck under transport`.
|
||||
|
||||
### Task 3.2: Populate the Export tab
|
||||
|
||||
**Files:** Modify `main.py` — move widgets from `path_row` (main.py:4322–4331) and the encode/clip parts of `settings_row` (main.py:4334–4348) plus `_spn_workers` (main.py:4213).
|
||||
|
||||
**Step 1:** Build the Export tab with an aligned grid:
|
||||
```python
|
||||
def _build_export_tab(self) -> None:
|
||||
from PyQt6.QtWidgets import QGridLayout, QLabel, QHBoxLayout
|
||||
g = QGridLayout(self._tab_export)
|
||||
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||
# Row 0: annotation
|
||||
g.addWidget(QLabel("Label:"), 0, 0); g.addWidget(self._txt_label, 0, 1)
|
||||
g.addWidget(QLabel("Cat:"), 0, 2); g.addWidget(self._cmb_category, 0, 3)
|
||||
g.addWidget(QLabel("Name:"), 0, 4); g.addWidget(self._txt_name, 0, 5)
|
||||
# Row 1: output path
|
||||
folder_row = QHBoxLayout()
|
||||
folder_row.addWidget(self._txt_folder, 1); folder_row.addWidget(self._btn_folder)
|
||||
g.addWidget(QLabel("Folder:"), 1, 0); g.addLayout(folder_row, 1, 1, 1, 5)
|
||||
# Row 2: encode / clip params
|
||||
g.addWidget(QLabel("Format:"), 2, 0); g.addWidget(self._cmb_format, 2, 1)
|
||||
g.addWidget(self._chk_hw, 2, 2)
|
||||
g.addWidget(QLabel("Resize:"), 2, 3); g.addWidget(self._spn_resize, 2, 4)
|
||||
# Rows 3–4: batch params + actions
|
||||
g.addWidget(QLabel("Duration:"), 3, 0); g.addWidget(self._spn_clip_dur, 3, 1)
|
||||
g.addWidget(QLabel("Clips:"), 3, 2); g.addWidget(self._spn_clips, 3, 3)
|
||||
g.addWidget(QLabel("Spread:"), 3, 4); g.addWidget(self._spn_spread, 3, 5)
|
||||
g.addWidget(QLabel("Workers:"), 4, 0); g.addWidget(self._spn_workers, 4, 1)
|
||||
g.addWidget(self._btn_reexport, 4, 5)
|
||||
```
|
||||
Call it from `_build_control_deck` (or right after, in `__init__`).
|
||||
|
||||
**Step 2:** Delete the now-duplicate `addWidget` calls for these widgets from `path_row` and `settings_row` construction. (Re-parenting via `addWidget` into the grid auto-removes them from the old layout, but remove the dead lines to keep `__init__` honest.)
|
||||
|
||||
**Step 3 (manual verify):** Export tab shows aligned Label/Cat/Name, Folder+browse, Format/HW/Resize, Duration/Clips/Spread/Workers/Re-export. Change each → still persists to `QSettings` and updates the timeline span / next-label as before. Export still works (E).
|
||||
|
||||
**Step 4:** Commit: `refactor: move export & encode controls into Export tab`.
|
||||
|
||||
### Task 3.3: Populate the Crop & Track tab
|
||||
|
||||
**Files:** Modify `main.py` — move `_cmb_portrait`, `_chk_rand_portrait`, `_chk_rand_square`, `_chk_track` from `settings_row` (main.py:4337, 4349–4351).
|
||||
|
||||
**Step 1:**
|
||||
```python
|
||||
def _build_crop_tab(self) -> None:
|
||||
from PyQt6.QtWidgets import QGridLayout, QLabel
|
||||
g = QGridLayout(self._tab_crop)
|
||||
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||
g.addWidget(QLabel("Portrait:"), 0, 0); g.addWidget(self._cmb_portrait, 0, 1)
|
||||
g.addWidget(self._chk_rand_portrait, 1, 0, 1, 2)
|
||||
g.addWidget(self._chk_rand_square, 2, 0, 1, 2)
|
||||
g.addWidget(self._chk_track, 3, 0, 1, 2)
|
||||
g.setRowStretch(4, 1); g.setColumnStretch(2, 1)
|
||||
```
|
||||
|
||||
**Step 2:** Remove those four widgets' old `settings_row.addWidget` lines.
|
||||
|
||||
**Step 3 (manual verify):** Crop & Track tab shows the four controls; portrait ratio still toggles the crop overlay/crop-bar; random/track checkboxes persist.
|
||||
|
||||
**Step 4:** Commit: `refactor: move crop & track controls into their tab`.
|
||||
|
||||
### Task 3.4: Populate the Scan tab (and drop menu-only buttons)
|
||||
|
||||
**Files:** Modify `main.py` — move scan widgets from `settings_row` (main.py:4352–4362). Buttons that became **menu-only** (Train, Scan All, Sub) are NOT added to the tab and are deleted.
|
||||
|
||||
**Step 1:**
|
||||
```python
|
||||
def _build_scan_tab(self) -> None:
|
||||
from PyQt6.QtWidgets import QGridLayout, QLabel, QHBoxLayout
|
||||
g = QGridLayout(self._tab_scan)
|
||||
g.setContentsMargins(8, 6, 8, 6); g.setHorizontalSpacing(8); g.setVerticalSpacing(6)
|
||||
model_row = QHBoxLayout()
|
||||
model_row.addWidget(self._cmb_scan_model, 1); model_row.addWidget(self._btn_model_history)
|
||||
g.addWidget(QLabel("Model:"), 0, 0); g.addLayout(model_row, 0, 1, 1, 3)
|
||||
g.addWidget(self._btn_scan, 1, 0); g.addWidget(self._btn_auto_export, 1, 1)
|
||||
g.addWidget(self._btn_speech, 1, 2); g.addWidget(self._btn_scan_mode, 1, 3)
|
||||
g.addWidget(self._spn_auto_fuse, 2, 0); g.addWidget(self._sld_threshold, 2, 1)
|
||||
g.setColumnStretch(3, 1)
|
||||
```
|
||||
|
||||
**Step 2:** Reverse-sync Review with the View menu (the forward sync was added in Task 1.1):
|
||||
```python
|
||||
self._btn_scan_mode.toggled.connect(self._act_review.setChecked)
|
||||
```
|
||||
Add this right after `_build_scan_tab` runs (both `_btn_scan_mode` and `_act_review` exist by then).
|
||||
|
||||
**Step 3:** Delete the menu-only buttons and their `settings_row` lines: `self._btn_train` (main.py:4167–4170), `self._btn_scan_all` (main.py:4172–4174), `self._btn_hide_subcats` (main.py:4154–4157). Their handlers (`_open_train_dialog`, `_start_scan_all`, `_show_subcat_menu`) stay — now reached via menus.
|
||||
|
||||
**Step 4:** Re-anchor `_show_subcat_menu` (main.py:5989) so it no longer depends on the deleted `_btn_hide_subcats`:
|
||||
```python
|
||||
# was: self._btn_hide_subcats.mapToGlobal(self._btn_hide_subcats.rect().bottomLeft())
|
||||
from PyQt6.QtGui import QCursor
|
||||
menu.exec(QCursor.pos())
|
||||
```
|
||||
Apply to **both** `exec` call sites in that method.
|
||||
|
||||
**Step 5 (manual verify):** Scan tab shows Model+history, Scan/Auto/Speech/Review, Fuse/Threshold. `Scan` runs; `Review` toggles and stays in sync with View ▸ Review mode (both directions); View ▸ Subcategory markers… opens the full popup near the cursor; Scan ▸ Scan All / Train still work.
|
||||
|
||||
**Step 6:** Commit: `refactor: move scan controls into Scan tab; Train/ScanAll/Sub to menus`.
|
||||
|
||||
### Task 3.5: Fix deck height; remove dead `path_row`/`settings_row`
|
||||
|
||||
**Files:** Modify `main.py` — `__init__`.
|
||||
|
||||
**Step 1:** The `path_row`/`settings_row` `QHBoxLayout`s should now be empty. Delete their construction blocks entirely (main.py:4321–4370 minus what was already removed). Do **not** delete the `self._transport_row = transport_row` line: it is still used by `_rebuild_subprofile_buttons`, so `transport_row` and that assignment both stay.
|
||||
|
||||
**Step 2:** Pin the deck height so tab switches don't move the video:
|
||||
```python
|
||||
self._control_deck.setFixedHeight(self._control_deck.sizeHint().height())
|
||||
```
|
||||
Call after all three tabs are built. If the tallest tab (Export, 5 rows) clips, set an explicit value instead (e.g. `setFixedHeight(150)`); confirm visually.
|
||||
|
||||
**Step 3 (manual verify):** Switching Export↔Crop↔Scan keeps the video size constant; no clipped controls; all three tabs fully usable.
|
||||
|
||||
**Step 4:** Commit: `refactor: fix control-deck height; drop dead settings rows`.
|
||||
|
||||
### Task 3.6: Extend the structure test for the deck
|
||||
|
||||
**Files:** Modify `tests/test_ui_structure.py`.
|
||||
|
||||
**Step 1:** Add invariants:
|
||||
```python
|
||||
def test_menubar_has_expected_menus(win):
|
||||
titles = [m.title().replace("&", "") for m in win.menuBar().findChildren(type(win.menuBar().addMenu("")))]
|
||||
for expected in ("File", "Edit", "Scan", "View", "Help"):
|
||||
assert any(expected == t for t in titles)
|
||||
|
||||
def test_status_bar_exists(win):
|
||||
assert win.statusBar() is not None
|
||||
|
||||
def test_workers_spinbox_in_export_tab(win):
|
||||
from PyQt6.QtWidgets import QSpinBox
|
||||
assert win._spn_workers in win._tab_export.findChildren(QSpinBox)
|
||||
|
||||
def test_scan_button_in_scan_tab(win):
|
||||
from PyQt6.QtWidgets import QPushButton
|
||||
assert win._btn_scan in win._tab_scan.findChildren(QPushButton)
|
||||
|
||||
def test_portrait_combo_in_crop_tab(win):
|
||||
from PyQt6.QtWidgets import QComboBox
|
||||
assert win._cmb_portrait in win._tab_crop.findChildren(QComboBox)
|
||||
```
|
||||
(Adjust the menu-title introspection if the helper is awkward; the key invariants are the tab-containment ones.)
|
||||
|
||||
**Step 2:** Run: `pytest tests/test_ui_structure.py -v` → PASS with a display (or SKIP headless).
|
||||
|
||||
**Step 3:** Commit: `test: assert control-deck containment invariants`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 4 — Transport bar tidy & subprofile menu sync
|
||||
|
||||
### Task 4.1: Confirm transport bar contents; keep subprofile export buttons inline
|
||||
|
||||
**Files:** Modify `main.py` — `transport_row` (main.py:4296–4319).
|
||||
|
||||
**Step 1:** The workers spinbox was moved in Task 3.2 — confirm `transport_row.addWidget(self._spn_workers)` is gone. Remaining transport order: Play, Pause, x2, x4, Lock, time, stretch, next-label, **Export**, subprofile buttons, `+` (add subprofile), Cancel, Delete. Leave subprofile **export** buttons inline (they carry the 1–9 shortcuts and belong with Export).
|
||||
|
||||
**Step 2:** Keep the inline `+` add-subprofile button, but also ensure the Edit ▸ Subprofiles ▸ Remove submenu is rebuilt whenever subprofiles change. In `_rebuild_subprofile_buttons` (main.py:5530-ish) and after add/remove, call `self._rebuild_remove_subprofile_menu()`.
|
||||
|
||||
**Step 3 (manual verify):** Transport row reads cleanly; adding/removing a subprofile updates both the inline buttons and Edit ▸ Subprofiles ▸ Remove; number keys 1–9 still export to subprofiles.
|
||||
|
||||
**Step 4:** Commit: `change: tidy transport row; sync subprofile remove menu`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 5 — Visual polish
|
||||
|
||||
All Stage 5 verification is **manual** (visual). Take a screenshot before 5.1 for comparison (use the `run`/`verify` skill).
|
||||
|
||||
### Task 5.1: Consolidate the stylesheet (tabs, status bar, toggles, primary button)
|
||||
|
||||
**Files:** Modify `main.py` — global stylesheet in `main()` (main.py:3811–3827).
|
||||
|
||||
**Step 1:** Extend the central sheet (append rules; keep existing ones):
|
||||
```css
|
||||
QTabWidget::pane { border: 1px solid #444; border-radius: 3px; top: -1px; }
|
||||
QTabBar::tab { background: #2a2a2a; color: #bbb; padding: 5px 12px;
|
||||
border: 1px solid #444; border-bottom: none;
|
||||
border-top-left-radius: 3px; border-top-right-radius: 3px; }
|
||||
QTabBar::tab:selected { background: #333; color: #fff; }
|
||||
QPushButton:checked { background: #4a3000; border-color: #ffd230; color: #fff; }
|
||||
QStatusBar { background: #1a1a1a; color: #bbb; }
|
||||
QStatusBar::item { border: none; }
|
||||
QPushButton#primary { background: #3a6ea8; border-color: #4f86c6; color: #fff; }
|
||||
QPushButton#primary:hover { background: #4f86c6; }
|
||||
QMenuBar { background: #1e1e1e; } QMenuBar::item:selected { background: #3a6ea8; }
|
||||
QMenu { background: #2a2a2a; border: 1px solid #555; }
|
||||
QMenu::item:selected { background: #3a6ea8; }
|
||||
```
|
||||
|
||||
**Step 2:** Mark Export primary: `self._btn_export.setObjectName("primary")`.
|
||||
|
||||
**Step 3:** Replace Lock's inline stylesheet swap (main.py:5705) — since `QPushButton:checked` now styles all toggles, delete the two `self._btn_lock.setStyleSheet(...)` lines in `_on_lock_toggled` (keep the rest of the handler).
|
||||
|
||||
**Step 4 (manual verify):** Tabs, menus, status bar, and checked toggles (x2/x4/Lock/Review) all read consistently; Export stands out as primary; Lock still highlights when active.
|
||||
|
||||
**Step 5:** Commit: `style: unify tab/menu/statusbar/toggle styling; mark Export primary`.
|
||||
|
||||
### Task 5.2: Preserve the "armed to overwrite" Export state
|
||||
|
||||
**Files:** Inspect `main.py` — the red-Export swaps (main.py:5403, and the resets at 4960/5211/5447/7170/7199/7218).
|
||||
|
||||
**Step 1:** These set/clear `self._btn_export.setStyleSheet("QPushButton { background: #6a3030; ... }")` to mean "this export will overwrite". With Export now `objectName("primary")`, an empty `setStyleSheet("")` reset reverts to the **primary** look (good). Confirm the armed (red) state still visually overrides primary — inline stylesheet beats the objectName rule, so it does.
|
||||
|
||||
**Step 2 (manual verify):** Select a marker for re-export → Export turns red (armed); deselect → returns to blue primary; export → resets correctly.
|
||||
|
||||
**Step 3:** Commit (only if changes were needed): `fix: keep armed-overwrite Export state over primary style`.
|
||||
|
||||
### Task 5.3: Label cleanup
|
||||
|
||||
**Files:** Modify `main.py` — prefixes/labels.
|
||||
|
||||
**Step 1:** De-abbreviate where free: `_sld_threshold.setPrefix("Threshold: ")` (main.py:4207) → keep short if it overflows the tab; `_spn_auto_fuse` prefix stays `"Fuse: "`. Replace the `⏲` history button text with a tooltip-backed `"History"` or a clearer glyph; keep `setFixedWidth` generous enough.
|
||||
|
||||
**Step 2 (manual verify):** Labels legible; nothing clipped in the Scan tab.
|
||||
|
||||
**Step 3:** Commit: `style: de-abbreviate scan labels`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 6 — Finalize
|
||||
|
||||
### Task 6.1: Full regression pass
|
||||
|
||||
**Step 1 (manual, use `verify` skill):** With a real video loaded, confirm end-to-end: scrub/play/pause/speed/lock; export (E) single + batch + subprofile (1–9); re-export; delete; portrait crop + random + track; scan + auto + speech + review + threshold/fuse; scan-all; train dialog opens; profile switch; queue filter/hide/show-hidden; Ctrl+Z undo; F1/`?` shortcuts.
|
||||
|
||||
**Step 2:** Run `pytest -q` (all suites). Expected: `core/` PASS; `test_ui_structure` PASS (display) or SKIP.
|
||||
|
||||
### Task 6.2: Docs & changelog
|
||||
|
||||
**Files:** Modify `README.md` (UI/shortcuts sections if any references moved) and the in-app `CHANGELOG` list (main.py:4500) — bump `APP_VERSION` and add a "UI restructure" entry so the What's-new dialog announces it.
|
||||
|
||||
**Step 1:** Add changelog entry summarizing: menu bar, tabbed control deck, status bar, visual polish; note all shortcuts unchanged.
|
||||
|
||||
**Step 2:** Commit: `docs: changelog + README for UI restructure`.
|
||||
|
||||
### Task 6.3: Hand off the branch
|
||||
|
||||
**Step 1:** `git log --oneline master..ui-restructure` — review the commit series.
|
||||
**Step 2:** Offer the user: merge to `master`, open a PR, or keep iterating (use `finishing-a-development-branch` skill).
|
||||
|
||||
---
|
||||
|
||||
## Risk register
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|-----------|
|
||||
| Re-parenting breaks a `connect()` | Widgets keep identity; only layout membership changes. Manual launch after every task catches breakage immediately. |
|
||||
| Headless test can't build `MpvWidget` | Structure test skips on construction failure; manual launch is authoritative. |
|
||||
| Menu/button state desync (Review, Hide exported) | Bidirectional `setChecked` (no re-emit on equal value → no loop); verified manually in 3.4. |
|
||||
| Subcat popup anchored to deleted button | Re-anchored to `QCursor.pos()` in Task 3.4. |
|
||||
| Deck height jump on tab switch | `setFixedHeight` in Task 3.5. |
|
||||
| Armed-overwrite red Export lost under primary style | Inline stylesheet overrides objectName rule; verified in 5.2. |
|
||||
| Mid-Stage-3 app missing controls | Expected between 3.1–3.4; each sub-task is still committable and launchable. |
|
||||
|
||||
## What this plan does NOT change
|
||||
|
||||
`core/` logic · export/scan/tracking/DB behavior · keyboard shortcuts · timeline mouse interactions · the Queue and Scan-results panes' internals · the dark Fusion theme.
|
||||
@@ -0,0 +1,96 @@
|
||||
# Multi-pane Control Deck — Design + Plan Addendum
|
||||
|
||||
> Addendum to `2026-06-13-ui-restructure-design.md` / `-implementation.md`. Same branch (`ui-restructure`), same constraints (preserve behavior; reorg/feature only; no `core/` changes).
|
||||
|
||||
**Goal:** Let the control-deck panels (Export / Crop & Track / Scan) optionally show **side-by-side as resizable columns** instead of one-at-a-time tabs — mirroring the existing playlist pin→side-by-side pattern.
|
||||
|
||||
> **Revision (post-use, 2026-06-13):** The first implementation showed unpinned panels as a "leftover" tab-column so nothing was hidden — but in use, pinning 2 panels then displayed 3 columns, which read as "all three pinned" and was confusing (and inconsistent with what persisted). **Revised behavior:** the split view shows **exactly the pinned panels** as columns (pin 2 → 2 columns, pin 3 → 3). Unpinned panels are not shown as columns. Because the right-click-tab "Show side-by-side" gesture only works in tabbed mode, an always-available **View ▸ Side-by-side panels ▸ Export / Crop / Scan** submenu of checkable toggles is the way to pin/unpin any panel (including adding a 3rd while already in split view). The `if leftovers:` block below is removed; the View submenu + its sync in `_refresh_deck_layout` replace it.
|
||||
|
||||
**Mirror these existing playlist members** (study them — the deck is a simpler, fixed-3-panel version): `_PlaylistTabBar` (main.py:3284), `_refresh_layout` (~4872), `_on_pin_toggle`/`_on_unpin` (~4942), `_detach_all_pws`/`_clear_split_container` (~4861), and the `_list_stack`/`_split_container` setup (~3916–3923).
|
||||
|
||||
---
|
||||
|
||||
## Design
|
||||
|
||||
### Panel identity
|
||||
The deck's three pages (`_tab_export`, `_tab_crop`, `_tab_scan`) each get three attributes (set in `_build_control_deck`):
|
||||
- `_pinned: bool = False`
|
||||
- `_label: str` — "Export" / "Crop & Track" / "Scan"
|
||||
- `_deck_key: str` — "export" / "crop" / "scan" (stable key for persistence)
|
||||
|
||||
Keep an ordered list `self._deck_panels = [self._tab_export, self._tab_crop, self._tab_scan]` for deterministic column order.
|
||||
|
||||
### Tab bar
|
||||
New `class _DeckTabBar(QTabBar)` (minimal version of `_PlaylistTabBar`): on `contextMenuEvent`, show a checkable "Show side-by-side" action reflecting the page's `_pinned`, and emit `pin_toggle_requested(idx)` when chosen. No rename/folder. Install via `self._control_deck.setTabBar(_DeckTabBar())` in `_build_control_deck` and connect `pin_toggle_requested → self._on_deck_pin_toggle`.
|
||||
|
||||
### Stacked container (mirrors `_list_stack`)
|
||||
Wrap the deck so it can swap between tabbed and split views:
|
||||
- `self._deck_split_container = QWidget()` with a `QHBoxLayout` (`_deck_split_layout`, margins 0, spacing 2).
|
||||
- `self._deck_stack = QStackedWidget()`; page 0 = `self._control_deck`, page 1 = `self._deck_split_container`.
|
||||
- In `right_layout`, mount `self._deck_stack` where `self._control_deck` is currently added (replace that one `addWidget`).
|
||||
|
||||
### `_refresh_deck_layout()` (mirrors `_refresh_layout`)
|
||||
```
|
||||
pinned = [p for p in self._deck_panels if p._pinned]
|
||||
guard self._deck_loading = True (avoid re-entrant signals)
|
||||
detach all panels (setParent(None)); self._control_deck.clear(); clear _deck_split_layout
|
||||
if len(pinned) >= 2:
|
||||
splitter = QSplitter(Horizontal); splitter.setChildrenCollapsible(False)
|
||||
leftovers = []
|
||||
for panel in self._deck_panels: # preserve deck order
|
||||
if panel._pinned:
|
||||
col = QWidget(); v = QVBoxLayout(col) (0 margins)
|
||||
header = label(panel._label, bold) + "✕" button (unpin, fixed 18x18,
|
||||
tooltip "Return to tabs", clicked → self._on_deck_unpin(panel))
|
||||
header fixed height ~22
|
||||
panel.setVisible(True) # reparented pages start hidden
|
||||
v.addWidget(header); v.addWidget(panel, 1)
|
||||
splitter.addWidget(col)
|
||||
else:
|
||||
leftovers.append(panel)
|
||||
if leftovers: # keep unpinned reachable as a tab-column
|
||||
lt = QTabWidget(); lt.setDocumentMode(True)
|
||||
for panel in leftovers:
|
||||
panel.setVisible(True); lt.addTab(panel, panel._label)
|
||||
splitter.addWidget(lt)
|
||||
splitter.setSizes([1000]*splitter.count())
|
||||
_deck_split_layout.addWidget(splitter)
|
||||
self._deck_stack.setCurrentWidget(self._deck_split_container)
|
||||
else:
|
||||
for panel in self._deck_panels: # fixed order
|
||||
self._control_deck.addTab(panel, panel._label)
|
||||
self._deck_stack.setCurrentWidget(self._control_deck)
|
||||
restore self._deck_loading
|
||||
```
|
||||
|
||||
### Toggle handlers (mirror `_on_pin_toggle`/`_on_unpin`)
|
||||
- `_on_deck_pin_toggle(idx)`: `panel = self._control_deck.widget(idx)` (only valid in tabbed mode — pin is only offered there); flip `panel._pinned`; if the panel is now pinned but fewer than 2 panels are pinned in total, `_show_status("Pin another panel to show them side-by-side", 3500)`; then `_refresh_deck_layout()` and `_save_deck_layout()`.
|
||||
- `_on_deck_unpin(panel)`: `panel._pinned = False`; `_refresh_deck_layout()`; `_save_deck_layout()`.
|
||||
|
||||
### Persistence
|
||||
- `_save_deck_layout()`: `self._settings.setValue("deck_pinned", [p._deck_key for p in self._deck_panels if p._pinned])`.
|
||||
- Restore at the end of `__init__` (after the deck + menubar exist): read `deck_pinned` (handle str/list like the subprofiles loader at main.py:3867), set each panel's `_pinned`, then `_refresh_deck_layout()` once.
|
||||
|
||||
### Height
|
||||
The deck pages now also render with a 22px header in split mode. After building, set the stack's minimum height to fit the tallest **split-mode** column (header + Export content) so split mode never clips: compute once via `self._deck_stack.setMinimumHeight(...)` using `sizeHint`, and keep vertical size policy `Fixed` (as the deck has now). Switching INTO split mode may change the deck height slightly (deliberate user action — acceptable); switching tabs within tabbed mode must still not jump. Reuse the existing height-pin logic — apply it to `_deck_stack` instead of `_control_deck`.
|
||||
|
||||
---
|
||||
|
||||
## Implementation tasks (bite-sized, commit per task)
|
||||
|
||||
**Task M.1 — scaffolding (no behavior change yet).** Add `_DeckTabBar`; in `_build_control_deck` set it on the deck, set `_pinned/_label/_deck_key` on the three pages, build `self._deck_panels`, create `_deck_split_container`/`_deck_split_layout`/`_deck_stack`, and mount `_deck_stack` in `right_layout` instead of `_control_deck`. Connect `pin_toggle_requested` to a stub. App still behaves as plain tabs. Verify: `import main`, structure tests 6/6, and a probe that `_deck_stack.currentWidget() is _control_deck`.
|
||||
|
||||
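**Task M.2 — split rendering.** Implement `_refresh_deck_layout`, `_detach_deck_panels`, `_clear_deck_split`, `_on_deck_pin_toggle`, `_on_deck_unpin`. Verify with a probe: set two panels `_pinned=True`, call `_refresh_deck_layout()`, assert stack shows `_deck_split_container`, the splitter has 3 columns (2 pinned + 1 leftover QTabWidget), and all three panels are visible/parented; unpin one → back to `_control_deck` with 3 tabs in order. *(Note: the 3-column expectation describes the first implementation. Per the Revision note at the top of this addendum, the split view now shows exactly the pinned panels, so with two panels pinned the probe should expect 2 columns and no leftover QTabWidget; the unpinned panel is not shown in split view.)*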
|
||||
|
||||
**Task M.3 — persistence.** Add `_save_deck_layout()` + restore block in `__init__`. Verify a probe round-trips a pinned set through QSettings (use an isolated QSettings scope in the test if needed) without error and that restore calls refresh exactly once.
|
||||
|
||||
**Task M.4 — height + tests.** Apply the height-pin to `_deck_stack`; confirm split mode doesn't clip the tallest column. Add structure tests: `test_deck_stack_exists`, and `test_pinning_two_panels_switches_to_split` (programmatically pin 2, refresh, assert `_deck_stack.currentWidget() is _deck_split_container`).
|
||||
|
||||
## Verification note
|
||||
Env quirk (same as the restructure): bare `python -c` constructing `MainWindow` segfaults on mpv GL; run checks under the pytest fixture and `LD_PRELOAD=/usr/lib/libstdc++.so.6 QT_QPA_PLATFORM=offscreen`. Visual confirmation (drag dividers, pin/unpin gestures, persistence across real launches) is the user's, done at the end.
|
||||
|
||||
## Risks
|
||||
- **Reparenting hidden pages:** QTabWidget hides non-current pages; reparented panels must be `setVisible(True)` in split columns (same gotcha the playlist documents at main.py:4909-4911).
|
||||
- **Signal re-entrancy:** guard with `_deck_loading` during refresh.
|
||||
- **Pin offered in split mode:** `_on_deck_pin_toggle` reads `_control_deck.widget(idx)`, which is only meaningful in tabbed mode. The ✕ header is the unpin path in split mode — don't rely on the context menu there.
|
||||
- **Height jump on mode toggle:** acceptable (deliberate); tab-switch-within-tabs must remain jump-free.
|
||||
@@ -0,0 +1,66 @@
|
||||
# LTX-2 per-tab export mode — Design
|
||||
|
||||
**Goal:** Add an export *pipeline mode* to each file-list tab — **Foley** (current behavior) or **LTX-2** — so the same source videos can feed both a Foley dataset (8 s clips) and an LTX-2 V2A dataset (frame-exact, ÷32, 25 fps) without the two ever mixing.
|
||||
|
||||
**Depends on:** the per-tab export folder feature (branch `tab-export-folder`) — this design extends that per-tab state. Implementation branch `ltx2-preset` is based on it.
|
||||
|
||||
**Scope:** soft preset (no hard enforcement — defaults are LTX-2-legal but every control stays editable). `core/` gains optional pipeline params; Foley path is byte-for-byte unchanged.
|
||||
|
||||
---
|
||||
|
||||
## LTX-2 constraints (why this exists)
|
||||
|
||||
LTX-2 (32× spatial VAE, 8× temporal + 1) requires, for a clip:
|
||||
- **W and H each divisible by 32.**
|
||||
- **Frame count F such that `F % 8 == 1`** → 9, 17, 25, … 201, … (transformer seq-len ∝ `(W/32)·(H/32)·((F−1)/8+1)`).
|
||||
- **fps** only sets real duration `F/fps`; for V2A it fixes the paired-audio length and audio↔motion sync, so it must be **consistent across the dataset and equal to the inference `frame_rate`**. Target: **25 fps**.
|
||||
- V2A video is frozen conditioning → low spatial res (384–512) is fine and cheaper.
|
||||
|
||||
Note: 8 s @ 25 fps = 200 frames, and `200 % 8 == 0` → **8 s is not legal**. Nearest legal: F=193 (7.72 s) or **F=201 (8.04 s)**.
|
||||
|
||||
---
|
||||
|
||||
## Model: per-tab mode
|
||||
|
||||
Each tab (`PlaylistWidget`) gains `_mode ∈ {"foley","ltx2"}`, persisted alongside `_dest_folder`/`_pinned`/`_tab_folder` in `_save_playlist_tabs`/`_load_playlist_tabs`. Default `"foley"` → existing tabs load unchanged. The **active tab's mode drives the export pipeline and the length control.**
|
||||
|
||||
### Tab context menu (`_DeckTabBar`/`_PlaylistTabBar`)
|
||||
- **Duplicate as LTX-2** — headline action: clone the tab's file list + separators into a new tab; set `mode="ltx2"`; derive a separate export folder `"<dest_folder>_ltx2"`; load LTX-2 default geometry. Lets you spin an LTX-2 dataset off a Foley working set.
|
||||
- **Duplicate tab** — clone keeping the same mode.
|
||||
- **LTX-2 mode** — checkable, flips an existing tab between foley/ltx2.
|
||||
- Tab label shows a small **`[LTX2]`** badge when `mode=="ltx2"`.
|
||||
|
||||
## What `ltx2` mode changes (soft — still editable)
|
||||
|
||||
| Aspect | Foley | LTX-2 |
|
||||
|--------|-------|-------|
|
||||
| Clip length | Duration spinbox (seconds) | **Frame-count F** control stepping the legal series (9, 17, …, 201, …); shows `= F/25 s` |
|
||||
| Output fps | inherits source | **forced 25 fps** (resample; preserves duration/sync) |
|
||||
| Output W×H | short-side resize → even long side | **center-cropped to ÷32** on both axes (no aspect distortion; loses ≤31 px per axis, split between the two edges); resize default **512** |
|
||||
| Frame exactness | duration-based | exactly **F** frames (`-frames:v F`) |
|
||||
|
||||
Defaults loaded on convert: resize **512**, **F = 201** (≈8.04 s, mirrors the 8 s Foley clips), ratio as set. All editable afterward.
|
||||
|
||||
## Pipeline (`core/ffmpeg.build_ffmpeg_command`)
|
||||
|
||||
Add optional params; Foley calls pass none → identical output to today:
|
||||
- `target_fps: float | None` — when set, append `fps={target_fps}` filter and `-r {target_fps}`.
|
||||
- `snap32: bool` — when true, after the scale append a centered crop to the nearest lower multiple of 32 on each axis: `crop=trunc(iw/32)*32:trunc(ih/32)*32`.
|
||||
- Frame-exact length: caller computes `duration = F/target_fps` and passes `-frames:v F` on the video output so the clip has exactly F frames; audio extract uses the same `F/target_fps` duration so V2A pairing stays aligned.
|
||||
|
||||
Filter order: portrait-crop (aspect) → scale (short side, ÷32 default) → snap32 crop → fps. The snap32 center-crop runs after scaling so the ÷32 trim is on final pixels.
|
||||
|
||||
## UI wiring (`MainWindow`)
|
||||
|
||||
- The length spinbox area swaps with the active tab's mode: Foley shows *Duration (s)*; LTX-2 shows *Frames (F)* with a live `= s @25fps` readout. Switching tabs (or toggling mode) reconfigures it; uses the existing `_sync_folder_field_to_tab`-style sync hook on tab change.
|
||||
- `_on_export` / `_start_export_batch`: when the active tab is `ltx2`, pass `target_fps=25`, `snap32=True`, and frame-exact length to the ffmpeg builder; otherwise unchanged.
|
||||
- The mismatch guardrail (just added) and per-tab folder continue to apply.
|
||||
|
||||
## Persistence & migration
|
||||
`_mode` added to each tab's saved JSON (default `"foley"` when absent). No DB changes. Existing sessions load every tab as Foley → zero behavior change until a tab is converted.
|
||||
|
||||
## What this does NOT do
|
||||
- No hard enforcement: you can set an illegal F or non-÷32 resize manually; the pipeline still crops to ÷32 and uses whatever F you pick (the *control* defaults/steps keep you legal, but nothing blocks you).
|
||||
- No motion interpolation on fps resample (frame drop/dup only); keep sources native 25 fps where possible.
|
||||
- No change to Foley exports, the scan pipeline, or the DB schema.
|
||||
- No automatic re-export of existing clips into LTX-2 — you cut LTX-2 clips in the converted tab.
|
||||
@@ -0,0 +1,179 @@
|
||||
# LTX-2 per-tab export mode — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Add a per-tab export pipeline mode (Foley | LTX-2) so the same videos can feed both an 8 s Foley dataset and a frame-exact, ÷32, 25 fps LTX-2 V2A dataset, with a "Duplicate as LTX-2" tab action.
|
||||
|
||||
**Architecture:** `core/ffmpeg.build_ffmpeg_command` gains optional `target_fps` / `snap32` / `frames` params (Foley path unchanged); a tiny `core/ltx2.py` holds the legal-frame math. `PlaylistWidget` gains `_mode`; the tab menu gains duplicate/convert actions; the length control + `_on_export` wiring switch on the active tab's mode. Soft preset — defaults are legal, everything stays editable.
|
||||
|
||||
**Tech Stack:** Python 3.11+, PyQt6, ffmpeg, pytest. Branch `ltx2-preset` (based on `tab-export-folder`). Design: `docs/plans/2026-06-18-ltx2-preset-design.md`.
|
||||
|
||||
---
|
||||
|
||||
## Conventions
|
||||
- **Core (`core/ffmpeg.py`, `core/ltx2.py`) is real TDD** — pure functions tested in `tests/test_utils.py` style. Run: `LD_PRELOAD=/usr/lib/libstdc++.so.6 python -m pytest tests/test_utils.py -q` (the preload is needed because importing `main` pulls `mpv`; see `project_qt_test_env`). 3 pre-existing failures there are unrelated — don't count them.
|
||||
- **GUI parts** verified by the offscreen structure test (`LD_PRELOAD=/usr/lib/libstdc++.so.6 QT_QPA_PLATFORM=offscreen python -m pytest tests/test_ui_structure.py -v`) plus a **manual launch** (`./8cut.sh`).
|
||||
- Line numbers are starting anchors; locate by symbol. Commit per task. Co-author trailer on every commit:
|
||||
`Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>`
|
||||
|
||||
---
|
||||
|
||||
## Stage 1 — LTX-2 math (`core/ltx2.py`) [TDD]
|
||||
|
||||
### Task 1.1: legal-frame helpers
|
||||
**Files:** Create `core/ltx2.py`; Test in `tests/test_utils.py` (append).
|
||||
|
||||
**Step 1 — failing tests** (append to `tests/test_utils.py`):
|
||||
```python
|
||||
from core.ltx2 import is_legal_frames, nearest_legal_frames, frames_for_duration, duration_for_frames, legal_frames
|
||||
|
||||
def test_ltx2_is_legal():
|
||||
assert is_legal_frames(201) and is_legal_frames(9) and is_legal_frames(25)
|
||||
assert not is_legal_frames(200) and not is_legal_frames(8)
|
||||
|
||||
def test_ltx2_nearest():
|
||||
assert nearest_legal_frames(200) == 201 # 200 -> nearest 8k+1
|
||||
assert nearest_legal_frames(196) == 193
|
||||
assert nearest_legal_frames(5) == 9 # floor at 9
|
||||
|
||||
def test_ltx2_duration_roundtrip():
|
||||
assert duration_for_frames(201, 25) == 201 / 25
|
||||
assert frames_for_duration(8.0, 25) == 201 # 200 -> 201
|
||||
|
||||
def test_ltx2_legal_series():
|
||||
s = legal_frames(min_f=9, max_f=33)
|
||||
assert s == [9, 17, 25, 33]
|
||||
```
|
||||
**Step 2 — run, expect ImportError/FAIL:** `LD_PRELOAD=/usr/lib/libstdc++.so.6 python -m pytest tests/test_utils.py -k ltx2 -q`
|
||||
|
||||
**Step 3 — implement `core/ltx2.py`:**
|
||||
```python
|
||||
"""LTX-2 frame-count math. Legal F satisfy F % 8 == 1 (8x temporal + 1)."""
|
||||
|
||||
def is_legal_frames(f: int) -> bool:
|
||||
return f >= 9 and f % 8 == 1
|
||||
|
||||
def legal_frames(min_f: int = 9, max_f: int = 1000) -> list[int]:
|
||||
start = max(9, min_f + ((1 - min_f) % 8)) # first 8k+1 >= min_f
|
||||
return list(range(start, max_f + 1, 8))
|
||||
|
||||
def nearest_legal_frames(f: int) -> int:
|
||||
if f <= 9:
|
||||
return 9
|
||||
low = ((f - 1) // 8) * 8 + 1
|
||||
high = low + 8
|
||||
return low if (f - low) <= (high - f) else high
|
||||
|
||||
def duration_for_frames(frames: int, fps: float) -> float:
|
||||
return frames / fps
|
||||
|
||||
def frames_for_duration(duration: float, fps: float) -> int:
|
||||
return nearest_legal_frames(round(duration * fps))
|
||||
```
|
||||
**Step 4 — run, expect PASS** (same command). **Step 5 — commit:** `feat: LTX-2 legal-frame helpers (core/ltx2.py)`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 2 — ffmpeg pipeline params [TDD]
|
||||
|
||||
### Task 2.1: `target_fps`, `snap32`, `frames` in `build_ffmpeg_command`
|
||||
**Files:** Modify `core/ffmpeg.py:74` (`build_ffmpeg_command`); Test `tests/test_utils.py`.
|
||||
|
||||
**Step 1 — failing tests:**
|
||||
```python
|
||||
def test_ffmpeg_ltx2_fps_and_frames():
|
||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||
short_side=512, target_fps=25, frames=201)
|
||||
assert "-r" in cmd and cmd[cmd.index("-r")+1] == "25"
|
||||
assert "-frames:v" in cmd and cmd[cmd.index("-frames:v")+1] == "201"
|
||||
vf = cmd[cmd.index("-vf")+1]
|
||||
assert "fps=25" in vf
|
||||
|
||||
def test_ffmpeg_ltx2_snap32_crop():
|
||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4",
|
||||
short_side=512, snap32=True)
|
||||
vf = cmd[cmd.index("-vf")+1]
|
||||
assert "crop=trunc(iw/32)*32:trunc(ih/32)*32" in vf
|
||||
|
||||
def test_ffmpeg_foley_unchanged():
|
||||
cmd = build_ffmpeg_command("/in/v.mp4", 0.0, "/out/c.mp4", short_side=256)
|
||||
assert "-r" not in cmd and "-frames:v" not in cmd
|
||||
assert "crop=trunc" not in cmd[cmd.index("-vf")+1]
|
||||
```
|
||||
**Step 2 — run, expect FAIL** (unexpected kwargs).
|
||||
|
||||
**Step 3 — implement:** add params `target_fps: float | None = None, snap32: bool = False, frames: int | None = None` to the signature. After the scale filter (and before the VAAPI block), append:
|
||||
```python
|
||||
if snap32:
|
||||
filters.append("crop=trunc(iw/32)*32:trunc(ih/32)*32")
|
||||
if target_fps is not None:
|
||||
filters.append(f"fps={target_fps:g}")
|
||||
```
|
||||
Add output flags: after `-t duration` (or near the encoder args, before `output_path`), when `target_fps` is set add `cmd += ["-r", f"{target_fps:g}"]`; when `frames` is set add `cmd += ["-frames:v", str(frames)]` (video frame cap — exact F). Ensure the ordering keeps `-vf` before the output path. Append the `fps`/`snap32` filters outside the `image_sequence=False`/`True` branches so they apply identically to both (the WebP sequence also benefits from fps/÷32).
|
||||
|
||||
**Step 4 — run, expect PASS.** Also run full `tests/test_utils.py` (the 3 pre-existing failures only). **Step 5 — commit:** `feat: LTX-2 ffmpeg params (target_fps, snap32, frames)`.
|
||||
|
||||
### Task 2.2: audio extract honors frame-exact duration
|
||||
**Files:** `core/ffmpeg.py:145` (`build_audio_extract_command`) — confirm it takes a duration; if it derives from a fixed 8 s, add a `duration` param so the `.wav` for an LTX-2 webp sequence is exactly `F/25 s`. Add a test mirroring `test_audio_extract_timing` asserting the `-t` value equals `frames/fps`. Commit: `fix: audio extract duration for LTX-2 frame-exact clips`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 3 — per-tab `_mode`
|
||||
|
||||
### Task 3.1: attribute + persistence + migration
|
||||
**Files:** `main.py` — `PlaylistWidget.__init__` (~3409, next to `_dest_folder`); `_save_playlist_tabs` (~5271); `_load_playlist_tabs` (~5315).
|
||||
- Add `self._mode: str = "foley"` in `PlaylistWidget.__init__`.
|
||||
- `_save_playlist_tabs`: add `"mode": pw._mode` to each tab dict.
|
||||
- `_load_playlist_tabs`: after creating each pw, `pw._mode = t.get("mode", "foley")`.
|
||||
- `_add_playlist_tab`: new tabs default `_mode="foley"` (already via init).
|
||||
|
||||
**Verify:** structure test passes; add `test_tab_mode_defaults_foley` (construct, assert each `_pws[i]._mode == "foley"`). Commit: `feat: per-tab export mode attribute (foley default)`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 4 — tab menu: duplicate / convert / toggle
|
||||
|
||||
### Task 4.1: menu actions + label badge
|
||||
**Files:** `main.py` — `_PlaylistTabBar.contextMenuEvent` (~3300) add items; new handlers in `MainWindow`; tab-title rendering.
|
||||
- Add to the tab context menu: **"Duplicate tab"**, **"Duplicate as LTX-2"**, and a checkable **"LTX-2 mode"** (checked when `pw._mode=="ltx2"`). Emit new signals (e.g. `duplicate_requested(idx, as_ltx2: bool)`, `mode_toggle_requested(idx)`) like the existing `pin_toggle_requested`.
|
||||
- `MainWindow._on_duplicate_tab(idx, as_ltx2)`: build a new tab via `_add_playlist_tab(label=…, files=list(src._paths), separators=sorted(src._separators_before), select=True)`; set `pw._dest_folder = src._dest_folder + ("_ltx2" if as_ltx2 else "")`; `pw._mode = "ltx2" if as_ltx2 else src._mode`; if ltx2, apply LTX-2 defaults (Stage 5 hook); `_save_playlist_tabs()`; refresh.
|
||||
- `MainWindow._on_tab_mode_toggle(idx)`: flip `pw._mode`; if now ltx2, apply LTX-2 defaults; `_save_playlist_tabs()`; re-sync controls (Stage 5).
|
||||
- Label badge: when adding/refreshing a tab whose `_mode=="ltx2"`, show `f"{label} [LTX2]"` (or set a distinct color) — apply in `_refresh_layout`/`_add_playlist_tab` title set.
|
||||
|
||||
**Verify:** manual launch — right-click a tab → Duplicate as LTX-2 creates a `[LTX2]` tab with `_ltx2` folder; toggle works. Structure test still green. Commit: `feat: tab duplicate / Duplicate-as-LTX-2 / mode toggle + [LTX2] badge`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 5 — length control swap + export wiring
|
||||
|
||||
### Task 5.1: length control reflects active tab mode
|
||||
**Files:** `main.py` — the clip-length widgets (`_spn_clip_dur` ~4051 area) + the tab-change sync hook (`_on_tab_changed` / `_sync_folder_field_to_tab` neighbor).
|
||||
- Add a frames spinbox `_spn_frames` (min 9, singleStep 8 → always 8k+1; suffix " f"; tooltip live `= F/25 s`). Default 201.
|
||||
- Add `_apply_mode_to_controls()`: if active tab `ltx2` → show `_spn_frames` (+ "Frames" label), hide the seconds Duration control, default resize 512 if unset; else show Duration (seconds), hide frames. Call it from `_on_tab_changed`, after `_on_duplicate_tab`/`_on_tab_mode_toggle`, and once after `_load_playlist_tabs`.
|
||||
- A small label shows `= {F/25:.2f}s @25fps` updating on `_spn_frames.valueChanged`.
|
||||
|
||||
### Task 5.2: route LTX-2 params through export
|
||||
**Files:** `main.py` — `_on_export` (~7317) + `ExportWorker` construction (~7484) + `_update_next_label`.
|
||||
- When the active tab's `_mode=="ltx2"`: compute `frames = self._spn_frames.value()`; `fps = 25`; `duration = frames / fps`; pass `target_fps=25, snap32=True, frames=frames, duration=duration` through to `ExportWorker` → `build_ffmpeg_command`. Default `short_side` to 512 if 0/None in ltx2.
|
||||
- Foley path: unchanged (no new params).
|
||||
- `ExportWorker.__init__`/`run`: thread the new params (default None/False) into `build_ffmpeg_command`.
|
||||
|
||||
**Verify (manual, authoritative):** in an LTX-2 tab, export → inspect an output clip: `ffprobe` shows **25 fps, exactly F frames, W&H ÷32**; a Foley tab still exports 8 s/source-fps unchanged. Structure test green; full `pytest tests/test_utils.py` (3 pre-existing fails only). Commit: `feat: route LTX-2 (25fps, ÷32 crop, F frames) through export for ltx2 tabs`.
|
||||
|
||||
---
|
||||
|
||||
## Stage 6 — finalize
|
||||
- **Task 6.1:** Full regression — `pytest tests/test_ui_structure.py` + `tests/test_utils.py` separately; manual: Foley export unchanged, LTX-2 export legal (ffprobe), duplicate/convert, persistence across relaunch, guardrail + per-tab folder still work.
|
||||
- **Task 6.2:** Changelog (`main.py` CHANGELOG, bump APP_VERSION) + README note (per-tab LTX-2 mode). Commit `docs: changelog + README for LTX-2 export mode`.
|
||||
- **Task 6.3:** Hand off branch (depends on `tab-export-folder`; merge that first, then this).
|
||||
|
||||
## Risks
|
||||
| Risk | Mitigation |
|
||||
|------|-----------|
|
||||
| `-frames:v` vs `-t` interaction yields F±1 frames | Set both `-t F/fps` and `-frames:v F`; verify exact count with ffprobe in 5.2. |
|
||||
| `fps` filter + HW (VAAPI) filter ordering | Place `fps`/`snap32` among CPU filters before the VAAPI hwupload block; test a HW-encoder build if available. |
|
||||
| Length-control swap leaves stale state across tab switches | `_apply_mode_to_controls()` called on every tab change + mode toggle + load. |
|
||||
| Depends on unmerged `tab-export-folder` | Branch is based on it; land that branch first. |
|
||||
|
||||
## NOT in scope
|
||||
Hard enforcement (illegal F/resize allowed manually), motion-interpolated fps, auto re-export of existing Foley clips, DB schema changes, scan-pipeline changes.
|
||||
@@ -1,4 +1,23 @@
|
||||
# Core GUI
|
||||
PyQt6>=6.4
|
||||
python-mpv>=1.0
|
||||
pytest>=7.0
|
||||
|
||||
# Audio & ML
|
||||
librosa>=0.10
|
||||
numpy>=1.24
|
||||
scikit-learn>=1.3
|
||||
joblib>=1.3
|
||||
soundfile>=0.12
|
||||
|
||||
# Deep learning — install via setup_env.sh for correct CUDA version,
|
||||
# or manually: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
torch>=2.0
|
||||
torchaudio>=2.0
|
||||
transformers>=4.30,<5.0 # EAT remote model code incompatible with transformers 5.x
|
||||
timm>=0.9
|
||||
|
||||
# Object detection
|
||||
ultralytics>=8.0
|
||||
|
||||
# Dev
|
||||
pytest>=7.0
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from core.db import ProcessedDB
|
||||
from .config import DB_PATH
|
||||
from .routes import files, stream, markers, export, hidden
|
||||
from . import ws
|
||||
|
||||
# Application object that every route module attaches to.
app = FastAPI(title="8-cut Server")

# CORS is left fully open: any origin, method and header is accepted.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Shared database handle; route handlers import it from this module.
db = ProcessedDB(DB_PATH)

# Every REST router is mounted under the same /api prefix, in this order.
for _route_module in (files, stream, markers, export, hidden):
    app.include_router(_route_module.router, prefix="/api")


@app.websocket("/ws/export")
async def export_ws(websocket: WebSocket):
    """Hand an incoming export websocket over to the ``ws`` module."""
    await ws.connect(websocket)
|
||||
@@ -1,171 +0,0 @@
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
from core.paths import _bin, _log
|
||||
from .config import CACHE_DIR, QUALITY_PRESETS
|
||||
|
||||
|
||||
class CacheStatus(str, Enum):
|
||||
READY = "ready"
|
||||
TRANSCODING = "transcoding"
|
||||
MISSING = "missing"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
_jobs_lock = threading.Lock()
|
||||
_active_jobs: dict[str, threading.Thread] = {}
|
||||
|
||||
|
||||
def _cache_key(source_path: str) -> str:
|
||||
"""Stable hash from absolute source path."""
|
||||
return hashlib.sha256(source_path.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def cache_path(source_path: str, quality: str) -> str:
    """Return where the *quality* transcode of *source_path* is cached.

    Layout: ``<CACHE_DIR>/<quality>/<key>.mp4``.
    """
    file_name = f"{_cache_key(source_path)}.mp4"
    return os.path.join(CACHE_DIR, quality, file_name)
|
||||
|
||||
|
||||
def audio_cache_path(source_path: str) -> str:
    """Return where the extracted audio of *source_path* is cached.

    Layout: ``<CACHE_DIR>/audio/<key>.wav``.
    """
    file_name = f"{_cache_key(source_path)}.wav"
    return os.path.join(CACHE_DIR, "audio", file_name)
|
||||
|
||||
|
||||
def get_status(source_path: str, quality: str) -> CacheStatus:
    """Report the cache state of the *quality* transcode of *source_path*.

    READY when the cache file exists, TRANSCODING when a live worker is
    registered for it, MISSING otherwise.
    """
    if os.path.isfile(cache_path(source_path, quality)):
        return CacheStatus.READY

    job_key = f"{source_path}:{quality}"
    with _jobs_lock:
        worker = _active_jobs.get(job_key)
        running = worker is not None and worker.is_alive()
    return CacheStatus.TRANSCODING if running else CacheStatus.MISSING
|
||||
|
||||
|
||||
def get_audio_status(source_path: str) -> CacheStatus:
    """Report the cache state of the extracted audio of *source_path*.

    READY when the cache file exists, TRANSCODING when a live worker is
    registered for it, MISSING otherwise.
    """
    if os.path.isfile(audio_cache_path(source_path)):
        return CacheStatus.READY

    job_key = f"{source_path}:audio"
    with _jobs_lock:
        worker = _active_jobs.get(job_key)
        running = worker is not None and worker.is_alive()
    return CacheStatus.TRANSCODING if running else CacheStatus.MISSING
|
||||
|
||||
|
||||
def get_all_statuses(source_path: str) -> dict:
    """Return the cache status of every quality preset plus ``"audio"``.

    Keys are the names in QUALITY_PRESETS followed by ``"audio"``; values
    are CacheStatus members.
    """
    statuses = {
        preset_name: get_status(source_path, preset_name)
        for preset_name in QUALITY_PRESETS
    }
    statuses["audio"] = get_audio_status(source_path)
    return statuses
|
||||
|
||||
|
||||
def _transcode_worker(source_path: str, quality: str) -> None:
    """Transcode *source_path* to the *quality* preset and store it in the cache.

    ffmpeg writes to a temporary file next to the final location, which is
    moved into place only after ffmpeg exits successfully. Failures are
    logged rather than raised.

    Raises:
        KeyError: if *quality* is not a key of QUALITY_PRESETS.
    """
    preset = QUALITY_PRESETS[quality]
    out = cache_path(source_path, quality)
    os.makedirs(os.path.dirname(out), exist_ok=True)
    tmp = out + ".tmp.mp4"

    cmd = [_bin("ffmpeg"), "-y", "-i", source_path, "-an"]

    # A height of 0 means "keep the original resolution": no scale filter.
    if preset["height"] > 0:
        cmd += [
            "-vf", f"scale=-2:{preset['height']}:flags=lanczos",
        ]

    cmd += [
        "-c:v", "libx264",
        "-preset", "fast",
        "-b:v", preset["bitrate"],
        "-movflags", "+faststart",
        tmp,
    ]

    _log(f"Transcode start: {source_path} @ {quality}")
    try:
        # errors="replace": ffmpeg may print non-UTF-8 bytes; a strict decode
        # would raise after a successful run and discard the finished file.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace", timeout=3600,
        )
        if result.returncode == 0:
            # os.replace overwrites an existing target on every platform,
            # whereas os.rename fails on Windows when the target exists.
            os.replace(tmp, out)
            _log(f"Transcode done: {out}")
        else:
            _log(f"Transcode failed: {result.stderr[-300:]}")
    except Exception as e:
        # Thread boundary: log instead of letting the worker die noisily.
        _log(f"Transcode error: {e}")
    finally:
        # Single cleanup point. After a successful move the temp file is
        # already gone, which is the FileNotFoundError case.
        try:
            os.unlink(tmp)
        except FileNotFoundError:
            pass
        except OSError as e:
            _log(f"Transcode cleanup failed: {e}")
|
||||
|
||||
|
||||
def _audio_extract_worker(source_path: str) -> None:
    """Extract the audio of *source_path* as 16-bit PCM WAV into the cache.

    ffmpeg writes to a temporary file next to the final location, which is
    moved into place only after ffmpeg exits successfully. Failures are
    logged rather than raised.
    """
    out = audio_cache_path(source_path)
    os.makedirs(os.path.dirname(out), exist_ok=True)
    tmp = out + ".tmp.wav"

    cmd = [
        _bin("ffmpeg"), "-y",
        "-i", source_path,
        "-vn",
        "-c:a", "pcm_s16le",
        tmp,
    ]

    _log(f"Audio extract start: {source_path}")
    try:
        # errors="replace": ffmpeg may print non-UTF-8 bytes; a strict decode
        # would raise after a successful run and discard the finished file.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace", timeout=600,
        )
        if result.returncode == 0:
            # os.replace overwrites an existing target on every platform,
            # whereas os.rename fails on Windows when the target exists.
            os.replace(tmp, out)
            _log(f"Audio extract done: {out}")
        else:
            _log(f"Audio extract failed: {result.stderr[-300:]}")
    except Exception as e:
        # Thread boundary: log instead of letting the worker die noisily.
        _log(f"Audio extract error: {e}")
    finally:
        # Single cleanup point. After a successful move the temp file is
        # already gone, which is the FileNotFoundError case.
        try:
            os.unlink(tmp)
        except FileNotFoundError:
            pass
        except OSError as e:
            _log(f"Audio extract cleanup failed: {e}")
|
||||
|
||||
|
||||
def _prune_dead_jobs() -> None:
    """Drop finished worker threads from _active_jobs.

    The caller must already hold _jobs_lock.
    """
    # Iterate over a snapshot so entries can be deleted while looping.
    for job_key, worker in list(_active_jobs.items()):
        if not worker.is_alive():
            del _active_jobs[job_key]
|
||||
|
||||
|
||||
def ensure_transcode(source_path: str, quality: str) -> CacheStatus:
    """Start a background transcode unless one is cached or already running.

    Returns:
        READY if the cache file exists, otherwise TRANSCODING (a worker is
        running, or was just started by this call).
    """
    status = get_status(source_path, quality)
    if status != CacheStatus.MISSING:
        return status

    job_key = f"{source_path}:{quality}"
    with _jobs_lock:
        _prune_dead_jobs()
        if job_key in _active_jobs and _active_jobs[job_key].is_alive():
            return CacheStatus.TRANSCODING
        # get_status() ran without the lock held across its two checks, so a
        # worker may have published the file and exited after the file check.
        # Re-check here to avoid redoing a finished transcode.
        if os.path.isfile(cache_path(source_path, quality)):
            return CacheStatus.READY
        t = threading.Thread(target=_transcode_worker, args=(source_path, quality), daemon=True)
        _active_jobs[job_key] = t
        t.start()
    return CacheStatus.TRANSCODING
|
||||
|
||||
|
||||
def ensure_audio(source_path: str) -> CacheStatus:
    """Start a background audio extraction unless cached or already running.

    Returns:
        READY if the cache file exists, otherwise TRANSCODING (a worker is
        running, or was just started by this call).
    """
    status = get_audio_status(source_path)
    if status != CacheStatus.MISSING:
        return status

    job_key = f"{source_path}:audio"
    with _jobs_lock:
        _prune_dead_jobs()
        if job_key in _active_jobs and _active_jobs[job_key].is_alive():
            return CacheStatus.TRANSCODING
        # get_audio_status() ran without the lock held across its two checks,
        # so a worker may have published the file and exited after the file
        # check. Re-check here to avoid redoing a finished extraction.
        if os.path.isfile(audio_cache_path(source_path)):
            return CacheStatus.READY
        t = threading.Thread(target=_audio_extract_worker, args=(source_path,), daemon=True)
        _active_jobs[job_key] = t
        t.start()
    return CacheStatus.TRANSCODING
|
||||
@@ -1,21 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Every default below lives under the user's home directory.
_HOME = Path.home()

# MEDIA_DIRS is a comma-separated list of roots; blank entries are dropped.
_media_dirs_raw = os.environ.get("MEDIA_DIRS", str(_HOME))
MEDIA_DIRS: list[str] = [
    entry
    for entry in (part.strip() for part in _media_dirs_raw.split(","))
    if entry
]
EXPORT_DIR: str = os.environ.get("EXPORT_DIR", str(_HOME / "8cut-exports"))
DB_PATH: str = os.environ.get("DB_PATH", str(_HOME / ".8cut.db"))
CACHE_DIR: str = os.environ.get("CACHE_DIR", str(_HOME / ".8cut-cache"))
HOST: str = os.environ.get("HOST", "0.0.0.0")
PORT: int = int(os.environ.get("PORT", "8000"))

# Lower-case file extensions treated as video files.
VIDEO_EXTENSIONS = {
    ".mp4", ".mkv", ".avi", ".mov",
    ".webm", ".ts", ".flv", ".wmv",
}

# Streaming transcode presets: target height in pixels and video bitrate.
QUALITY_PRESETS = {
    "potato": dict(height=480, bitrate="500k"),
    "low": dict(height=720, bitrate="2M"),
    "medium": dict(height=1080, bitrate="5M"),
    "high": dict(height=0, bitrate="10M"),  # 0 = original resolution
}
|
||||
@@ -1,227 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.export import ExportRunner
|
||||
from core.paths import build_export_path, build_sequence_dir
|
||||
from core.ffmpeg import _RATIOS, apply_keyframes_to_jobs
|
||||
from .. import ws as ws_module
|
||||
from ..config import EXPORT_DIR, MEDIA_DIRS
|
||||
|
||||
router = APIRouter()

# Export jobs keyed by short job id; finished entries are evicted in
# start_export once more than _MAX_FINISHED_JOBS have accumulated.
_jobs: dict[str, dict] = {}
_MAX_FINISHED_JOBS = 200

# Serialises counter allocation and output-directory creation.
_counter_lock = threading.Lock()

# Encoder names accepted from clients; anything else is rejected with 400.
_VALID_ENCODERS = {
    "libx264",
    "h264_nvenc",
    "h264_vaapi",
    "h264_qsv",
    "h264_amf",
    "h264_videotoolbox",
}


class CropKeyframe(BaseModel):
    """One crop keyframe supplied with an export request."""

    time: float
    center: float
    ratio: str | None = None
    rand_portrait: bool = False
    rand_square: bool = False


class ExportRequest(BaseModel):
    """JSON body accepted by ``POST /export``."""

    # Source file and where on its timeline the first clip starts.
    input_path: str
    cursor: float
    name: str

    # Number of clips and the offset added to the start of each later clip.
    clips: int = 3
    spread: float = 3.0

    # Output geometry and container.
    short_side: int | None = None
    portrait_ratio: str | None = None
    crop_center: float = 0.5
    format: str = "MP4"

    # Metadata stored with each exported clip.
    label: str = ""
    category: str = ""
    profile: str = "default"

    # Appended to the export directory name as "_<suffix>" when non-empty.
    folder_suffix: str = ""

    # Optional per-time crop overrides and request-wide randomisation flags.
    crop_keyframes: list[CropKeyframe] | None = None
    rand_portrait: bool = False
    rand_square: bool = False

    encoder: str = "libx264"
|
||||
|
||||
|
||||
def _next_counter(folder: str, basename: str) -> int:
|
||||
"""Scan folder for existing {basename}_NNN dirs and return max + 1."""
|
||||
pattern = re.compile(rf'^{re.escape(basename)}_(\d{{3}})$')
|
||||
highest = 0
|
||||
if os.path.isdir(folder):
|
||||
for entry in os.listdir(folder):
|
||||
m = pattern.match(entry)
|
||||
if m:
|
||||
highest = max(highest, int(m.group(1)))
|
||||
return highest + 1
|
||||
|
||||
|
||||
def _validate_input_path(path: str) -> str:
    """Resolve *path* and confirm it lies inside a configured media root.

    Returns:
        The resolved real path (symlinks followed).

    Raises:
        HTTPException: 403 when the resolved path is not equal to, or
            nested under, any entry of MEDIA_DIRS.
    """
    real = os.path.realpath(path)
    for root in MEDIA_DIRS:
        root_real = os.path.realpath(root)
        try:
            # commonpath compares whole path components, so "/media2" is not
            # accepted for root "/media". Unlike the prefix test
            # `root + os.sep`, it also works when the root is the filesystem
            # root itself.
            inside = os.path.commonpath([real, root_real]) == root_real
        except ValueError:
            # Raised e.g. for paths on different drives: not inside this root.
            inside = False
        if inside:
            return real
    raise HTTPException(status_code=403, detail="input_path outside media directories")
|
||||
|
||||
|
||||
@router.post("/export")
def start_export(req: ExportRequest):
    """Validate an export request, start its runner and return the job id.

    Each of the ``req.clips`` clips starts ``req.spread`` after the previous
    one, beginning at ``req.cursor``. Progress is broadcast through the
    websocket module and kept in ``_jobs`` under the returned id.

    Raises:
        HTTPException: 403 for an input path outside the media roots, 400
            for an unknown encoder or portrait ratio, or for a name or
            folder suffix containing path separators or "..".
    """
    from ..app import db

    # ---- request validation -------------------------------------------
    input_path = _validate_input_path(req.input_path)

    if req.encoder not in _VALID_ENCODERS:
        raise HTTPException(status_code=400, detail=f"invalid encoder: {req.encoder}")

    ratio = req.portrait_ratio
    if ratio is not None and ratio not in _RATIOS:
        raise HTTPException(status_code=400, detail=f"invalid portrait_ratio: {req.portrait_ratio}")

    forbidden = ("/", "\\", "..")
    suffix = req.folder_suffix
    if suffix and any(token in suffix for token in forbidden):
        raise HTTPException(status_code=400, detail="folder_suffix must not contain path separators")
    if any(token in req.name for token in forbidden):
        raise HTTPException(status_code=400, detail="name must not contain path separators")

    # ---- output planning ----------------------------------------------
    job_id = str(uuid.uuid4())[:8]
    folder = (EXPORT_DIR.rstrip(os.sep) + "_" + suffix) if suffix else EXPORT_DIR

    image_sequence = req.format in ("WebP", "WebP sequence")
    build_output = build_sequence_dir if image_sequence else build_export_path
    multi_clip = req.clips > 1

    # Counter allocation and directory creation happen under one lock so
    # concurrent exports cannot pick the same counter.
    with _counter_lock:
        counter = _next_counter(folder, req.name)
        jobs = []
        for clip_index in range(req.clips):
            clip_start = req.cursor + clip_index * req.spread
            out = build_output(
                folder, req.name, counter,
                sub=clip_index if multi_clip else None,
            )
            os.makedirs(os.path.dirname(out), exist_ok=True)
            jobs.append((clip_start, out, req.portrait_ratio, req.crop_center))

    # Keyframes widen each job to a 6-tuple; keep only the first four fields.
    if req.crop_keyframes:
        kf_tuples = [
            (kf.time, kf.center, kf.ratio, kf.rand_portrait, kf.rand_square)
            for kf in req.crop_keyframes
        ]
        widened = apply_keyframes_to_jobs(
            jobs, kf_tuples,
            req.crop_center, req.portrait_ratio,
            req.rand_portrait, req.rand_square,
        )
        jobs = [
            (clip_start, out, clip_ratio, clip_center)
            for clip_start, out, clip_ratio, clip_center, _portrait, _square in widened
        ]

    # ---- runner callbacks ---------------------------------------------
    completed = []

    def on_clip_done(path: str):
        completed.append(path)
        # Record the clip in the DB so its marker shows up.
        db.add(
            filename=os.path.basename(input_path),
            start_time=req.cursor,
            output_path=path,
            label=req.label,
            category=req.category,
            short_side=req.short_side,
            portrait_ratio=req.portrait_ratio or "",
            crop_center=req.crop_center,
            fmt=req.format,
            clip_count=req.clips,
            spread=req.spread,
            profile=req.profile,
        )
        ws_module.broadcast({"type": "clip_done", "job_id": job_id, "path": path})

    def on_all_done():
        _jobs[job_id]["status"] = "done"
        _jobs[job_id].pop("runner", None)
        ws_module.broadcast({"type": "all_done", "job_id": job_id})

    def on_error(msg: str):
        _jobs[job_id]["status"] = "error"
        _jobs[job_id]["error"] = msg
        _jobs[job_id].pop("runner", None)
        ws_module.broadcast({"type": "error", "job_id": job_id, "msg": msg})

    runner = ExportRunner(
        input_path=input_path,
        jobs=jobs,
        short_side=req.short_side,
        image_sequence=image_sequence,
        encoder=req.encoder,
        on_clip_done=on_clip_done,
        on_all_done=on_all_done,
        on_error=on_error,
    )

    # ---- bookkeeping ---------------------------------------------------
    # Evict the oldest finished jobs so _jobs cannot grow without bound.
    finished = [key for key, job in _jobs.items() if job["status"] in ("done", "error")]
    excess = len(finished) - _MAX_FINISHED_JOBS
    if excess > 0:
        for key in finished[:excess]:
            del _jobs[key]

    # Register the job before starting so the callbacks always find it.
    _jobs[job_id] = {
        "status": "running",
        "total": len(jobs),
        "completed": completed,
        "runner": runner,
        "created_at": time.monotonic(),
    }
    runner.start()

    return {"job_id": job_id}
|
||||
|
||||
|
||||
@router.get("/export/{job_id}")
def get_export_status(job_id: str):
    """Report the current state of an export job.

    The job is looked up in the module-level ``_jobs`` registry.

    Returns:
        A dict with ``status``, ``total``, ``completed`` (number of finished
        clips), ``outputs`` (a copy of the finished clip paths) and ``error``
        (``None`` when the job entry has no ``"error"`` key).

    Raises:
        HTTPException: 404 when ``job_id`` is not in the registry.
    """
    entry = _jobs.get(job_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="job not found")

    payload = {
        "status": entry["status"],
        "total": entry["total"],
        "completed": len(entry["completed"]),
        # Copy so the response does not alias the list stored in the registry.
        "outputs": list(entry["completed"]),
        "error": entry.get("error"),
    }
    return payload
|
||||
|
||||
|
||||
def _is_under_export_dir(real_path: str) -> bool:
    """Return True when *real_path* lives below EXPORT_DIR or an ``EXPORT_DIR_*`` sibling.

    The check starts at the parent directory of ``real_path`` and climbs
    towards the filesystem root, so ``real_path`` itself being the export
    directory does not count as "under" it.
    """
    base = os.path.realpath(EXPORT_DIR).rstrip(os.sep)
    sibling_prefix = base + "_"

    current = os.path.dirname(real_path)
    while True:
        parent = os.path.dirname(current)
        if parent == current:
            # dirname() stopped changing: top reached without a match.
            # The top-most directory itself is never tested.
            return False
        if current == base or current.startswith(sibling_prefix):
            return True
        current = parent
|
||||
|
||||
|
||||
@router.delete("/export")
def delete_export(output_path: str = Query(...)):
    """Delete an exported file or directory and its database record.

    ``output_path`` is resolved with ``os.path.realpath`` before validation,
    so symlinks cannot be used to point outside the export area.

    Returns:
        ``{"deleted": <resolved path>}``, also when nothing existed on disk.

    Raises:
        HTTPException: 403 when the resolved path is not under the export
            directory according to ``_is_under_export_dir``.
    """
    from ..app import db

    target = os.path.realpath(output_path)
    if not _is_under_export_dir(target):
        raise HTTPException(status_code=403, detail="path outside export directory")

    # The DB record is removed first, then whatever exists on disk.
    db.delete_by_output_path(target)

    # A regular file is unlinked; otherwise a directory is removed recursively.
    removers = ((os.path.isfile, os.unlink), (os.path.isdir, shutil.rmtree))
    for exists, remove in removers:
        if exists(target):
            remove(target)
            break

    return {"deleted": target}
|
||||
@@ -1,56 +0,0 @@
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
||||
|
||||
# Router that the media endpoints in this module (/files, /roots, /video) register on.
router = APIRouter()
|
||||
|
||||
|
||||
def _scan_videos(root: str) -> list[dict]:
    """Recursively collect the video files found under *root*.

    A file counts as a video when its lower-cased extension is in
    ``VIDEO_EXTENSIONS``. Within each directory, files are visited in sorted
    name order; directory order is whatever ``os.walk`` yields.

    Returns:
        One dict per video with the keys ``name`` (file name), ``path``
        (path relative to ``root``), ``root`` (the argument as given) and
        ``size`` (size in bytes).
    """
    results = []
    for dirpath, _, filenames in os.walk(root):
        for f in sorted(filenames):
            if os.path.splitext(f)[1].lower() not in VIDEO_EXTENSIONS:
                continue
            full = os.path.join(dirpath, f)
            try:
                size = os.path.getsize(full)
            except OSError:
                # Dangling symlink, or the file vanished / became unreadable
                # between the walk and the stat. Skip this entry instead of
                # failing the whole listing.
                continue
            results.append({
                "name": f,
                "path": os.path.relpath(full, root),
                "root": root,
                "size": size,
            })
    return results
|
||||
|
||||
|
||||
@router.get("/files")
def list_files(root: str | None = Query(None)):
    """List video files from one media root, or from all of them.

    When ``root`` is given and is one of ``MEDIA_DIRS``, only that directory
    is scanned. Otherwise, including when ``root`` is not a configured media
    directory, every entry of ``MEDIA_DIRS`` is scanned.
    """
    if root and root in MEDIA_DIRS:
        search_roots = [root]
    else:
        search_roots = MEDIA_DIRS

    return [
        entry
        for media_root in search_roots
        for entry in _scan_videos(media_root)
    ]
|
||||
|
||||
|
||||
@router.get("/roots")
def list_roots():
    """Return the configured media root directories (``MEDIA_DIRS``)."""
    return MEDIA_DIRS
|
||||
|
||||
|
||||
def _safe_resolve(path: str, root: str) -> str:
    """Resolve *path* relative to *root*, refusing anything that escapes it.

    Returns:
        The real (symlink-resolved) absolute path of ``root/path``.

    Raises:
        HTTPException: 400 when ``root`` is not one of ``MEDIA_DIRS``;
            403 when the resolved path is not strictly inside the resolved
            root (the root directory itself is rejected too).
    """
    if root not in MEDIA_DIRS:
        raise HTTPException(status_code=400, detail="invalid root")

    candidate = os.path.realpath(os.path.join(root, path))
    # The trailing separator keeps "/media_other" from matching root "/media".
    root_prefix = os.path.realpath(root) + os.sep
    if candidate.startswith(root_prefix):
        return candidate
    raise HTTPException(status_code=403, detail="path outside media root")
|
||||
|
||||
|
||||
@router.get("/video/{path:path}")
def serve_video(path: str, root: str = Query(...)):
    """Serve a video file located under one of the media roots.

    ``path`` is validated and resolved against ``root`` by ``_safe_resolve``,
    which raises for an unknown root or a path escaping it.

    Raises:
        HTTPException: 404 when the resolved path is not an existing file.
    """
    resolved = _safe_resolve(path, root)
    if os.path.isfile(resolved):
        # The response is always labelled video/mp4, whatever the extension.
        return FileResponse(resolved, media_type="video/mp4")
    raise HTTPException(status_code=404, detail="not found")
|
||||