From a0286d5cf9ac8b3b55932f367c4ef485e98f30cd Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 19 Apr 2026 12:53:48 +0200 Subject: [PATCH] feat: waveform overlay, signal safety, training cancel, dynamic batch size, duplicate detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - WaveformWorker extracts low-res audio envelope via ffmpeg, drawn as green polygon on timeline track - _safe_disconnect() replaces bare TypeError catches for signal cleanup - Train button toggles to Cancel during training, calls worker.cancel() - Dynamic GPU batch sizing: 64 for ≥16GB VRAM, 32 for ≥8GB, 16 default - Overlap warning before exporting clips that intersect existing markers Co-Authored-By: Claude Opus 4.6 --- core/audio_scan.py | 11 ++++ main.py | 123 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 120 insertions(+), 14 deletions(-) diff --git a/core/audio_scan.py b/core/audio_scan.py index be9fc4c..7d7fda3 100644 --- a/core/audio_scan.py +++ b/core/audio_scan.py @@ -171,7 +171,18 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR, import torch model, device = _get_w2v_model(model_name) is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS" + # Auto-size batches based on total memory of GPU 0; 16 remains the fallback batch_size = 16 + if device == "cuda": + try: + vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if vram_gb >= 16: + batch_size = 64 + elif vram_gb >= 8: + batch_size = 32 + _log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)") + except Exception as exc: + _log(f"audio_scan: VRAM query failed, keeping batch_size={batch_size}: {exc}") timestamps = np.arange(n_windows) * hop embeddings = [] diff --git a/main.py b/main.py index 8e6da6f..836941a 100755 --- a/main.py +++ b/main.py @@ -988,6 +988,42 @@ class ScanResultsPanel(QWidget): super().keyPressEvent(event) +class WaveformWorker(QThread): + """Extract a low-res waveform envelope in the background.""" + done = pyqtSignal(object) # emits numpy array of peak values + + def __init__(self, video_path: str, n_bins: int = 
2000): + super().__init__() + self._path = video_path + self._n_bins = n_bins + + def run(self): + import numpy as np + try: + cmd = [ + _bin("ffmpeg"), "-i", self._path, + "-vn", "-ac", "1", "-ar", "8000", + "-f", "f32le", "-loglevel", "error", "pipe:1", + ] + proc = subprocess.run(cmd, capture_output=True, timeout=60) + if proc.returncode != 0: + return + samples = np.frombuffer(proc.stdout, dtype=np.float32) + if len(samples) == 0: + return + # Downsample to n_bins peak values + bin_size = max(1, len(samples) // self._n_bins) + n = (len(samples) // bin_size) * bin_size + peaks = np.abs(samples[:n].reshape(-1, bin_size)).max(axis=1) + # Normalize to 0-1 + mx = peaks.max() + if mx > 0: + peaks = peaks / mx + self.done.emit(peaks) + except Exception: + pass + + class TimelineWidget(QWidget): cursor_changed = pyqtSignal(float) # emits position in seconds seek_changed = pyqtSignal(float) # emits seek position (lock mode) @@ -1020,6 +1056,9 @@ class TimelineWidget(QWidget): self._scan_regions: list[tuple[float, float, float, float, float]] = [] self._scan_neg_times: set[float] = set() + # Waveform data (numpy array of 0-1 peak values, or None) + self._waveform = None + # Edge-drag state for scan regions self._drag_idx: int | None = None # which region self._drag_edge: str | None = None # "left" or "right" @@ -1052,6 +1091,10 @@ class TimelineWidget(QWidget): self._rebuild_hover_cache() self.update() + def set_waveform(self, peaks) -> None: + self._waveform = peaks + self.update() + def set_clip_span(self, span: float): self._clip_span = span self.update() @@ -1193,6 +1236,28 @@ class TimelineWidget(QWidget): p.setPen(QPen(QColor(55, 55, 55))) p.drawLine(0, rh, w, rh) + # ── waveform ────────────────────────────────────────────────── + if self._waveform is not None and len(self._waveform) > 0: + n = len(self._waveform) + mid_y = rh + th // 2 + half_h = th * 0.4 # waveform uses 80% of track height + p.setPen(Qt.PenStyle.NoPen) + p.setBrush(QColor(80, 180, 80, 50)) + 
from PyQt6.QtGui import QPolygonF + from PyQt6.QtCore import QPointF + pts = [] + # Top half (positive peaks) + for i in range(n): + x = i * w / n + y = mid_y - self._waveform[i] * half_h + pts.append(QPointF(x, y)) + # Bottom half (mirror) + for i in range(n - 1, -1, -1): + x = i * w / n + y = mid_y + self._waveform[i] * half_h + pts.append(QPointF(x, y)) + p.drawPolygon(QPolygonF(pts)) + # ── selection region (full clip span) ───────────────────────── x_start = int(self._cursor / self._duration * w) if not self._scan_mode: @@ -3013,6 +3078,14 @@ class MainWindow(QMainWindow): ) self._update_scan_export_count() + # Start waveform extraction in background; detach any previous worker so a stale result cannot overwrite this file's waveform + self._timeline.set_waveform(None) + if hasattr(self, '_waveform_worker') and self._waveform_worker is not None: + self._safe_disconnect(self._waveform_worker.done) + self._waveform_worker = WaveformWorker(self._file_path) + self._waveform_worker.done.connect(self._timeline.set_waveform) + self._waveform_worker.start() + dur = self._mpv.get_duration() self._timeline.set_duration(dur) self._cursor = 0.0 @@ -3560,15 +3633,22 @@ class MainWindow(QMainWindow): restore_model_version(chosen.data(), self._profile, embed_name) self._start_scan() + @staticmethod + def _safe_disconnect(*signals) -> None: + for sig in signals: + try: + sig.disconnect() + except (TypeError, RuntimeError): + pass + def _cleanup_scan_worker(self) -> None: """Disconnect signals, cancel, and schedule deletion of old scan worker.""" if self._scan_worker is not None: - try: - self._scan_worker.scan_done.disconnect() - self._scan_worker.error.disconnect() - self._scan_worker.progress.disconnect() - except TypeError: - pass # already disconnected + self._safe_disconnect( + self._scan_worker.scan_done, + self._scan_worker.error, + self._scan_worker.progress, + ) self._scan_worker.cancel() if self._scan_worker.isRunning(): self._scan_worker.finished.connect(self._scan_worker.deleteLater) @@ -3878,12 +3958,11 @@ class MainWindow(QMainWindow): def 
_cleanup_train_worker(self) -> None: """Disconnect signals and schedule deletion of old train worker.""" if self._train_worker is not None: - try: - self._train_worker.train_done.disconnect() - self._train_worker.error.disconnect() - self._train_worker.progress.disconnect() - except TypeError: - pass + self._safe_disconnect( + self._train_worker.train_done, + self._train_worker.error, + self._train_worker.progress, + ) if self._train_worker.isRunning(): self._train_worker.cancel() self._train_worker.finished.connect(self._train_worker.deleteLater) @@ -3894,7 +3973,10 @@ class MainWindow(QMainWindow): def _open_train_dialog(self): """Show the training config dialog and start training if accepted.""" if self._train_worker and self._train_worker.isRunning(): - self._show_status("Training already in progress…") + self._train_worker.cancel() + self._btn_train.setText("Train") + self._btn_train.setEnabled(False) + self._show_status("Cancelling training…") return # Default video dir: parent of currently loaded file, or saved setting @@ -3935,7 +4017,7 @@ class MainWindow(QMainWindow): model_path = default_model_path(self._profile, embed_model) self._cleanup_train_worker() - self._btn_train.setEnabled(False) + self._btn_train.setText("Cancel") self._show_status(f"Training {embed_model} on {len(video_infos)} videos...") n_workers = self._spn_workers.value() @@ -3946,12 +4028,14 @@ class MainWindow(QMainWindow): self._train_worker.start() def _on_train_done(self, model_path: str): + self._btn_train.setText("Train") self._btn_train.setEnabled(True) self._refresh_scan_models() self._show_status(f"Model trained and saved") _log(f"Training complete: {model_path}") def _on_train_error(self, msg: str): + self._btn_train.setText("Train") self._btn_train.setEnabled(True) self._show_status(f"Training error: {msg}") @@ -4232,6 +4316,17 @@ class MainWindow(QMainWindow): self._show_status("Export already running…") return + # Check for overlapping existing markers + if not 
self._overwrite_path: + clip_end = self._cursor + 8.0 + (self._spn_clips.value() - 1) * self._spn_spread.value() + for t, _num, _path in self._timeline._markers: + if abs(t - self._cursor) < 0.1: + continue # same position (overwrite case) + marker_end = t + 8.0 + if self._cursor < marker_end and clip_end > t: + self._show_status("Warning: overlaps with existing export", 3000) + break + fmt = self._cmb_format.currentText() image_sequence = fmt == "WebP sequence" folder = self._txt_folder.text()