From eca49caee9dc44f8f64abdda48d3854cc3c6cc73 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Mon, 6 Apr 2026 15:45:53 +0200
Subject: [PATCH] =?UTF-8?q?fix:=20sam=5Fmasks=20=E2=80=94=20correct=20SAM2?=
 =?UTF-8?q?=20API=20and=20mask=20logit=20threshold?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Use SAM2VideoPredictor.from_pretrained() instead of the checkpoint-based
build_sam2_video_predictor() which doesn't accept HuggingFace model IDs.
Threshold out_mask_logits at 0.0 and squeeze shape before converting to
binary PNG.

Co-Authored-By: Claude Sonnet 4.6
---
 tools/sam_masks.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tools/sam_masks.py b/tools/sam_masks.py
index 93cb4af..e445a7f 100644
--- a/tools/sam_masks.py
+++ b/tools/sam_masks.py
@@ -49,12 +49,12 @@ def main():
 
     print(f"Extracted {idx} frames", flush=True)
 
-    from sam2.build_sam import build_sam2_video_predictor
+    # SAM2: use from_pretrained (SAM2.1+ / HuggingFace integration)
+    from sam2.sam2_video_predictor import SAM2VideoPredictor
 
-    predictor = build_sam2_video_predictor(
-        "facebook/sam2-hiera-large",
-        device=device,
-    )
+    predictor = SAM2VideoPredictor.from_pretrained(
+        "facebook/sam2-hiera-large"
+    ).to(device)
 
     with torch.inference_mode():
         state = predictor.init_state(video_path=frame_dir)
@@ -69,9 +69,9 @@ def main():
             labels=np.array([1], dtype=np.int32),
         )
 
-        for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
-            # masks shape: (N_objects, H, W) bool tensor
-            mask = masks[0].cpu().numpy().astype(np.uint8) * 255
+        for frame_idx, obj_ids, out_mask_logits in predictor.propagate_in_video(state):
+            # out_mask_logits: (N_objects, 1, H, W) — threshold logits at 0
+            mask = (out_mask_logits[0].squeeze().cpu().numpy() > 0.0).astype(np.uint8) * 255
             out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png")
             cv2.imwrite(out_path, mask)
             print(f"frame {frame_idx + 1}/{total}", flush=True)