feat: waveform overlay, signal safety, training cancel, dynamic batch size, duplicate detection

- WaveformWorker extracts low-res audio envelope via ffmpeg, drawn as
  green polygon on timeline track
- _safe_disconnect() replaces bare TypeError catches for signal cleanup
- Train button toggles to Cancel during training, calls worker.cancel()
- Dynamic GPU batch sizing: 64 for ≥16GB VRAM, 32 for ≥8GB, 16 default
- Overlap warning before exporting clips that intersect existing markers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-19 12:53:48 +02:00
parent 2b7dfb330d
commit a0286d5cf9
2 changed files with 120 additions and 14 deletions
+11
View File
@@ -171,7 +171,18 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
import torch import torch
model, device = _get_w2v_model(model_name) model, device = _get_w2v_model(model_name)
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
# Auto-size batches based on available GPU memory
batch_size = 16 batch_size = 16
if device == "cuda":
try:
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
if vram_gb >= 16:
batch_size = 64
elif vram_gb >= 8:
batch_size = 32
_log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
except Exception:
pass
timestamps = np.arange(n_windows) * hop timestamps = np.arange(n_windows) * hop
embeddings = [] embeddings = []
+109 -14
View File
@@ -988,6 +988,42 @@ class ScanResultsPanel(QWidget):
super().keyPressEvent(event) super().keyPressEvent(event)
class WaveformWorker(QThread):
"""Extract a low-res waveform envelope in the background."""
done = pyqtSignal(object) # emits numpy array of peak values
def __init__(self, video_path: str, n_bins: int = 2000):
super().__init__()
self._path = video_path
self._n_bins = n_bins
def run(self):
import numpy as np
try:
cmd = [
_bin("ffmpeg"), "-i", self._path,
"-vn", "-ac", "1", "-ar", "8000",
"-f", "f32le", "-loglevel", "error", "pipe:1",
]
proc = subprocess.run(cmd, capture_output=True, timeout=60)
if proc.returncode != 0:
return
samples = np.frombuffer(proc.stdout, dtype=np.float32)
if len(samples) == 0:
return
# Downsample to n_bins peak values
bin_size = max(1, len(samples) // self._n_bins)
n = (len(samples) // bin_size) * bin_size
peaks = np.abs(samples[:n].reshape(-1, bin_size)).max(axis=1)
# Normalize to 0-1
mx = peaks.max()
if mx > 0:
peaks = peaks / mx
self.done.emit(peaks)
except Exception:
pass
class TimelineWidget(QWidget): class TimelineWidget(QWidget):
cursor_changed = pyqtSignal(float) # emits position in seconds cursor_changed = pyqtSignal(float) # emits position in seconds
seek_changed = pyqtSignal(float) # emits seek position (lock mode) seek_changed = pyqtSignal(float) # emits seek position (lock mode)
@@ -1020,6 +1056,9 @@ class TimelineWidget(QWidget):
self._scan_regions: list[tuple[float, float, float, float, float]] = [] self._scan_regions: list[tuple[float, float, float, float, float]] = []
self._scan_neg_times: set[float] = set() self._scan_neg_times: set[float] = set()
# Waveform data (numpy array of 0-1 peak values, or None)
self._waveform = None
# Edge-drag state for scan regions # Edge-drag state for scan regions
self._drag_idx: int | None = None # which region self._drag_idx: int | None = None # which region
self._drag_edge: str | None = None # "left" or "right" self._drag_edge: str | None = None # "left" or "right"
@@ -1052,6 +1091,10 @@ class TimelineWidget(QWidget):
self._rebuild_hover_cache() self._rebuild_hover_cache()
self.update() self.update()
def set_waveform(self, peaks) -> None:
self._waveform = peaks
self.update()
def set_clip_span(self, span: float): def set_clip_span(self, span: float):
self._clip_span = span self._clip_span = span
self.update() self.update()
@@ -1193,6 +1236,28 @@ class TimelineWidget(QWidget):
p.setPen(QPen(QColor(55, 55, 55))) p.setPen(QPen(QColor(55, 55, 55)))
p.drawLine(0, rh, w, rh) p.drawLine(0, rh, w, rh)
# ── waveform ──────────────────────────────────────────────────
if self._waveform is not None and len(self._waveform) > 0:
n = len(self._waveform)
mid_y = rh + th // 2
half_h = th * 0.4 # waveform uses 80% of track height
p.setPen(Qt.PenStyle.NoPen)
p.setBrush(QColor(80, 180, 80, 50))
from PyQt6.QtGui import QPolygonF
from PyQt6.QtCore import QPointF
pts = []
# Top half (positive peaks)
for i in range(n):
x = i * w / n
y = mid_y - self._waveform[i] * half_h
pts.append(QPointF(x, y))
# Bottom half (mirror)
for i in range(n - 1, -1, -1):
x = i * w / n
y = mid_y + self._waveform[i] * half_h
pts.append(QPointF(x, y))
p.drawPolygon(QPolygonF(pts))
# ── selection region (full clip span) ───────────────────────── # ── selection region (full clip span) ─────────────────────────
x_start = int(self._cursor / self._duration * w) x_start = int(self._cursor / self._duration * w)
if not self._scan_mode: if not self._scan_mode:
@@ -3013,6 +3078,14 @@ class MainWindow(QMainWindow):
) )
self._update_scan_export_count() self._update_scan_export_count()
# Start waveform extraction in background
self._timeline.set_waveform(None)
if hasattr(self, '_waveform_worker') and self._waveform_worker is not None:
self._waveform_worker.quit()
self._waveform_worker = WaveformWorker(self._file_path)
self._waveform_worker.done.connect(self._timeline.set_waveform)
self._waveform_worker.start()
dur = self._mpv.get_duration() dur = self._mpv.get_duration()
self._timeline.set_duration(dur) self._timeline.set_duration(dur)
self._cursor = 0.0 self._cursor = 0.0
@@ -3560,15 +3633,22 @@ class MainWindow(QMainWindow):
restore_model_version(chosen.data(), self._profile, embed_name) restore_model_version(chosen.data(), self._profile, embed_name)
self._start_scan() self._start_scan()
@staticmethod
def _safe_disconnect(*signals) -> None:
for sig in signals:
try:
sig.disconnect()
except (TypeError, RuntimeError):
pass
def _cleanup_scan_worker(self) -> None: def _cleanup_scan_worker(self) -> None:
"""Disconnect signals, cancel, and schedule deletion of old scan worker.""" """Disconnect signals, cancel, and schedule deletion of old scan worker."""
if self._scan_worker is not None: if self._scan_worker is not None:
try: self._safe_disconnect(
self._scan_worker.scan_done.disconnect() self._scan_worker.scan_done,
self._scan_worker.error.disconnect() self._scan_worker.error,
self._scan_worker.progress.disconnect() self._scan_worker.progress,
except TypeError: )
pass # already disconnected
self._scan_worker.cancel() self._scan_worker.cancel()
if self._scan_worker.isRunning(): if self._scan_worker.isRunning():
self._scan_worker.finished.connect(self._scan_worker.deleteLater) self._scan_worker.finished.connect(self._scan_worker.deleteLater)
@@ -3878,12 +3958,11 @@ class MainWindow(QMainWindow):
def _cleanup_train_worker(self) -> None: def _cleanup_train_worker(self) -> None:
"""Disconnect signals and schedule deletion of old train worker.""" """Disconnect signals and schedule deletion of old train worker."""
if self._train_worker is not None: if self._train_worker is not None:
try: self._safe_disconnect(
self._train_worker.train_done.disconnect() self._train_worker.train_done,
self._train_worker.error.disconnect() self._train_worker.error,
self._train_worker.progress.disconnect() self._train_worker.progress,
except TypeError: )
pass
if self._train_worker.isRunning(): if self._train_worker.isRunning():
self._train_worker.cancel() self._train_worker.cancel()
self._train_worker.finished.connect(self._train_worker.deleteLater) self._train_worker.finished.connect(self._train_worker.deleteLater)
@@ -3894,7 +3973,10 @@ class MainWindow(QMainWindow):
def _open_train_dialog(self): def _open_train_dialog(self):
"""Show the training config dialog and start training if accepted.""" """Show the training config dialog and start training if accepted."""
if self._train_worker and self._train_worker.isRunning(): if self._train_worker and self._train_worker.isRunning():
self._show_status("Training already in progress…") self._train_worker.cancel()
self._btn_train.setText("Train")
self._btn_train.setEnabled(False)
self._show_status("Cancelling training…")
return return
# Default video dir: parent of currently loaded file, or saved setting # Default video dir: parent of currently loaded file, or saved setting
@@ -3935,7 +4017,7 @@ class MainWindow(QMainWindow):
model_path = default_model_path(self._profile, embed_model) model_path = default_model_path(self._profile, embed_model)
self._cleanup_train_worker() self._cleanup_train_worker()
self._btn_train.setEnabled(False) self._btn_train.setText("Cancel")
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...") self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
n_workers = self._spn_workers.value() n_workers = self._spn_workers.value()
@@ -3946,12 +4028,14 @@ class MainWindow(QMainWindow):
self._train_worker.start() self._train_worker.start()
def _on_train_done(self, model_path: str): def _on_train_done(self, model_path: str):
self._btn_train.setText("Train")
self._btn_train.setEnabled(True) self._btn_train.setEnabled(True)
self._refresh_scan_models() self._refresh_scan_models()
self._show_status(f"Model trained and saved") self._show_status(f"Model trained and saved")
_log(f"Training complete: {model_path}") _log(f"Training complete: {model_path}")
def _on_train_error(self, msg: str): def _on_train_error(self, msg: str):
self._btn_train.setText("Train")
self._btn_train.setEnabled(True) self._btn_train.setEnabled(True)
self._show_status(f"Training error: {msg}") self._show_status(f"Training error: {msg}")
@@ -4232,6 +4316,17 @@ class MainWindow(QMainWindow):
self._show_status("Export already running…") self._show_status("Export already running…")
return return
# Check for overlapping existing markers
if not self._overwrite_path:
clip_end = self._cursor + 8.0 + (self._spn_clips.value() - 1) * self._spn_spread.value()
for t, _num, _path in self._timeline._markers:
if abs(t - self._cursor) < 0.1:
continue # same position (overwrite case)
marker_end = t + 8.0
if self._cursor < marker_end and clip_end > t:
self._show_status("Warning: overlaps with existing export", 3000)
break
fmt = self._cmb_format.currentText() fmt = self._cmb_format.currentText()
image_sequence = fmt == "WebP sequence" image_sequence = fmt == "WebP sequence"
folder = self._txt_folder.text() folder = self._txt_folder.text()