fix bounding box draw bug

pull/26/head
Chengsong Zhang 2023-04-15 04:59:44 +08:00
parent 370663a13a
commit 520e7fa74f
2 changed files with 15 additions and 31 deletions

View File

@@ -40,7 +40,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
image = copy.deepcopy(image_np) image = copy.deepcopy(image_np)
for idx, box in enumerate(boxes): for idx, box in enumerate(boxes):
x, y, w, h = box x, y, w, h = box
cv2.rectangle(image, (x, y), (x+w, y+h), color, thickness) cv2.rectangle(image, (x, y), (w, h), color, thickness)
if show_index: if show_index:
font = cv2.FONT_HERSHEY_SIMPLEX font = cv2.FONT_HERSHEY_SIMPLEX
text = str(idx) text = str(idx)

View File

@@ -146,22 +146,21 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
if dino_enabled and boxes_filt.shape[0] > 1: if dino_enabled and boxes_filt.shape[0] > 1:
print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.") print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2]) transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
masks, _, _ = predictor.predict_torch( masks, _, _ = predictor.predict_torch(
point_coords=None, point_coords=None,
point_labels=None, point_labels=None,
boxes=boxes.to(device), boxes=transformed_boxes.to(device),
multimask_output=True, multimask_output=True,
) )
masks = masks.permute(1, 0, 2, 3).cpu().numpy() masks = masks.permute(1, 0, 2, 3).cpu().numpy()
boxes = boxes.cpu().numpy().astype(int)
else: else:
print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.") print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
point_coords = np.array(positive_points + negative_points) point_coords = np.array(positive_points + negative_points)
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points)) point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
box = boxes_filt[0].numpy() if boxes_filt is not None and boxes_filt.shape[0] > 0 else None box = copy.deepcopy(boxes_filt[0].numpy()) if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
masks, _, _ = predictor.predict( masks, _, _ = predictor.predict(
point_coords=point_coords if len(point_coords) > 0 else None, point_coords=point_coords if len(point_coords) > 0 else None,
point_labels=point_labels if len(point_coords) > 0 else None, point_labels=point_labels if len(point_coords) > 0 else None,
@@ -170,7 +169,6 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
) )
masks = masks[:, None, ...] masks = masks[:, None, ...]
boxes = predictor.transform.apply_boxes(box, image_np.shape[:2]).astype(int) if box is not None else None
if shared.cmd_opts.lowvram: if shared.cmd_opts.lowvram:
sam.to(cpu) sam.to(cpu)
@@ -182,8 +180,9 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
masks_gallery = [] masks_gallery = []
matted_images = [] matted_images = []
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
for mask in masks: for mask in masks:
blended_image = show_masks(show_boxes(image_np, boxes), mask) blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
masks_gallery.append(Image.fromarray(np.any(mask, axis=0))) masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
mask_images.append(Image.fromarray(blended_image)) mask_images.append(Image.fromarray(blended_image))
image_np_copy = copy.deepcopy(image_np) image_np_copy = copy.deepcopy(image_np)
@@ -193,26 +192,11 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
return mask_images + masks_gallery + matted_images return mask_images + masks_gallery + matted_images
def dino_predict(sam_model_name, input_image, dino_model_name, text_prompt, box_threshold): def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
image_np = np.array(input_image) image_np = np.array(input_image)
image_np_rgb = image_np[..., :3] boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy()
boxes_filt = dino_predict_internal(
input_image, dino_model_name, text_prompt, box_threshold).numpy()
sam = init_sam_model(sam_model_name)
predictor = SamPredictor(sam)
predictor.set_image(image_np_rgb)
boxes_filt = predictor.transform.apply_boxes(
boxes_filt, image_np.shape[:2]).astype(int)
if shared.cmd_opts.lowvram:
sam.to(cpu)
gc.collect()
torch_gc()
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])] boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
return Image.fromarray(show_boxes(image_np, boxes_filt, show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice) return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
def dino_batch_process( def dino_batch_process(
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt, batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
@@ -234,21 +218,21 @@ def dino_batch_process(
boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold) boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
predictor.set_image(image_np_rgb) predictor.set_image(image_np_rgb)
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2]) transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
masks, _, _ = predictor.predict_torch( masks, _, _ = predictor.predict_torch(
point_coords=None, point_coords=None,
point_labels=None, point_labels=None,
boxes=boxes.to(device), boxes=transformed_boxes.to(device),
multimask_output=(dino_batch_output_per_image == 1), multimask_output=(dino_batch_output_per_image == 1),
) )
masks = masks.permute(1, 0, 2, 3).cpu().numpy() masks = masks.permute(1, 0, 2, 3).cpu().numpy()
boxes = boxes.cpu().numpy().astype(int) boxes_filt = boxes_filt.cpu().numpy().astype(int)
filename, ext = os.path.splitext(os.path.basename(input_image_file)) filename, ext = os.path.splitext(os.path.basename(input_image_file))
for idx, mask in enumerate(masks): for idx, mask in enumerate(masks):
blended_image = show_masks(show_boxes(image_np, boxes), mask) blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
_, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt) _, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt)
image_np_copy = copy.deepcopy(image_np) image_np_copy = copy.deepcopy(image_np)
image_np_copy[~merged_mask] = np.array([0, 0, 0, 0]) image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
@@ -318,7 +302,7 @@ class Script(scripts.Script):
dino_preview_boxes_button.click( dino_preview_boxes_button.click(
fn=dino_predict, fn=dino_predict,
_js="submit_dino", _js="submit_dino",
inputs=[sam_model_name, input_image, dino_model_name, text_prompt, box_threshold], inputs=[input_image, dino_model_name, text_prompt, box_threshold],
outputs=[dino_preview_boxes, dino_preview_boxes_selection] outputs=[dino_preview_boxes, dino_preview_boxes_selection]
) )