fix bounding box draw bug

pull/26/head
Chengsong Zhang 2023-04-15 04:59:44 +08:00
parent 370663a13a
commit 520e7fa74f
2 changed files with 15 additions and 31 deletions

View File

@@ -40,7 +40,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
image = copy.deepcopy(image_np)
for idx, box in enumerate(boxes):
x, y, w, h = box
cv2.rectangle(image, (x, y), (x+w, y+h), color, thickness)
cv2.rectangle(image, (x, y), (w, h), color, thickness)
if show_index:
font = cv2.FONT_HERSHEY_SIMPLEX
text = str(idx)

View File

@@ -146,22 +146,21 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
if dino_enabled and boxes_filt.shape[0] > 1:
print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=boxes.to(device),
boxes=transformed_boxes.to(device),
multimask_output=True,
)
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
boxes = boxes.cpu().numpy().astype(int)
else:
print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
point_coords = np.array(positive_points + negative_points)
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
box = boxes_filt[0].numpy() if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
box = copy.deepcopy(boxes_filt[0].numpy()) if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
masks, _, _ = predictor.predict(
point_coords=point_coords if len(point_coords) > 0 else None,
point_labels=point_labels if len(point_coords) > 0 else None,
@@ -170,7 +169,6 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
)
masks = masks[:, None, ...]
boxes = predictor.transform.apply_boxes(box, image_np.shape[:2]).astype(int) if box is not None else None
if shared.cmd_opts.lowvram:
sam.to(cpu)
@@ -181,9 +179,10 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
mask_images = []
masks_gallery = []
matted_images = []
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
for mask in masks:
blended_image = show_masks(show_boxes(image_np, boxes), mask)
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
mask_images.append(Image.fromarray(blended_image))
image_np_copy = copy.deepcopy(image_np)
@@ -193,26 +192,11 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
return mask_images + masks_gallery + matted_images
def dino_predict(sam_model_name, input_image, dino_model_name, text_prompt, box_threshold):
def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
image_np = np.array(input_image)
image_np_rgb = image_np[..., :3]
boxes_filt = dino_predict_internal(
input_image, dino_model_name, text_prompt, box_threshold).numpy()
sam = init_sam_model(sam_model_name)
predictor = SamPredictor(sam)
predictor.set_image(image_np_rgb)
boxes_filt = predictor.transform.apply_boxes(
boxes_filt, image_np.shape[:2]).astype(int)
if shared.cmd_opts.lowvram:
sam.to(cpu)
gc.collect()
torch_gc()
boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy()
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
return Image.fromarray(show_boxes(image_np, boxes_filt, show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
def dino_batch_process(
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
@@ -234,21 +218,21 @@ def dino_batch_process(
boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
predictor.set_image(image_np_rgb)
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=boxes.to(device),
boxes=transformed_boxes.to(device),
multimask_output=(dino_batch_output_per_image == 1),
)
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
boxes = boxes.cpu().numpy().astype(int)
boxes_filt = boxes_filt.cpu().numpy().astype(int)
filename, ext = os.path.splitext(os.path.basename(input_image_file))
for idx, mask in enumerate(masks):
blended_image = show_masks(show_boxes(image_np, boxes), mask)
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
_, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt)
image_np_copy = copy.deepcopy(image_np)
image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
@@ -318,7 +302,7 @@ class Script(scripts.Script):
dino_preview_boxes_button.click(
fn=dino_predict,
_js="submit_dino",
inputs=[sam_model_name, input_image, dino_model_name, text_prompt, box_threshold],
inputs=[input_image, dino_model_name, text_prompt, box_threshold],
outputs=[dino_preview_boxes, dino_preview_boxes_selection]
)