feat: waveform overlay, signal safety, training cancel, dynamic batch size, duplicate detection
- WaveformWorker extracts low-res audio envelope via ffmpeg, drawn as green polygon on timeline track - _safe_disconnect() replaces bare TypeError catches for signal cleanup - Train button toggles to Cancel during training, calls worker.cancel() - Dynamic GPU batch sizing: 64 for ≥16GB VRAM, 32 for ≥8GB, 16 default - Overlap warning before exporting clips that intersect existing markers Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -988,6 +988,42 @@ class ScanResultsPanel(QWidget):
|
||||
super().keyPressEvent(event)
|
||||
|
||||
|
||||
class WaveformWorker(QThread):
|
||||
"""Extract a low-res waveform envelope in the background."""
|
||||
done = pyqtSignal(object) # emits numpy array of peak values
|
||||
|
||||
def __init__(self, video_path: str, n_bins: int = 2000):
|
||||
super().__init__()
|
||||
self._path = video_path
|
||||
self._n_bins = n_bins
|
||||
|
||||
def run(self):
|
||||
import numpy as np
|
||||
try:
|
||||
cmd = [
|
||||
_bin("ffmpeg"), "-i", self._path,
|
||||
"-vn", "-ac", "1", "-ar", "8000",
|
||||
"-f", "f32le", "-loglevel", "error", "pipe:1",
|
||||
]
|
||||
proc = subprocess.run(cmd, capture_output=True, timeout=60)
|
||||
if proc.returncode != 0:
|
||||
return
|
||||
samples = np.frombuffer(proc.stdout, dtype=np.float32)
|
||||
if len(samples) == 0:
|
||||
return
|
||||
# Downsample to n_bins peak values
|
||||
bin_size = max(1, len(samples) // self._n_bins)
|
||||
n = (len(samples) // bin_size) * bin_size
|
||||
peaks = np.abs(samples[:n].reshape(-1, bin_size)).max(axis=1)
|
||||
# Normalize to 0-1
|
||||
mx = peaks.max()
|
||||
if mx > 0:
|
||||
peaks = peaks / mx
|
||||
self.done.emit(peaks)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TimelineWidget(QWidget):
|
||||
cursor_changed = pyqtSignal(float) # emits position in seconds
|
||||
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
|
||||
@@ -1020,6 +1056,9 @@ class TimelineWidget(QWidget):
|
||||
self._scan_regions: list[tuple[float, float, float, float, float]] = []
|
||||
self._scan_neg_times: set[float] = set()
|
||||
|
||||
# Waveform data (numpy array of 0-1 peak values, or None)
|
||||
self._waveform = None
|
||||
|
||||
# Edge-drag state for scan regions
|
||||
self._drag_idx: int | None = None # which region
|
||||
self._drag_edge: str | None = None # "left" or "right"
|
||||
@@ -1052,6 +1091,10 @@ class TimelineWidget(QWidget):
|
||||
self._rebuild_hover_cache()
|
||||
self.update()
|
||||
|
||||
def set_waveform(self, peaks) -> None:
|
||||
self._waveform = peaks
|
||||
self.update()
|
||||
|
||||
def set_clip_span(self, span: float):
|
||||
self._clip_span = span
|
||||
self.update()
|
||||
@@ -1193,6 +1236,28 @@ class TimelineWidget(QWidget):
|
||||
p.setPen(QPen(QColor(55, 55, 55)))
|
||||
p.drawLine(0, rh, w, rh)
|
||||
|
||||
# ── waveform ──────────────────────────────────────────────────
|
||||
if self._waveform is not None and len(self._waveform) > 0:
|
||||
n = len(self._waveform)
|
||||
mid_y = rh + th // 2
|
||||
half_h = th * 0.4 # waveform uses 80% of track height
|
||||
p.setPen(Qt.PenStyle.NoPen)
|
||||
p.setBrush(QColor(80, 180, 80, 50))
|
||||
from PyQt6.QtGui import QPolygonF
|
||||
from PyQt6.QtCore import QPointF
|
||||
pts = []
|
||||
# Top half (positive peaks)
|
||||
for i in range(n):
|
||||
x = i * w / n
|
||||
y = mid_y - self._waveform[i] * half_h
|
||||
pts.append(QPointF(x, y))
|
||||
# Bottom half (mirror)
|
||||
for i in range(n - 1, -1, -1):
|
||||
x = i * w / n
|
||||
y = mid_y + self._waveform[i] * half_h
|
||||
pts.append(QPointF(x, y))
|
||||
p.drawPolygon(QPolygonF(pts))
|
||||
|
||||
# ── selection region (full clip span) ─────────────────────────
|
||||
x_start = int(self._cursor / self._duration * w)
|
||||
if not self._scan_mode:
|
||||
@@ -3013,6 +3078,14 @@ class MainWindow(QMainWindow):
|
||||
)
|
||||
self._update_scan_export_count()
|
||||
|
||||
# Start waveform extraction in background
|
||||
self._timeline.set_waveform(None)
|
||||
if hasattr(self, '_waveform_worker') and self._waveform_worker is not None:
|
||||
self._waveform_worker.quit()
|
||||
self._waveform_worker = WaveformWorker(self._file_path)
|
||||
self._waveform_worker.done.connect(self._timeline.set_waveform)
|
||||
self._waveform_worker.start()
|
||||
|
||||
dur = self._mpv.get_duration()
|
||||
self._timeline.set_duration(dur)
|
||||
self._cursor = 0.0
|
||||
@@ -3560,15 +3633,22 @@ class MainWindow(QMainWindow):
|
||||
restore_model_version(chosen.data(), self._profile, embed_name)
|
||||
self._start_scan()
|
||||
|
||||
@staticmethod
|
||||
def _safe_disconnect(*signals) -> None:
|
||||
for sig in signals:
|
||||
try:
|
||||
sig.disconnect()
|
||||
except (TypeError, RuntimeError):
|
||||
pass
|
||||
|
||||
def _cleanup_scan_worker(self) -> None:
|
||||
"""Disconnect signals, cancel, and schedule deletion of old scan worker."""
|
||||
if self._scan_worker is not None:
|
||||
try:
|
||||
self._scan_worker.scan_done.disconnect()
|
||||
self._scan_worker.error.disconnect()
|
||||
self._scan_worker.progress.disconnect()
|
||||
except TypeError:
|
||||
pass # already disconnected
|
||||
self._safe_disconnect(
|
||||
self._scan_worker.scan_done,
|
||||
self._scan_worker.error,
|
||||
self._scan_worker.progress,
|
||||
)
|
||||
self._scan_worker.cancel()
|
||||
if self._scan_worker.isRunning():
|
||||
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
|
||||
@@ -3878,12 +3958,11 @@ class MainWindow(QMainWindow):
|
||||
def _cleanup_train_worker(self) -> None:
|
||||
"""Disconnect signals and schedule deletion of old train worker."""
|
||||
if self._train_worker is not None:
|
||||
try:
|
||||
self._train_worker.train_done.disconnect()
|
||||
self._train_worker.error.disconnect()
|
||||
self._train_worker.progress.disconnect()
|
||||
except TypeError:
|
||||
pass
|
||||
self._safe_disconnect(
|
||||
self._train_worker.train_done,
|
||||
self._train_worker.error,
|
||||
self._train_worker.progress,
|
||||
)
|
||||
if self._train_worker.isRunning():
|
||||
self._train_worker.cancel()
|
||||
self._train_worker.finished.connect(self._train_worker.deleteLater)
|
||||
@@ -3894,7 +3973,10 @@ class MainWindow(QMainWindow):
|
||||
def _open_train_dialog(self):
|
||||
"""Show the training config dialog and start training if accepted."""
|
||||
if self._train_worker and self._train_worker.isRunning():
|
||||
self._show_status("Training already in progress…")
|
||||
self._train_worker.cancel()
|
||||
self._btn_train.setText("Train")
|
||||
self._btn_train.setEnabled(False)
|
||||
self._show_status("Cancelling training…")
|
||||
return
|
||||
|
||||
# Default video dir: parent of currently loaded file, or saved setting
|
||||
@@ -3935,7 +4017,7 @@ class MainWindow(QMainWindow):
|
||||
model_path = default_model_path(self._profile, embed_model)
|
||||
|
||||
self._cleanup_train_worker()
|
||||
self._btn_train.setEnabled(False)
|
||||
self._btn_train.setText("Cancel")
|
||||
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
|
||||
|
||||
n_workers = self._spn_workers.value()
|
||||
@@ -3946,12 +4028,14 @@ class MainWindow(QMainWindow):
|
||||
self._train_worker.start()
|
||||
|
||||
def _on_train_done(self, model_path: str):
|
||||
self._btn_train.setText("Train")
|
||||
self._btn_train.setEnabled(True)
|
||||
self._refresh_scan_models()
|
||||
self._show_status(f"Model trained and saved")
|
||||
_log(f"Training complete: {model_path}")
|
||||
|
||||
def _on_train_error(self, msg: str):
|
||||
self._btn_train.setText("Train")
|
||||
self._btn_train.setEnabled(True)
|
||||
self._show_status(f"Training error: {msg}")
|
||||
|
||||
@@ -4232,6 +4316,17 @@ class MainWindow(QMainWindow):
|
||||
self._show_status("Export already running…")
|
||||
return
|
||||
|
||||
# Check for overlapping existing markers
|
||||
if not self._overwrite_path:
|
||||
clip_end = self._cursor + 8.0 + (self._spn_clips.value() - 1) * self._spn_spread.value()
|
||||
for t, _num, _path in self._timeline._markers:
|
||||
if abs(t - self._cursor) < 0.1:
|
||||
continue # same position (overwrite case)
|
||||
marker_end = t + 8.0
|
||||
if self._cursor < marker_end and clip_end > t:
|
||||
self._show_status("Warning: overlaps with existing export", 3000)
|
||||
break
|
||||
|
||||
fmt = self._cmb_format.currentText()
|
||||
image_sequence = fmt == "WebP sequence"
|
||||
folder = self._txt_folder.text()
|
||||
|
||||
Reference in New Issue
Block a user