feat: integrate training UI, BEATs model, and clean up legacy code
- Remove legacy distance-mode scanning (build_profile, _similarity, etc.) and hand-crafted intensity features — the pipeline is now embedding-only
- Integrate Microsoft BEATs as an embedding option alongside wav2vec2/HuBERT
- Add TrainDialog with positive class selector, model picker, video dir fallback, and live training stats
- Add TrainWorker QThread with cancel support and proper lifecycle cleanup
- Add source_path column to DB for robust source video tracking
- Add get_export_folders/get_training_data/get_training_stats to DB
- Wire source_path in all export DB writes (_on_clip_done, _on_auto_clip_done)
- Cancel scan/train workers in closeEvent to prevent use-after-free crashes
- Add setup_env.sh supporting both conda and python venv (CUDA 12.8)
- Update requirements.txt with all actual dependencies
- Update 8cut_train.py with --positive flag for new DB-driven training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+106
-5
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
@@ -7,7 +8,7 @@ from .paths import _log
|
||||
|
||||
|
||||
class ProcessedDB:
|
||||
_SCHEMA_VERSION = 3 # bump when schema changes
|
||||
_SCHEMA_VERSION = 4 # bump when schema changes
|
||||
|
||||
def __init__(self, db_path: str | None = None):
|
||||
if db_path is None:
|
||||
@@ -47,6 +48,7 @@ class ProcessedDB:
|
||||
" clip_count INTEGER NOT NULL DEFAULT 3,"
|
||||
" spread REAL NOT NULL DEFAULT 3.0,"
|
||||
" profile TEXT NOT NULL DEFAULT 'default',"
|
||||
" source_path TEXT NOT NULL DEFAULT '',"
|
||||
" processed_at TEXT NOT NULL"
|
||||
")"
|
||||
)
|
||||
@@ -62,6 +64,7 @@ class ProcessedDB:
|
||||
"clip_count": "INTEGER NOT NULL DEFAULT 3",
|
||||
"spread": "REAL NOT NULL DEFAULT 3.0",
|
||||
"profile": "TEXT NOT NULL DEFAULT 'default'",
|
||||
"source_path": "TEXT NOT NULL DEFAULT ''",
|
||||
}
|
||||
for col, typedef in new_cols.items():
|
||||
if col not in cols:
|
||||
@@ -85,7 +88,7 @@ class ProcessedDB:
|
||||
short_side: int | None = None, portrait_ratio: str = "",
|
||||
crop_center: float = 0.5, fmt: str = "MP4",
|
||||
clip_count: int = 3, spread: float = 3.0,
|
||||
profile: str = "default") -> None:
|
||||
profile: str = "default", source_path: str = "") -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
with self._lock:
|
||||
@@ -93,11 +96,11 @@ class ProcessedDB:
|
||||
"INSERT INTO processed"
|
||||
" (filename, start_time, output_path, label, category,"
|
||||
" short_side, portrait_ratio, crop_center, format,"
|
||||
" clip_count, spread, profile, processed_at)"
|
||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
" clip_count, spread, profile, source_path, processed_at)"
|
||||
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(filename, start_time, output_path, label, category,
|
||||
short_side, portrait_ratio, crop_center, fmt,
|
||||
clip_count, spread, profile,
|
||||
clip_count, spread, profile, source_path,
|
||||
datetime.now(timezone.utc).isoformat()),
|
||||
)
|
||||
self._con.commit()
|
||||
@@ -223,6 +226,104 @@ class ProcessedDB:
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def get_export_folders(self, profile: str = "default") -> list[str]:
    """List the export folders that hold clips recorded for *profile*.

    Stored output paths follow the layout
    ``.../export_folder/group_dir/clip.mp4``, so the export folder is the
    name of the directory two levels above each ``output_path``.

    Returns:
        The distinct folder names in alphabetical order, e.g.
        ``["mp4_Intense", "mp4_Soft"]``. An empty list is returned when
        the database is disabled.
    """
    if not self._enabled:
        return []
    cursor = self._con.execute(
        "SELECT DISTINCT output_path FROM processed WHERE profile = ?",
        (profile,),
    )
    names = {
        os.path.basename(os.path.dirname(os.path.dirname(row[0])))
        for row in cursor.fetchall()
    }
    # Paths too shallow to have a grandparent directory produce "".
    names.discard("")
    return sorted(names)
|
||||
|
||||
def get_training_data(self, profile: str, positive_folder: str,
                      fallback_video_dir: str = "",
                      ) -> list[tuple[str, list[float], list[float]]]:
    """Assemble per-video training labels from the rows stored for *profile*.

    Args:
        profile: profile name whose rows are read.
        positive_folder: export folder name that marks the positive class
            (e.g. "mp4_Intense").
        fallback_video_dir: directory searched for ``filename`` when the
            recorded source_path is empty or no longer exists on disk.

    Returns:
        One ``(source_video_path, positive_times, soft_times)`` tuple per
        video that has at least one positive clip and a source file that
        can be located. Both time lists are sorted ascending. Soft times
        are the start times of that video's clips whose export folder is
        not *positive_folder*.
    """
    if not self._enabled:
        return []
    records = self._con.execute(
        "SELECT filename, start_time, output_path, source_path"
        " FROM processed WHERE profile = ?",
        (profile,),
    ).fetchall()

    positives: dict[str, set[float]] = {}
    softs: dict[str, set[float]] = {}
    known_sources: dict[str, str] = {}

    for filename, start, out_path, src_path in records:
        if src_path:
            # Later rows overwrite earlier ones for the same filename.
            known_sources[filename] = src_path
        folder = os.path.basename(os.path.dirname(os.path.dirname(out_path)))
        bucket = positives if folder == positive_folder else softs
        bucket.setdefault(filename, set()).add(start)

    video_infos = []
    # Only videos with at least one positive clip are emitted.
    for filename, pos_times in positives.items():
        video_path = known_sources.get(filename, "")
        if not (video_path and os.path.exists(video_path)):
            # Recorded source is unusable: try fallback_video_dir / filename.
            if fallback_video_dir:
                video_path = os.path.join(fallback_video_dir, filename)
            if not (video_path and os.path.exists(video_path)):
                continue
        soft_times = softs.get(filename, set())
        video_infos.append((video_path, sorted(pos_times), sorted(soft_times)))
    return video_infos
|
||||
|
||||
def get_training_stats(self, profile: str) -> dict[str, dict]:
    """Return per-subprofile stats for training readiness display.

    A subprofile is an export folder: the name of the directory two levels
    above a row's ``output_path``. Rows whose path has no such directory
    are ignored.

    All figures come from a single query and a single pass over its rows,
    so the folder list and the counts always describe the same snapshot
    of the table.

    Args:
        profile: profile name whose rows are summarised.

    Returns:
        Dict mapping subprofile_name → {
            'videos': number of distinct source video filenames,
            'clips': total clip count,
        }
        with keys in alphabetical order. Empty when the database is
        disabled or the profile has no rows.
    """
    if not self._enabled:
        return {}
    rows = self._con.execute(
        "SELECT filename, output_path FROM processed WHERE profile = ?",
        (profile,),
    ).fetchall()

    videos_by_folder: dict[str, set[str]] = {}
    clips_by_folder: dict[str, int] = {}
    for filename, output_path in rows:
        folder = os.path.basename(os.path.dirname(os.path.dirname(output_path)))
        if not folder:
            # Path too shallow to contain an export folder.
            continue
        videos_by_folder.setdefault(folder, set()).add(filename)
        clips_by_folder[folder] = clips_by_folder.get(folder, 0) + 1

    return {
        folder: {
            "videos": len(videos_by_folder[folder]),
            "clips": clips_by_folder[folder],
        }
        for folder in sorted(videos_by_folder)
    }
|
||||
|
||||
def hide_file(self, filename: str, profile: str = "default") -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user