diff --git a/tools/sam_masks.py b/tools/sam_masks.py new file mode 100644 index 0000000..93cb4af --- /dev/null +++ b/tools/sam_masks.py @@ -0,0 +1,83 @@ +"""SAM2 mask generation script. + +Usage: + python tools/sam_masks.py --input video.mp4 --output masks_dir/ + +Outputs one binary PNG per frame: frame_0000.png, frame_0001.png, … +Uses center of first frame as positive point prompt, propagates across all frames. +Requires: torch, segment-anything-2, opencv-python +""" +import argparse +import os +import sys +import tempfile + +import cv2 +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True) + parser.add_argument("--output", required=True) + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + + import torch + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}", flush=True) + + # Extract frames to temp directory (SAM2 video predictor needs image files) + with tempfile.TemporaryDirectory() as frame_dir: + cap = cv2.VideoCapture(args.input) + if not cap.isOpened(): + print(f"ERROR: cannot open {args.input}", file=sys.stderr) + sys.exit(1) + + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + cv2.imwrite(os.path.join(frame_dir, f"{idx:04d}.jpg"), frame) + idx += 1 + cap.release() + + print(f"Extracted {idx} frames", flush=True) + + from sam2.build_sam import build_sam2_video_predictor + + predictor = build_sam2_video_predictor( + "facebook/sam2-hiera-large", + device=device, + ) + + with torch.inference_mode(): + state = predictor.init_state(video_path=frame_dir) + + # Center of first frame as positive point prompt + cx, cy = width // 2, height // 2 + _, _, _ = predictor.add_new_points_or_box( + inference_state=state, + frame_idx=0, + obj_id=1, + points=np.array([[cx, cy]], dtype=np.float32), + labels=np.array([1], dtype=np.int32), + ) + + for frame_idx, obj_ids, masks in predictor.propagate_in_video(state): + # masks shape: (N_objects, H, W) bool tensor + mask = masks[0].cpu().numpy().astype(np.uint8) * 255 + out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png") + cv2.imwrite(out_path, mask) + print(f"frame {frame_idx + 1}/{total}", flush=True) + + print("done", flush=True) + + +if __name__ == "__main__": + main()