Add fallback nms for environ without batched_nms
parent
71dc7fe3ff
commit
1ccb9e462c
|
|
@ -13,24 +13,13 @@ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
|||
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2.utils.amg import (
|
||||
area_from_rle,
|
||||
batch_iterator,
|
||||
batched_mask_to_box,
|
||||
box_xyxy_to_xywh,
|
||||
build_all_layer_point_grids,
|
||||
calculate_stability_score,
|
||||
coco_encode_rle,
|
||||
generate_crop_boxes,
|
||||
is_box_near_crop_edge,
|
||||
mask_to_rle_pytorch,
|
||||
MaskData,
|
||||
remove_small_regions,
|
||||
rle_to_mask,
|
||||
uncrop_boxes_xyxy,
|
||||
uncrop_masks,
|
||||
uncrop_points,
|
||||
)
|
||||
from sam2.utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box,
|
||||
box_xyxy_to_xywh, build_all_layer_point_grids,
|
||||
calculate_stability_score, coco_encode_rle, generate_crop_boxes,
|
||||
is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions,
|
||||
rle_to_mask, uncrop_boxes_xyxy, uncrop_masks, uncrop_points)
|
||||
|
||||
from .utils.torch_nms import nms
|
||||
|
||||
|
||||
class SAM2AutomaticMaskGenerator:
|
||||
|
|
@ -220,13 +209,21 @@ class SAM2AutomaticMaskGenerator:
|
|||
# Prefer masks from smaller crops
|
||||
scores = 1 / box_area(data["crop_boxes"])
|
||||
scores = scores.to(data["boxes"].device)
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
scores,
|
||||
torch.zeros_like(data["boxes"][:, 0]), # categories
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
try:
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
scores,
|
||||
torch.zeros_like(data["boxes"][:, 0]), # categories
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
except Exception:
|
||||
keep_by_nms = nms(
|
||||
data["boxes"].float(),
|
||||
scores,
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
data.to_numpy()
|
||||
return data
|
||||
|
||||
|
|
@ -258,12 +255,19 @@ class SAM2AutomaticMaskGenerator:
|
|||
self.predictor.reset_predictor()
|
||||
|
||||
# Remove duplicates within this crop.
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
data["iou_preds"],
|
||||
torch.zeros_like(data["boxes"][:, 0]), # categories
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
try:
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
data["iou_preds"],
|
||||
torch.zeros_like(data["boxes"][:, 0]), # categories
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
except Exception:
|
||||
keep_by_nms = nms(
|
||||
data["boxes"].float(),
|
||||
data["iou_preds"],
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
# Return to the original image frame
|
||||
|
|
@ -398,12 +402,19 @@ class SAM2AutomaticMaskGenerator:
|
|||
# Recalculate boxes and remove any new duplicates
|
||||
masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(masks)
|
||||
keep_by_nms = batched_nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
torch.zeros_like(boxes[:, 0]), # categories
|
||||
iou_threshold=nms_thresh,
|
||||
)
|
||||
try:
|
||||
keep_by_nms = batched_nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
torch.zeros_like(boxes[:, 0]), # categories
|
||||
iou_threshold=nms_thresh,
|
||||
)
|
||||
except Exception:
|
||||
keep_by_nms = nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
iou_threshold=nms_thresh,
|
||||
)
|
||||
|
||||
# Only recalculate RLEs for masks that have changed
|
||||
for i_mask in keep_by_nms:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
import torch
|
||||
from torchvision.ops.boxes import box_iou
|
||||
|
||||
|
||||
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
||||
order = torch.argsort(-scores).to(bboxes.device)
|
||||
indices = torch.arange(bboxes.shape[0]).to(bboxes.device)
|
||||
keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device)
|
||||
for i in indices:
|
||||
if keep[i]:
|
||||
bbox = bboxes[order[i]]
|
||||
iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None])
|
||||
overlapped = torch.nonzero(iou > iou_threshold)
|
||||
keep[overlapped + i + 1] = 0
|
||||
return order[keep]
|
||||
Loading…
Reference in New Issue