fix: sam_masks — correct SAM2 API and mask logit threshold
Use SAM2VideoPredictor.from_pretrained() instead of the checkpoint-based build_sam2_video_predictor() which doesn't accept HuggingFace model IDs. Threshold out_mask_logits at 0.0 and squeeze shape before converting to binary PNG. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+8
-8
@@ -49,12 +49,12 @@ def main():
|
|||||||
|
|
||||||
print(f"Extracted {idx} frames", flush=True)
|
print(f"Extracted {idx} frames", flush=True)
|
||||||
|
|
||||||
from sam2.build_sam import build_sam2_video_predictor
|
# SAM2: use from_pretrained (SAM2.1+ / HuggingFace integration)
|
||||||
|
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
||||||
|
|
||||||
predictor = build_sam2_video_predictor(
|
predictor = SAM2VideoPredictor.from_pretrained(
|
||||||
"facebook/sam2-hiera-large",
|
"facebook/sam2-hiera-large"
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
state = predictor.init_state(video_path=frame_dir)
|
state = predictor.init_state(video_path=frame_dir)
|
||||||
@@ -69,9 +69,9 @@ def main():
|
|||||||
labels=np.array([1], dtype=np.int32),
|
labels=np.array([1], dtype=np.int32),
|
||||||
)
|
)
|
||||||
|
|
||||||
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
|
for frame_idx, obj_ids, out_mask_logits in predictor.propagate_in_video(state):
|
||||||
# masks shape: (N_objects, H, W) bool tensor
|
# out_mask_logits: (N_objects, 1, H, W) — threshold logits at 0
|
||||||
mask = masks[0].cpu().numpy().astype(np.uint8) * 255
|
mask = (out_mask_logits[0].squeeze().cpu().numpy() > 0.0).astype(np.uint8) * 255
|
||||||
out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png")
|
out_path = os.path.join(args.output, f"frame_{frame_idx:04d}.png")
|
||||||
cv2.imwrite(out_path, mask)
|
cv2.imwrite(out_path, mask)
|
||||||
print(f"frame {frame_idx + 1}/{total}", flush=True)
|
print(f"frame {frame_idx + 1}/{total}", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user