Fix fallback nms to output same value as original
parent
1ccb9e462c
commit
3ddf311c64
|
|
@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou
|
||||||
|
|
||||||
|
|
||||||
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
||||||
order = torch.argsort(-scores).to(bboxes.device)
|
order = torch.argsort(-scores)
|
||||||
indices = torch.arange(bboxes.shape[0]).to(bboxes.device)
|
keep = []
|
||||||
keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device)
|
|
||||||
for i in indices:
|
while order.numel() > 0:
|
||||||
if keep[i]:
|
i = order[0]
|
||||||
bbox = bboxes[order[i]]
|
keep.append(i.item())
|
||||||
iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None])
|
|
||||||
overlapped = torch.nonzero(iou > iou_threshold)
|
if order.numel() == 1:
|
||||||
keep[overlapped + i + 1] = 0
|
break
|
||||||
return order[keep]
|
|
||||||
|
ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
|
||||||
|
mask = ious <= iou_threshold
|
||||||
|
order = order[1:][mask]
|
||||||
|
|
||||||
|
return torch.tensor(keep, device=bboxes.device)
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou
|
||||||
|
|
||||||
|
|
||||||
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
||||||
order = torch.argsort(-scores).to(bboxes.device)
|
order = torch.argsort(-scores)
|
||||||
indices = torch.arange(bboxes.shape[0]).to(bboxes.device)
|
keep = []
|
||||||
keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device)
|
|
||||||
for i in indices:
|
while order.numel() > 0:
|
||||||
if keep[i]:
|
i = order[0]
|
||||||
bbox = bboxes[order[i]]
|
keep.append(i.item())
|
||||||
iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None])
|
|
||||||
overlapped = torch.nonzero(iou > iou_threshold)
|
if order.numel() == 1:
|
||||||
keep[overlapped + i + 1] = 0
|
break
|
||||||
return order[keep]
|
|
||||||
|
ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
|
||||||
|
mask = ious <= iou_threshold
|
||||||
|
order = order[1:][mask]
|
||||||
|
|
||||||
|
return torch.tensor(keep, device=bboxes.device)
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou
|
||||||
|
|
||||||
|
|
||||||
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
||||||
order = torch.argsort(-scores).to(bboxes.device)
|
order = torch.argsort(-scores)
|
||||||
indices = torch.arange(bboxes.shape[0]).to(bboxes.device)
|
keep = []
|
||||||
keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device)
|
|
||||||
for i in indices:
|
while order.numel() > 0:
|
||||||
if keep[i]:
|
i = order[0]
|
||||||
bbox = bboxes[order[i]]
|
keep.append(i.item())
|
||||||
iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None])
|
|
||||||
overlapped = torch.nonzero(iou > iou_threshold)
|
if order.numel() == 1:
|
||||||
keep[overlapped + i + 1] = 0
|
break
|
||||||
return order[keep]
|
|
||||||
|
ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
|
||||||
|
mask = ious <= iou_threshold
|
||||||
|
order = order[1:][mask]
|
||||||
|
|
||||||
|
return torch.tensor(keep, device=bboxes.device)
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,18 @@ from torchvision.ops.boxes import box_iou
|
||||||
|
|
||||||
|
|
||||||
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
|
||||||
order = torch.argsort(-scores).to(bboxes.device)
|
order = torch.argsort(-scores)
|
||||||
indices = torch.arange(bboxes.shape[0]).to(bboxes.device)
|
keep = []
|
||||||
keep = torch.ones_like(indices, dtype=torch.bool).to(bboxes.device)
|
|
||||||
for i in indices:
|
while order.numel() > 0:
|
||||||
if keep[i]:
|
i = order[0]
|
||||||
bbox = bboxes[order[i]]
|
keep.append(i.item())
|
||||||
iou = box_iou(bbox[None, ...], (bboxes[order[i + 1:]]) * keep[i + 1:][..., None])
|
|
||||||
overlapped = torch.nonzero(iou > iou_threshold)
|
if order.numel() == 1:
|
||||||
keep[overlapped + i + 1] = 0
|
break
|
||||||
return order[keep]
|
|
||||||
|
ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
|
||||||
|
mask = ious <= iou_threshold
|
||||||
|
order = order[1:][mask]
|
||||||
|
|
||||||
|
return torch.tensor(keep, device=bboxes.device)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue