diff --git a/main.py b/main.py
index 836941a..d40b309 100755
--- a/main.py
+++ b/main.py
@@ -223,6 +223,103 @@ class ScanWorker(QThread):
self.error.emit(str(e))
+class DatasetStatsDialog(QDialog):
+ """Per-video dataset breakdown with class balance visualization."""
+
+ def __init__(self, video_infos: list, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle("Dataset Statistics")
+ self.setMinimumSize(600, 400)
+
+ layout = QVBoxLayout(self)
+
+ # ── Totals ────────────────────────────────────────────
+ n_pos = sum(len(vi[1]) for vi in video_infos)
+ n_soft = sum(len(vi[2]) for vi in video_infos)
+ n_neg = sum(len(vi[3]) for vi in video_infos)
+ n_total = n_pos + n_soft + n_neg
+
+ totals = QLabel(
+ f"{len(video_infos)} videos | "
+ f"{n_total} total clips | "
+ f"■ {n_pos} positive "
+ f"■ {n_soft} soft "
+ f"■ {n_neg} negative"
+ )
+ layout.addWidget(totals)
+
+ # ── Class balance bar ─────────────────────────────────
+ if n_total > 0:
+ bar = QWidget()
+ bar.setFixedHeight(20)
+ bar.setStyleSheet("background: #222;")
+
+ class _BalanceBar(QWidget):
+ def __init__(self, pos, soft, neg, total):
+ super().__init__()
+ self._fracs = (pos / total, soft / total, neg / total)
+ self.setFixedHeight(20)
+
+ def paintEvent(self, _ev):
+ p = QPainter(self)
+ w = self.width()
+ colors = [QColor(80, 170, 80), QColor(170, 170, 60), QColor(170, 70, 70)]
+ x = 0
+ for frac, col in zip(self._fracs, colors):
+ bw = int(frac * w)
+ if bw > 0:
+ p.fillRect(x, 0, bw, 20, col)
+ x += bw
+ p.end()
+
+ balance = _BalanceBar(n_pos, n_soft, n_neg, n_total)
+ layout.addWidget(balance)
+
+ # ── Per-video table ───────────────────────────────────
+ table = QTableWidget(len(video_infos), 5)
+ table.setHorizontalHeaderLabels(["Video", "Pos", "Soft", "Neg", "Total"])
+ table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
+ for c in range(1, 5):
+ table.horizontalHeader().setSectionResizeMode(c, QHeaderView.ResizeMode.ResizeToContents)
+ table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
+ table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
+ table.verticalHeader().setVisible(False)
+
+ for row, (path, pos, soft, neg) in enumerate(video_infos):
+ name = os.path.basename(path)
+ table.setItem(row, 0, QTableWidgetItem(name))
+ for col, val in enumerate([len(pos), len(soft), len(neg),
+ len(pos) + len(soft) + len(neg)], 1):
+ item = QTableWidgetItem(str(val))
+ item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
+ table.setItem(row, col, item)
+
+ table.sortItems(1, Qt.SortOrder.DescendingOrder)
+ layout.addWidget(table)
+
+ # ── Warnings ──────────────────────────────────────────
+ warnings = []
+ if n_pos == 0:
+ warnings.append("No positive clips — export some clips first.")
+ elif n_pos < 20:
+ warnings.append(f"Only {n_pos} positive clips — aim for 20+ for decent results.")
+ # Check for videos with zero positives (only negatives)
+ neg_only = sum(1 for vi in video_infos if len(vi[1]) == 0 and len(vi[3]) > 0)
+ if neg_only:
+ warnings.append(f"{neg_only} video(s) have only negatives, no positives.")
+ # Check balance ratio
+ if n_pos > 0 and n_neg > 0 and (n_neg / n_pos > 5 or n_pos / n_neg > 5):
+ warnings.append("Class imbalance >5:1 — consider adding more of the minority class.")
+ if warnings:
+ lbl = QLabel("
".join(f"⚠ {w}" for w in warnings))
+ lbl.setStyleSheet("color: #cc8800;")
+ layout.addWidget(lbl)
+
+ btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Close)
+ btns.rejected.connect(self.close)
+ layout.addWidget(btns)
+
+
class TrainDialog(QDialog):
"""Dialog for configuring and launching classifier training."""
@@ -301,11 +398,19 @@ class TrainDialog(QDialog):
layout.addLayout(form)
- # Stats summary
+ # Stats summary with details button
+ stats_row = QHBoxLayout()
self._lbl_stats = QLabel()
+ stats_row.addWidget(self._lbl_stats, 1)
+ self._btn_details = QPushButton("Details…")
+ self._btn_details.setFixedWidth(70)
+ self._btn_details.clicked.connect(self._show_details)
+ self._btn_details.setEnabled(False)
+ stats_row.addWidget(self._btn_details, 0, Qt.AlignmentFlag.AlignTop)
+ self._video_infos: list = []
self._update_stats()
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
- layout.addWidget(self._lbl_stats)
+ layout.addLayout(stats_row)
# Buttons
btns = QDialogButtonBox(
@@ -376,6 +481,8 @@ class TrainDialog(QDialog):
self._lbl_video_dir.setVisible(needs_fallback)
self._video_dir_widget.setVisible(needs_fallback)
+ self._video_infos = video_infos
+ self._btn_details.setEnabled(len(video_infos) > 0)
n_videos = len(video_infos)
n_pos = sum(len(vi[1]) for vi in video_infos)
n_soft = sum(len(vi[2]) for vi in video_infos)
@@ -392,6 +499,11 @@ class TrainDialog(QDialog):
lines.append("Recommend at least 3 videos for decent results.")
self._lbl_stats.setText("
".join(lines))
+ def _show_details(self):
+ if self._video_infos:
+ dlg = DatasetStatsDialog(self._video_infos, parent=self)
+ dlg.exec()
+
@property
def positive_folder(self) -> str:
return self._cmb_positive.currentData() or ""