diff --git a/tools/sam_masks.py b/tools/sam_masks.py index 93cb4af..e445a7f 100644 --- a/tools/sam_masks.py +++ b/tools/sam_masks.py @@ -49,12 +49,12 @@ def main(): print(f"Extracted {idx} frames", flush=True) - from sam2.build_sam import build_sam2_video_predictor + # SAM2: use from_pretrained (SAM2.1+ / HuggingFace integration) + from sam2.sam2_video_predictor import SAM2VideoPredictor - predictor = build_sam2_video_predictor( - "facebook/sam2-hiera-large", - device=device, - ) + predictor = SAM2VideoPredictor.from_pretrained( + "facebook/sam2-hiera-large" + ).to(device) with torch.inference_mode(): state = predictor.init_state(video_path=frame_dir) @@ -69,9 +69,9 @@ def main(): labels=np.array([1], dtype=np.int32), ) - for frame_idx, obj_ids, masks in predictor.propagate_in_video(state): - # masks shape: (N_objects, H, W) bool tensor - mask = masks[0].cpu().numpy().astype(np.uint8) * 255 + for frame_idx, obj_ids, out_mask_logits in predictor.propagate_in_video(state): + # out_mask_logits: (N_objects, 1, H, W) — threshold logits at 0 + mask = (out_mask_logits[0].squeeze().cpu().numpy() > 0.0).astype(np.uint8) * 255 out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png") cv2.imwrite(out_path, mask) print(f"frame {frame_idx + 1}/{total}", flush=True)