12ed183f1b
- Remove legacy distance-mode scanning (build_profile, _similarity, etc.) and hand-crafted intensity features — the pipeline is now embedding-only
- Integrate Microsoft BEATs as an embedding option alongside wav2vec2/HuBERT
- Add TrainDialog with a positive-class selector, model picker, video-dir fallback, and live training stats
- Add TrainWorker QThread with cancel support and proper lifecycle cleanup
- Add a source_path column to the DB for robust source-video tracking
- Add get_export_folders/get_training_data/get_training_stats to the DB
- Wire source_path into all export DB writes (_on_clip_done, _on_auto_clip_done)
- Cancel scan/train workers in closeEvent to prevent use-after-free crashes
- Add setup_env.sh supporting both conda and python venv (CUDA 12.8)
- Update requirements.txt with all actual dependencies
- Update 8cut_train.py with a --positive flag for the new DB-driven training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Train an audio scan classifier from DB ground truth.
|
|
|
|
Usage:
|
|
python 8cut_train.py # default model, auto-detect positive
|
|
python 8cut_train.py --model BEATS # specific embedding model
|
|
python 8cut_train.py --positive mp4_Intense # explicit positive folder
|
|
python 8cut_train.py --positive mp4_Intense --model BEATS # both
|
|
"""
|
|
import sys, os, warnings
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
warnings.filterwarnings("ignore")
|
|
|
|
from core.audio_scan import train_classifier, default_model_path, _EMBED_MODELS
|
|
from core.db import ProcessedDB
|
|
|
|
# DB profile whose export rows are read as training ground truth; it is also
# handed to default_model_path() to decide where the trained model is saved.
PROFILE_NAME: str = "JAV_missionary"

# Directory searched for a source video when its DB row predates the
# source_path column (no recorded source location).
PLEX_DIR: str = "/media/unraid/appdata/plex/download/porn_jav/"
|
|
|
|
|
|
def _parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` means
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``model``, ``positive``, ``profile`` and
        ``video_dir`` attributes.

    Exits:
        Status 2 (raised by argparse) on malformed options, e.g. a flag given
        without its value; status 1 when ``--model`` is not one of
        ``_EMBED_MODELS``.
    """
    # Local import: only this script's entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train an audio scan classifier from DB ground truth.",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="embedding model name; omit to use the train_classifier default",
    )
    parser.add_argument(
        "--positive",
        default=None,
        help=(
            "export folder suffix holding the positive class "
            "(e.g. mp4_Intense); omit for the legacy _Intense/_Soft split"
        ),
    )
    parser.add_argument(
        "--profile",
        default=PROFILE_NAME,
        help="DB profile to train on (default: %(default)s)",
    )
    parser.add_argument(
        "--video-dir",
        default=PLEX_DIR,
        help="fallback directory for source videos of rows without source_path",
    )
    args = parser.parse_args(argv)

    # Validated by hand rather than via argparse ``choices`` so the error
    # output and exit status match what this script has always produced.
    if args.model is not None and args.model not in _EMBED_MODELS:
        print(f"Unknown model: {args.model}")
        print(f"Available: {', '.join(_EMBED_MODELS)}")
        sys.exit(1)
    return args


def _legacy_video_infos(db, profile, video_dir):
    """Collect training data by classifying exports on their output folder.

    Rows of ``profile`` whose ``output_path`` contains ``_Intense/`` are
    intense clips; otherwise rows containing ``_Soft/`` are soft clips. Clip
    start times are grouped per source ``filename``.

    Args:
        db: Open ``ProcessedDB``; its private ``_con`` connection is queried
            directly.
        profile: DB profile whose rows are read.
        video_dir: Directory joined with ``filename`` for videos that have no
            recorded ``source_path``.

    Returns:
        List of ``(video_path, gt_intense, gt_soft)`` tuples whose ground-truth
        lists are sorted start times. Videos without any intense clip, or
        whose file does not exist (reported on stdout), are left out.
    """
    rows = db._con.execute(
        "SELECT filename, start_time, output_path, source_path"
        " FROM processed WHERE profile = ?",
        (profile,),
    ).fetchall()

    intense_by_video, soft_by_video = {}, {}
    source_by_fn = {}
    for fn, st, op, sp in rows:
        if sp:
            source_by_fn[fn] = sp
        # A NULL/empty output_path cannot be classified by folder (and
        # `"..." in None` would raise TypeError), so skip such rows.
        if not op:
            continue
        # "_Intense/" also covers the canonical "/mp4_Intense/" folder.
        if "_Intense/" in op:
            intense_by_video.setdefault(fn, set()).add(st)
        elif "_Soft/" in op:
            soft_by_video.setdefault(fn, set()).add(st)

    video_infos = []
    for fn, intense_starts in intense_by_video.items():
        # Prefer the source_path recorded in the DB; old rows fall back to
        # the configured video directory.
        vpath = source_by_fn.get(fn) or os.path.join(video_dir, fn)
        if not os.path.exists(vpath):
            print(f" skip (not found): {fn}")
            continue
        video_infos.append(
            (vpath, sorted(intense_starts), sorted(soft_by_video.get(fn, ())))
        )
    return video_infos


def main(argv=None):
    """Train the audio scan classifier for a profile and save it.

    Args:
        argv: Command-line arguments excluding the program name; ``None``
            means ``sys.argv[1:]`` (the parameter exists for programmatic use).

    Exits with status 1 when no training data is found or when
    ``train_classifier`` returns ``None``.
    """
    args = _parse_args(argv)
    db = ProcessedDB()

    if args.positive:
        # Explicit positive folder: use the DB helper.
        video_infos = db.get_training_data(
            args.profile, args.positive, fallback_video_dir=args.video_dir,
        )
        if not video_infos:
            print(f"No training data found for positive='{args.positive}'")
            sys.exit(1)
    else:
        video_infos = _legacy_video_infos(db, args.profile, args.video_dir)
        # Same guard as the --positive branch: don't start training on zero
        # videos.
        if not video_infos:
            print("No training data found in _Intense/_Soft export folders")
            sys.exit(1)

    # NOTE(review): the label assumes train_classifier defaults to
    # WAV2VEC2_BASE when embed_model is None — confirm in core.audio_scan.
    label = args.model or "WAV2VEC2_BASE"
    print(f"Training {label} model on {len(video_infos)} videos...")
    model_path = default_model_path(args.profile)
    result = train_classifier(
        video_infos, model_path=model_path, embed_model=args.model,
    )
    if result is None:
        print("Training failed: no valid samples or missing class balance")
        sys.exit(1)
    print(f"Model saved to {model_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, never on import; main()
    # returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|