feat: create core/tracking module with YOLO subject tracking
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from .paths import _bin, _log
|
||||||
|
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_yolo():
|
||||||
|
"""Lazy-load YOLOv8-nano. Returns None if ultralytics is not installed."""
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model is None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
_yolo_model = YOLO("yolov8n.pt")
|
||||||
|
_log("YOLO model loaded")
|
||||||
|
except ImportError:
|
||||||
|
_log("ultralytics not installed — tracking disabled")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
_log(f"YOLO load failed: {e}")
|
||||||
|
return None
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frame_cv(video_path: str, time: float):
|
||||||
|
"""Extract a single frame as a numpy array (BGR) via ffmpeg -> temp PNG -> cv2."""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
fd, tmp = tempfile.mkstemp(suffix=".png")
|
||||||
|
os.close(fd)
|
||||||
|
try:
|
||||||
|
cmd = [_bin("ffmpeg"), "-y", "-ss", str(time), "-i", video_path,
|
||||||
|
"-frames:v", "1", tmp]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, timeout=10)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return cv2.imread(tmp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_subject_center(
|
||||||
|
video_path: str, time: float, target_cls: int | None, last_x: float, last_y: float,
|
||||||
|
) -> tuple[int | None, float, float] | None:
|
||||||
|
"""Detect objects at *time* and return (class_id, norm_x, norm_y) of the
|
||||||
|
best match to (target_cls, last_x, last_y). Returns None on failure."""
|
||||||
|
model = _get_yolo()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
frame = extract_frame_cv(video_path, time)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
results = model(frame, verbose=False)
|
||||||
|
if not results or len(results[0].boxes) == 0:
|
||||||
|
return None
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
dets = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
cx = (x1 + x2) / 2 / w
|
||||||
|
cy = (y1 + y2) / 2 / h
|
||||||
|
dets.append((cls, cx, cy))
|
||||||
|
# Prefer same class, nearest to last known position.
|
||||||
|
def score(d):
|
||||||
|
cls_penalty = 0 if (target_cls is None or d[0] == target_cls) else 1.0
|
||||||
|
dist = (d[1] - last_x) ** 2 + (d[2] - last_y) ** 2
|
||||||
|
return cls_penalty + dist
|
||||||
|
best = min(dets, key=score)
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def track_centers_for_jobs(
|
||||||
|
video_path: str, cursor: float, crop_center: float,
|
||||||
|
starts: list[float],
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run detection at the cursor (to identify the target) then at each start
|
||||||
|
time. Returns a list of horizontal crop centers (one per start)."""
|
||||||
|
ref = detect_subject_center(video_path, cursor, None, crop_center, 0.5)
|
||||||
|
if ref is None:
|
||||||
|
_log("Tracking: no detection at cursor, using fixed center")
|
||||||
|
return [crop_center] * len(starts)
|
||||||
|
target_cls, last_x, last_y = ref
|
||||||
|
_log(f"Tracking: target class={target_cls} at ({last_x:.2f}, {last_y:.2f})")
|
||||||
|
centers = []
|
||||||
|
for t in starts:
|
||||||
|
det = detect_subject_center(video_path, t, target_cls, last_x, last_y)
|
||||||
|
if det is not None:
|
||||||
|
_, cx, cy = det
|
||||||
|
_log(f" t={t:.2f}s → center={cx:.3f}")
|
||||||
|
centers.append(cx)
|
||||||
|
last_x, last_y = cx, cy
|
||||||
|
else:
|
||||||
|
_log(f" t={t:.2f}s → lost, reusing {last_x:.3f}")
|
||||||
|
centers.append(last_x)
|
||||||
|
return centers
|
||||||
Reference in New Issue
Block a user