feat: waveform overlay, signal safety, training cancel, dynamic batch size, duplicate detection
- WaveformWorker extracts low-res audio envelope via ffmpeg, drawn as green polygon on timeline track - _safe_disconnect() replaces bare TypeError catches for signal cleanup - Train button toggles to Cancel during training, calls worker.cancel() - Dynamic GPU batch sizing: 64 for ≥16GB VRAM, 32 for ≥8GB, 16 default - Overlap warning before exporting clips that intersect existing markers Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -171,7 +171,18 @@ def _extract_w2v_windows(y: np.ndarray, sr: int = _SR,
|
|||||||
import torch
|
import torch
|
||||||
model, device = _get_w2v_model(model_name)
|
model, device = _get_w2v_model(model_name)
|
||||||
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
is_beats = (model_name or _DEFAULT_EMBED_MODEL) == "BEATS"
|
||||||
|
# Auto-size batches based on available GPU memory
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
|
||||||
|
if vram_gb >= 16:
|
||||||
|
batch_size = 64
|
||||||
|
elif vram_gb >= 8:
|
||||||
|
batch_size = 32
|
||||||
|
_log(f"audio_scan: batch_size={batch_size} (VRAM {vram_gb:.1f} GB)")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
timestamps = np.arange(n_windows) * hop
|
timestamps = np.arange(n_windows) * hop
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
|
|||||||
@@ -988,6 +988,42 @@ class ScanResultsPanel(QWidget):
|
|||||||
super().keyPressEvent(event)
|
super().keyPressEvent(event)
|
||||||
|
|
||||||
|
|
||||||
|
class WaveformWorker(QThread):
|
||||||
|
"""Extract a low-res waveform envelope in the background."""
|
||||||
|
done = pyqtSignal(object) # emits numpy array of peak values
|
||||||
|
|
||||||
|
def __init__(self, video_path: str, n_bins: int = 2000):
|
||||||
|
super().__init__()
|
||||||
|
self._path = video_path
|
||||||
|
self._n_bins = n_bins
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
import numpy as np
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
_bin("ffmpeg"), "-i", self._path,
|
||||||
|
"-vn", "-ac", "1", "-ar", "8000",
|
||||||
|
"-f", "f32le", "-loglevel", "error", "pipe:1",
|
||||||
|
]
|
||||||
|
proc = subprocess.run(cmd, capture_output=True, timeout=60)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
return
|
||||||
|
samples = np.frombuffer(proc.stdout, dtype=np.float32)
|
||||||
|
if len(samples) == 0:
|
||||||
|
return
|
||||||
|
# Downsample to n_bins peak values
|
||||||
|
bin_size = max(1, len(samples) // self._n_bins)
|
||||||
|
n = (len(samples) // bin_size) * bin_size
|
||||||
|
peaks = np.abs(samples[:n].reshape(-1, bin_size)).max(axis=1)
|
||||||
|
# Normalize to 0-1
|
||||||
|
mx = peaks.max()
|
||||||
|
if mx > 0:
|
||||||
|
peaks = peaks / mx
|
||||||
|
self.done.emit(peaks)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TimelineWidget(QWidget):
|
class TimelineWidget(QWidget):
|
||||||
cursor_changed = pyqtSignal(float) # emits position in seconds
|
cursor_changed = pyqtSignal(float) # emits position in seconds
|
||||||
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
|
seek_changed = pyqtSignal(float) # emits seek position (lock mode)
|
||||||
@@ -1020,6 +1056,9 @@ class TimelineWidget(QWidget):
|
|||||||
self._scan_regions: list[tuple[float, float, float, float, float]] = []
|
self._scan_regions: list[tuple[float, float, float, float, float]] = []
|
||||||
self._scan_neg_times: set[float] = set()
|
self._scan_neg_times: set[float] = set()
|
||||||
|
|
||||||
|
# Waveform data (numpy array of 0-1 peak values, or None)
|
||||||
|
self._waveform = None
|
||||||
|
|
||||||
# Edge-drag state for scan regions
|
# Edge-drag state for scan regions
|
||||||
self._drag_idx: int | None = None # which region
|
self._drag_idx: int | None = None # which region
|
||||||
self._drag_edge: str | None = None # "left" or "right"
|
self._drag_edge: str | None = None # "left" or "right"
|
||||||
@@ -1052,6 +1091,10 @@ class TimelineWidget(QWidget):
|
|||||||
self._rebuild_hover_cache()
|
self._rebuild_hover_cache()
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
|
def set_waveform(self, peaks) -> None:
|
||||||
|
self._waveform = peaks
|
||||||
|
self.update()
|
||||||
|
|
||||||
def set_clip_span(self, span: float):
|
def set_clip_span(self, span: float):
|
||||||
self._clip_span = span
|
self._clip_span = span
|
||||||
self.update()
|
self.update()
|
||||||
@@ -1193,6 +1236,28 @@ class TimelineWidget(QWidget):
|
|||||||
p.setPen(QPen(QColor(55, 55, 55)))
|
p.setPen(QPen(QColor(55, 55, 55)))
|
||||||
p.drawLine(0, rh, w, rh)
|
p.drawLine(0, rh, w, rh)
|
||||||
|
|
||||||
|
# ── waveform ──────────────────────────────────────────────────
|
||||||
|
if self._waveform is not None and len(self._waveform) > 0:
|
||||||
|
n = len(self._waveform)
|
||||||
|
mid_y = rh + th // 2
|
||||||
|
half_h = th * 0.4 # waveform uses 80% of track height
|
||||||
|
p.setPen(Qt.PenStyle.NoPen)
|
||||||
|
p.setBrush(QColor(80, 180, 80, 50))
|
||||||
|
from PyQt6.QtGui import QPolygonF
|
||||||
|
from PyQt6.QtCore import QPointF
|
||||||
|
pts = []
|
||||||
|
# Top half (positive peaks)
|
||||||
|
for i in range(n):
|
||||||
|
x = i * w / n
|
||||||
|
y = mid_y - self._waveform[i] * half_h
|
||||||
|
pts.append(QPointF(x, y))
|
||||||
|
# Bottom half (mirror)
|
||||||
|
for i in range(n - 1, -1, -1):
|
||||||
|
x = i * w / n
|
||||||
|
y = mid_y + self._waveform[i] * half_h
|
||||||
|
pts.append(QPointF(x, y))
|
||||||
|
p.drawPolygon(QPolygonF(pts))
|
||||||
|
|
||||||
# ── selection region (full clip span) ─────────────────────────
|
# ── selection region (full clip span) ─────────────────────────
|
||||||
x_start = int(self._cursor / self._duration * w)
|
x_start = int(self._cursor / self._duration * w)
|
||||||
if not self._scan_mode:
|
if not self._scan_mode:
|
||||||
@@ -3013,6 +3078,14 @@ class MainWindow(QMainWindow):
|
|||||||
)
|
)
|
||||||
self._update_scan_export_count()
|
self._update_scan_export_count()
|
||||||
|
|
||||||
|
# Start waveform extraction in background
|
||||||
|
self._timeline.set_waveform(None)
|
||||||
|
if hasattr(self, '_waveform_worker') and self._waveform_worker is not None:
|
||||||
|
self._waveform_worker.quit()
|
||||||
|
self._waveform_worker = WaveformWorker(self._file_path)
|
||||||
|
self._waveform_worker.done.connect(self._timeline.set_waveform)
|
||||||
|
self._waveform_worker.start()
|
||||||
|
|
||||||
dur = self._mpv.get_duration()
|
dur = self._mpv.get_duration()
|
||||||
self._timeline.set_duration(dur)
|
self._timeline.set_duration(dur)
|
||||||
self._cursor = 0.0
|
self._cursor = 0.0
|
||||||
@@ -3560,15 +3633,22 @@ class MainWindow(QMainWindow):
|
|||||||
restore_model_version(chosen.data(), self._profile, embed_name)
|
restore_model_version(chosen.data(), self._profile, embed_name)
|
||||||
self._start_scan()
|
self._start_scan()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_disconnect(*signals) -> None:
|
||||||
|
for sig in signals:
|
||||||
|
try:
|
||||||
|
sig.disconnect()
|
||||||
|
except (TypeError, RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
def _cleanup_scan_worker(self) -> None:
|
def _cleanup_scan_worker(self) -> None:
|
||||||
"""Disconnect signals, cancel, and schedule deletion of old scan worker."""
|
"""Disconnect signals, cancel, and schedule deletion of old scan worker."""
|
||||||
if self._scan_worker is not None:
|
if self._scan_worker is not None:
|
||||||
try:
|
self._safe_disconnect(
|
||||||
self._scan_worker.scan_done.disconnect()
|
self._scan_worker.scan_done,
|
||||||
self._scan_worker.error.disconnect()
|
self._scan_worker.error,
|
||||||
self._scan_worker.progress.disconnect()
|
self._scan_worker.progress,
|
||||||
except TypeError:
|
)
|
||||||
pass # already disconnected
|
|
||||||
self._scan_worker.cancel()
|
self._scan_worker.cancel()
|
||||||
if self._scan_worker.isRunning():
|
if self._scan_worker.isRunning():
|
||||||
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
|
self._scan_worker.finished.connect(self._scan_worker.deleteLater)
|
||||||
@@ -3878,12 +3958,11 @@ class MainWindow(QMainWindow):
|
|||||||
def _cleanup_train_worker(self) -> None:
|
def _cleanup_train_worker(self) -> None:
|
||||||
"""Disconnect signals and schedule deletion of old train worker."""
|
"""Disconnect signals and schedule deletion of old train worker."""
|
||||||
if self._train_worker is not None:
|
if self._train_worker is not None:
|
||||||
try:
|
self._safe_disconnect(
|
||||||
self._train_worker.train_done.disconnect()
|
self._train_worker.train_done,
|
||||||
self._train_worker.error.disconnect()
|
self._train_worker.error,
|
||||||
self._train_worker.progress.disconnect()
|
self._train_worker.progress,
|
||||||
except TypeError:
|
)
|
||||||
pass
|
|
||||||
if self._train_worker.isRunning():
|
if self._train_worker.isRunning():
|
||||||
self._train_worker.cancel()
|
self._train_worker.cancel()
|
||||||
self._train_worker.finished.connect(self._train_worker.deleteLater)
|
self._train_worker.finished.connect(self._train_worker.deleteLater)
|
||||||
@@ -3894,7 +3973,10 @@ class MainWindow(QMainWindow):
|
|||||||
def _open_train_dialog(self):
|
def _open_train_dialog(self):
|
||||||
"""Show the training config dialog and start training if accepted."""
|
"""Show the training config dialog and start training if accepted."""
|
||||||
if self._train_worker and self._train_worker.isRunning():
|
if self._train_worker and self._train_worker.isRunning():
|
||||||
self._show_status("Training already in progress…")
|
self._train_worker.cancel()
|
||||||
|
self._btn_train.setText("Train")
|
||||||
|
self._btn_train.setEnabled(False)
|
||||||
|
self._show_status("Cancelling training…")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Default video dir: parent of currently loaded file, or saved setting
|
# Default video dir: parent of currently loaded file, or saved setting
|
||||||
@@ -3935,7 +4017,7 @@ class MainWindow(QMainWindow):
|
|||||||
model_path = default_model_path(self._profile, embed_model)
|
model_path = default_model_path(self._profile, embed_model)
|
||||||
|
|
||||||
self._cleanup_train_worker()
|
self._cleanup_train_worker()
|
||||||
self._btn_train.setEnabled(False)
|
self._btn_train.setText("Cancel")
|
||||||
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
|
self._show_status(f"Training {embed_model} on {len(video_infos)} videos...")
|
||||||
|
|
||||||
n_workers = self._spn_workers.value()
|
n_workers = self._spn_workers.value()
|
||||||
@@ -3946,12 +4028,14 @@ class MainWindow(QMainWindow):
|
|||||||
self._train_worker.start()
|
self._train_worker.start()
|
||||||
|
|
||||||
def _on_train_done(self, model_path: str):
|
def _on_train_done(self, model_path: str):
|
||||||
|
self._btn_train.setText("Train")
|
||||||
self._btn_train.setEnabled(True)
|
self._btn_train.setEnabled(True)
|
||||||
self._refresh_scan_models()
|
self._refresh_scan_models()
|
||||||
self._show_status(f"Model trained and saved")
|
self._show_status(f"Model trained and saved")
|
||||||
_log(f"Training complete: {model_path}")
|
_log(f"Training complete: {model_path}")
|
||||||
|
|
||||||
def _on_train_error(self, msg: str):
|
def _on_train_error(self, msg: str):
|
||||||
|
self._btn_train.setText("Train")
|
||||||
self._btn_train.setEnabled(True)
|
self._btn_train.setEnabled(True)
|
||||||
self._show_status(f"Training error: {msg}")
|
self._show_status(f"Training error: {msg}")
|
||||||
|
|
||||||
@@ -4232,6 +4316,17 @@ class MainWindow(QMainWindow):
|
|||||||
self._show_status("Export already running…")
|
self._show_status("Export already running…")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check for overlapping existing markers
|
||||||
|
if not self._overwrite_path:
|
||||||
|
clip_end = self._cursor + 8.0 + (self._spn_clips.value() - 1) * self._spn_spread.value()
|
||||||
|
for t, _num, _path in self._timeline._markers:
|
||||||
|
if abs(t - self._cursor) < 0.1:
|
||||||
|
continue # same position (overwrite case)
|
||||||
|
marker_end = t + 8.0
|
||||||
|
if self._cursor < marker_end and clip_end > t:
|
||||||
|
self._show_status("Warning: overlaps with existing export", 3000)
|
||||||
|
break
|
||||||
|
|
||||||
fmt = self._cmb_format.currentText()
|
fmt = self._cmb_format.currentText()
|
||||||
image_sequence = fmt == "WebP sequence"
|
image_sequence = fmt == "WebP sequence"
|
||||||
folder = self._txt_folder.text()
|
folder = self._txt_folder.text()
|
||||||
|
|||||||
Reference in New Issue
Block a user