diff --git a/main.py b/main.py
index 5c40c39..e3c78a5 100644
--- a/main.py
+++ b/main.py
@@ -91,6 +91,30 @@ def build_audio_extract_command(input_path: str, start: float, sequence_dir: str
     ]
 
 
+def build_annotation_tsv_path(folder: str) -> str:
+    """Return the path of the annotation file dataset.tsv inside *folder*."""
+    return os.path.join(folder, "dataset.tsv")
+
+
+def append_to_tsv(folder: str, clip_stem: str, label: str) -> None:
+    """Append one tab-separated "clip_stem, label" row to dataset.tsv in *folder*.
+
+    The file is created if absent. Nothing is written when the label is
+    empty or whitespace-only. Whitespace runs (tabs and newlines included)
+    inside either field are collapsed to a single space so that a value can
+    never introduce an extra column or row.
+    """
+    # str.split() with no argument splits on any whitespace run, so joining
+    # the pieces both trims the value and removes embedded tabs/newlines.
+    clip_stem = " ".join(clip_stem.split())
+    label = " ".join(label.split())
+    if not label:
+        return
+    tsv_path = build_annotation_tsv_path(folder)
+    with open(tsv_path, "a", encoding="utf-8") as f:
+        f.write(f"{clip_stem}\t{label}\n")
+
+
 def build_mask_output_dir(video_path: str) -> str:
     """Return path of mask output directory: _masks/ next to the video."""
     p = Path(video_path)
@@ -144,7 +168,7 @@ def _normalize_filename(filename: str) -> str:
 
 
 class ProcessedDB:
-    _SCHEMA_VERSION = 2  # bump when schema changes
+    _SCHEMA_VERSION = 3  # bump when schema changes
 
     def __init__(self, db_path: str | None = None):
         if db_path is None:
@@ -165,7 +189,7 @@ class ProcessedDB:
             row[1]
             for row in self._con.execute("PRAGMA table_info(processed)").fetchall()
         }
-        needs_recreate = "start_time" not in cols or "output_path" not in cols
+        needs_recreate = not {"start_time", "output_path", "label", "category"}.issubset(cols)
         if needs_recreate:
             self._con.execute("DROP TABLE IF EXISTS processed")
             self._con.execute(
@@ -174,6 +198,8 @@ class ProcessedDB:
                 " filename TEXT NOT NULL,"
                 " start_time REAL NOT NULL,"
                 " output_path TEXT NOT NULL,"
+                " label TEXT NOT NULL DEFAULT '',"
+                " category TEXT NOT NULL DEFAULT '',"
                 " processed_at TEXT NOT NULL"
                 ")"
             )
@@ -182,13 +208,15 @@ class ProcessedDB:
             )
         self._con.commit()
 
-    def add(self, filename: str, start_time: float, output_path: str) -> None:
+    def add(self, filename: str, start_time: float, output_path: str,
+            label: str = "", category: str = "") -> None:
         if not self._enabled:
             return
         self._con.execute(
-            "INSERT INTO processed (filename, start_time, output_path, processed_at)"
-            " VALUES (?, ?, ?, ?)",
-            (filename, start_time, output_path, datetime.now(timezone.utc).isoformat()),
+            "INSERT INTO processed (filename, start_time, output_path, label, category, processed_at)"
+            " VALUES (?, ?, ?, ?, ?, ?)",
+            (filename, start_time, output_path, label, category,
+             datetime.now(timezone.utc).isoformat()),
         )
         self._con.commit()
 
@@ -885,6 +913,25 @@ class MainWindow(QMainWindow):
         )
         self._cmb_format.currentTextChanged.connect(self._update_next_label)
 
+        self._txt_label = QLineEdit()
+        self._txt_label.setPlaceholderText("Sound label (e.g. dog barking)")
+        self._txt_label.setFixedWidth(200)
+        saved_label = self._settings.value("sound_label", "")
+        self._txt_label.setText(saved_label)
+        self._txt_label.textChanged.connect(
+            lambda v: self._settings.setValue("sound_label", v)
+        )
+
+        self._cmb_category = QComboBox()
+        _SELVA_CATEGORIES = ["", "Human", "Animal", "Vehicle", "Tool", "Music", "Nature", "Sport", "Other"]
+        self._cmb_category.addItems(_SELVA_CATEGORIES)
+        saved_cat = self._settings.value("sound_category", "")
+        cat_idx = self._cmb_category.findText(saved_cat)
+        self._cmb_category.setCurrentIndex(max(cat_idx, 0))
+        self._cmb_category.currentTextChanged.connect(
+            lambda v: self._settings.setValue("sound_category", v)
+        )
+
         self._crop_bar = CropBarWidget()
         self._crop_bar.set_crop_center(self._crop_center)
         self._crop_bar.set_portrait_ratio(
@@ -960,8 +1007,16 @@ class MainWindow(QMainWindow):
         show_masks = self._settings.value("show_masks_row", "true") == "true"
         self._mask_row_widget.setVisible(show_masks)
 
+        annotation_row = QHBoxLayout()
+        annotation_row.addWidget(QLabel("Label:"))
+        annotation_row.addWidget(self._txt_label)
+        annotation_row.addWidget(QLabel("Cat:"))
+        annotation_row.addWidget(self._cmb_category)
+        annotation_row.addStretch()
+
         right_layout.addLayout(controls)
         right_layout.addLayout(export_row)
+        right_layout.addLayout(annotation_row)
         right_layout.addWidget(self._mask_row_widget)
 
         # Left: queue label + playlist
@@ -1177,7 +1232,17 @@ class MainWindow(QMainWindow):
         self._export_worker.start()
 
     def _on_export_done(self, path: str):
-        self._db.add(os.path.basename(self._file_path), self._cursor, path)
+        label = self._txt_label.text().strip()
+        category = self._cmb_category.currentText()
+        self._db.add(
+            os.path.basename(self._file_path),
+            self._cursor,
+            path,
+            label=label,
+            category=category,
+        )
+        clip_stem = os.path.splitext(os.path.basename(path))[0]
+        append_to_tsv(self._txt_folder.text(), clip_stem, label)
         # For MP4 exports path is a file; for WebP sequence it is a directory.
         # build_mask_output_dir handles both correctly via Path.stem.
         self._last_export_path = path
diff --git a/tests/test_utils.py b/tests/test_utils.py
index dc100fa..29e839e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,5 @@
 import tempfile, os
-from main import build_export_path, format_time, build_ffmpeg_command, build_mask_output_dir, build_sequence_dir, build_audio_extract_command
+from main import build_export_path, format_time, build_ffmpeg_command, build_mask_output_dir, build_sequence_dir, build_audio_extract_command, build_annotation_tsv_path, append_to_tsv
 from main import _normalize_filename, ProcessedDB
 
 
@@ -235,3 +235,53 @@ def test_ffmpeg_command_image_sequence_no_audio():
     assert "-an" in cmd
     assert "-c:a" not in cmd
     assert "aac" not in cmd
+
+
+def test_annotation_tsv_path():
+    assert build_annotation_tsv_path("/out") == "/out/dataset.tsv"
+
+def test_append_to_tsv_creates_file():
+    with tempfile.TemporaryDirectory() as d:
+        append_to_tsv(d, "clip_001", "dog barking")
+        with open(os.path.join(d, "dataset.tsv")) as f:
+            lines = f.readlines()
+        assert lines == ["clip_001\tdog barking\n"]
+
+def test_append_to_tsv_appends():
+    with tempfile.TemporaryDirectory() as d:
+        append_to_tsv(d, "clip_001", "dog barking")
+        append_to_tsv(d, "clip_002", "cat meowing")
+        with open(os.path.join(d, "dataset.tsv")) as f:
+            lines = f.readlines()
+        assert lines == ["clip_001\tdog barking\n", "clip_002\tcat meowing\n"]
+
+def test_append_to_tsv_empty_label_skips():
+    with tempfile.TemporaryDirectory() as d:
+        append_to_tsv(d, "clip_001", "")
+        assert not os.path.exists(os.path.join(d, "dataset.tsv"))
+
+def test_append_to_tsv_whitespace_label_skips():
+    with tempfile.TemporaryDirectory() as d:
+        append_to_tsv(d, "clip_001", " \t\n")
+        assert not os.path.exists(os.path.join(d, "dataset.tsv"))
+
+def test_append_to_tsv_sanitizes_separators():
+    with tempfile.TemporaryDirectory() as d:
+        append_to_tsv(d, "clip_001", "dog\tbarking\nloudly")
+        with open(os.path.join(d, "dataset.tsv"), encoding="utf-8") as f:
+            lines = f.readlines()
+        assert lines == ["clip_001\tdog barking loudly\n"]
+
+def test_db_stores_label_and_category():
+    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
+        path = f.name
+    try:
+        db = ProcessedDB(path)
+        db.add("video.mp4", 0.0, "/out/clip_001.mp4", label="dog barking", category="Animal")
+        markers = db.get_markers("video.mp4")
+        assert len(markers) == 1
+        assert markers[0][0] == 0.0
+        row = db._con.execute("SELECT label, category FROM processed").fetchone()
+        assert tuple(row) == ("dog barking", "Animal")
+    finally:
+        os.unlink(path)