feat: integrate training UI, BEATs model, and clean up legacy code

- Remove legacy distance-mode scanning (build_profile, _similarity, etc.)
  and hand-crafted intensity features — pipeline is now embedding-only
- Integrate Microsoft BEATs as embedding option alongside wav2vec2/HuBERT
- Add TrainDialog with positive class selector, model picker, video dir
  fallback, and live training stats
- Add TrainWorker QThread with cancel support and proper lifecycle cleanup
- Add source_path column to DB for robust source video tracking
- Add get_export_folders/get_training_data/get_training_stats to DB
- Wire source_path in all export DB writes (_on_clip_done, _on_auto_clip_done)
- Cancel scan/train workers in closeEvent to prevent use-after-free crashes
- Add setup_env.sh supporting both conda and python venv (CUDA 12.8)
- Update requirements.txt with all actual dependencies
- Update 8cut_train.py with --positive flag for new DB-driven training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-18 11:52:27 +02:00
parent f2c38aee79
commit 12ed183f1b
11 changed files with 2608 additions and 338 deletions
+483 -64
View File
@@ -15,7 +15,7 @@ from PyQt6.QtWidgets import (
QLabel, QPushButton, QLineEdit, QFileDialog,
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
QMessageBox, QInputDialog,
QMessageBox, QInputDialog, QDialog, QDialogButtonBox, QFormLayout,
)
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, QRect, QSize, pyqtSignal, QSettings
from PyQt6.QtGui import QPainter, QColor, QPen, QPixmap, QDragEnterEvent, QDropEvent, QCursor, QFont, QKeySequence, QShortcut
@@ -191,12 +191,11 @@ class ScanWorker(QThread):
error = pyqtSignal(str)
progress = pyqtSignal(str) # status message
def __init__(self, video_path: str, clip_paths: list[str],
mode: str = "average", threshold: float = 0.7):
def __init__(self, video_path: str, model: dict,
threshold: float = 0.30):
super().__init__()
self._video_path = video_path
self._clip_paths = clip_paths
self._mode = mode
self._model = model
self._threshold = threshold
self._cancel = False
@@ -204,20 +203,12 @@ class ScanWorker(QThread):
self._cancel = True
def run(self):
from core.audio_scan import build_profile, scan_video
from core.audio_scan import scan_video
try:
self.progress.emit(f"Building profile from {len(self._clip_paths)} clips...")
profile = build_profile(self._clip_paths)
if self._cancel:
return
if profile is None:
self.error.emit("No valid reference clips found")
return
self.progress.emit("Scanning audio...")
regions = scan_video(
self._video_path, profile,
mode=self._mode, threshold=self._threshold,
cancel_flag=self,
self._video_path, model=self._model,
threshold=self._threshold, cancel_flag=self,
)
if not self._cancel:
self.scan_done.emit(regions)
@@ -226,6 +217,151 @@ class ScanWorker(QThread):
self.error.emit(str(e))
class TrainDialog(QDialog):
"""Dialog for configuring and launching classifier training."""
def __init__(self, db: ProcessedDB, profile: str, video_dir: str = "",
parent=None):
super().__init__(parent)
self.setWindowTitle("Train Classifier")
self.setMinimumWidth(400)
from core.audio_scan import _EMBED_MODELS
self._db = db
self._profile = profile
self._video_dir = video_dir
layout = QVBoxLayout(self)
form = QFormLayout()
# Positive class selector — lists export folders
self._cmb_positive = QComboBox()
stats = db.get_training_stats(profile)
if not stats:
form.addRow("", QLabel("No exported clips found for this profile."))
for folder_name, info in stats.items():
label = f"{folder_name} ({info['videos']} videos, {info['clips']} clips)"
self._cmb_positive.addItem(label, userData=folder_name)
form.addRow("Positive class:", self._cmb_positive)
# Model selector
self._cmb_model = QComboBox()
for name in _EMBED_MODELS:
self._cmb_model.addItem(name)
self._cmb_model.setCurrentText("WAV2VEC2_BASE")
form.addRow("Model:", self._cmb_model)
# Video source directory (fallback for old DB rows without source_path)
self._txt_video_dir = QLineEdit(video_dir)
self._txt_video_dir.setPlaceholderText("Directory containing source videos")
self._debounce = QTimer(self)
self._debounce.setSingleShot(True)
self._debounce.setInterval(400)
self._debounce.timeout.connect(self._update_stats)
self._txt_video_dir.textChanged.connect(lambda: self._debounce.start())
vid_row = QHBoxLayout()
vid_row.addWidget(self._txt_video_dir)
btn_browse = QPushButton("...")
btn_browse.setFixedWidth(30)
btn_browse.clicked.connect(self._browse_video_dir)
vid_row.addWidget(btn_browse)
form.addRow("Video dir:", vid_row)
layout.addLayout(form)
# Stats summary
self._lbl_stats = QLabel()
self._update_stats()
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
layout.addWidget(self._lbl_stats)
# Buttons
btns = QDialogButtonBox(
QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel
)
btns.button(QDialogButtonBox.StandardButton.Ok).setText("Train")
btns.button(QDialogButtonBox.StandardButton.Ok).setEnabled(
self._cmb_positive.count() > 0
)
btns.accepted.connect(self.accept)
btns.rejected.connect(self.reject)
layout.addWidget(btns)
def _browse_video_dir(self):
d = QFileDialog.getExistingDirectory(self, "Select video source directory")
if d:
self._txt_video_dir.setText(d)
def _update_stats(self):
folder = self._cmb_positive.currentData()
if not folder:
self._lbl_stats.setText("No export folder data available.")
return
video_infos = self._db.get_training_data(
self._profile, folder,
fallback_video_dir=self._txt_video_dir.text(),
)
n_videos = len(video_infos)
n_pos = sum(len(gt) for _, gt, _ in video_infos)
n_soft = sum(len(s) for _, _, s in video_infos)
lines = [f"<b>{n_videos}</b> videos with positive clips"]
lines.append(f"<b>{n_pos}</b> positive markers, <b>{n_soft}</b> soft/buffer markers")
if n_videos == 0:
lines.append("<i>No source videos found. Set Video dir above.</i>")
elif n_videos < 3:
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
self._lbl_stats.setText("<br>".join(lines))
@property
def positive_folder(self) -> str:
return self._cmb_positive.currentData() or ""
@property
def embed_model(self) -> str:
return self._cmb_model.currentText()
@property
def video_dir(self) -> str:
return self._txt_video_dir.text()
class TrainWorker(QThread):
"""Trains an audio classifier off the main thread."""
train_done = pyqtSignal(str) # emits model path on success
error = pyqtSignal(str)
progress = pyqtSignal(str) # per-video status
def __init__(self, video_infos: list, model_path: str,
embed_model: str | None = None):
super().__init__()
self._video_infos = video_infos
self._model_path = model_path
self._embed_model = embed_model
self._cancel = False
def cancel(self) -> None:
self._cancel = True
def run(self):
from core.audio_scan import train_classifier
try:
self.progress.emit(f"Training on {len(self._video_infos)} videos...")
result = train_classifier(
self._video_infos,
model_path=self._model_path,
embed_model=self._embed_model,
)
if self._cancel:
return
if result is None:
self.error.emit("Training failed: not enough data or missing class balance")
else:
self.train_done.emit(self._model_path)
except Exception as e:
if not self._cancel:
self.error.emit(str(e))
class TimelineWidget(QWidget):
cursor_changed = pyqtSignal(float) # emits position in seconds
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
@@ -1564,23 +1700,35 @@ class MainWindow(QMainWindow):
self._btn_scan.setToolTip("Scan current video for audio segments matching reference clips")
self._btn_scan.clicked.connect(self._start_scan)
self._btn_auto_export = QPushButton("Auto")
self._btn_auto_export.setToolTip("Scan + auto-export best 8s clips")
self._btn_auto_export.clicked.connect(self._auto_export)
self._btn_train = QPushButton("Train")
self._btn_train.setToolTip("Train audio classifier from exported clips")
self._btn_train.clicked.connect(self._open_train_dialog)
self._train_worker: TrainWorker | None = None
self._spn_auto_fuse = QDoubleSpinBox()
self._spn_auto_fuse.setDecimals(1)
self._spn_auto_fuse.setRange(0.0, 60.0)
self._spn_auto_fuse.setSingleStep(1.0)
self._spn_auto_fuse.setValue(float(self._settings.value("auto_fuse", "4.0")))
self._spn_auto_fuse.setPrefix("Fuse: ")
self._spn_auto_fuse.setSuffix("s")
self._spn_auto_fuse.setToolTip("Max gap between scan regions to merge into one cluster")
self._spn_auto_fuse.valueChanged.connect(
lambda v: self._settings.setValue("auto_fuse", str(v))
)
self._sld_threshold = QDoubleSpinBox()
self._sld_threshold.setDecimals(2)
self._sld_threshold.setRange(0.0, 1.0)
self._sld_threshold.setSingleStep(0.01)
self._sld_threshold.setValue(0.05)
self._sld_threshold.setValue(0.30)
self._sld_threshold.setPrefix("Thr: ")
self._sld_threshold.setToolTip("Similarity threshold (0=match everything, 1=exact match)")
self._cmb_scan_mode = QComboBox()
self._cmb_scan_mode.addItems(["Average", "Nearest"])
self._cmb_scan_mode.setToolTip("Average: compare to mean profile\nNearest: compare to closest clip")
self._cmb_scan_ref = QComboBox()
self._cmb_scan_ref.addItems(["Current Profile", "Custom Folder"])
self._cmb_scan_ref.currentIndexChanged.connect(self._on_scan_ref_changed)
self._scan_folder: str = ""
self._scan_worker: ScanWorker | None = None
cpu_count = os.cpu_count() or 2
@@ -1716,9 +1864,10 @@ class MainWindow(QMainWindow):
settings_row.addWidget(self._chk_rand_square)
settings_row.addWidget(self._chk_track)
settings_row.addWidget(self._btn_scan)
settings_row.addWidget(self._btn_auto_export)
settings_row.addWidget(self._spn_auto_fuse)
settings_row.addWidget(self._sld_threshold)
settings_row.addWidget(self._cmb_scan_mode)
settings_row.addWidget(self._cmb_scan_ref)
settings_row.addWidget(self._btn_train)
settings_row.addStretch()
self._lbl_status = QLabel()
self._lbl_status.setStyleSheet("color: #888; font-size: 11px;")
@@ -2503,16 +2652,6 @@ class MainWindow(QMainWindow):
return
self._step_cursor(markers[0][0] - self._cursor) # wrap to first
def _on_scan_ref_changed(self, index: int) -> None:
if index == 1: # Custom Folder
folder = QFileDialog.getExistingDirectory(self, "Select reference clip folder")
if folder:
self._scan_folder = folder
else:
self._cmb_scan_ref.blockSignals(True)
self._cmb_scan_ref.setCurrentIndex(0)
self._cmb_scan_ref.blockSignals(False)
def _cleanup_scan_worker(self) -> None:
"""Disconnect signals and schedule deletion of old scan worker."""
if self._scan_worker is not None:
@@ -2540,35 +2679,22 @@ class MainWindow(QMainWindow):
# Clean up previous worker
self._cleanup_scan_worker()
# Collect reference clip paths
if self._cmb_scan_ref.currentIndex() == 0:
# Current profile — all exports across all files in this profile
clip_paths = [p for p in self._db.get_all_export_paths(self._profile)
if os.path.exists(p)]
else:
# Custom folder
if not self._scan_folder:
self._show_status("No reference folder selected")
return
exts = (".mp4", ".mkv", ".avi", ".mov", ".wav", ".mp3", ".flac")
clip_paths = [
os.path.join(self._scan_folder, f)
for f in sorted(os.listdir(self._scan_folder))
if f.lower().endswith(exts)
]
if not clip_paths:
self._show_status("No reference clips found")
return
mode = self._cmb_scan_mode.currentText().lower()
threshold = self._sld_threshold.value()
self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path # remember which file we're scanning
self._show_status(f"Scanning with {len(clip_paths)} reference clips...")
from core.audio_scan import load_classifier, default_model_path
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
self._scan_worker = ScanWorker(self._file_path, clip_paths, mode, threshold)
if model is None:
self._show_status("No trained model — click Train first")
return
self._btn_scan.setEnabled(False)
self._scan_file_path = self._file_path
self._show_status("Scanning...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
self._scan_worker.scan_done.connect(self._on_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
@@ -2576,6 +2702,7 @@ class MainWindow(QMainWindow):
def _on_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
# Ignore stale results if the user switched files during scan
if self._file_path != getattr(self, '_scan_file_path', None):
return
@@ -2584,8 +2711,294 @@ class MainWindow(QMainWindow):
def _on_scan_error(self, msg: str) -> None:
self._btn_scan.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._show_status(f"Scan error: {msg}")
# ── Training ────────────────────────────────────────────────
def _cleanup_train_worker(self) -> None:
"""Disconnect signals and schedule deletion of old train worker."""
if self._train_worker is not None:
try:
self._train_worker.train_done.disconnect()
self._train_worker.error.disconnect()
self._train_worker.progress.disconnect()
except TypeError:
pass
if self._train_worker.isRunning():
self._train_worker.cancel()
self._train_worker.finished.connect(self._train_worker.deleteLater)
else:
self._train_worker.deleteLater()
self._train_worker = None
def _open_train_dialog(self):
"""Show the training config dialog and start training if accepted."""
if self._train_worker and self._train_worker.isRunning():
self._show_status("Training already in progress…")
return
# Default video dir: parent of currently loaded file, or saved setting
default_dir = ""
if self._file_path:
default_dir = os.path.dirname(self._file_path)
saved_dir = self._settings.value("train_video_dir", default_dir)
dlg = TrainDialog(self._db, self._profile,
video_dir=saved_dir or default_dir, parent=self)
if dlg.exec() != QDialog.DialogCode.Accepted:
return
pos_folder = dlg.positive_folder
embed_model = dlg.embed_model
video_dir = dlg.video_dir
if not pos_folder:
self._show_status("No positive class selected")
return
# Persist video dir for next time
if video_dir:
self._settings.setValue("train_video_dir", video_dir)
video_infos = self._db.get_training_data(
self._profile, pos_folder, fallback_video_dir=video_dir,
)
if not video_infos:
self._show_status("No training data found for this subprofile")
return
from core.audio_scan import default_model_path
model_path = default_model_path(self._profile)
self._cleanup_train_worker()
self._btn_train.setEnabled(False)
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
self._train_worker = TrainWorker(video_infos, model_path, embed_model)
self._train_worker.train_done.connect(self._on_train_done)
self._train_worker.error.connect(self._on_train_error)
self._train_worker.progress.connect(self._show_status)
self._train_worker.start()
def _on_train_done(self, model_path: str):
self._btn_train.setEnabled(True)
self._show_status(f"Model trained and saved")
_log(f"Training complete: {model_path}")
def _on_train_error(self, msg: str):
self._btn_train.setEnabled(True)
self._show_status(f"Training error: {msg}")
# ── Auto-export ─────────────────────────────────────────────
def _auto_export(self) -> None:
"""Scan → NMS → export one 8s clip per selected position."""
if not self._file_path:
self._show_status("No video loaded")
return
if self._export_worker and self._export_worker.isRunning():
self._show_status("Export already running…")
return
if self._scan_worker and self._scan_worker.isRunning():
self._show_status("Scan already running")
return
self._cleanup_scan_worker()
self._btn_auto_export.setEnabled(False)
self._btn_scan.setEnabled(False)
threshold = self._sld_threshold.value()
from core.audio_scan import load_classifier, default_model_path
model_path = default_model_path(self._profile)
model = load_classifier(model_path)
if model is not None:
self._scan_file_path = self._file_path
self._show_status("Auto: scanning with classifier...")
self._scan_worker = ScanWorker(
self._file_path, model=model, threshold=threshold,
)
else:
self._show_status("Auto: no trained model — click Train first")
self._btn_auto_export.setEnabled(True)
self._btn_scan.setEnabled(True)
return
self._scan_worker.scan_done.connect(self._on_auto_scan_done)
self._scan_worker.error.connect(self._on_scan_error)
self._scan_worker.progress.connect(self._show_status)
self._scan_worker.start()
@staticmethod
def _select_export_positions(regions: list[tuple[float, float, float]],
min_gap: float = 2.0,
cluster_fuse: float = 30.0,
) -> list[float]:
"""Cluster scan regions, then fill each cluster with clips spaced min_gap apart.
1. Merge overlapping regions into clusters, fusing clusters <cluster_fuse apart.
2. Within each cluster, greedily pick positions by score, min_gap apart.
"""
if not regions:
return []
# Build clusters — merge overlapping + fuse if gap < cluster_fuse
sorted_r = sorted(regions, key=lambda r: r[0])
clusters: list[list[tuple[float, float, float]]] = []
cur_start, cur_end = sorted_r[0][0], sorted_r[0][1]
cur_regions = [sorted_r[0]]
for start, end, score in sorted_r[1:]:
if start - cur_end <= cluster_fuse:
cur_end = max(cur_end, end)
cur_regions.append((start, end, score))
else:
clusters.append(cur_regions)
cur_start, cur_end = start, end
cur_regions = [(start, end, score)]
clusters.append(cur_regions)
# Within each cluster, NMS by score with min_gap
picked: list[float] = []
for cluster in clusters:
by_score = sorted(cluster, key=lambda r: -r[2])
cluster_picks: list[float] = []
for start, _end, _score in by_score:
if all(abs(start - p) >= min_gap for p in cluster_picks):
cluster_picks.append(start)
picked.extend(cluster_picks)
return sorted(picked)
def _on_auto_scan_done(self, regions: list) -> None:
self._btn_scan.setEnabled(True)
if self._file_path != getattr(self, '_scan_file_path', None):
self._btn_auto_export.setEnabled(True)
return
self._timeline.set_scan_regions(regions)
if not regions:
self._show_status("Auto: no regions found")
self._btn_auto_export.setEnabled(True)
return
positions = self._select_export_positions(
regions, min_gap=2.0, cluster_fuse=self._spn_auto_fuse.value(),
)
if not positions:
self._show_status("Auto: no positions after NMS")
self._btn_auto_export.setEnabled(True)
return
# Build export jobs — one 8s clip per position
folder = self._txt_folder.text()
name = self._txt_name.text() or "clip"
fmt = self._cmb_format.currentText()
image_sequence = fmt == "WebP sequence"
os.makedirs(folder, exist_ok=True)
# Find starting counter
counter = 1
while True:
if image_sequence:
p = build_sequence_dir(folder, name, counter, sub=0)
else:
p = build_export_path(folder, name, counter, sub=0)
if not os.path.exists(p):
break
counter += 1
jobs = []
self._auto_export_positions = [] # stash for DB writes
for start_t in positions:
group_dir = os.path.join(folder, f"{name}_{counter:03d}")
os.makedirs(group_dir, exist_ok=True)
if image_sequence:
out = build_sequence_dir(folder, name, counter, sub=0)
else:
out = build_export_path(folder, name, counter, sub=0)
jobs.append((start_t, out, None, 0.5))
self._auto_export_positions.append((start_t, counter))
counter += 1
self._show_status(f"Auto: exporting {len(jobs)} clips...")
short_side = self._spn_resize.value() or None
self._export_short_side = short_side
self._export_portrait = "Off"
self._export_format = fmt
self._export_clip_count = 1
self._export_spread = 0
self._export_folder = folder
self._export_folder_suffix = ""
hw_on = self._chk_hw.isChecked() and self._hw_encoders
encoder = self._hw_encoders[0] if hw_on else "libx264"
max_workers = min(self._spn_workers.value(), 3) if hw_on else self._spn_workers.value()
self._export_worker = ExportWorker(
self._file_path, jobs,
short_side=short_side,
image_sequence=image_sequence,
max_workers=max_workers,
encoder=encoder,
)
self._export_worker.finished.connect(self._on_auto_clip_done)
self._export_worker.all_done.connect(self._on_auto_batch_done)
self._export_worker.error.connect(self._on_export_error)
self._export_worker.cancelled.connect(self._on_export_cancelled)
self._btn_cancel.setEnabled(True)
self._btn_export.setEnabled(False)
self._set_subprofile_btns_enabled(False)
self._export_worker.start()
def _on_auto_clip_done(self, path: str):
"""Record each auto-exported clip to DB."""
# Find the start_time for this clip from stashed positions
counter_str = os.path.basename(os.path.dirname(path)) # e.g. "clip_042"
name = self._txt_name.text() or "clip"
start_t = None
for t, c in self._auto_export_positions:
if counter_str == f"{name}_{c:03d}":
start_t = t
break
label = self._txt_label.currentText().strip()
category = self._cmb_category.currentText()
self._db.add(
os.path.basename(self._file_path),
start_t or 0.0,
path,
label=label,
category=category,
short_side=self._export_short_side,
portrait_ratio="",
crop_center=0.5,
fmt=self._export_format,
clip_count=1,
spread=0,
profile=self._profile,
source_path=self._file_path,
)
upsert_clip_annotation(self._export_folder, path, label)
self._show_status(f"Auto: {os.path.basename(path)}")
_log(f" auto clip done: {os.path.basename(path)}")
def _on_auto_batch_done(self):
n = len(self._auto_export_positions)
self._btn_auto_export.setEnabled(True)
self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True)
self._set_subprofile_btns_enabled(True)
self._refresh_markers()
markers = self._db.get_markers(os.path.basename(self._file_path), self._profile)
self._playlist.mark_done(self._file_path, len(markers))
self._update_next_label()
self._show_status(f"Auto export complete: {n} clips")
_log(f"Auto export complete: {n} clips")
def _jump_to_next_scan_region(self) -> None:
regions = sorted(self._timeline._scan_regions, key=lambda r: r[0])
if not regions:
@@ -2812,6 +3225,7 @@ class MainWindow(QMainWindow):
clip_count=self._export_clip_count,
spread=self._export_spread,
profile=self._profile,
source_path=self._file_path,
)
upsert_clip_annotation(self._export_folder, path, label)
self._last_export_path = path
@@ -2851,6 +3265,7 @@ class MainWindow(QMainWindow):
_log(f"Export error: {msg}")
self._btn_cancel.setEnabled(False)
self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export")
self._btn_export.setStyleSheet("")
@@ -2866,6 +3281,7 @@ class MainWindow(QMainWindow):
def _on_export_cancelled(self):
_log("Export cancelled")
self._btn_export.setEnabled(True)
self._btn_auto_export.setEnabled(True)
self._set_subprofile_btns_enabled(True)
self._btn_export.setText("Export")
self._btn_export.setStyleSheet("")
@@ -2886,6 +3302,9 @@ class MainWindow(QMainWindow):
_log("Shutting down…")
# Save session playlist for resume.
self._settings.setValue("session_files", self._playlist._paths)
# Cancel background workers to prevent callbacks into dead objects.
self._cleanup_scan_worker()
self._cleanup_train_worker()
# Stop timers first to prevent callbacks into dead objects.
self._preview_timer.stop()
self._mpv._render_timer.stop()