remove: mask generation, venv setup, and settings dialog
Dead code — masking is handled externally via ComfyUI. Removes SetupWorker, MaskWorker, SettingsDialog, build_mask_output_dir, the mask UI row, Settings button, and associated test cases.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from PyQt6.QtWidgets import (
|
|||||||
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
||||||
QLabel, QPushButton, QLineEdit, QFileDialog, QFrame, QStatusBar,
|
QLabel, QPushButton, QLineEdit, QFileDialog, QFrame, QStatusBar,
|
||||||
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
|
QListWidget, QListWidgetItem, QAbstractItemView, QSplitter, QToolTip,
|
||||||
QComboBox, QDialog, QPlainTextEdit, QCheckBox, QSpinBox, QDoubleSpinBox,
|
QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox,
|
||||||
QMessageBox,
|
QMessageBox,
|
||||||
)
|
)
|
||||||
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, pyqtSignal, QSettings
|
from PyQt6.QtCore import Qt, QObject, QThread, QTimer, pyqtSignal, QSettings
|
||||||
@@ -157,26 +157,12 @@ def upsert_clip_annotation(folder: str, clip_path: str, label: str) -> None:
|
|||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
def build_mask_output_dir(video_path: str) -> str:
    """Build the mask output directory path for a video.

    The directory is a sibling of the video, named after the video's stem
    with a ``_masks`` suffix (``/out/clip.mp4`` -> ``/out/clip_masks``).
    Only the path string is computed; nothing is created on disk.
    """
    video = Path(video_path)
    masks_dir_name = video.stem + "_masks"
    return str(video.parent / masks_dir_name)
|
|
||||||
|
|
||||||
|
|
||||||
_RATIOS: dict[str, tuple[int, int]] = {
|
_RATIOS: dict[str, tuple[int, int]] = {
|
||||||
"9:16": (9, 16),
|
"9:16": (9, 16),
|
||||||
"4:5": (4, 5),
|
"4:5": (4, 5),
|
||||||
"1:1": (1, 1),
|
"1:1": (1, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Python interpreter inside the app's virtualenv at ~/.8cut/venv. The
# sub-directory and executable name differ on Windows ("Scripts/python.exe")
# versus everything else ("bin/python").
_VENV_PYTHON = str(
    Path.home().joinpath(
        ".8cut",
        "venv",
        "Scripts" if sys.platform == "win32" else "bin",
        "python.exe" if sys.platform == "win32" else "python",
    )
)
# "tools" directory located next to this source file.
_TOOLS_DIR = str(Path(__file__).parent.joinpath("tools"))
|
|
||||||
|
|
||||||
|
|
||||||
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
def _portrait_crop_filter(ratio: str, crop_center: float) -> str:
|
||||||
"""Return an ffmpeg crop= filter expression for the given portrait ratio.
|
"""Return an ffmpeg crop= filter expression for the given portrait ratio.
|
||||||
|
|
||||||
@@ -1097,150 +1083,6 @@ class PlaylistWidget(QListWidget):
|
|||||||
self._select(self.row(item))
|
self._select(self.row(item))
|
||||||
|
|
||||||
|
|
||||||
class SetupWorker(QThread):
    """Create the ML virtualenv and install its packages on a worker thread.

    Every step runs as a subprocess whose combined stdout/stderr is forwarded
    line-by-line through the ``line`` signal.

    Signals:
        line(str): one line of subprocess output, trailing whitespace stripped.
        finished(): emitted once, after every step exited with code 0.
        error(str): emitted when a step exits non-zero or raises; the
            remaining steps are not run.
    """
    line = pyqtSignal(str)
    # NOTE(review): this name shadows QThread's own ``finished`` signal, so
    # here it is only emitted on success — confirm callers rely on that.
    finished = pyqtSignal()
    error = pyqtSignal(str)

    def run(self):
        """Run the install steps in order, stopping at the first failure."""
        venv_dir = str(Path.home() / ".8cut" / "venv")
        steps = [
            [sys.executable, "-m", "venv", venv_dir],
            [_VENV_PYTHON, "-m", "pip", "install", "--upgrade", "pip"],
            [
                _VENV_PYTHON, "-m", "pip", "install",
                "torch", "torchvision",
                "transformers",
                "opencv-python",
                "Pillow",
                "segment-anything-2",
            ],
        ]
        try:
            for cmd in steps:
                # The context manager closes the pipe and reaps the child
                # even if forwarding output raises. errors="replace" keeps an
                # undecodable byte in the output from aborting the install.
                with subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    errors="replace",
                ) as proc:
                    for output_line in proc.stdout:
                        self.line.emit(output_line.rstrip())
                    proc.wait()
                if proc.returncode != 0:
                    # Only the first three argv items identify the step; the
                    # full package list would be noise in the message.
                    self.error.emit(f"Step failed: {' '.join(cmd[:3])}")
                    return
            self.finished.emit()
        except Exception as e:
            # Thread boundary: surface any failure (e.g. a missing
            # interpreter) to the UI instead of letting it escape run().
            self.error.emit(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
class MaskWorker(QThread):
    """Run a mask generation script with the ML venv interpreter.

    The script is invoked as
    ``<venv python> <script> --input <input_path> --output <output_dir>`` and
    its combined stdout/stderr is forwarded line-by-line.

    Signals:
        progress(str): one line of script output, trailing whitespace stripped.
        finished(): emitted when the script exits with code 0.
        error(str): emitted on a non-zero exit code, a missing interpreter,
            or any other exception.
    """
    progress = pyqtSignal(str)
    # NOTE(review): this name shadows QThread's own ``finished`` signal, so
    # here it is only emitted on success — confirm callers rely on that.
    finished = pyqtSignal()
    error = pyqtSignal(str)

    def __init__(self, script: str, input_path: str, output_dir: str):
        """Store the script path and its ``--input`` / ``--output`` arguments."""
        super().__init__()
        self._script = script
        self._input = input_path
        self._output = output_dir

    def run(self):
        """Execute the script and report the outcome through the signals."""
        cmd = [_VENV_PYTHON, self._script, "--input", self._input, "--output", self._output]
        try:
            # The context manager closes the pipe and reaps the child even if
            # forwarding output raises. errors="replace" keeps an undecodable
            # byte in the script output from aborting the run.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                errors="replace",
            ) as proc:
                for line in proc.stdout:
                    self.progress.emit(line.rstrip())
                proc.wait()
            if proc.returncode == 0:
                self.finished.emit()
            else:
                self.error.emit(f"Script exited with code {proc.returncode}")
        except FileNotFoundError:
            # Popen raises this when the venv interpreter does not exist.
            self.error.emit("venv not found — install ML tools via Settings")
        except Exception as e:
            # Thread boundary: surface anything else to the UI.
            self.error.emit(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
class SettingsDialog(QDialog):
    """Dialog that reports ML venv status and offers Install/Reinstall.

    Also hosts the "Show mask generation row" checkbox (persisted in QSettings
    under ``show_masks_row``) and a read-only log of installer output.
    """

    # Emitted after the install worker reports success.
    venv_installed = pyqtSignal()
    # Emitted with the new checkbox state whenever the mask-row box toggles.
    masks_visibility_changed = pyqtSignal(bool)

    def __init__(self, parent=None):
        """Build the widgets and restore the persisted checkbox state."""
        super().__init__(parent)
        self.setWindowTitle("Settings")
        self.setMinimumWidth(500)
        self.setMinimumHeight(300)

        self._worker: SetupWorker | None = None
        self._qsettings = QSettings("8cut", "8cut")

        # Status label: reflects whether the venv interpreter exists on disk.
        if Path(_VENV_PYTHON).exists():
            status_text = "Installed"
        else:
            status_text = "Not installed"
        self._lbl_status = QLabel(f"ML Tools: {status_text}")

        # Install button: labelled "Reinstall" when a venv is already present.
        if Path(_VENV_PYTHON).exists():
            btn_label = "Reinstall"
        else:
            btn_label = "Install"
        self._btn_install = QPushButton(btn_label)
        self._btn_install.clicked.connect(self._on_install)

        # Mask-row visibility checkbox; the setting is stored as the strings
        # "true"/"false" and defaults to "true".
        self._chk_masks = QCheckBox("Show mask generation row")
        stored_flag = self._qsettings.value("show_masks_row", "true")
        self._chk_masks.setChecked(stored_flag == "true")
        self._chk_masks.toggled.connect(self._on_masks_toggled)

        # Read-only installer log.
        self._log = QPlainTextEdit()
        self._log.setReadOnly(True)
        self._log.setPlaceholderText("Install output will appear here…")

        header_row = QHBoxLayout()
        header_row.addWidget(self._lbl_status)
        header_row.addStretch()
        header_row.addWidget(self._btn_install)

        root = QVBoxLayout(self)
        root.addLayout(header_row)
        root.addWidget(self._chk_masks)
        root.addWidget(self._log)

    def _on_masks_toggled(self, checked: bool) -> None:
        """Persist the checkbox state and notify listeners."""
        if checked:
            stored = "true"
        else:
            stored = "false"
        self._qsettings.setValue("show_masks_row", stored)
        self.masks_visibility_changed.emit(checked)

    def _on_install(self):
        """Start a fresh SetupWorker unless one is already running."""
        previous = self._worker
        if previous:
            if previous.isRunning():
                return
            # Let the previous (finished) thread wind down before replacing it.
            previous.quit()
            previous.wait()

        self._btn_install.setEnabled(False)
        self._log.clear()

        worker = SetupWorker()
        worker.line.connect(self._log.appendPlainText)
        worker.finished.connect(self._on_install_done)
        worker.error.connect(self._on_install_error)
        self._worker = worker
        worker.start()

    def _on_install_done(self):
        """Update the UI after a successful install and emit venv_installed."""
        self._lbl_status.setText("ML Tools: Installed")
        self._btn_install.setText("Reinstall")
        self._btn_install.setEnabled(True)
        self._log.appendPlainText("✓ Installation complete.")
        self.venv_installed.emit()

    def _on_install_error(self, msg: str):
        """Re-enable the button and append the error to the log."""
        self._btn_install.setEnabled(True)
        self._log.appendPlainText(f"ERROR: {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
class _KeyFilter(QObject):
|
class _KeyFilter(QObject):
|
||||||
"""Suppress global keyboard shortcuts when a text input widget has focus."""
|
"""Suppress global keyboard shortcuts when a text input widget has focus."""
|
||||||
def eventFilter(self, obj, event):
|
def eventFilter(self, obj, event):
|
||||||
@@ -1300,7 +1142,6 @@ class MainWindow(QMainWindow):
|
|||||||
self._export_worker: ExportWorker | None = None
|
self._export_worker: ExportWorker | None = None
|
||||||
self._last_export_path: str = ""
|
self._last_export_path: str = ""
|
||||||
self._overwrite_path: str = "" # set when a marker is selected for re-export
|
self._overwrite_path: str = "" # set when a marker is selected for re-export
|
||||||
self._mask_worker: MaskWorker | None = None
|
|
||||||
self._db_worker: _DBWorker | None = None
|
self._db_worker: _DBWorker | None = None
|
||||||
self._frame_grabber: FrameGrabber | None = None
|
self._frame_grabber: FrameGrabber | None = None
|
||||||
self._fps: float = 25.0 # cached on file load via get_fps()
|
self._fps: float = 25.0 # cached on file load via get_fps()
|
||||||
@@ -1478,25 +1319,9 @@ class MainWindow(QMainWindow):
|
|||||||
self._btn_delete.setToolTip("Delete last export (or selected marker) from disk, DB, and dataset.json")
|
self._btn_delete.setToolTip("Delete last export (or selected marker) from disk, DB, and dataset.json")
|
||||||
self._btn_delete.clicked.connect(self._on_delete_export)
|
self._btn_delete.clicked.connect(self._on_delete_export)
|
||||||
|
|
||||||
# Settings dialog
|
|
||||||
self._settings_dialog = SettingsDialog(self)
|
|
||||||
self._settings_dialog.venv_installed.connect(self._on_venv_installed)
|
|
||||||
self._settings_dialog.masks_visibility_changed.connect(self._on_masks_visibility_changed)
|
|
||||||
|
|
||||||
self._btn_settings = QPushButton("Settings…")
|
|
||||||
self._btn_settings.clicked.connect(self._settings_dialog.show)
|
|
||||||
|
|
||||||
# Mask generation row
|
|
||||||
self._cmb_mask = QComboBox()
|
|
||||||
self._cmb_mask.addItems(["Depth Anything", "SAM"])
|
|
||||||
self._btn_masks = QPushButton("Generate Masks")
|
|
||||||
self._btn_masks.setEnabled(Path(_VENV_PYTHON).exists())
|
|
||||||
self._btn_masks.clicked.connect(self._on_generate_masks)
|
|
||||||
|
|
||||||
# Right-side layout (video + controls)
|
# Right-side layout (video + controls)
|
||||||
top_bar = QHBoxLayout()
|
top_bar = QHBoxLayout()
|
||||||
top_bar.addWidget(self._lbl_file, stretch=1)
|
top_bar.addWidget(self._lbl_file, stretch=1)
|
||||||
top_bar.addWidget(self._btn_settings)
|
|
||||||
|
|
||||||
# Row 1 — transport + annotation + export trigger
|
# Row 1 — transport + annotation + export trigger
|
||||||
transport_row = QHBoxLayout()
|
transport_row = QHBoxLayout()
|
||||||
@@ -1543,21 +1368,6 @@ class MainWindow(QMainWindow):
|
|||||||
right_layout.addLayout(transport_row)
|
right_layout.addLayout(transport_row)
|
||||||
right_layout.addLayout(settings_row)
|
right_layout.addLayout(settings_row)
|
||||||
|
|
||||||
self._mask_row_widget = QWidget()
|
|
||||||
mask_row = QHBoxLayout(self._mask_row_widget)
|
|
||||||
mask_row.setContentsMargins(0, 0, 0, 0)
|
|
||||||
mask_row.addWidget(QLabel("Masks:"))
|
|
||||||
mask_row.addWidget(self._cmb_mask)
|
|
||||||
mask_row.addWidget(self._btn_masks)
|
|
||||||
_lbl_mask_warn = QLabel("⚠ Untested — use ComfyUI instead")
|
|
||||||
_lbl_mask_warn.setStyleSheet("color: #e0a030; font-style: italic;")
|
|
||||||
mask_row.addWidget(_lbl_mask_warn)
|
|
||||||
mask_row.addStretch()
|
|
||||||
show_masks = self._settings.value("show_masks_row", "true") == "true"
|
|
||||||
self._mask_row_widget.setVisible(show_masks)
|
|
||||||
|
|
||||||
right_layout.addWidget(self._mask_row_widget)
|
|
||||||
|
|
||||||
# Left: queue label + playlist
|
# Left: queue label + playlist
|
||||||
queue_label = QLabel("Queue")
|
queue_label = QLabel("Queue")
|
||||||
queue_label.setStyleSheet("color: #aaa; padding: 4px;")
|
queue_label.setStyleSheet("color: #aaa; padding: 4px;")
|
||||||
@@ -2070,53 +1880,5 @@ class MainWindow(QMainWindow):
|
|||||||
if self._db.get_markers(os.path.basename(p)):
|
if self._db.get_markers(os.path.basename(p)):
|
||||||
self._playlist.mark_done(p)
|
self._playlist.mark_done(p)
|
||||||
|
|
||||||
def _on_venv_installed(self) -> None:
    """Enable the "Generate Masks" button."""
    masks_button = self._btn_masks
    masks_button.setEnabled(True)
|
|
||||||
|
|
||||||
def _on_masks_visibility_changed(self, visible: bool) -> None:
    """Show or hide the mask row widget according to *visible*."""
    mask_row = self._mask_row_widget
    mask_row.setVisible(visible)
|
|
||||||
|
|
||||||
def _on_generate_masks(self) -> None:
    """Start mask generation for the most recently exported clip.

    Does nothing beyond a status-bar message when no clip has been exported,
    when the last export is a directory rather than a single video file, or
    when a mask worker is already running. Otherwise it creates the
    ``<stem>_masks`` directory next to the clip, disables the button, and
    launches a MaskWorker with the script matching the selected method.
    """
    clip_path = self._last_export_path
    if not clip_path:
        self.statusBar().showMessage("No clip exported yet — export first.")
        return
    if os.path.isdir(clip_path):
        self.statusBar().showMessage("Mask generation requires an MP4 export — switch format to MP4 and export first.")
        return
    if self._mask_worker and self._mask_worker.isRunning():
        self.statusBar().showMessage("Mask generation already running…")
        return

    output_dir = build_mask_output_dir(clip_path)
    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as exc:
        # Report the failure (read-only folder, name collision with a file,
        # ...) instead of letting the exception escape this slot.
        self.statusBar().showMessage(f"Mask error: cannot create {output_dir}: {exc}")
        return

    method = self._cmb_mask.currentText()
    # Any selection other than "Depth Anything" falls through to the SAM script.
    script = os.path.join(
        _TOOLS_DIR,
        "depth_masks.py" if method == "Depth Anything" else "sam_masks.py",
    )

    self._btn_masks.setEnabled(False)
    self.statusBar().showMessage(f"Generating masks ({method})…")

    self._mask_worker = MaskWorker(script, clip_path, output_dir)
    self._mask_worker.progress.connect(self._on_masks_progress)
    self._mask_worker.finished.connect(self._on_masks_done)
    self._mask_worker.error.connect(self._on_masks_error)
    self._mask_worker.start()
|
|
||||||
|
|
||||||
def _on_masks_progress(self, msg: str) -> None:
    """Show a progress line from the mask worker in the status bar."""
    status_bar = self.statusBar()
    status_bar.showMessage(msg)
|
|
||||||
|
|
||||||
def _on_masks_done(self) -> None:
    """Re-enable the button and report where the masks were saved.

    The directory name is recomputed from the current last-export path.
    """
    self._btn_masks.setEnabled(True)
    output_dir = build_mask_output_dir(self._last_export_path)
    status = f"Masks saved to {os.path.basename(output_dir)}/"
    self.statusBar().showMessage(status)
|
|
||||||
|
|
||||||
def _on_masks_error(self, msg: str) -> None:
    """Re-enable the button and show the worker's error in the status bar."""
    self._btn_masks.setEnabled(True)
    status = f"Mask error: {msg}"
    self.statusBar().showMessage(status)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
+1
-11
@@ -1,5 +1,5 @@
|
|||||||
import tempfile, os, json
|
import tempfile, os, json
|
||||||
from main import build_export_path, format_time, build_ffmpeg_command, build_mask_output_dir, build_sequence_dir, build_audio_extract_command, build_annotation_json_path, upsert_clip_annotation
|
from main import build_export_path, format_time, build_ffmpeg_command, build_sequence_dir, build_audio_extract_command, build_annotation_json_path, upsert_clip_annotation
|
||||||
from main import _normalize_filename, ProcessedDB
|
from main import _normalize_filename, ProcessedDB
|
||||||
|
|
||||||
|
|
||||||
@@ -182,16 +182,6 @@ def test_ffmpeg_command_portrait_off():
|
|||||||
cmd = build_ffmpeg_command("/in/video.mp4", 0.0, "/out/clip.mp4")
|
cmd = build_ffmpeg_command("/in/video.mp4", 0.0, "/out/clip.mp4")
|
||||||
assert "-vf" not in cmd
|
assert "-vf" not in cmd
|
||||||
|
|
||||||
def test_mask_output_dir_basic():
    """An .mp4 clip maps to a sibling ``<stem>_masks`` directory."""
    result = build_mask_output_dir("/out/clip_001.mp4")
    assert result == "/out/clip_001_masks"
|
|
||||||
|
|
||||||
def test_mask_output_dir_mkv():
    """The extension is dropped whatever the container format is."""
    result = build_mask_output_dir("/out/my_clip.mkv")
    assert result == "/out/my_clip_masks"
|
|
||||||
|
|
||||||
def test_mask_output_dir_nested():
    """Nested parent directories are preserved unchanged."""
    result = build_mask_output_dir("/a/b/c/shot_042.mp4")
    assert result == "/a/b/c/shot_042_masks"
|
|
||||||
|
|
||||||
|
|
||||||
# --- build_audio_extract_command ---
|
# --- build_audio_extract_command ---
|
||||||
|
|
||||||
def test_audio_extract_output_path():
|
def test_audio_extract_output_path():
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
"""Depth Anything V2 mask generation script.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python tools/depth_masks.py --input video.mp4 --output masks_dir/
|
|
||||||
|
|
||||||
Outputs one binary PNG per frame: frame_0000.png, frame_0001.png, …
|
|
||||||
Foreground = white (255), background = black (0), via Otsu threshold on depth map.
|
|
||||||
Requires: torch, transformers, opencv-python, Pillow
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Write one binary mask PNG per frame of the input video.

    Command-line arguments:
        --input:  path of the video to read.
        --output: directory receiving ``frame_0000.png``, ``frame_0001.png``, …
                  (created if missing).

    Each frame goes through the depth-estimation pipeline; the depth map is
    normalised to 0–255 and split with an Otsu threshold. Progress is printed
    as ``frame i/total``. Exits with status 1 if the video cannot be opened
    or a mask cannot be written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}", flush=True)

    pipe = pipeline(
        "depth-estimation",
        model="depth-anything/Depth-Anything-V2-Large-hf",
        device=device,
    )

    cap = cv2.VideoCapture(args.input)
    try:
        if not cap.isOpened():
            print(f"ERROR: cannot open {args.input}", file=sys.stderr)
            sys.exit(1)

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            result = pipe(pil_img)
            # NOTE(review): dtype depends on what the pipeline returns under
            # "depth" — confirm before relying on it being float32.
            depth = np.array(result["depth"])

            # Normalise to 0–255; a constant depth map yields an all-zero mask.
            d_min, d_max = depth.min(), depth.max()
            if d_max > d_min:
                depth_u8 = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
            else:
                depth_u8 = np.zeros_like(depth, dtype=np.uint8)

            # Otsu threshold: closer objects (higher depth value) = foreground
            _, mask = cv2.threshold(depth_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

            out_path = os.path.join(args.output, f"frame_{idx:04d}.png")
            # imwrite signals failure through its return value; without this
            # check a failed write would still end in "done" with exit code 0.
            if not cv2.imwrite(out_path, mask):
                print(f"ERROR: cannot write {out_path}", file=sys.stderr)
                sys.exit(1)

            idx += 1
            print(f"frame {idx}/{total}", flush=True)
    finally:
        # Release the capture on every path, including the sys.exit() ones.
        cap.release()

    print("done", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
"""SAM2 mask generation script.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python tools/sam_masks.py --input video.mp4 --output masks_dir/
|
|
||||||
|
|
||||||
Outputs one binary PNG per frame: frame_0000.png, frame_0001.png, …
|
|
||||||
Uses center of first frame as positive point prompt, propagates across all frames.
|
|
||||||
Requires: torch, segment-anything-2, opencv-python
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Write one binary mask PNG per frame using the SAM2 video predictor.

    Command-line arguments:
        --input:  path of the video to read.
        --output: directory receiving ``frame_0000.png``, ``frame_0001.png``, …
                  (created if missing).

    Frames are first dumped as JPEGs into a temporary directory, the centre
    pixel of frame 0 is given as a single positive point prompt, and the
    resulting mask is propagated through the video. Exits with status 1 if
    the video cannot be opened, yields no frames, or a file cannot be written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}", flush=True)

    # Extract frames to temp directory (SAM2 video predictor needs image files)
    with tempfile.TemporaryDirectory() as frame_dir:
        cap = cv2.VideoCapture(args.input)
        try:
            if not cap.isOpened():
                print(f"ERROR: cannot open {args.input}", file=sys.stderr)
                sys.exit(1)

            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_path = os.path.join(frame_dir, f"{idx:04d}.jpg")
                # imwrite signals failure through its return value.
                if not cv2.imwrite(frame_path, frame):
                    print(f"ERROR: cannot write {frame_path}", file=sys.stderr)
                    sys.exit(1)
                idx += 1
        finally:
            # Release the capture on every path, including the sys.exit() ones.
            cap.release()

        print(f"Extracted {idx} frames", flush=True)
        if idx == 0:
            # Nothing to prompt on or propagate through; stop with a clear
            # message instead of handing an empty directory to the predictor.
            print(f"ERROR: no frames decoded from {args.input}", file=sys.stderr)
            sys.exit(1)

        # SAM2: use from_pretrained (SAM2.1+ / HuggingFace integration)
        from sam2.sam2_video_predictor import SAM2VideoPredictor

        predictor = SAM2VideoPredictor.from_pretrained(
            "facebook/sam2-hiera-large"
        ).to(device)

        with torch.inference_mode():
            state = predictor.init_state(video_path=frame_dir)

            # Center of first frame as positive point prompt
            cx, cy = width // 2, height // 2
            predictor.add_new_points_or_box(
                inference_state=state,
                frame_idx=0,
                obj_id=1,
                points=np.array([[cx, cy]], dtype=np.float32),
                labels=np.array([1], dtype=np.int32),
            )

            for frame_idx, obj_ids, out_mask_logits in predictor.propagate_in_video(state):
                # out_mask_logits: (N_objects, 1, H, W) — threshold logits at 0
                mask = (out_mask_logits[0].squeeze().cpu().numpy() > 0.0).astype(np.uint8) * 255
                out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png")
                if not cv2.imwrite(out_path, mask):
                    print(f"ERROR: cannot write {out_path}", file=sys.stderr)
                    sys.exit(1)
                print(f"frame {frame_idx + 1}/{total}", flush=True)

    print("done", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user