21 lines
520 B
Python
21 lines
520 B
Python
import torch
|
|
from torchvision.ops.boxes import box_iou
|
|
|
|
|
|
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """Greedy non-maximum suppression.

    Repeatedly selects the highest-scoring remaining box, keeps it, and
    discards every other remaining box whose IoU with it is greater than
    ``iou_threshold``.

    Args:
        bboxes: Boxes in the format accepted by ``torchvision.ops.box_iou``
            (``(N, 4)`` as ``(x1, y1, x2, y2)``).
        scores: One score per box (``(N,)``); higher is better.
        iou_threshold: A remaining box is suppressed when its IoU with the
            currently selected box is strictly greater than this value.

    Returns:
        1-D ``int64`` tensor of indices into ``bboxes`` for the kept boxes,
        in decreasing score order, on ``bboxes.device``. The result is an
        empty ``int64`` tensor (still usable as an index) when there are no
        boxes.
    """
    order = torch.argsort(scores, descending=True)
    # Keep 0-d index tensors rather than calling .item() each iteration,
    # which would force a device->host sync per kept box on GPU.
    keep = []

    while order.numel() > 0:
        best = order[0]
        keep.append(best)

        rest = order[1:]
        if rest.numel() == 0:
            break

        # IoU of the selected box against every remaining candidate.
        ious = box_iou(bboxes[best].unsqueeze(0), bboxes[rest])[0]
        order = rest[ious <= iou_threshold]

    if not keep:
        # torch.tensor([]) would default to float32; indices must be int64.
        return torch.empty(0, dtype=torch.long, device=bboxes.device)
    return torch.stack(keep).to(bboxes.device)
|