feat: YOLO subject tracking for per-clip crop centering, fix end-frame preview

- Track subject checkbox: auto-adjusts crop center per sub-clip using
  YOLOv8-nano detection on each start frame
- Detects target nearest to user's crop click, follows same class across clips
- Graceful fallback when ultralytics not installed or detection fails
- Fix end-frame preview not updating on clip/spread change (retry on busy grabber)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-13 15:12:17 +02:00
parent bd37938a4a
commit 703874721b
2 changed files with 134 additions and 0 deletions
+133
View File
@@ -10,6 +10,7 @@ import random
import shutil
import sqlite3
import subprocess
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from difflib import SequenceMatcher
@@ -249,6 +250,110 @@ def _normalize_filename(filename: str) -> str:
return name
# ---------------------------------------------------------------------------
# Subject tracking (YOLO-based, optional)
# ---------------------------------------------------------------------------
# Process-wide cache for the detector; stays None until a load succeeds.
_yolo_model = None


def _get_yolo():
    """Return the shared YOLOv8-nano model, loading it on first use.

    The model is cached in the module-level ``_yolo_model``.  When the
    ``ultralytics`` import or the model construction fails, the failure is
    logged and ``None`` is returned; nothing is cached in that case, so the
    next call attempts the load again.
    """
    global _yolo_model
    # Fast path: a previous call already loaded the model.
    if _yolo_model is not None:
        return _yolo_model
    try:
        # Imported lazily so the application still starts without ultralytics.
        from ultralytics import YOLO
        _yolo_model = YOLO("yolov8n.pt")
        _log("YOLO model loaded")
    except ImportError:
        _log("ultralytics not installed — tracking disabled")
        return None
    except Exception as exc:
        _log(f"YOLO load failed: {exc}")
        return None
    return _yolo_model
def extract_frame_cv(video_path: str, time: float):
    """Extract a single frame as a numpy array (BGR) via ffmpeg → temp PNG → cv2.

    Args:
        video_path: Path of the video file passed to ffmpeg's ``-i``.
        time: Seek position in seconds, passed to ffmpeg's ``-ss``.

    Returns:
        The image returned by ``cv2.imread`` for the extracted frame, or
        ``None`` when OpenCV is not installed, ffmpeg is missing, fails or
        times out, or the written image cannot be decoded.
    """
    try:
        import cv2
    except ImportError:
        return None
    # mkstemp only reserves a unique path; ffmpeg reopens it by name, so the
    # descriptor is closed right away.
    fd, tmp = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        cmd = [
            "ffmpeg", "-nostdin", "-y", "-ss", str(time), "-i", video_path,
            "-frames:v", "1", tmp,
        ]
        # ffmpeg reads interactive commands from stdin by default; detach it so
        # the child can neither eat terminal input nor be stopped when the
        # application runs in a background process group.
        result = subprocess.run(
            cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=10,
        )
        if result.returncode != 0:
            return None
        return cv2.imread(tmp)
    except Exception:
        # Best-effort helper: a missing ffmpeg binary, a timeout or a decode
        # error all mean "no frame available" to the caller.
        return None
    finally:
        try:
            os.unlink(tmp)
        except OSError:
            # Cleanup is best effort as well; a leftover temp file is
            # preferable to raising out of a helper that reports failure
            # through its return value.
            pass
def detect_subject_center(
    video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
) -> tuple[int | None, float, float] | None:
    """Detect objects at *time* and return (class_id, norm_x, norm_y) of the
    best match to (target_cls, last_x, last_y). Returns None on failure.

    Args:
        video_path: Video file to sample.
        time: Timestamp in seconds of the frame to analyse.
        target_cls: Class id to prefer, or ``None`` to accept any class.
        last_x: Reference x position, as a fraction of the frame width.
        last_y: Reference y position, as a fraction of the frame height.

    Returns:
        ``(class_id, norm_x, norm_y)`` for the chosen detection, where the
        coordinates are the box centre divided by the frame width / height.
        ``None`` when the model is unavailable, the frame cannot be extracted,
        inference raises, or nothing is detected.
    """
    model = _get_yolo()
    if model is None:
        return None
    frame = extract_frame_cv(video_path, time)
    if frame is None:
        return None
    h, w = frame.shape[:2]
    try:
        results = model(frame, verbose=False)
        if not results or len(results[0].boxes) == 0:
            return None
        dets = []
        for box in results[0].boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            cls = int(box.cls[0])
            # Box centre, normalised by the frame dimensions.
            cx = (x1 + x2) / 2 / w
            cy = (y1 + y2) / 2 / h
            dets.append((cls, cx, cy))
    except Exception as exc:
        # Inference or result parsing can fail inside the detection backend;
        # report it as "no detection" so the caller can fall back instead of
        # having the exception escape into the export handler.
        _log(f"Tracking: detection failed at t={time:.2f}s: {exc}")
        return None

    # Prefer same class, nearest to last known position.
    def score(d):
        # A class mismatch adds a flat 1.0 on top of the squared distance.
        cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
        dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
        return cls_penalty + dist

    return min(dets, key=score)
def track_centers_for_jobs(
    video_path: str, cursor: float, crop_center: float,
    starts: list[float],
) -> list[float]:
    """Compute one horizontal crop center per entry in *starts*.

    A first detection at *cursor* picks the target: the detection closest to
    ``(crop_center, 0.5)``, with no class preference.  Each start time is then
    searched for the best match to that target's class and the most recent
    known position.

    Returns:
        A list with one center per start time.  If the detection at the cursor
        fails, every entry is *crop_center*.  If a later detection fails, that
        entry reuses the last known x position.
    """
    anchor = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
    if anchor is None:
        _log("Tracking: no detection at cursor, using fixed center")
        return [crop_center for _ in starts]

    target_cls, last_x, last_y = anchor
    _log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")

    tracked: list[float] = []
    for t in starts:
        hit = detect_subject_center(video_path, t, target_cls, last_x, last_y)
        if hit is None:
            # Target lost in this frame: keep the previous position.
            _log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
            tracked.append(last_x)
            continue
        cx, cy = hit[1], hit[2]
        _log(f" t={t:.2f}s → center={cx:.3f}")
        tracked.append(cx)
        # The next search starts from where the subject was just found.
        last_x, last_y = cx, cy
    return tracked
class ProcessedDB:
_SCHEMA_VERSION = 3 # bump when schema changes
@@ -1596,6 +1701,18 @@ class MainWindow(QMainWindow):
)
self._chk_rand_square.toggled.connect(self._on_rand_toggle)
self._chk_track = QCheckBox("Track subject")
self._chk_track.setToolTip(
"Auto-adjust crop center per sub-clip using YOLO detection\n"
"(requires: pip install ultralytics)"
)
self._chk_track.setChecked(
self._settings.value("track_subject", "false") == "true"
)
self._chk_track.toggled.connect(
lambda v: self._settings.setValue("track_subject", "true" if v else "false")
)
cpu_count = os.cpu_count() or 2
self._spn_workers = QSpinBox()
self._spn_workers.setRange(1, cpu_count)
@@ -1710,6 +1827,7 @@ class MainWindow(QMainWindow):
settings_row.addWidget(self._spn_spread)
settings_row.addWidget(self._chk_rand_portrait)
settings_row.addWidget(self._chk_rand_square)
settings_row.addWidget(self._chk_track)
settings_row.addStretch()
right = QWidget()
@@ -2126,6 +2244,8 @@ class MainWindow(QMainWindow):
if not self._file_path:
return
if self._frame_grabber and self._frame_grabber.isRunning():
# Previous grab still running — retry shortly.
self._preview_timer.start()
return
end_t = self._cursor + self._clip_span
dur = self._mpv.get_duration()
@@ -2298,6 +2418,19 @@ class MainWindow(QMainWindow):
s, o, _, _ = jobs[idx]
jobs[idx] = (s, o, random.choice(ratios), base_center)
# Subject tracking: re-detect crop center per sub-clip.
if self._chk_track.isChecked() and any(j[2] for j in jobs):
starts = [j[0] for j in jobs]
self.statusBar().showMessage(f"Tracking subject across {len(jobs)} clip(s)…")
QApplication.processEvents()
centers = track_centers_for_jobs(
self._file_path, self._cursor, base_center, starts,
)
jobs = [
(s, o, r, centers[i] if r else c)
for i, (s, o, r, c) in enumerate(jobs)
]
short_side = self._spn_resize.value() or None
# Stash export config for _on_clip_done DB writes.