diff --git a/main.py b/main.py index 609026a..6fea374 100755 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ import random import shutil import sqlite3 import subprocess +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from difflib import SequenceMatcher @@ -249,6 +250,110 @@ def _normalize_filename(filename: str) -> str: return name +# --------------------------------------------------------------------------- +# Subject tracking (YOLO-based, optional) +# --------------------------------------------------------------------------- + +_yolo_model = None + + +def _get_yolo(): + """Lazy-load YOLOv8-nano. Returns None if ultralytics is not installed.""" + global _yolo_model + if _yolo_model is None: + try: + from ultralytics import YOLO + _yolo_model = YOLO("yolov8n.pt") + _log("YOLO model loaded") + except ImportError: + _log("ultralytics not installed — tracking disabled") + return None + except Exception as e: + _log(f"YOLO load failed: {e}") + return None + return _yolo_model + + +def extract_frame_cv(video_path: str, time: float): + """Extract a single frame as a numpy array (BGR) via ffmpeg → temp PNG → cv2.""" + try: + import cv2 + import numpy as np + except ImportError: + return None + fd, tmp = tempfile.mkstemp(suffix=".png") + os.close(fd) + try: + cmd = ["ffmpeg", "-y", "-ss", str(time), "-i", video_path, + "-frames:v", "1", tmp] + result = subprocess.run(cmd, capture_output=True, timeout=10) + if result.returncode != 0: + return None + return cv2.imread(tmp) + except Exception: + return None + finally: + if os.path.exists(tmp): + os.unlink(tmp) + + +def detect_subject_center( + video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float, +) -> tuple[int | None, float, float] | None: + """Detect objects at *time* and return (class_id, norm_x, norm_y) of the + best match to (target_cls, last_x, last_y). Returns None on failure.""" + model = _get_yolo() + if model is None: + return None + frame = extract_frame_cv(video_path, time) + if frame is None: + return None + results = model(frame, verbose=False) + if not results or len(results[0].boxes) == 0: + return None + h, w = frame.shape[:2] + dets = [] + for box in results[0].boxes: + x1, y1, x2, y2 = box.xyxy[0].tolist() + cls = int(box.cls[0]) + cx = (x1 + x2) / 2 / w + cy = (y1 + y2) / 2 / h + dets.append((cls, cx, cy)) + # Prefer same class, nearest to last known position. + def score(d): + cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0 + dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2 + return cls_penalty + dist + best = min(dets, key=score) + return best + + +def track_centers_for_jobs( + video_path: str, cursor: float, crop_center: float, + starts: list[float], +) -> list[float]: + """Run detection at the cursor (to identify the target) then at each start + time. Returns a list of horizontal crop centers (one per start).""" + ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5) + if ref is None: + _log("Tracking: no detection at cursor, using fixed center") + return [crop_center] * len(starts) + target_cls, last_x, last_y = ref + _log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})") + centers = [] + for t in starts: + det = detect_subject_center(video_path, t, target_cls, last_x, last_y) + if det is not None: + _, cx, cy = det + _log(f" t={t:.2f}s → center={cx:.3f}") + centers.append(cx) + last_x, last_y = cx, cy + else: + _log(f" t={t:.2f}s → lost, reusing {last_x:.3f}") + centers.append(last_x) + return centers + + class ProcessedDB: _SCHEMA_VERSION = 3 # bump when schema changes @@ -1596,6 +1701,18 @@ class MainWindow(QMainWindow): ) self._chk_rand_square.toggled.connect(self._on_rand_toggle) + self._chk_track = QCheckBox("Track subject") + self._chk_track.setToolTip( + "Auto-adjust crop center per sub-clip using YOLO detection\n" + "(requires: pip install ultralytics)" + ) + self._chk_track.setChecked( + self._settings.value("track_subject", "false") == "true" + ) + self._chk_track.toggled.connect( + lambda v: self._settings.setValue("track_subject", "true" if v else "false") + ) + cpu_count = os.cpu_count() or 2 self._spn_workers = QSpinBox() self._spn_workers.setRange(1, cpu_count) @@ -1710,6 +1827,7 @@ class MainWindow(QMainWindow): settings_row.addWidget(self._spn_spread) settings_row.addWidget(self._chk_rand_portrait) settings_row.addWidget(self._chk_rand_square) + settings_row.addWidget(self._chk_track) settings_row.addStretch() right = QWidget() @@ -2126,6 +2244,8 @@ class MainWindow(QMainWindow): if not self._file_path: return if self._frame_grabber and self._frame_grabber.isRunning(): + # Previous grab still running — retry shortly. + self._preview_timer.start() return end_t = self._cursor + self._clip_span dur = self._mpv.get_duration() @@ -2298,6 +2418,19 @@ class MainWindow(QMainWindow): s, o, _, _ = jobs[idx] jobs[idx] = (s, o, random.choice(ratios), base_center) + # Subject tracking: re-detect crop center per sub-clip. + if self._chk_track.isChecked() and any(j[2] for j in jobs): + starts = [j[0] for j in jobs] + self.statusBar().showMessage(f"Tracking subject across {len(jobs)} clip(s)…") + QApplication.processEvents() + centers = track_centers_for_jobs( + self._file_path, self._cursor, base_center, starts, + ) + jobs = [ + (s, o, r, centers[i] if r else c) + for i, (s, o, r, c) in enumerate(jobs) + ] + short_side = self._spn_resize.value() or None # Stash export config for _on_clip_done DB writes. diff --git a/requirements.txt b/requirements.txt index 8d8db26..180af07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ PyQt6>=6.4 python-mpv>=1.0 pytest>=7.0 +ultralytics>=8.0