From 3ddf311c6454a9e0f49d670559a73922419a0f8d Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:29:52 +0900 Subject: [PATCH] Fix fallback nms to output same value as original --- mobile_sam/utils/torch_nms.py | 25 +++++++++++++++---------- sam2/utils/torch_nms.py | 25 +++++++++++++++---------- segment_anything_fb/utils/torch_nms.py | 25 +++++++++++++++---------- segment_anything_hq/utils/torch_nms.py | 25 +++++++++++++++---------- 4 files changed, 60 insertions(+), 40 deletions(-) diff --git a/mobile_sam/utils/torch_nms.py b/mobile_sam/utils/torch_nms.py index e62c32f..4bd93d9 100644 --- a/mobile_sam/utils/torch_nms.py +++ b/mobile_sam/utils/torch_nms.py @@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: - order = torch.argsort(-scores).to(bboxes.device) - indices = torch.arange(bboxes.shape[0]).to(bboxes.device) - keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device) - for i in indices: - if keep[i]: - bbox = bboxes[order[i]] - iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None]) - overlapped = torch.nonzero(iou > iou_threshold) - keep[overlapped + i + 1] = 0 - return order[keep] + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/sam2/utils/torch_nms.py b/sam2/utils/torch_nms.py index e62c32f..4bd93d9 100644 --- a/sam2/utils/torch_nms.py +++ b/sam2/utils/torch_nms.py @@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: - order = torch.argsort(-scores).to(bboxes.device) - indices = torch.arange(bboxes.shape[0]).to(bboxes.device) - keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device) - for i in indices: - if keep[i]: - bbox = bboxes[order[i]] - iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None]) - overlapped = torch.nonzero(iou > iou_threshold) - keep[overlapped + i + 1] = 0 - return order[keep] + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/segment_anything_fb/utils/torch_nms.py b/segment_anything_fb/utils/torch_nms.py index e62c32f..4bd93d9 100644 --- a/segment_anything_fb/utils/torch_nms.py +++ b/segment_anything_fb/utils/torch_nms.py @@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: - order = torch.argsort(-scores).to(bboxes.device) - indices = torch.arange(bboxes.shape[0]).to(bboxes.device) - keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device) - for i in indices: - if keep[i]: - bbox = bboxes[order[i]] - iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None]) - overlapped = torch.nonzero(iou > iou_threshold) - keep[overlapped + i + 1] = 0 - return order[keep] + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/segment_anything_hq/utils/torch_nms.py b/segment_anything_hq/utils/torch_nms.py index e62c32f..4bd93d9 100644 --- a/segment_anything_hq/utils/torch_nms.py +++ b/segment_anything_hq/utils/torch_nms.py @@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: - order = torch.argsort(-scores).to(bboxes.device) - indices = torch.arange(bboxes.shape[0]).to(bboxes.device) - keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device) - for i in indices: - if keep[i]: - bbox = bboxes[order[i]] - iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None]) - overlapped = torch.nonzero(iou > iou_threshold) - keep[overlapped + i + 1] = 0 - return order[keep] + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device)