Skip to content

Commit

Permalink
Move batched NMS indices to correct device (closes facebookresearch#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmmx committed Apr 6, 2023
1 parent aac76a1 commit 006dbe5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions segment_anything/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _generate_masks(self, image: np.ndarray) -> MaskData:
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros(len(data["boxes"])), # categories
torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
Expand Down Expand Up @@ -251,7 +251,7 @@ def _process_crop(
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros(len(data["boxes"])), # categories
torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
Expand Down

0 comments on commit 006dbe5

Please sign in to comment.