Add fallback NMS for environments where batched_nms is unavailable

main
Uminosachi 2024-08-01 14:34:20 +09:00
parent 71dc7fe3ff
commit 1ccb9e462c
2 changed files with 62 additions and 36 deletions

View File

@ -13,24 +13,13 @@ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.amg import (
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
MaskData,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
from sam2.utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box,
box_xyxy_to_xywh, build_all_layer_point_grids,
calculate_stability_score, coco_encode_rle, generate_crop_boxes,
is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions,
rle_to_mask, uncrop_boxes_xyxy, uncrop_masks, uncrop_points)
from .utils.torch_nms import nms
class SAM2AutomaticMaskGenerator:
@ -220,13 +209,21 @@ class SAM2AutomaticMaskGenerator:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
try:
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
except Exception:
keep_by_nms = nms(
data["boxes"].float(),
scores,
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
@ -258,12 +255,19 @@ class SAM2AutomaticMaskGenerator:
self.predictor.reset_predictor()
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
try:
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
except Exception:
keep_by_nms = nms(
data["boxes"].float(),
data["iou_preds"],
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
@ -398,12 +402,19 @@ class SAM2AutomaticMaskGenerator:
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
try:
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
except Exception:
keep_by_nms = nms(
boxes.float(),
torch.as_tensor(scores),
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:

15
sam2/utils/torch_nms.py Normal file
View File

@ -0,0 +1,15 @@
import torch
from torchvision.ops.boxes import box_iou
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """Pure-PyTorch non-maximum suppression fallback.

    Used when ``torchvision.ops.boxes.batched_nms`` raises (e.g. on backends
    without the compiled NMS kernel).

    Args:
        bboxes: ``(N, 4)`` boxes in ``(x1, y1, x2, y2)`` format.
        scores: ``(N,)`` confidence scores.
        iou_threshold: a lower-scoring box whose IoU with a kept box exceeds
            this value is suppressed.

    Returns:
        1-D ``int64`` tensor of indices into ``bboxes`` of the kept boxes,
        ordered by descending score.
    """
    # Process boxes highest-score first.
    order = torch.argsort(scores, descending=True).to(bboxes.device)
    boxes = bboxes[order].float()
    # Per-box areas; clamp handles degenerate (inverted) boxes.
    areas = (boxes[:, 2] - boxes[:, 0]).clamp(min=0) * (boxes[:, 3] - boxes[:, 1]).clamp(min=0)
    n = boxes.shape[0]
    keep = torch.ones(n, dtype=torch.bool, device=bboxes.device)
    for i in range(n):
        if not keep[i]:
            continue
        rest = boxes[i + 1:]
        # Intersection of box i with every lower-scored box.
        lt = torch.maximum(boxes[i, :2], rest[:, :2])
        rb = torch.minimum(boxes[i, 2:], rest[:, 2:])
        wh = (rb - lt).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]
        union = areas[i] + areas[i + 1:] - inter
        # Guard the division so zero-area pairs yield IoU 0 instead of NaN.
        iou = inter / union.clamp(min=1e-12)
        # Suppress only the boxes that actually overlap box i. (The previous
        # implementation indexed `keep` with the 2-D result of torch.nonzero
        # on a (1, M) IoU matrix, which also zeroed keep[i + 1] whenever any
        # later box overlapped — wrongly dropping a non-overlapping box.)
        keep[i + 1:] &= iou <= iou_threshold
    return order[keep]