feat: dataset statistics dialog with per-video breakdown and class balance
Details button in Train dialog opens a stats view showing:
- Class totals (positive/soft/negative) with colored balance bar
- Per-video table sortable by column
- Warnings for low clip counts, class imbalance, negative-only videos

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -223,6 +223,103 @@ class ScanWorker(QThread):
|
|||||||
self.error.emit(str(e))
|
self.error.emit(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetStatsDialog(QDialog):
    """Per-video dataset breakdown with class balance visualization.

    The dialog stacks, top to bottom: a totals line, a proportional
    positive/soft/negative balance bar (only when there is at least one
    clip), a per-video table, optional warnings, and a Close button.
    """

    def __init__(self, video_infos: list, parent=None):
        """Build the dialog contents from *video_infos*.

        Args:
            video_infos: Sequence of ``(path, positives, softs, negatives)``
                entries. Only the base name of ``path`` and the length of
                each clip collection are used.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Dataset Statistics")
        self.setMinimumSize(600, 400)

        layout = QVBoxLayout(self)

        # Reduce every entry to (name, n_pos, n_soft, n_neg) once so the
        # totals, the table and the warnings all read the same numbers.
        rows = [
            (os.path.basename(path), len(pos), len(soft), len(neg))
            for path, pos, soft, neg in video_infos
        ]

        # ── Totals ────────────────────────────────────────────
        n_pos = sum(row[1] for row in rows)
        n_soft = sum(row[2] for row in rows)
        n_neg = sum(row[3] for row in rows)
        n_total = n_pos + n_soft + n_neg

        totals = QLabel(
            f"<b>{len(video_infos)}</b> videos &nbsp;|&nbsp; "
            f"<b>{n_total}</b> total clips &nbsp;|&nbsp; "
            f"<span style='color:#4a4'>■</span> {n_pos} positive &nbsp; "
            f"<span style='color:#aa4'>■</span> {n_soft} soft &nbsp; "
            f"<span style='color:#a44'>■</span> {n_neg} negative"
        )
        layout.addWidget(totals)

        # ── Class balance bar ─────────────────────────────────
        # Skipped for an empty dataset: the fractions would divide by zero.
        if n_total > 0:
            layout.addWidget(_BalanceBar(n_pos, n_soft, n_neg, n_total))

        # ── Per-video table ───────────────────────────────────
        table = QTableWidget(len(rows), 5)
        table.setHorizontalHeaderLabels(["Video", "Pos", "Soft", "Neg", "Total"])
        header = table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
        for col in range(1, 5):
            header.setSectionResizeMode(col, QHeaderView.ResizeMode.ResizeToContents)
        table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
        table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
        table.verticalHeader().setVisible(False)

        for row, (name, pos_count, soft_count, neg_count) in enumerate(rows):
            table.setItem(row, 0, QTableWidgetItem(name))
            counts = (
                pos_count,
                soft_count,
                neg_count,
                pos_count + soft_count + neg_count,
            )
            for col, count in enumerate(counts, 1):
                item = QTableWidgetItem()
                # Store the count as an int, not a string, so the column
                # sorts numerically (string items would rank "9" above "10").
                item.setData(Qt.ItemDataRole.DisplayRole, count)
                item.setTextAlignment(Qt.AlignmentFlag.AlignCenter)
                table.setItem(row, col, item)

        # Enable header-click sorting only after the table is populated;
        # enabling it earlier would re-order rows while they are inserted.
        table.setSortingEnabled(True)
        table.sortByColumn(1, Qt.SortOrder.DescendingOrder)
        layout.addWidget(table)

        # ── Warnings ──────────────────────────────────────────
        warnings = []
        if n_pos == 0:
            warnings.append("No positive clips — export some clips first.")
        elif n_pos < 20:
            warnings.append(f"Only {n_pos} positive clips — aim for 20+ for decent results.")
        # Videos that contribute negatives but no positives
        neg_only = sum(1 for _, p, _, n in rows if p == 0 and n > 0)
        if neg_only:
            warnings.append(f"{neg_only} video(s) have only negatives, no positives.")
        # Positive/negative ratio in either direction (soft clips are ignored)
        if n_pos > 0 and n_neg > 0 and (n_neg / n_pos > 5 or n_pos / n_neg > 5):
            warnings.append("Class imbalance >5:1 — consider adding more of the minority class.")
        if warnings:
            lbl = QLabel("<br>".join(f"⚠ {text}" for text in warnings))
            lbl.setStyleSheet("color: #cc8800;")
            layout.addWidget(lbl)

        btns = QDialogButtonBox(QDialogButtonBox.StandardButton.Close)
        btns.rejected.connect(self.close)
        layout.addWidget(btns)


class _BalanceBar(QWidget):
    """Fixed-height horizontal bar split into positive/soft/negative segments."""

    def __init__(self, pos, soft, neg, total):
        """Store each class's share of *total*.

        Args:
            pos: Number of positive clips.
            soft: Number of soft clips.
            neg: Number of negative clips.
            total: Sum of the three counts; must be greater than zero.
        """
        super().__init__()
        self._fracs = (pos / total, soft / total, neg / total)
        self._colors = (
            QColor(80, 170, 80),
            QColor(170, 170, 60),
            QColor(170, 70, 70),
        )
        self.setFixedHeight(20)

    def paintEvent(self, _ev):
        """Paint the three segments left to right across the full width."""
        painter = QPainter(self)
        width = self.width()
        height = self.height()
        left = 0
        cumulative = 0.0
        for frac, color in zip(self._fracs, self._colors):
            # Round the running boundary rather than each segment width,
            # so per-segment truncation cannot leave a gap at the right.
            cumulative += frac
            right = min(width, round(cumulative * width))
            if right > left:
                painter.fillRect(left, 0, right - left, height, color)
            left = right
        painter.end()
|
||||||
|
|
||||||
|
|
||||||
class TrainDialog(QDialog):
|
class TrainDialog(QDialog):
|
||||||
"""Dialog for configuring and launching classifier training."""
|
"""Dialog for configuring and launching classifier training."""
|
||||||
|
|
||||||
@@ -301,11 +398,19 @@ class TrainDialog(QDialog):
|
|||||||
|
|
||||||
layout.addLayout(form)
|
layout.addLayout(form)
|
||||||
|
|
||||||
# Stats summary
|
# Stats summary with details button
|
||||||
|
stats_row = QHBoxLayout()
|
||||||
self._lbl_stats = QLabel()
|
self._lbl_stats = QLabel()
|
||||||
|
stats_row.addWidget(self._lbl_stats, 1)
|
||||||
|
self._btn_details = QPushButton("Details…")
|
||||||
|
self._btn_details.setFixedWidth(70)
|
||||||
|
self._btn_details.clicked.connect(self._show_details)
|
||||||
|
self._btn_details.setEnabled(False)
|
||||||
|
stats_row.addWidget(self._btn_details, 0, Qt.AlignmentFlag.AlignTop)
|
||||||
|
self._video_infos: list = []
|
||||||
self._update_stats()
|
self._update_stats()
|
||||||
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
|
self._cmb_positive.currentIndexChanged.connect(self._update_stats)
|
||||||
layout.addWidget(self._lbl_stats)
|
layout.addLayout(stats_row)
|
||||||
|
|
||||||
# Buttons
|
# Buttons
|
||||||
btns = QDialogButtonBox(
|
btns = QDialogButtonBox(
|
||||||
@@ -376,6 +481,8 @@ class TrainDialog(QDialog):
|
|||||||
self._lbl_video_dir.setVisible(needs_fallback)
|
self._lbl_video_dir.setVisible(needs_fallback)
|
||||||
self._video_dir_widget.setVisible(needs_fallback)
|
self._video_dir_widget.setVisible(needs_fallback)
|
||||||
|
|
||||||
|
self._video_infos = video_infos
|
||||||
|
self._btn_details.setEnabled(len(video_infos) > 0)
|
||||||
n_videos = len(video_infos)
|
n_videos = len(video_infos)
|
||||||
n_pos = sum(len(vi[1]) for vi in video_infos)
|
n_pos = sum(len(vi[1]) for vi in video_infos)
|
||||||
n_soft = sum(len(vi[2]) for vi in video_infos)
|
n_soft = sum(len(vi[2]) for vi in video_infos)
|
||||||
@@ -392,6 +499,11 @@ class TrainDialog(QDialog):
|
|||||||
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
|
lines.append("<i>Recommend at least 3 videos for decent results.</i>")
|
||||||
self._lbl_stats.setText("<br>".join(lines))
|
self._lbl_stats.setText("<br>".join(lines))
|
||||||
|
|
||||||
|
def _show_details(self):
    """Open the dataset statistics dialog for the stored video infos.

    Does nothing when ``self._video_infos`` is empty.
    """
    if not self._video_infos:
        return
    stats_dialog = DatasetStatsDialog(self._video_infos, parent=self)
    stats_dialog.exec()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positive_folder(self) -> str:
|
def positive_folder(self) -> str:
|
||||||
return self._cmb_positive.currentData() or ""
|
return self._cmb_positive.currentData() or ""
|
||||||
|
|||||||
Reference in New Issue
Block a user